
Deep Learning Basics: nn.Sequential | nn.ModuleList | nn.ModuleDict

1. nn.Sequential, nn.ModuleList, and nn.ModuleDict all inherit from the Module class.
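
The inheritance claim can be checked directly with issubclass. This is a minimal sketch that uses only the public torch.nn classes; the variable names are illustrative.

```python
import torch.nn as nn

# Check that each container class is a subclass of nn.Module.
containers = [nn.Sequential, nn.ModuleList, nn.ModuleDict]
is_module = {cls.__name__: issubclass(cls, nn.Module) for cls in containers}
print(is_module)
```

Because all three are modules, they can be assigned as attributes of another nn.Module and their submodules are registered along with it.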

2. Syntax of nn.Sequential, nn.ModuleList, and nn.ModuleDict

Code:

import torch
import torch.nn as nn

net1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
net2 = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net3 = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

print(net1)
print(net2)
print(net3)

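Output:

Sequential(
  (0): Linear(in_features=32, out_features=64, bias=True)
  (1): ReLU()
)
ModuleList(
  (0): Linear(in_features=32, out_features=64, bias=True)
  (1): ReLU()
)
ModuleDict(
  (linear): Linear(in_features=32, out_features=64, bias=True)
  (act): ReLU()
)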
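
Besides construction, the three containers also differ in how their submodules are looked up. The sketch below assumes the same three networks as above: nn.Sequential and nn.ModuleList are indexed by position, while nn.ModuleDict is indexed by key.

```python
import torch.nn as nn

net1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
net2 = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net3 = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

# Sequential and ModuleList use integer indices; ModuleDict uses string keys.
first_seq = net1[0]
first_list = net2[0]
linear_dict = net3['linear']

print(type(first_seq).__name__, type(first_list).__name__, type(linear_dict).__name__)
print(len(net1), len(net2), len(net3))
```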

3. Differences between Sequential, ModuleDict, and ModuleList

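nn.Sequential implements forward and applies its submodules in order, whereas nn.ModuleList and nn.ModuleDict are only containers and do not define forward.

Code: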

import torch
import torch.nn as nn

net1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
net2 = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net3 = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})


x = torch.randn(8, 3, 32)
print(net1(x).shape)    # output: torch.Size([8, 3, 64])
# print(net2(x).shape)  # raises NotImplementedError: ModuleList has no forward
# print(net3(x).shape)  # raises NotImplementedError: ModuleDict has no forward
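
To see the failure without crashing the script, the calls can be wrapped in try/except. This is a small sketch; the dictionary name failed is only for illustration.

```python
import torch
import torch.nn as nn

net2 = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net3 = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})
x = torch.randn(8, 3, 32)

# Record whether calling each container raises NotImplementedError.
failed = {}
for name, net in [('ModuleList', net2), ('ModuleDict', net3)]:
    try:
        net(x)
        failed[name] = False
    except NotImplementedError as err:
        failed[name] = True
        print(f"{name}: {type(err).__name__}")
```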

Writing a forward function for nn.ModuleList
Code:

import torch
import torch.nn as nn


class My_Model(nn.Module):
    def __init__(self):
        super(My_Model, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = My_Model()

x = torch.randn(8, 3, 32)
out = net(x)
print(out.shape)

Output:
torch.Size([8, 3, 64])
Writing a forward function for nn.ModuleDict

import torch
import torch.nn as nn


class My_Model(nn.Module):
    def __init__(self):
        super(My_Model, self).__init__()
        self.layers = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x

net = My_Model()
x = torch.randn(8, 3, 32)
out = net(x)
print(out.shape)
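
Iterating over values() works because nn.ModuleDict keeps insertion order. Where a ModuleDict differs most from a ModuleList is that forward can pick a submodule by key. The sketch below illustrates this with a hypothetical ChoiceModel class and a hypothetical act argument; neither comes from the original post.

```python
import torch
import torch.nn as nn


class ChoiceModel(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 64)
        # Candidate activations, looked up by key at call time.
        self.acts = nn.ModuleDict({'relu': nn.ReLU(), 'tanh': nn.Tanh()})

    def forward(self, x, act):
        # Apply the linear layer, then the activation named by `act`.
        return self.acts[act](self.linear(x))


net = ChoiceModel()
x = torch.randn(8, 3, 32)
out_relu = net(x, 'relu')
out_tanh = net(x, 'tanh')
print(out_relu.shape, out_tanh.shape)
```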

Converting nn.ModuleList to nn.Sequential

import torch
import torch.nn as nn

module_list = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net = nn.Sequential(*module_list)
x = torch.randn(8, 3, 32)
print(net(x).shape)

Output:
torch.Size([8, 3, 64])
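
One point worth keeping in mind: unpacking a ModuleList into nn.Sequential does not copy the layers. The sketch below checks that both containers hold the same module objects, so training one also changes the other.

```python
import torch.nn as nn

module_list = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net = nn.Sequential(*module_list)

# The Sequential refers to the very same layer and weight objects.
shares_layer = net[0] is module_list[0]
shares_weight = net[0].weight is module_list[0].weight
print(shares_layer, shares_weight)
```

If independent copies are needed, copy.deepcopy on the modules before wrapping them is one option.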

Converting nn.ModuleDict to nn.Sequential

import torch
import torch.nn as nn

module_dict = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})
net = nn.Sequential(*module_dict.values())
x = torch.randn(8, 3, 32)
print(net(x).shape)

Output:
torch.Size([8, 3, 64])
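
Unpacking values() drops the keys, so the resulting Sequential names its layers 0 and 1. If the names should be kept, one option (a sketch, not the only approach) is to pass an OrderedDict, which nn.Sequential accepts as its single argument.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

module_dict = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

# Build the Sequential from (name, module) pairs so the keys are preserved.
net = nn.Sequential(OrderedDict(module_dict.items()))
names = [name for name, _ in net.named_children()]
print(names)

x = torch.randn(8, 3, 32)
out = net(x)
print(out.shape)
```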

4. Differences between ModuleDict and ModuleList

import torch.nn as nn

net = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net.append(nn.Linear(64, 10))
print(net)

import torch.nn as nn

net = nn.ModuleDict({'linear1': nn.Linear(32, 64), 'act': nn.ReLU()})
net['linear2'] = nn.Linear(64, 128)
print(net)
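
Beyond append and key assignment, both containers offer a few more list-like and dict-like methods. The sketch below uses extend and insert on a ModuleList and update and pop on a ModuleDict; the extra layers are arbitrary choices for illustration.

```python
import torch.nn as nn

# ModuleList behaves like a list: extend and insert work by position.
net_list = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])
net_list.extend([nn.Linear(64, 10), nn.Softmax(dim=-1)])
net_list.insert(2, nn.Dropout(p=0.5))
list_types = [type(m).__name__ for m in net_list]
print(list_types)

# ModuleDict behaves like a dict: update and pop work by key.
net_dict = nn.ModuleDict({'linear1': nn.Linear(32, 64), 'act': nn.ReLU()})
net_dict.update({'linear2': nn.Linear(64, 128)})
removed = net_dict.pop('act')
dict_keys = list(net_dict.keys())
print(dict_keys)
```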

5. Differences between nn.ModuleList / nn.ModuleDict and Python list / dict

import torch.nn as nn

net = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])

for name, param in net.named_parameters():
    print(name, param)


print("-----------------------------")
for name, param in net.named_parameters():
    print(name, param.size())

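Output: the first loop prints each parameter name followed by its randomly initialized values; the second loop prints:
0.weight torch.Size([64, 32])
0.bias torch.Size([64])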

import torch.nn as nn

net = nn.ModuleDict({'linear': nn.Linear(32, 64), 'act': nn.ReLU()})

for name, param in net.named_parameters():
    print(name, param.size())
print("--------------------------")

for name, param in net.named_parameters():
    print(name, param.size())

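Output (both loops print the same two lines, separated by the dashed line):
linear.weight torch.Size([64, 32])
linear.bias torch.Size([64])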
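
The examples above show that parameters inside nn.ModuleList and nn.ModuleDict are registered and therefore visible to named_parameters(). For contrast, the sketch below stores the same layers in a plain Python list; the class names PlainListModel and ModuleListModel are hypothetical.

```python
import torch.nn as nn


class PlainListModel(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        # A plain Python list: the layers are NOT registered as submodules.
        self.layers = [nn.Linear(32, 64), nn.ReLU()]


class ModuleListModel(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        # nn.ModuleList: the layers are registered as submodules.
        self.layers = nn.ModuleList([nn.Linear(32, 64), nn.ReLU()])


n_plain = len(list(PlainListModel().parameters()))
n_registered = len(list(ModuleListModel().parameters()))
print(n_plain, n_registered)
```

With the plain list, an optimizer built from model.parameters() would receive nothing to update, and calls such as model.to(device) or state_dict() would skip those layers. The same applies to a plain dict versus nn.ModuleDict.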
