自然语言处理(NLP: Natural Language Processing)与语音识别(ASR: Automatic Speech Recognition)都是典型的序列识别任务,现阶段皆可以按Transformer模型架构进行建模处理,如SAN-M、BERT、BART、GPT2、T5、Switch-Transformer等模型。
Transformer模型的输入,表示为Sequence embeddings(序列特征),其中NLP的输入Embedding、表示Token的高维矢量编码,ASR的输入特征、则是经过语音信号采样、分帧、预加重、加窗、FFT等处理的声学特征(Acoustic features,如LogfBank、MFCC、Spectrum)。
数据增强的目的,是通过对样本进行变换,以扩增数据量、并丰富样本分布,但要求变换后样本能保持原有的标签语义。因此,数据增强需同时兼顾相似性与多样性。对序列特征进行随机掩码处理,可实现NLP或ASR模型训练的数据增强,从而提升模型的稳健性与泛化性,典型如ASR的SpecAugment、与NLP的Cutoff。
SpecAugment
GitHub地址:https://github.com/DemisEom/SpecAugment
SpecAugment直接对ASR的输入声学特征进行操作:
- 声学特征是时域语音信号、经处理变换获得的谱图(如Mel谱图),包括水平轴(时间轴)与纵轴(频率轴);
- SpecAugment的可选操作包括沿时间轴的变形、沿时间轴的通道掩码、或沿频率轴的通道掩码,以对抗时域上的变形,频域上的部分片段损失,从而增强ASR模型训练;
Cutoff
GitHub地址:https://github.com/dinghanshen/Cutoff
Cutoff直接对Embedding layer的输出特征进行操作:
- Cutoff类型分三种:
- Token cutoff:沿Sequence维度,随机掩码某些Token embeddings;
- Feature cutoff:沿Embedding维度,随机掩码某些Vectors;
- Span cutoff:沿Sequence维度,随机地从某个Token开始,掩码一段连续的Token embeddings;
- 为了对齐样本变换前后的语义,引入了Jensen-Shannon (JS) divergence consistency loss;因此Total loss如下:
Cutoff总体示意、实验结果与代码分析(以Span cutoff为例)如下:
def js_div(p, q):
""" Jensen-Shannon (JS) divergence consistency."""
m = (p + q) / 2
a = F.kl_div(p.log(), m, reduction='batchmean')
b = F.kl_div(q.log(), m, reduction='batchmean')
jsd = ((a + b) / 2)
return jsd
# Cutoff: cut embedding_output and attention mask
input_ids = inputs['input_ids']
token_type_ids = inputs.get('token_type_ids', None)
labels = inputs.get('labels', None)
# Output features of embedding layer
embeds = model.get_embedding_output(input_ids=input_ids, token_type_ids=token_type_ids)
masks = inputs['attention_mask']
input_lens = torch.sum(masks, dim=1)
input_embeds = []
input_masks = []
# Span cutoff
for i in range(embeds.shape[0]):
cutoff_length = int(input_lens[i] * self.args.aug_cutoff_ratio)
start = int(torch.rand(1) * (input_lens[i] - cutoff_length))
cutoff_embed = torch.cat((embeds[i][:start], torch.zeros([cutoff_length, embeds.shape[-1]], dtype=torch.float).to(self.args.device), embeds[i][start + cutoff_length:]), dim=0)
cutoff_mask = torch.cat((masks[i][:start], torch.zeros([cutoff_length], dtype=torch.long).to(self.args.device), masks[i][start + cutoff_length:]), dim=0)
input_embeds.append(cutoff_embed)
input_masks.append(cutoff_mask)
input_embeds = torch.stack(input_embeds, dim=0)
input_masks = torch.stack(input_masks, dim=0)
# Preditions of augmented samples
cutoff_outputs = model.get_logits_from_embedding_output(embedding_output=input_embeds, attention_mask=input_masks, labels=labels)
# CE loss of augmented samples
if self.args.aug_ce_loss > 0:
loss += self.args.aug_ce_loss * cutoff_outputs[0]
# JS divergence loss between original samples and augmented samples
if self.args.aug_js_loss > 0:
assert self.args.n_gpu == 1
ori_logits = ori_outputs[1]
aug_logits = cutoff_outputs[1]
p = torch.softmax(ori_logits + 1e-10, dim=1)
q = torch.softmax(aug_logits + 1e-10, dim=1)
aug_js_loss = js_div(p, q)
loss += self.args.aug_js_loss * aug_js_loss