0
点赞
收藏
分享

微信扫一扫

pytorch中的nn.Module抽象类的参数


我们在搭建网络时,通常要继承​​nn.Module​​​这个模块,并且实现其​​forward​​方法,那么这个基类中到底有何属性呢?

def __init__(self):
self._parameters = OrderedDict()
self._modules = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hoods = OrderedDict()
self.training = True

这个基类有以下属性:

  • ​_parameters​​:有序字典,保存用户直接设置的Parameter。例如,对于self.param1 = nn.Parameter(torch.randn(3, 3)),构造函数会在字典中加入一个key为param1、value为对应Parameter的item。self.submodule = nn.Linear(3, 4)中的Parameter不会被保存在该字典中。
  • ​_modules​​:子module。例如,通过self.submodel = nn.Linear(3, 4)指定的子module会被保存于此。
  • ​_buffers​​:缓存。例如,BatchNorm使用动量机制,每次前向传播时都需要用到上一次前向传播的结果。
  • ​_backward_hooks​​:钩子技术,用来提取中间变量。
  • ​_forward_hoods​​​:钩子技术,用来提取中间变量。
    +​​​training​​ :BatchNorm层与Dropout层在训练阶段和测试阶段采取的策略不同,通过training属性决定前向传播策略。


举报

相关推荐

0 条评论