0
点赞
收藏
分享

微信扫一扫

pytorch笔记:ModuleDict

腾讯优测 2024-05-04 阅读 13

1 介绍 

  • 在 PyTorch 中,nn.ModuleDict 是一个方便的容器,用于存储一组子模块(即 nn.Module 对象)的字典
  • 这个容器主要用于动态地管理多个模块,并通过键来访问它们,类似于 Python 的字典

2 特点

  • 组织性
    • nn.ModuleDict 提供了一种将多个模块有序组织在一起的方法。
    • 这有助于让代码更加结构化,易于理解和维护
  • 动态操作
    • ​​​​​​​​​​​​​​可以像操作普通字典那样添加或删除模块
    • 例如使用 module_dict['key'] = module 添加模块,使用 del module_dict['key'] 删除模块
  • 自动参数注册
    • ​​​​​​​当将模块添加到 ModuleDict 中时,它们的参数会自动注册到整个网络中,确保在模型训练时这些参数可以被识别和更新

3 例子

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleDict({
            'linear': nn.Linear(10, 20),
            'activation': nn.ReLU()
        })

    def forward(self, x):
        x = self.layers['linear'](x)
        x = self.layers['activation'](x)
        return x
  • 在构造函数中,我们使用 ModuleDict 来存储一个线性层和一个激活层,并在前向传播 forward 方法中通过键名访问这些层

4 主要方法

clear清空 ModuleDict 中的所有条目
items

返回一个可迭代对象,包含 ModuleDict 中的键/值对

keys

返回一个可迭代对象,包含 ModuleDict 的所有键

popModuleDict 中移除指定的键,并返回其对应的模块
update(modules)使用另一个映射或键值对迭代器更新 ModuleDict,如果存在相同的键,则会覆盖原有的条目
values返回一个可迭代对象,包含 ModuleDict 的所有模块值
举报

相关推荐

0 条评论