0
点赞
收藏
分享

微信扫一扫

【PyTorch】利用torch.cat()实现Tensor的拼接


问题

方法

import torch
from torch import nn

conv1 = nn.Conv2d(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
padding=1
)

conv2 = nn.Conv2d(
in_channels=3,
out_channels=16,
kernel_size=3,
stride=1,
padding=1
)

x = torch.rand(128, 3, 224, 224)

x1 = conv1(x) # [128, 32, 224, 224]
x2 = conv2(x) # [128, 16, 224, 224]

# 表示对dim=1维进行cat操作,其他维度均不变
out = torch.cat([x1, x2], dim=1)
print(out.shape) #[128, 48, 224, 224]

结语


举报

相关推荐

0 条评论