import torch
import torch.nn.functional as F
from torch.autograd import *
a = Variable(torch.FloatTensor([[0,0,0,0,0,0,90,100]]))
b=F.softmax(a,-1)
print(b.multinomial()) # 7 或 6
print(b.multinomial(2)) # 6,7 或 7,6
print(b.multinomial(2,True)) # 7,7 或 7,6 或 6,7 或 6,6
也可以试试
WeightedRandomSampler
主要是replace也就是True False那个参数决定采样数据是否重复