import torch
x = torch.randn(1,3,224,224)
print(x[0].shape)
print(x[0::,...].shape)
torch.Size([3, 224, 224])
torch.Size([1, 3, 224, 224])
- 索引::加上,会保持维度不变。
微信扫一扫
import torch
x = torch.randn(1,3,224,224)
print(x[0].shape)
print(x[0::,...].shape)
torch.Size([3, 224, 224])
torch.Size([1, 3, 224, 224])
相关推荐