0
点赞
收藏
分享

微信扫一扫

通过mask选择预测标签以及展平二维数组

sullay 2022-01-23 阅读 44

前言

在进行自然语言处理任务的时候,我们为了让一个句子的长度保持一致,经常会使用padding的操作,但是在算一些指标的时候,这些填充的token不应该被算进去。因此本文总结一些常用的小技巧快速实现目的。(本文会不断更新…



通过mask选择标签

import torch
# 原数据,一般是预测出来的标签
src = torch.arange(1, 13).view(3, 4)
print(src)


掩码:

mask = torch.tril(torch.ones(3, 4)).bool()
print(mask)


最后选择的输出:

print(src[mask])


展平

from itertools import chain

a = [[1, 2], [3, 4]]
print(list(chain.from_iterable(a)))

输出: [1, 2, 3, 4]

举报

相关推荐

0 条评论