import torch
# A 2x3 float tensor built from six evenly spaced values on [0, 10].
a = torch.linspace(0, 10, 6).reshape(2, 3)

# dim=0: reduce down each column.
#   sum    -> dim 0 disappears, result has one value per column
#   cumsum -> shape is unchanged, each row holds the running total so far
b = torch.sum(a, dim=0)
c = a.cumsum(dim=0)
print(a, b, c, sep="\n")
# tensor([[ 0., 2., 4.],
#         [ 6., 8., 10.]])
# tensor([ 6., 10., 14.])
# tensor([[ 0., 2., 4.],
#         [ 6., 10., 14.]])

# dim=1: reduce across each row.
#   sum    -> one value per row
#   cumsum -> each column holds the running total along its row
d = torch.sum(a, dim=1)
e = a.cumsum(dim=1)
print(d, e, sep="\n")
# tensor([ 6., 24.])
# tensor([[ 0., 2., 6.],
#         [ 6., 14., 24.]])
# keepdim 参数,说明输出结果是否保留维度
# (keepdim: whether the reduced dimension is kept in the output as a size-1 dimension)
import torch
# Same 2x3 input as above: six evenly spaced values on [0, 10].
a = torch.linspace(0, 10, 6).reshape(2, 3)

# Summing over dim 0 with and without keepdim:
#   keepdim=False (default) -> dim 0 is removed entirely
#   keepdim=True            -> dim 0 is retained with size 1
b = torch.sum(a, dim=0)
c = torch.sum(a, dim=0, keepdim=True)

print(a, b, b.shape, c, c.shape, sep="\n")
# tensor([[ 0., 2., 4.],
#         [ 6., 8., 10.]])
# tensor([ 6., 10., 14.])
# torch.Size([3])
# tensor([[ 6., 10., 14.]])
# torch.Size([1, 3])