0
点赞
收藏
分享

微信扫一扫

pytorch,构造矩阵mask

Fifi的天马行空 2022-07-27 阅读 46
编程语言


import torch
mask = torch.triu(
torch.ones(5, 5), diagonal=1).byte()
print(mask)
mask = torch.triu(
torch.ones(5, 5), diagonal=2).byte()
print(mask)

tensor(
[[0, 1, 1, 1, 1],
[0, 0, 1, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0]], dtype=torch.uint8)

tensor(
[[0, 0, 1, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]], dtype=torch.uint8)


举报

相关推荐

0 条评论