郑重承诺:本文提供的代码保证可以正常运行;如果无法运行,欢迎留言,我看到后一定帮忙解决!
许多边缘计算芯片不支持 sort 算子。经过分析发现,可以用简单的矩阵运算配合 torch.max 来替代 sort 算子,实现升序排序,代码如下:
def sort_torch_max_sheng_xu(x: "torch.Tensor", dim: int = 2):
    """Sort ``x`` in ascending order along ``dim`` without a sort operator.

    Intended as a drop-in for ``x.sort(dim, False)`` on targets that lack a
    sort kernel.  It is a selection sort built only from ``torch.max``,
    ``torch.min``, ``clamp`` and element-wise arithmetic: on every pass the
    current maximum along ``dim`` is written to the last free output slot and
    then replaced in a working copy by a sentinel that is strictly smaller
    than every input value, so it can never be selected again.

    Args:
        x: Tensor of any rank >= 1 with finite values.  It is not modified.
        dim: Dimension to sort along; negative values count from the end.

    Returns:
        A ``(sorted_x, indices)`` tuple, both shaped like ``x``.
        ``sorted_x`` holds the values in ascending order along ``dim``.
        ``indices`` holds the source position of each sorted value and, for
        backward compatibility, has the same dtype as ``x`` (not int64);
        call ``indices.long()`` before using it with ``gather``.

    Notes:
        * The order of equal values follows the tie-breaking of
          ``torch.max`` and may differ from ``torch.sort``.
        * Non-finite inputs (``inf``/``nan``) are not supported because the
          arithmetic masking multiplies values by zero.

    Example:
        >>> x = torch.rand(1, 2, 32, 4, 5) - 0.5
        >>> values, idx = sort_torch_max_sheng_xu(x, dim=2)
        >>> torch.equal(values, x.sort(2, False)[0])
        True
    """
    sorted_x = torch.zeros_like(x)
    indices = torch.zeros_like(x)
    if x.numel() == 0:
        # Nothing to sort; also avoids calling min() on an empty tensor.
        return sorted_x, indices

    if dim < 0:
        dim += x.dim()
    length = x.size(dim)

    # Positions 0..length-1 laid out along `dim` and broadcastable against x,
    # whatever the rank of x is.
    view_shape = [1] * x.dim()
    view_shape[dim] = length
    positions = torch.arange(length, device=x.device).view(view_shape)

    # Sentinel strictly below every input value.  Subtracting max(|min|, 1)
    # (rather than just 1) keeps it strictly smaller even when |min| is so
    # large that `min - 1` would round back to `min` in floating point.
    lowest = x.min()
    floor = lowest - lowest.abs().clamp(min=1)

    work = x
    # Fill the output from the last slot to the first: the largest remaining
    # value goes to the highest free slot, which yields ascending order.
    for slot in range(length - 1, -1, -1):
        max_val, max_idx = torch.max(work, dim=dim)
        sorted_x.select(dim, slot).copy_(max_val)
        indices.select(dim, slot).copy_(max_idx)

        # keep == 0 exactly at the position just picked and 1 elsewhere.
        # The squared integer difference is 0 only at the picked position;
        # clamping caps every other entry at 1.
        diff = max_idx.unsqueeze(dim) - positions
        keep = (diff * diff).clamp(max=1).to(x.dtype)

        # Replace the picked element with the sentinel instead of zeroing it,
        # so negative and zero inputs are ordered correctly.
        work = work * keep + floor * (1 - keep)

    return sorted_x, indices