2 注意力汇聚:Nadaraya-Watson核回归
- 查询(自主提示)和键(非自主提示)之间的交互形成了注意力汇聚;注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。以Nadaraya‐Watson核回归模型为例介绍注意力汇聚
1 | import torch |
2.1 生成数据集
简单起见,考虑下面这个回归问题:给定的成对的“输入-输出”数据集(x1,y1),…,(xn,yn),如何学习f来预测任意新输入x所对应的输出ˆy=f(x)?
生成一个人工数据集:
yi=2sin(xi)+x0.8i+ϵ,ϵ∼N(0,0.5)
- 生成50个训练样本和50个测试样本。为了更好地可视化之后的注意力模式,需要将训练样本进行排序。
1 | n_train = 50 # 训练样本数 |
- 定义函数,绘制所有的训练样本(样本由圆圈表示),不带噪声项的真实数据生成函数f(标记为“Truth”),以及学习得到的预测函数(标记为“Pred”)。
1 | def plot_kernel_reg(y_hat): |
2.2 平均汇聚
- 使用平均汇聚,真实函数与预测函数相差很大
1 | import os |
2.3 非参数注意力汇聚
- 平均汇聚忽略了输入xi。于是Nadaraya和Watson提出了一个更好的想法,根据输入的位置对输出yi进行加权:
f(x)=n∑i=1K(x−xi)∑nj=1K(x−xj)yi
- 其中K是核(kernel)。上述预测器被称为Nadaraya-Watson核回归。受此启发,我们可以从注意力机制框架的角度重写,成为一个更加通用的注意力汇聚(attention pooling)公式:
f(x)=n∑i=1α(x,xi)yi
其中x是查询,xi,yi是键值对。注意力汇聚是yi的加权平均。将查询x和键xi之间的关系建模为注意力权重α(x,xi),这个权重将被分配给每一个对应值yi。对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布:它们是非负的,并且总和为1。
为了更好地理解注意力汇聚,下面考虑一个高斯核:
K(u)=1√2πexp(−u22)
- 因此
f(x)=n∑i=1α(x,xi)yi =n∑i=1exp(−12(x−xi)2)∑nj=1exp(−12(x−xj)2)yi =n∑i=1softmax(−12(x−xi)2)yi
如果一个键xi与查询x越相似,那么分配给这个键对应值yi的注意力权重就越大,也就“获得了更多的注意力”
Nadaraya‐Watson核回归是一个非参数模型。
基于这个非参数的注意力汇聚模型来绘制预测结果。从绘制的结果会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。
1 | # X_repeat的形状:(n_test,n_train), |
- 观察注意力的权重。这里测试数据的输入相当于查询,而训练数据的输入相当于键。因为两个输入都是经过排序的,因此由观察可知“查询‐键”对越接近,注意力汇聚的注意力权重就越高。
1 | d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs') |
2.4 带参数注意力汇聚
非参数的Nadaraya‐Watson核回归具有一致性(consistency)的优点:如果有足够的数据,此模型会收敛到最优结果。尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。
例如,在查询x和键xi之间的距离乘以可学习参数w:
f(x)=n∑i=1α(x,xi)yi =n∑i=1exp(−12((x−xi)w)2)∑nj=1exp(−12((x−xj)w)2)yi =n∑i=1softmax(−12((x−xi)w)2)yi
2.4.1 批量矩阵乘法
batch1:n个a*b的矩阵:X1,…,Xn
batch2:n个b*c的矩阵:Y1,…,Yn
它们的批量矩阵乘法得到n个a*c矩阵X1Y1,…,XnYn
(n, a, b) * (n, b, c) -> (n, a, c)
1 | X = torch.ones((2, 1, 4)) |
torch.Size([2, 1, 6])
- 在注意力机制背景中,我们可以使用批量矩阵乘法来计算小批量数据中的加权平均值。
1 | weights = torch.ones((2, 10)) * 0.1 |
tensor([[[ 4.5000]],
[[14.5000]]])
2.4.2 定义模型
- Nadaraya‐Watson核回归的带参数版本:
1 | class NWKernelRegression(nn.Module): #与无参的类似 |
2.4.3 训练
- 将训练数据集变换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出。
1 | # X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入 |
- 在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑。
1 | # keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键) |
- 看一下输出结果的绘制图:与非参数的注意力汇聚模型相比,带参数的模型加入可学习的参数后,曲线在注意力权重较大的区域变得更不平滑。
1 | d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs') |