0
点赞
收藏
分享

微信扫一扫

2、torch.nn.Flatten()的参数实例

罗蓁蓁 2022-03-18 阅读 168

2、torch.nn.Flatten()的参数实例

# 默认参数
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten()
a1 = F(a)

a的大小:
torch.Size([8, 3, 64, 64])

a1的大小:
torch.Size([8, 12288])
默认将第0维保留下来,其余拍成一维

# 一个参数
import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(2)
a1 = F(a)


a的大小:
torch.Size([8, 3, 64, 64])

a1的大小:
torch.Size([8, 3, 4096])
从第二维开始,拍成一维

# 两个参数

import torch
a = torch.randn(8,3,64,64)
F = torch.nn.Flatten(1,2)
a1 = F(a)

a的大小:
torch.Size([8, 3, 64, 64])

a1的大小:
torch.Size([8, 192, 64])
将第一维到第二维拍成一维,其余不变

举报

相关推荐

0 条评论