mmaction2训练相关源码概览⽂章⽬录
0. 前⾔
⽬标:整理 mmaction2 训练的实现过程。
1. ⼊⼝函数详解
⼊⼝代码:tools/train.py
实现的功能:
第⼀步:构建参数,包括命令⾏参数以及配置⽂件参数。
第⼆步:初始化⼀堆东西,⽐如创建输出路径、logger、random seed等。
第三步:构建模型。
第四步:构建数据集。
第五步:执⾏训练。
1.1. 构建模型
⼊⼝函数:mmaction/models/builder.py 中的 build_model 函数。
实现的功能:
根据配置⽂件中的 del['type'] 判断模型类型。
通过注册机制,根据模型类型字符串选择对应的类。
通过 del 中除 type 外的其他参数作为模型初始化参数,构建模型对象。
更多内容请参考
1.2. 构建数据集
⼊⼝函数:mmaction/datasets/builder.py 中的 build_dataset 函数。
实现功能:
根据配置⽂件中的 ain['type'] 判断数据集类型。
通过注册机制,根据数据集类型字符串选择对应的类。
通过 ain 中除 type 外的其他参数作为数据集初始化参数,构建最终数据集。
更多内容请参考
1.3. 执⾏训练。
⼊⼝函数:mmaction/apis/train.py 中的 train_model 函数。
从流程看:
第⼀步:构建logger
第⼆步:根据参数构建dataloader
第三步:根据需求,构建分布式模型
第四步:构建optimizer
第五步:初始化 EpochBaseRunner,简称为 runner。
第六步:根据需求设置 fp16 量化。
第七步:构建各类hooks,包括学习率、log、优化器、保存模型、分布式sampler等。
第⼋步:根据需求设置validate参数,包括构建val数据集以及对应dataloader,以及对应的 eval hook。
第九步:根据需求初始化模型参数。
第⼗步:实际执⾏训练。
从实现机制看:
训练细节都是通过 EpochBaseRunner 实现的。
⼀些具体细节都是通过 runner 中的hook实现。
2. Runner 介绍
2.1. BaseRunner
代码位于 mmcv.runner.base_runner.py 中
作⽤:pytorch训练相关代码。
构造函数:
输⼊参数
model
batch_processor(callable⽅法,调⽤⽅法是 batch_processor(model, data, train_mode),输出⼀个字典)
optimizer
work_dir(保存模型、⽇志⽂件)
logger
404页面网站源码meta(字典,包括环境信息和seed等)
执⾏的操作:
判断输⼊参数合法性。
将输⼊数据保存为成员变量。
初始化其他成员变量。
⽀持的成员变量:model_name、rank、world_size、hooks、epoch、iter、inner_iter、max_epochs、max_iters
抽象⽅法:train、val、run、save_checkpoint
实现的功能
获取optimizer中每个param_groups中lr、momentum、betas的数值。
注册hook、定义⼀些默认hook。
resume 模型权重。
hooks 相关功能:
在 register_hook 时会根据输⼊的 priority 获得具体的优先级数值,内部保存hooks时会根据优先级数值进⾏排序。
定义 hook 的 helper function,⽤来运⾏所有hook的某个⽅法。
定义training中默认⽤到的六种hook
LrUpdaterHook:详见 mmcv.runner.hooks.lr_updater.py
MomentumUpdaterHook:optimizer中momentum的更新,详见 mmcv.um_updater.py
OptimizerHook:更新参数hook,详见 mmcv.runner.hooks.optimizer.py
CheckpointHook:保存模型hook,详见 mmcv.runner.hooks.checkpoint.py
IterTimerHook:为logger增加计时功能,在logger中增加了两个参数data_time和time,前者表⽰数据获取时间,后者表⽰iter总时间,详见 mmcv.runner.hooks.iter_timer.py
LoggerHook(s):详见 mmcv.runner.hooks.logger 中
2.2. EpochBaseRunner
定义了BaseRunner中 train/val/run/save_checkpoint 四个抽象⽅法。
训练代码主要功能:
在响应位置调⽤hooks对应的⽅法,这⾥就包含了更新参数、logger、更新学习率、模型保存等功能。
遍历⼀遍dataloader,分别执⾏前向过程,得到损失函数。
如果设置了 batch_processor,则通过该函数计算损失函数。
如果没有设置 batch_processor,则通过 ain_step 获得损失函数。
def train(self, data_loader,**kwargs):
self.data_loader = data_loader
self._max_iters = self._max_epochs *len(data_loader)
self.call_hook('before_train_epoch')
time.sleep(2)# Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
if self.batch_processor is None:
outputs = ain_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
if not isinstance(outputs,dict):
raise TypeError('"batch_processor()" or "ain_step()"'
' must return a dict')
if'log_vars'in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter +=1
self.call_hook('after_train_epoch')
self._epoch +=1
验证相关代码,主要功能包括:
执⾏各类hooks。
遍历 dataloader,通过 self.batch_processor 或 model.val_step 执⾏前向操作,得到模型输出结果
def val(self, data_loader,**kwargs):
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2)# Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
_grad():
if self.batch_processor is None:
outputs = del.val_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
if not isinstance(outputs,dict):
raise TypeError('"batch_processor()" or "model.val_step()"'
' must return a dict')
if'log_vars'in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
run,包括训练/验证⼯作
输⼊参数包括 data_loaders/workflow,两者的长度相同,分别对应。
workflow 加⼊是 [('train', 2), ('val', 1)],则表⽰train 2 epoch then val 1 epoch,按照这个顺序依次进⾏训练,作为⼀个
epoch。
后续会根据 workflow 根据 mode 选择对应的 train/val ⽅法。
def run(self, data_loaders, workflow, max_epochs,**kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
max_epochs (int): Total training epochs.
"""
assert isinstance(data_loaders,list)
assert mmcv.is_list_of(workflow,tuple)
assert len(data_loaders)==len(workflow)
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode =='train':
self._max_iters = self._max_epochs *len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
while self.epoch < max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode,str):# ain()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner =getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode =='train'and self.epoch >= max_epochs:
return
epoch_runner(data_loaders[i],**kwargs)
time.sleep(1)# wait for some hooks like loggers to finish
self.call_hook('after_run')

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。