文章目录
1. torch.chunk
尝试将一个张量分割成指定数量的块。每个块都是输入张量的一个视图。
# 创建一个5行6列的张量
x_input_chunk = torch.arange(30).reshape(5,6)
# 将x_input_chunk按行分割成3分,如果除不尽,多余的在最后一个
x_chunk_dim_0 = x_input_chunk.chunk(3,dim=0)
# 将x_input_chunk按列分割成3分,如果除不尽,多余的在最后一个
x_chunk_dim_1 = x_input_chunk.chunk(3,dim=1)
print(f"x_input_chunk={x_input_chunk}")
print(f"x_chunk_dim_0={x_chunk_dim_0}")
print(f"x_chunk_dim_1={x_chunk_dim_1}")
x_input_chunk=tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29]])
# 将x_input_chunk按行分割成3分,如果除不尽,多余的在最后一个
x_chunk_dim_0=(tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]]), tensor([[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]]), tensor([[24, 25, 26, 27, 28, 29]]))
# 将x_input_chunk按列分割成3分,如果除不尽,多余的在最后一个
x_chunk_dim_1=(tensor([[ 0, 1],
[ 6, 7],
[12, 13],
[18, 19],
[24, 25]]), tensor([[ 2, 3],
[ 8, 9],
[14, 15],
[20, 21],
[26, 27]]), tensor([[ 4, 5],
[10, 11],
[16, 17],
[22, 23],
[28, 29]]))