1. Formula
We add a learnable parameter $w$ to the distance between the query $x$ and the key $x_i$, giving:
$$f(x)=\sum_{i=1}^n \alpha(x, x_i)\, y_i=\sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}((x-x_i)w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}((x-x_j)w)^2\right)}\, y_i$$
$$f(x)=\sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x-x_i)w)^2\right) y_i$$
- $x$: the query
- $x_i$: the keys
- $y_i$: the values
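To make the formula concrete, here is a minimal sketch (not from the original post) that evaluates the weighted sum for a single query; the numbers are made up, and $w$ is fixed here rather than learned:

import torch

x = torch.tensor(2.0)                 # query (made-up value)
x_i = torch.tensor([1.0, 2.5, 4.0])   # keys (made up)
y_i = torch.tensor([0.5, 1.5, 3.0])   # values (made up)
w = torch.tensor(2.0)                 # fixed for illustration; learned below

# softmax over -((x - x_i) * w)^2 / 2 gives the attention weights alpha(x, x_i)
alpha = torch.softmax(-((x - x_i) * w) ** 2 / 2, dim=0)
f_x = (alpha * y_i).sum()             # f(x): attention-weighted sum of the values

Keys closer to the query get larger weights, and a larger $w$ makes the weighting more peaked.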
2. Code
# Import the required libraries
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
# Set the number of training samples
n_train = 50
# Define the ground-truth function
def f(x):
    return 2 * torch.sin(x) + x ** 0.8
# Randomly generate n_train training inputs, sorted in ascending order
x_train, _ = torch.sort(torch.rand(n_train) * 5)
# Training targets: f(x) plus Gaussian noise
y_train = f(x_train) + torch.normal(0, 0.5, (n_train,))
# Test inputs: evenly spaced points along the x-axis
x_test = torch.arange(0, 5, 0.1)
# Ground-truth outputs at the test inputs
y_truth = f(x_test)
# Plotting helper
def plot_kernel_reg(y_hat):
    # Line plot of ground truth vs. prediction
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    # Scatter plot of the training data
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)
# Define a "parametric attention pooling" class
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The learnable parameter w
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries: (#queries,) -> (#queries, #"key-value" pairs)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        # attention_weight: (#queries, #"key-value" pairs)
        self.attention_weight = nn.functional.softmax(
            -((queries - keys) * self.w) ** 2 / 2, dim=1)
        # bmm of (#queries, 1, #pairs) with (#queries, #pairs, 1),
        # then reshape the result to (#queries,)
        return torch.bmm(self.attention_weight.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)
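# (Illustrative shape check, not part of the original script.) With 4 queries
# and 3 "key-value" pairs per query, forward returns one scalar per query:
_net = NWKernelRegression()
print(_net(torch.rand(4), torch.rand(4, 3), torch.rand(4, 3)).shape)  # torch.Size([4])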
# x_title.shape = [n_train, n_train]; repeat copies x_train row by row
x_title = x_train.repeat((n_train, 1))
# y_title.shape = [n_train, n_train]
y_title = y_train.repeat((n_train, 1))
# Drop the diagonal of x_title to get the keys, and of y_title to get the
# values: during training, each point attends to every point except itself
keys = x_title[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
values = y_title[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
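# (Illustrative aside, not part of the original script.) With a toy n of 3,
# the leave-one-out construction above keeps, in row i, every point except x_i:
_x = torch.tensor([10., 20., 30.])
_t = _x.repeat((3, 1))
print(_t[(1 - torch.eye(3)).type(torch.bool)].reshape((3, -1)))
# tensor([[20., 30.],
#         [10., 30.],
#         [10., 20.]])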
# Instantiate the parametric attention pooling network
net = NWKernelRegression()
# reduction='none' keeps the per-sample losses
loss = nn.MSELoss(reduction='none')
# Train with stochastic gradient descent, learning rate 0.5
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])
# Training loop
for epoch in range(5):
    # Reset the gradients
    trainer.zero_grad()
    # Per-sample losses
    l = loss(net(x_train, keys, values), y_train)
    # Reduce to a scalar and backpropagate
    l.sum().backward()
    # Update the parameters
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))
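# Note: as the learned w grows, the Gaussian kernel narrows, so the fitted
# curve tends to track the noisy training targets more closely (and look
# less smooth) than the nonparametric version.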
# At prediction time, every test query attends to all n_train training
# "key-value" pairs, so no diagonal removal here
# (this relies on len(x_test) == n_train == 50)
keys = x_train.repeat((n_train, 1))
values = y_train.repeat((n_train, 1))
# Compute y_hat; detach() so it can be plotted
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
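# The stored attention weights have shape (#test queries, #training pairs);
# the two unsqueeze(0) calls produce the (1, 1, rows, cols) shape that
# d2l.show_heatmaps expects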
d2l.show_heatmaps(net.attention_weight.unsqueeze(0).unsqueeze(0), xlabel='x', ylabel='y')
plt.show()