前言
通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。
一、MultiHeadAttention的本质内涵
1.Self_Atention机制
2.MultiHead_Atention机制
二、使MultiHeadAttention在TensorFlow中的代码实现
1.参数说明
2.整体结构
''' 多头映射层 '''
query = self._query_dense(query)
key = self._key_dense(key)
value = self._value_dense(value)
''' 注意力层 '''
attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, training
)
''' 输出层 '''
attention_output = self._output_dense(attention_output)
3.多头映射层
4.注意力层
5.输出映射层
验证
import tensorflow as tf
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[9, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(query=target, value=source,
return_attention_scores=True)
''' 手动计算训练参数总数 '''
sum = 16*2*2*3+2*2*3+2*2*16+16
print(f'手动计算的训练参数总数为 : {sum}')
print(f'训练参数总共为 : {layer.count_params()}')
print(f'输出shape为 : {output_tensor.shape}')
print(f'注意力分数shape为 : {weights.shape}')
手动计算的训练参数总数为 : 284
训练参数总共为 : 284
输出shape为 : (None, 9, 16)
注意力分数shape为 : (None, 2, 9, 4)