0
点赞
收藏
分享

微信扫一扫

mmaction中的Registry

洛茄 2022-03-12 阅读 67

1. 定义一个自己的类

在csnmemde.py中,导入mmaction.models.builder 中的HEADS,使用HEADS注册器写在class上面

还定义了一个mmdet_imported用来最后的一步register_module() #最后一步的作用不晓得。

# csnmemde.py

from mmaction.models.builder import HEADS

try:
    from mmdet.models import BACKBONES as MMDET_BACKBONES  # 定义backbone时用到这句
    from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
    mmdet_imported = True
except (ImportError, ModuleNotFoundError):
    mmdet_imported = False

@HEADS.register_module()
class ResNetCSNMem(nn.Module):
    def __init__(self, chnum_in, mem_dim, feature_num, 
                 feature_num_2, feature_num_x2, feature_num_x4, 
                 feature_num_x6, feature_num_x8, shrink_thres=0.0025):
        super(ResNetCSNMem, self).__init__()
        print('ResNetCSNCov3DMem')
        self.chnum_in = chnum_in  # 通道数
        self.feature_num = feature_num
        self.feature_num_2 = feature_num_2
        self.feature_num_x2 = feature_num_x2
        self.feature_num_x4 = feature_num_x4
        self.feature_num_x6 = feature_num_x6
        self.feature_num_x8 = feature_num_x8


if mmdet_imported:
    MMDET_SHARED_HEADS.register_module()(ResNetCSNMem)

2.把ResNetCSNMem类放到mmaction的包中

在 mmacyion2/mmaction/models/head 的目录下添加csnmemde.py,或者可以直接在改目录下编辑代码。

3. 在head的__init__.py中添加ResNetCSNMem类

from .x3d_head import X3DHead
from .csnmemde import ResNetCSNMem
__all__ = [
    'TSNHead', 'I3DHead', 'BaseHead', 'TSMHead', 'SlowFastHead', 'SSNHead',
    'TPNHead', 'AudioTSNHead', 'X3DHead', 'BBoxHeadAVA', 'AVARoIHead',
    'FBOHead', 'LFBInferHead', 'TRNHead', 'TimeSformerHead', 'ACRNHead',
    'STGCNHead', 'ResNetCSNMem'
]

3.加载model

定义完一个新的类,第一次用到该类的时候,要确保重新activate了对应的虚拟环境, 重新activate了对应的虚拟环境, 重新activate了对应的虚拟环境, 这样这个新类才会注册到mmaction全局。

from mmaction.models import build_head
from mmcv import Config

cfg = Config.fromfile('config/csncfg.py')

memde = build_head(cfg.model.cls_head)

判断是否将新的类注册到全局,打印其HEADS看一下,注册成功!

from mmaction.models import HEADS
HEADS

 

 对于mmdet的注册方法一样,还可以自己定义BACKBONE, NECK等。这里完成了Registry的部分,如果用这个类,还需要定义config,使用时builder一下。

 

举报

相关推荐

0 条评论