(1) Hook 是什么
在 wiki 百科中定义如下:
钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)
下面我们来看一下,在mmdetection训练数据时,hook的整个使用流程,
首先在train_detector函数中,出现了runner.register_training_hooks这个函数的调用:
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))
runner.register_training_hooks这个函数体内分别注册多个hook,具体实现如下:
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
我们再以第一个函数self.register_lr_hook为例,看一下具体注册过程,
def register_lr_hook(self, lr_config):
if lr_config is None:
return
elif isinstance(lr_config, dict):
assert 'policy' in lr_config
policy_type = lr_config.pop('policy')
#首字母转为大写
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'LrUpdaterHook'
#我的例程中hook_type= 'StepLrUpdaterHook'
lr_config['type'] = hook_type
#HOOKS = Registry('hook')是之前注册好的,在 mmcv.build_from_cfg函数中,通过类名获得类,
#然后再用lr_config里面的参数对类进行实例化。class StepLrUpdaterHook(LrUpdaterHook):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
#把实例好的类插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中
self.register_hook(hook, priority='VERY_HIGH')
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
obj_type = args.pop('type')
obj_cls = registry.get(obj_type)
return obj_cls(**args)
总结一下,通过config文件实例化一个hook类,再把它注册进 Runner 中。