Parametric Attention Pooling

柠檬果然酸 2022-02-19


1. Formula

We add a learnable parameter $w$ to the distance between the query $x$ and each key $x_i$; concretely:
$$f(x)=\sum_{i=1}^n \alpha(x, x_i)\,y_i=\sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}((x-x_i)w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}((x-x_j)w)^2\right)}\,y_i$$

$$f(x)=\sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x-x_i)w)^2\right)y_i$$

  • $x$: the query
  • $x_i$: the keys
  • $y_i$: the values
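
To make the formula concrete, here is a minimal sketch (the values of $x$, $x_i$, $y_i$ and the fixed $w$ below are made up for illustration) that evaluates the attention weights and the weighted sum directly:

import torch

x = torch.tensor(2.0)                 # query (toy value)
x_i = torch.tensor([1.0, 2.0, 3.0])   # keys (toy values)
y_i = torch.tensor([0.5, 1.5, 2.5])   # values (toy values)
w = torch.tensor(1.0)                 # the learnable parameter, held fixed here

# alpha = softmax(-((x - x_i) * w)^2 / 2), taken over the keys
alpha = torch.softmax(-((x - x_i) * w) ** 2 / 2, dim=0)
f_x = (alpha * y_i).sum()             # f(x): attention-weighted sum of the values
print(alpha, f_x)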

2. Code


# Import the required libraries
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# Number of training samples
n_train = 50


# Ground-truth function to be recovered
def f(x):
	return 2 * torch.sin(x) + x ** 0.8


# Randomly generate n_train training inputs, sorted in ascending order
x_train, _ = torch.sort(torch.rand(n_train) * 5)
# Training targets: f(x_train) plus Gaussian noise
y_train = f(x_train) + torch.normal(0, 0.5, (n_train,))

# Evenly spaced test inputs (the x-axis for plotting)
x_test = torch.arange(0, 5, 0.1)

# Ground-truth outputs on the test inputs
y_truth = f(x_test)


# Plotting helper
def plot_kernel_reg(y_hat):
	# Line plot of the ground truth and the prediction
	d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
			 xlim=[0, 5], ylim=[-1, 5])
	# Scatter plot of the noisy training samples
	d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)


# Define a "parametric attention pooling" module
class NWKernelRegression(nn.Module):
	def __init__(self, **kwargs):
		super().__init__(**kwargs)
		# A single learnable scalar parameter w (nn.Parameter already sets
		# requires_grad=True, so the explicit flag is redundant but harmless)
		self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

	def forward(self, queries, keys, values):
		# Broadcast each query across all key-value pairs:
		# queries: (no. of queries,) -> (no. of queries, no. of key-value pairs)
		queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
		# Attention weights via a softmax over the scaled squared
		# query-key distances; shape (no. of queries, no. of key-value pairs)
		self.attention_weight = nn.functional.softmax(
			-((queries - keys) * self.w) ** 2 / 2, dim=1)
		# attention_weight.unsqueeze(1): (no. of queries, 1, no. of pairs)
		# values.unsqueeze(-1):          (no. of queries, no. of pairs, 1)
		# bmm output: (no. of queries, 1, 1) -> reshape(-1) -> (no. of queries,)
		return torch.bmm(self.attention_weight.unsqueeze(1),
						 values.unsqueeze(-1)).reshape(-1)
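

# Optional sanity check (a sketch with made-up toy tensors, not part of the
# original walkthrough): forward maps (q,), (q, p), (q, p) -> (q,)
_q, _k, _v = torch.rand(4), torch.rand(4, 9), torch.rand(4, 9)
assert NWKernelRegression()(_q, _k, _v).shape == (4,)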


# Tile the training data: each of the n_train rows repeats the full set
# x_tile.shape = [n_train, n_train]; y_tile.shape = [n_train, n_train]
x_tile = x_train.repeat((n_train, 1))
y_tile = y_train.repeat((n_train, 1))

# Drop the diagonal of x_tile to get the keys and of y_tile to get the
# values: during training, a sample must not attend to its own pair
keys = x_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
values = y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
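
# Illustrative shape check (assuming n_train = 50): each sample keeps the
# other n_train - 1 inputs as its keys and values
assert keys.shape == (n_train, n_train - 1)
assert values.shape == (n_train, n_train - 1)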

# Instantiate the parametric attention pooling network
net = NWKernelRegression()
# reduction='none' keeps a per-sample loss instead of averaging
loss = nn.MSELoss(reduction='none')
# Train with stochastic gradient descent at a learning rate of 0.5
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

# Training loop
for epoch in range(5):
	# Zero the accumulated gradients
	trainer.zero_grad()
	# Per-sample loss on the training set
	l = loss(net(x_train, keys, values), y_train)
	# Sum to a scalar and backpropagate
	l.sum().backward()
	# Take an optimization step
	trainer.step()
	print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
	animator.add(epoch + 1, float(l.sum()))

# Keys and values for prediction: every test query attends to all training
# pairs. Here len(x_test) happens to equal n_train (both are 50), so
# repeating n_train rows works; in general use len(x_test) rows.
keys = x_train.repeat((n_train, 1))
values = y_train.repeat((n_train, 1))

# Compute y_hat; detach() so the result can be plotted without gradients
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)
# Heatmap of the learned attention weights (test queries vs. training keys)
d2l.show_heatmaps(net.attention_weight.unsqueeze(0).unsqueeze(0), xlabel='x', ylabel='y')

plt.show()