0
点赞
收藏
分享

微信扫一扫

mmSegmentation 自定义验证钩子(eval_hooks)

前言

mmSegmentation 中的验证流程不满足需求, 因此自定义验证流程。

验证流程在 config 文件中配置如下, 属于 schedule 配置。

evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)  # 每 16000 个迭代周期进行一次评估,评估指标为 mIoU

每当 interval 个 iter 训练完成后, 启用一次 evaluation。

版本信息

mmcv 1.4.8
mmsegmentation 0.23.0

自定义验证钩子(EvalHooks)

1 复制mmseg/core/evaluation/eval_hooks.pymy_eval_hooks.py
在这里插入图片描述
2 在mmseg/core/evaluation/__init__.py中将eval_hooks.py中的EvalHook替换为my_eval_hooks.py中的EvalHook
在这里插入图片描述

3 可以啦, 在my_eval_hooks.pyEvalHook._do_evaluate()中修改验证流程吧。

注: EvalHook并不支持使用用配置文件注入, 因此只能用这种方式进行修改。

验证钩子生命周期

注册位置: mmseg/apis/train.py:train_segmentor()

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

调用位置:
EvalHooks属于after_train_hook, 在每个训练循环后被调用。
IterBasedRunner.train()(mmcv.runner.iter_based_runner) 控制一个Iter的主要流程, 当其执行self.call_hook('after_train_iter')时会调用EvalHooks的after_train_iter()函数。
after_train_iter()由mmcv.runner.EvalHook实现, 我们的EvalHook继承自该类, after_train_iter()会调用self._do_evaluate(), 因此我们将在self._do_evaluate()中修改流程。

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        if not self._should_evaluate(runner):
            return

        from mmseg.apis import single_gpu_test
        results = single_gpu_test(
            runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
        runner.log_buffer.clear()
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, results)
        if self.save_best:
            self._save_ckpt(runner, key_score)

钩子上下文

钩子统一传入runner运行时, 主要包括:

runner.model 模型
runner.data_loader 是训练集的dataloader
self.dataloader 才是验证集的dataloader
self.dataloader.dataset 可以直接访问数据集类, 功能可以写在这里
runner.meta 包括很多运行配置如'env_info', 'seed', 'exp_name', 'mmseg_version', 'config', 'CLASSES', 'PALETTE'
举报

相关推荐

0 条评论