0
点赞
收藏
分享

微信扫一扫

mmdetection - 初识hook的使用

归零者245号 2022-02-14 阅读 69

(1) Hook 是什么

在 wiki 百科中定义如下:

钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)

下面我们来看一下,在mmdetection训练数据时,hook的整个使用流程,
首先在train_detector函数中,出现了runner.register_training_hooks这个函数的调用:

runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

runner.register_training_hooks这个函数体内分别注册多个hook,具体实现如下:

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)

我们再以第一个函数self.register_lr_hook为例,看一下具体注册过程,

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
  			#首字母转为大写
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            #我的例程中hook_type= 'StepLrUpdaterHook'
            lr_config['type'] = hook_type
            #HOOKS = Registry('hook')是之前注册好的,在 mmcv.build_from_cfg函数中,通过类名获得类,
            #然后再用lr_config里面的参数对类进行实例化。class StepLrUpdaterHook(LrUpdaterHook):
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        #把实例好的类插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中
        self.register_hook(hook, priority='VERY_HIGH')
def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    return obj_cls(**args)

总结一下,通过config文件实例化一个hook类,再把它注册进 Runner 中。

举报

相关推荐

0 条评论