import torch
def upper_triangular_mask(
    size: int,
    diagonal: int = 1,
    dtype: torch.dtype = torch.uint8,
) -> torch.Tensor:
    """Return a ``(size, size)`` 0/1 mask that is 1 where ``col - row >= diagonal``.

    The mask is built by applying :func:`torch.triu` to an all-ones matrix.
    ``torch.triu`` zeroes every element below the ``diagonal``-th diagonal.

    Args:
        size: Number of rows and columns of the square mask.
        diagonal: Offset forwarded to ``torch.triu``. ``1`` excludes the main
            diagonal; ``2`` also excludes the first superdiagonal, and so on.
        dtype: Output dtype. Defaults to ``torch.uint8``, the dtype that
            ``.byte()`` produced in the original snippet.
            NOTE(review): ``Tensor.masked_fill`` documents a boolean mask;
            pass ``dtype=torch.bool`` if the mask is consumed that way.

    Returns:
        A ``(size, size)`` tensor of the requested dtype.
    """
    return torch.triu(torch.ones(size, size), diagonal=diagonal).to(dtype)


mask = upper_triangular_mask(5, diagonal=1)
print(mask)
mask = upper_triangular_mask(5, diagonal=2)
print(mask)

# Recorded output of the two prints above. It was previously pasted as bare
# code, which raised NameError ("tensor" is not defined) at import time.
# tensor(
# [[0, 1, 1, 1, 1],
# [0, 0, 1, 1, 1],
# [0, 0, 0, 1, 1],
# [0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0]], dtype=torch.uint8)
# tensor(
# [[0, 0, 1, 1, 1],
# [0, 0, 0, 1, 1],
# [0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0]], dtype=torch.uint8)