0
点赞
收藏
分享

微信扫一扫

2022-03-14thorough-pytorch-模型定义

王老师说 2022-03-17 阅读 65

目录

thorough pytorch

thorough pytorch
通过Sequential,ModuleList和ModuleDict三种方式定义PyTorch模型。

Sequential

对已定义好的模型顺序执行,不需要同时写__init__和forward

import torch.nn as nn
net = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10), 
        )
print(net)

OrderedDict格式的输入,有层名称的

import collections
import torch.nn as nn
net2 = nn.Sequential(collections.OrderedDict([
          ('fc1', nn.Linear(784, 256)),
          ('relu1', nn.ReLU()),
          ('fc2', nn.Linear(256, 10))
          ]))
print(net2)

ModuleList和ModuleDict

这俩只定义了模型,并没有规定执行顺序,需要init和foward,里面的层可以重复使用,ModuleDict就是加了个层名称

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 类似List的索引访问
print(net)
net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
class model(nn.Module):
  def __init__(self, ...):
    self.modulelist = ...
    ...
    
  def forward(self, x):
    for layer in self.modulelist:
      x = layer(x)
    return x

补充

举报

相关推荐

0 条评论