[PyTorch] nn.Linear, nn.Conv


nn.Linear

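nn.Linear(in_features, out_features) applies the affine map y = x W^T + b to the last dimension of its input; any extra leading dimensions are treated as batch dimensions. A minimal sketch (the sizes 3, 32, 128, 256 are chosen only for illustration):

import torch

linear = torch.nn.Linear(3, 32)   # weight: [32, 3], bias: [32]

x = torch.randn(128, 3)           # [batch, in_features]
print(linear(x).shape)
# torch.Size([128, 32])

# Only the last dimension must match in_features; leading dims pass through.
x_seq = torch.randn(128, 256, 3)
print(linear(x_seq).shape)
# torch.Size([128, 256, 32])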

nn.Conv1d

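nn.Conv1d(in_channels, out_channels, kernel_size) slides a kernel along the last (length) dimension of a [batch, channels, length] input. A minimal sketch, with made-up sizes:

import torch

conv1d = torch.nn.Conv1d(3, 32, kernel_size=3)

x = torch.randn(128, 3, 100)      # [batch, in_channels, length]
print(conv1d(x).shape)
# torch.Size([128, 32, 98])       # without padding, length shrinks by kernel_size - 1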

When kernel_size=1, nn.Conv1d has the same effect as nn.Linear; the only difference is the expected input layout:

import torch


def count_parameters(model):
    """Count the number of parameters in a model."""
    return sum(p.numel() for p in model.parameters())


conv = torch.nn.Conv1d(3, 32, kernel_size=1)
print(count_parameters(conv))
# 128  (weight 32*3*1 + bias 32)

linear = torch.nn.Linear(3, 32)
print(count_parameters(linear))
# 128  (weight 32*3 + bias 32)

print(conv.weight.shape)
# torch.Size([32, 3, 1])
print(linear.weight.shape)
# torch.Size([32, 3])

# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(2))
linear.bias = torch.nn.Parameter(conv.bias)

tensor = torch.randn(128, 256, 3)  # [batch, length, in_features] for nn.Linear
permuted_tensor = tensor.permute(0, 2, 1).contiguous()  # [batch, in_channels, length] for nn.Conv1d

out_linear = linear(tensor)
print(out_linear.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_linear.shape)
# torch.Size([128, 256, 32])


out_conv = conv(permuted_tensor)
print(out_conv.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_conv.shape)
# torch.Size([128, 32, 256])
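A stricter check than comparing means is to permute the Conv1d output back to [batch, length, features] and compare it to the Linear output elementwise (a minimal sanity check; the tolerance is arbitrary):

print(torch.allclose(out_linear, out_conv.permute(0, 2, 1), atol=1e-6))
# True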

nn.Conv2d

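nn.Conv2d expects [batch, channels, height, width] and convolves over the two spatial dimensions; with kernel_size=1 it again acts like nn.Linear applied independently at every pixel. A minimal sketch (sizes chosen only for illustration):

import torch

conv2d = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)

x = torch.randn(8, 3, 64, 64)     # [batch, in_channels, height, width]
print(conv2d(x).shape)
# torch.Size([8, 32, 64, 64])     # padding=1 keeps the spatial size with a 3x3 kernel

print(sum(p.numel() for p in conv2d.parameters()))
# 896                             # weight 32*3*3*3 + bias 32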

nn.Conv3d

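nn.Conv3d is the volumetric counterpart: it expects [batch, channels, depth, height, width], e.g. for video or 3D medical data. A minimal sketch (sizes chosen only for illustration):

import torch

conv3d = torch.nn.Conv3d(3, 32, kernel_size=3, padding=1)

x = torch.randn(2, 3, 16, 64, 64)  # [batch, in_channels, depth, height, width]
print(conv3d(x).shape)
# torch.Size([2, 32, 16, 64, 64])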

