商汤开源⽬标检测⼯具箱mmdetection代码详解(⼆)------
mmdetectio。。。
⽬录
mmdetection版本:2.0
mmcv版本:0.5.5
mmdetection和mmcv的关系是,mmdetection⼀些功能代码是直接通过调⽤mmcv的api实现的。
============================================================================
mmdetection的main()函数就在 tool/train.py⾥,在 中说过,看train.py的代码,看不懂的地⽅最先遇到的是
build_xxx(),build_xxx()已经在 中讲过了,然后现在就说第⼆个⽐较难懂的地⽅,就是训练的过程。
mmdetection的训练过程,只⽤调⽤⼀个接⼝,就是 train_detector(),这个接⼝被定义在mmdetection项⽬代码的
mmdet/apis/train.py⾥,注意这⾥的train.py和 tool/train.py是不同的,前者主要是 提供接⼝,后者是训练的顺序代码。
train_detector():
train_detector()主要接受三个参数,分别是model,cfg,dataset:
model:通过build_detector()实例化出来的某个⽬标检测⽹络类的对象。
cfg:cfg是来⾃配置⽂件的配置信息,这些配置⽂件⼀般都在mmdetection项⽬⾥的 config/_base_/
cfg是由4个配置⽂件组成的,以maskrcnn⽹络来训练COCO数据集为例,如下图,下图中只有第⼀个配置⽂件会随着选择的⽹络改变⽽改变,第⼆个随着你选择的数据集⽽改变,其余两个是不会变的。
上图中第⼀个配置⽂件:包含了Maskrcnn的配置信息以及训练、测试这个⽹络的训练、测试信息。
上图中第⼆个配置⽂件:包含了训练阶段和测试阶段如何处理COCO数据集的信息,如归⼀化参数,Resize的尺⼨。当然还有指定COCO 数据集的路径,也是在这个配置⽂件中指定的。还有batchsize也是在这⾥指定,只不过名字变成了samples_per_gpu。
上图中第三个配置⽂件:包含了训练模式的优化器、学习率、epoch的信息,当然是可以在这个配置⽂件⾥修改这些参数的。
上图中第四个配置⽂件:包含了训练过程中保存模型的间隔,⽇记记录的配置信息。
dataset:dataset通过build_dataset()实例化出来的数据集类的对象。
过程:
解析完train_detector()的参数之后,就可以看看train_detector()的过程了。
主要流程如下,后⾯会逐⼀讲解:
以上是train_detector()的⼤致代码流程,其中最重要的是 最后的 runner.run(),因为它控制的就是训练的流程。
为了更好地了解,下⾯会着重讲⼀下最重要的 Runner类 和 HOOK的使⽤。
Runner类:
Runner类位于mmcv.runner⾥,同样是不属于mmdetection的项⽬代码,但是要运⾏mmdetection就需要⽤到mmcv的包。Runner类主要包含 保存模型的过程、train训练的过程、val验证的过程、各种HOOK的管理过程(HOOK下⾯会详细介绍)。有⼈可能会疑惑了,为什么那么多train?有 tool/train.py、mmdet/api/train.py ,现在⼜有⼀个train(),关系是这样的:
从上图可以看到,训练的接⼝调来调去,其实最终是在 mmcv/runner.py 的train.py⽅法⾥实现较为底层的训练代码,在这个train()中是已经到了从data_loader⾥取出数据进⾏训练的地步了。但如果你觉得到这⾥就没什么tricks(骚操作)你就错了,尽管是简单的训练代码,都分了⼏块,如下图:
从上图可以看到:
模型的输出结果其实是经过 batch_processor()得到的。在模型进⾏输出之前,会经过⼏个HOOK,call_hook()就是调⽤HOOK的函
数,call_hook中 有字符,表⽰,具体的操作,例如 call_hook('before_train_epoch') 就表⽰在训练⼀个 epoch前需要进⾏的操作。然后下⾯就讲讲HOOK。
HOOK类:
我们先来看看HOOK这个类是怎么定义的:
位置:mmcv/runner/hooks/hook.py
HOOKS = Registry('hook')
class Hook(object):
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
def before_train_epoch(self, runner):tool工具箱
self.before_epoch(runner)
def before_val_epoch(self, runner):
self.before_epoch(runner)
def after_train_epoch(self, runner):
self.after_epoch(runner)
def after_val_epoch(self, runner):
self.after_epoch(runner)
def before_train_iter(self, runner):
self.before_iter(runner)
def before_val_iter(self, runner):
self.before_iter(runner)
def after_train_iter(self, runner):
self.after_iter(runner)
def after_val_iter(self, runner):
self.after_iter(runner)
def every_n_epochs(self, runner, n):
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n):
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader)可视化结果如下图:
很多⽅法的名字都挺有意思,例如 “before_epoch”就表⽰这个⽅法会在训练每个epoch之前执⾏。
同样也可以看到这个类⾥⾯定义了很多空的⽅法(都是pass),这个是给我们重载 ⽤的,就是说继承HOOK类的类,可以拥有这些⽅法,这就衍⽣了xxxHOOK的类了。其次我们看到了HOOKS,这是什么?就是在的注册表全局变量,这是HOOK的注册表全局变量,这就暗⽰了注册表⾥肯定有很多不同的 HOOKS类。
那我们能不能看⼀下,⼀共有多少种HOOKS被定义呢?在mmcv/runner/hooks/__init__.py下,有定义:
__all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook',
'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'TextLoggerHook',
'TensorboardLoggerHook', 'WandbLoggerHook', 'MomentumUpdaterHook'
]
可以从HOOK的名字看出来,每个HOOK都对应着⼀些特定的功能。
我们先看看调⽤HOOK的函数 call_hook()是怎么定义的:
(这位于mmcv/runner/runner.py,不属于mmdetection项⽬代码,属于mmcv)
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
可以看到,call_hook()先遍历hooks的列表中的每个不同类型的hook,由于每个hook都是从HOOK 类实例化出来的,所以都有
before_run(),after_run(),before_epoch(),after_epoch()....等⽅法。我们主要看看 call_hook调⽤的是什么HOOK,通过调试可以看到 self._hooks⾥的变量值(即hook),⾥⾯有:
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论