Java手搓简易Transfomer - Attention部分
Attension也是整个 Transformer 里最精髓的部分了, 也卡了我相当之久
公式解析部分
注意力公式如下, 很晦涩, 但我尽可能以简单的方式来解释这些问题
$$ Attention(Q,K,V)=softmax\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V $$
$QKV$计算
先解释 Q K V 分别是什么:
- Q: Query 向量, 代表着我要查什么东西
- K: Key 向量, 代表着当前被查询的这个东西
- V: Value 向量, 实际上当前被查询的东西的具体内容
- $d_{k}$: 向量的维度
简单来说, Q 负责确认"我要找什么", K 负责"我现在找到的是什么", V 是"找大的具体内容". 首先, 这三个矩阵的型是相同的
先看 $QK^{T}$, 两个高纬度的向量做点积, 根据高中的知识, 通常来说, 点积更大的, 这两个向量在方向上约接近
举个例子, 我们按照单词划分Token, 输入一个 3 Tokens的一句话
Token1 Token2 Token3
那对于这几个Token对应的矩阵, 就应当是三行(对应三个 Token) 二列(Token 维度为 2), 即输入矩阵 X
$$ X=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix} $$
我们可以简单的认为, 其中 Token1 对应第一行的向量(1, 2), Token2 对应第二行的(0, 1)
接下来, 根据公式 $Q=XW_{Q},\quad K=XW_{K},\quad V=XW_{V}$ 来生成我们需要的 QKV 三个矩阵, 其中 $X$ 是输入, $W_{Q,K,V}$ 是一个可训练的权重矩阵, 为了方便理解, 我们直接令:
$W_{Q}=\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix}$ $W_{K}=\begin{bmatrix}1 & 1\\ 0 & 1\end{bmatrix}$ $W_{V}=\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}$
由上述 QKV 公式, 我们可以计算
$$ Q=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix}=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix} $$
$$ K=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 1\\ 0 & 1\end{bmatrix}=\begin{bmatrix}1 & 3\\ 0 & 1\\ 3 & 4\end{bmatrix} $$
$$ V=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}=\begin{bmatrix}3 & 2\\ 1 & 1\\ 4 & 1\end{bmatrix} $$
$QK^{T}$计算
接下来, 计算 $QK^{T}$
$QK^{T}=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0 & 3\\ 3 & 1 & 4\end{bmatrix}=\begin{bmatrix}7 & 2 & 11\\ 3 & 1 & 4\\ 6 & 1 & 13\end{bmatrix}$
其中, 我们计算的结果矩阵中的每一行, 就是这个Token和其他另外的几个Token的相似程度, 并且$QK^{T}$一定是方阵, 每行和每列都对应一个Token
以第一行 Token1 为例, 通过计算得到, 在总共的三个Token中, Token1最关注Token3, 最不关注Token2, 但这只是一个初始的结果, 并不能直接用, 我们还需要进一步处理
除以 $\sqrt{d_{k}}$
为什么它偏偏要除以这个开方后的$d_k$呢, 在我们这个例子中可能看不出来, 因为我们选择的向量维度只有 2, 根据如下的公式
$$ Q_i \cdot K_j = \sum_{t=1}^{d_k} q_t k_t $$
随着我们向量维度的增大, 这个矩阵的数值也会开始爆炸, 大的数字会变得更大, 可能是这样的
$$ K^{T}=\begin{bmatrix}1 & 0 & 114514\\ 3 & 1 & 1919810114514\\ 888 & 888 & 888\end{bmatrix} $$
所以我们需要除以一个数, 来把矩阵里的每个参数的值拉回到一个正常的水平
$softmax$
接下来, 我们要把这个矩阵中的值转换为一个概率, 而最常用的方法就是Softmax, 把每个值映射到一个 0-1 的概率, 来决定注意力的分配
$$ softmax(x_{i})=\frac{e^{x_{i}}}{\sum e^{x_{j}}} $$
根据这个公式, 对每行单独进行softmax, 就计算出了每个Token对于其他Token的注意力分配, 最终得到的结果可能类似于 下面的结果
$$ softmax\left(\frac{QK^{T}}{\sqrt2}\right)\approx\begin{bmatrix}0.055 & 0.002 & 0.943\\ 0.304 & 0.074 & 0.622\\ 0.007 & 0.0002 & 0.993\end{bmatrix} $$
乘以$V$
根据上面的结果, 我们能发现, 对于Token1来说, 他对三个Token的注意力分布分别是0.055, 0.002, 0.943, 带入所有信息, 计算最终的Attention值
$$ Output=AV\approx\begin{bmatrix}3.94 & 1.06\\ 3.50 & 1.30\\ 3.99 & 1.01\end{bmatrix} $$
最终的结果的行数, 代表Token的总数, 而列数就是我们设定的向量维度的大小
代码实现
public class SingleSelfAttention {
//Three Matrix which is QKV in the original paper
private Matrix wq;
private Matrix wk;
private Matrix wv;
//The dimension of every vector
private int dModel;
//Cache
private Matrix lastInput;
private Matrix lastQ,lastK,lastV,lastAttention;
public SingleSelfAttention(int dModel){
if(dModel <= 0){
throw new IllegalArgumentException("dModel must be greater than 0\n");
}
this.dModel = dModel;
//init the weights
wq = Matrix.random(dModel, dModel);
wk = Matrix.random(dModel, dModel);
wv = Matrix.random(dModel, dModel);
}
public Matrix forward(Matrix matrix){
lastInput = matrix;
lastQ = matrix.multiply(wq);
lastK = matrix.multiply(wk);
lastV = matrix.multiply(wv);
//Attention Attention = QK^T/sqrt(dModel)
Matrix scores = lastQ.multiply(lastK.transpose());
scores = scores.constantMultiply(1.0 / Math.sqrt(dModel));
scores = applyCausalAttention(scores);
//use softmax
lastAttention = scores.softmaxByRow();
return lastAttention.multiply(lastV);
}
public Matrix backward(Matrix gradOutput, double learningRate){
//The gradient of attention
Matrix gradAttention = gradOutput.multiply(lastV.transpose());
// output = attention * V, so dV = attention^T * dOutput
Matrix gradV = lastAttention.transpose().multiply(gradOutput);
//The gradient of scores
Matrix gradScores = new Matrix (lastAttention.getRow(), lastAttention.getCol());
for(int i = 0 ; i < lastAttention.getRow() ; i++){
double dot = 0.0;
for(int j = 0 ; j < lastAttention.getCol(); j++){
dot += gradAttention.getElement(i, j) * lastAttention.getElement(i, j);
}
for(int j = 0; j < lastAttention.getCol(); j++){
double value = lastAttention.getElement(i, j)* (gradAttention.getElement(i, j) - dot);
if(j > i){
value = 0.0;
}
gradScores.setElement(i, j, value/ Math.sqrt(dModel));
}
}
Matrix gradQ = gradScores.multiply(lastK);
Matrix gradK = gradScores.transpose().multiply(lastQ);
Matrix gradWq = lastInput.transpose().multiply(gradQ);
Matrix gradWk = lastInput.transpose().multiply(gradK);
Matrix gradWv = lastInput.transpose().multiply(gradV);
Matrix gradInput = gradQ.multiply(wq.transpose())
.add(gradK.multiply(wk.transpose()))
.add(gradV.multiply(wv.transpose()));
updateWeights(wq, gradWq, learningRate);
updateWeights(wk, gradWk, learningRate);
updateWeights(wv, gradWv, learningRate);
return gradInput;
}
public void updateWeights(Matrix weights, Matrix gradWeights, double learningRate){
for(int i = 0; i < weights.getRow(); i++){
for(int j = 0 ; j < weights.getCol(); j++){
double value = weights.getElement(i, j);
weights.setElement(i, j, value - learningRate * gradWeights.getElement(i, j));
}
}
}
private Matrix applyCausalAttention(Matrix scores){
Matrix result = scores.copy();
for(int i = 0; i < result.getRow(); i++){
for(int j = i + 1; j < result.getCol(); j++){
result.setElement(i, j, -1e9);
}
}
return result;
}
public Matrix getWq() {
return wq;
}
public Matrix getWk() {
return wk;
}
public Matrix getWv() {
return wv;
}
public void setWq(Matrix newWq){
validateWeightShape(newWq, "wq");
this.wq = newWq.copy();
}
public void setWk(Matrix newWk){
validateWeightShape(newWk, "wk");
this.wk = newWk.copy();
}
public void setWv(Matrix newWv){
validateWeightShape(newWv, "wv");
this.wv = newWv.copy();
}
private void validateWeightShape(Matrix matrix, String weightName){
if(matrix == null){
throw new IllegalArgumentException("The " + weightName + " cannot be null");
}
if(matrix.getRow() != dModel || matrix.getCol() != dModel){
throw new IllegalArgumentException("The shape of " + weightName + " must be (" + dModel + ", " + dModel + ")");
}
}
}
评论 暂无