0
点赞
收藏
分享

微信扫一扫

string模拟实现(直接上源码)

小a草 2024-04-28 阅读 18


前言

通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。


一、MultiHeadAttention的本质内涵

1.Self_Atention机制

2.MultiHead_Atention机制

二、使MultiHeadAttention在TensorFlow中的代码实现

1.参数说明

2.整体结构

        ''' 多头映射层 '''
        query = self._query_dense(query)
        key = self._key_dense(key)
        value = self._value_dense(value)
        
        ''' 注意力层 '''
        attention_output, attention_scores = self._compute_attention(
            query, key, value, attention_mask, training
        )
        
        ''' 输出层 '''
        attention_output = self._output_dense(attention_output)

3.多头映射层

4.注意力层

5.输出映射层


验证

import tensorflow as tf


layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[9, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(query=target, value=source,
                               return_attention_scores=True)

''' 手动计算训练参数总数 '''
sum = 16*2*2*3+2*2*3+2*2*16+16
print(f'手动计算的训练参数总数为 : {sum}')
print(f'训练参数总共为 : {layer.count_params()}')
print(f'输出shape为 : {output_tensor.shape}')
print(f'注意力分数shape为 : {weights.shape}')



手动计算的训练参数总数为 : 284
训练参数总共为 : 284
输出shape为 : (None, 9, 16)
注意力分数shape为 : (None, 2, 9, 4)

举报

相关推荐

string模拟实现

string模拟实现:

string【2】模拟实现string类

string的模拟实现

c++ string模拟实现

9.string模拟实现

0 条评论