0
点赞
收藏
分享

微信扫一扫

自然语言处理(十九):Transformer前馈全连接层

自然语言处理笔记总目录


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, x):
        x = self.w1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.w2(x)
        return x
d_model = 512
d_ff = 2048
dropout = 0.2

x = out_mha
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
out_ff = ff(x)

print(out_ff)
print(out_ff.shape)
tensor([[[ 1.5385, -0.8189, -1.6781,  ...,  1.5737,  0.4905,  0.8483],
         [ 2.8966, -2.4892, -1.9388,  ...,  1.7022, -0.2211, -0.7838],
         [ 1.4625, -1.2973, -0.4546,  ...,  2.4504, -1.5376, -0.8824],
         [ 2.4148, -1.8958, -1.6720,  ...,  1.6979,  0.3737, -0.1442]],

        [[ 0.9309,  1.1935,  1.1984,  ...,  2.3999,  0.3744,  0.2678],
         [ 1.2424,  0.0684,  1.7166,  ...,  2.2012, -0.7395,  0.5636],
         [ 1.3801, -0.1511,  1.3062,  ...,  1.5764, -0.5672,  0.4452],
         [ 1.5959,  0.1437,  1.5425,  ...,  2.1625, -1.0858,  0.1428]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])
举报

相关推荐

0 条评论