TF2 implementation of the main layers that make up the ViT Transformer encoder
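
The script below builds a single pre-norm Transformer encoder block at ViT-Base size (embed_dims=768, 12 heads, MLP ratio 4): a PatchEmbedding layer (Conv2D patch projection plus a class token and a learnable position embedding), a MultiHead_Self_Attention layer, and an MLP block, connected with LayerNormalization and residual Add layers.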

import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1" # force CPU-only execution
import tensorflow as tf
from tensorflow.keras.layers import (Layer,Conv2D,LayerNormalization,Dense,
                                    Input,Dropout,Softmax,Add)

from tensorflow.keras.models import Model
from tensorflow.keras.activations import gelu




class PatchEmbedding(Layer):
    # PatchEmbedding layer: 1. patch embedding of the image, 2. class token and position embedding
    def __init__(self,image_size,patch_size,embed_dims,**kwargs):
        super(PatchEmbedding,self).__init__(**kwargs)

        self.embed_dims = embed_dims
        self.patch_size = patch_size
        self.image_size = image_size

        self.n_patches = (image_size//patch_size) * (image_size//patch_size)

        # non-overlapping patch projection: kernel_size = strides = patch_size
        self.patch_embed_func = Conv2D(self.embed_dims,
                                       kernel_size=self.patch_size,
                                       strides=self.patch_size)
    
    def build(self,input_shape):
        # learnable position embedding
        self.pos_embedding = self.add_weight('pos_embedding',
                                    shape=[input_shape[0],self.n_patches+1,self.embed_dims],
                                    dtype='float32',
                                    initializer='random_normal',
                                    trainable=True)
        # learnable class token
        self.cls_token = self.add_weight('cls_token',
                                    shape=[input_shape[0],1,self.embed_dims],
                                    dtype='float32',
                                    initializer='random_normal',
                                    trainable=True)

        super(PatchEmbedding,self).build(input_shape)


    def get_config(self):
        config = super(PatchEmbedding,self).get_config()
        config.update({"embed_dims":self.embed_dims,
                       "patch_size":self.patch_size,
                       "image_size":self.image_size})
        return config

    def call(self,inputs):
        # patch_size=16, embed_dims=768
        # b,224,224,3 -> b,14,14,768
        x = self.patch_embed_func(inputs)

        # b,14,14,768 -> b,196,768
        b,h,w,c = x.shape
        x = tf.reshape(x,shape=[-1,h*w,c])

        # append the class token (appended after the patch tokens here; the original ViT prepends it)
        x = tf.concat([x,self.cls_token],axis=1)

        # add the position embedding
        x = x + self.pos_embedding

        return x

    def compute_output_shape(self, input_shape):
        b,h,w,c = input_shape
        # n_patches + 1 to account for the class token
        return [b,self.n_patches+1,self.embed_dims]
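
# PatchEmbedding shape walk-through (recap of the call() above, with
# image_size=224, patch_size=16, embed_dims=768):
#   input            : (b, 224, 224, 3)
#   Conv2D           : (b, 14, 14, 768)   14 = 224 // 16
#   reshape          : (b, 196, 768)      196 = 14 * 14 patches
#   concat cls_token : (b, 197, 768)      class token appended to the sequence
#   + pos_embedding  : (b, 197, 768)      learnable position embedding, added element-wise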


class MultiHead_Self_Attention(Layer):
    def __init__(self,embed_dims,num_heads,atten_dropIndice=0.0,**kwargs):
        super(MultiHead_Self_Attention,self).__init__(**kwargs)

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.atten_dropIndice = atten_dropIndice
        
        self.head_dims = self.embed_dims // self.num_heads
        self.all_head_dims = self.num_heads * self.head_dims

        self.scale = self.head_dims ** (-0.5) # scaling factor applied to q @ k^T
        

        self.qkv = Dense(self.all_head_dims*3)
        self.proj = Dense(self.all_head_dims)
        self.softmax = Softmax()
        self.attention_dropout = Dropout(self.atten_dropIndice)
    
    def get_config(self):
        config = super(MultiHead_Self_Attention,self).get_config()
        config.update({"embed_dims":self.embed_dims,
                       "num_heads":self.num_heads,
                       "atten_dropIndice":self.atten_dropIndice})
        return config

    def call(self,inputs):
        # num_heads = 12

        # -> b,197,768*3
        qkv = self.qkv(inputs)
        # q,k,v: b,197,768
        q,k,v = tf.split(qkv,3,axis=-1)

        b,n_patches,all_head_dims = q.shape

        # b,197,768 -> b,197,12,64
        q = tf.reshape(q,shape=[-1,n_patches,self.num_heads,self.head_dims])
        k = tf.reshape(k,shape=[-1,n_patches,self.num_heads,self.head_dims])
        v = tf.reshape(v,shape=[-1,n_patches,self.num_heads,self.head_dims])

        # b,197,12,64 -> b,12,197,64
        q = tf.transpose(q,[0,2,1,3])
        k = tf.transpose(k,[0,2,1,3])
        v = tf.transpose(v,[0,2,1,3])

        # -> b,12,197,197
        attention = tf.matmul(q,k,transpose_b=True)
        attention = self.scale * attention
        attention = self.softmax(attention)
        attention = self.attention_dropout(attention)

        # -> b,12,197,64
        attention = tf.matmul(attention,v)

        # b,12,197,64 -> b,197,12,64
        out = tf.transpose(attention,[0,2,1,3])
        # b,197,12,64 -> b,197,768
        out = tf.reshape(out,shape=[-1,n_patches,all_head_dims])

        out = self.proj(out)

        return out
    
    def compute_output_shape(self, input_shape):
        return input_shape
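
# Attention recap: with embed_dims=768 and num_heads=12, each head attends over
# head_dims = 768 // 12 = 64 channels.
#   scores = softmax((q @ k^T) * head_dims ** -0.5)   # (b, 12, 197, 197)
#   out    = scores @ v                               # (b, 12, 197, 64)
# The heads are then transposed and reshaped back to (b, 197, 768) and passed
# through the final `proj` Dense layer.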
    
class MLP(Layer):
    def __init__(self,embed_dims,mlp_ratio=4.0,dropoutIndice=0.0,**kwargs):
        super(MLP,self).__init__(**kwargs)

        self.embed_dims = embed_dims
        self.mlp_ratio = mlp_ratio
        self.dropoutIndice = dropoutIndice

        self.fc1 = Dense(int(self.embed_dims*self.mlp_ratio))
        self.fc2 = Dense(self.embed_dims)
        self.mlp_dropout = Dropout(self.dropoutIndice)
        

    def get_config(self):
        config = super(MLP,self).get_config()
        config.update({"embed_dims":self.embed_dims,
                       "mlp_ratio":self.mlp_ratio,
                       "dropoutIndice":self.dropoutIndice})

        return config
    
    def call(self,inputs):
        # mlp_ratio=4
        # b,197,768 -> b,197,768*4
        x = self.fc1(inputs)
        x = gelu(x)
        x = self.mlp_dropout(x)

        #  b,197,768*4 ->  b,197,768
        x = self.fc2(x)
        x = self.mlp_dropout(x)

        return x
    
    def compute_output_shape(self, input_shape):
        return input_shape
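
# MLP recap: with embed_dims=768 and mlp_ratio=4.0, fc1 expands 768 -> 3072,
# GELU is applied, and fc2 projects back 3072 -> 768, with dropout after each
# Dense layer.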


if __name__ == '__main__':
    inputs = Input(shape=(224,224,3),batch_size=4)
    # 4,224,224,3 -> 4,197,768
    patch_embedding = PatchEmbedding(224,16,768,name='patch_embed')(inputs)
    # keep a copy of the input for the residual connection
    h = patch_embedding

    patch_embedding = LayerNormalization(name='laynorm1')(patch_embedding)

    # 4,197,768 -> 4,197,768
    attention = MultiHead_Self_Attention(768,12,0,name='msa')(patch_embedding)
    attention = Add(name='Add1')([attention,h])
    # keep a copy for the second residual connection
    h = attention

    outputs = LayerNormalization(name='laynorm2')(attention)
    # 4,197,768 -> 4,197,768
    outputs = MLP(768,4,name='mlp')(outputs)
    outputs = Add(name='Add2')([outputs,h])

    model = Model(inputs=inputs,outputs=outputs)
    model.summary()
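
Running the script (on CPU, as forced at the top) builds the encoder block and prints the following model summary: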

 

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_1 (InputLayer)           [(4, 224, 224, 3)]   0           []

 patch_embed (PatchEmbedding)   (4, 197, 768)        1198848     ['input_1[0][0]']

 laynorm1 (LayerNormalization)  (4, 197, 768)        1536        ['patch_embed[0][0]']

 msa (MultiHead_Self_Attention)  (4, 197, 768)       2362368     ['laynorm1[0][0]']

 Add1 (Add)                     (4, 197, 768)        0           ['msa[0][0]',
                                                                  'patch_embed[0][0]']

 laynorm2 (LayerNormalization)  (4, 197, 768)        1536        ['Add1[0][0]']

 mlp (MLP)                      (4, 197, 768)        4722432     ['laynorm2[0][0]']

 Add2 (Add)                     (4, 197, 768)        0           ['mlp[0][0]',
                                                                  'Add1[0][0]']

==================================================================================================
Total params: 8,286,720
Trainable params: 8,286,720
Non-trainable params: 0
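
The parameter counts match the code: patch_embed = 16*16*3*768 + 768 (Conv2D kernel + bias) + 4*197*768 (pos_embedding) + 4*1*768 (cls_token) = 1,198,848. Note that pos_embedding and cls_token are allocated once per sample because their first dimension is the fixed batch size of 4 from Input(batch_size=4). msa = 768*2304 + 2304 (qkv) + 768*768 + 768 (proj) = 2,362,368, mlp = 768*3072 + 3072 (fc1) + 3072*768 + 768 (fc2) = 4,722,432, and each LayerNormalization contributes 2*768 = 1,536, giving the 8,286,720 total shown above.

As a quick sanity check (a minimal sketch, assuming the script above has already been run in the same session), a forward pass with a random batch should reproduce the (4, 197, 768) output shape from the summary:

dummy = tf.random.normal((4, 224, 224, 3))   # fake batch matching Input(batch_size=4)
out = model(dummy, training=False)
print(out.shape)                             # expected: (4, 197, 768)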