本文内容较新 · 今天更新
最后更新: 2026年03月31日
预计阅读时间: 26.3 分钟
6575 字 250 字/分

Attension也是整个 Transformer 里最精髓的部分了, 也卡了我相当之久

公式解析部分

注意力公式如下, 很晦涩, 但我尽可能以简单的方式来解释这些问题

$$ Attention(Q,K,V)=softmax\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V $$

$QKV$计算

先解释 Q K V 分别是什么:

  • Q: Query 向量, 代表着我要查什么东西
  • K: Key 向量, 代表着当前被查询的这个东西
  • V: Value 向量, 实际上当前被查询的东西的具体内容
  • $d_{k}$: 向量的维度

简单来说, Q 负责确认"我要找什么", K 负责"我现在找到的是什么", V 是"找大的具体内容". 首先, 这三个矩阵的型是相同的

先看 $QK^{T}$, 两个高纬度的向量做点积, 根据高中的知识, 通常来说, 点积更大的, 这两个向量在方向上约接近

举个例子, 我们按照单词划分Token, 输入一个 3 Tokens的一句话

Token1 Token2 Token3

那对于这几个Token对应的矩阵, 就应当是三行(对应三个 Token) 二列(Token 维度为 2), 即输入矩阵 X

$$ X=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix} $$

我们可以简单的认为, 其中 Token1 对应第一行的向量(1, 2), Token2 对应第二行的(0, 1)

接下来, 根据公式 $Q=XW_{Q},\quad K=XW_{K},\quad V=XW_{V}$ 来生成我们需要的 QKV 三个矩阵, 其中 $X$ 是输入, $W_{Q,K,V}$ 是一个可训练的权重矩阵, 为了方便理解, 我们直接令:

$W_{Q}=\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix}$ $W_{K}=\begin{bmatrix}1 & 1\\ 0 & 1\end{bmatrix}$ $W_{V}=\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}$

由上述 QKV 公式, 我们可以计算

$$ Q=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix}=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix} $$

$$ K=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 1\\ 0 & 1\end{bmatrix}=\begin{bmatrix}1 & 3\\ 0 & 1\\ 3 & 4\end{bmatrix} $$

$$ V=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}=\begin{bmatrix}3 & 2\\ 1 & 1\\ 4 & 1\end{bmatrix} $$

$QK^{T}$计算

接下来, 计算 $QK^{T}$

$QK^{T}=\begin{bmatrix}1 & 2\\ 0 & 1\\ 3 & 1\end{bmatrix}\cdot\begin{bmatrix}1 & 0 & 3\\ 3 & 1 & 4\end{bmatrix}=\begin{bmatrix}7 & 2 & 11\\ 3 & 1 & 4\\ 6 & 1 & 13\end{bmatrix}$

其中, 我们计算的结果矩阵中的每一行, 就是这个Token和其他另外的几个Token的相似程度, 并且$QK^{T}$一定是方阵, 每行和每列都对应一个Token

以第一行 Token1 为例, 通过计算得到, 在总共的三个Token中, Token1最关注Token3, 最不关注Token2, 但这只是一个初始的结果, 并不能直接用, 我们还需要进一步处理

除以 $\sqrt{d_{k}}$

为什么它偏偏要除以这个开方后的$d_k$呢, 在我们这个例子中可能看不出来, 因为我们选择的向量维度只有 2, 根据如下的公式

$$ Q_i \cdot K_j = \sum_{t=1}^{d_k} q_t k_t $$

随着我们向量维度的增大, 这个矩阵的数值也会开始爆炸, 大的数字会变得更大, 可能是这样的

$$ K^{T}=\begin{bmatrix}1 & 0 & 114514\\ 3 & 1 & 1919810114514\\ 888 & 888 & 888\end{bmatrix} $$

所以我们需要除以一个数, 来把矩阵里的每个参数的值拉回到一个正常的水平

$softmax$

接下来, 我们要把这个矩阵中的值转换为一个概率, 而最常用的方法就是Softmax, 把每个值映射到一个 0-1 的概率, 来决定注意力的分配

$$ softmax(x_{i})=\frac{e^{x_{i}}}{\sum e^{x_{j}}} $$

根据这个公式, 对每行单独进行softmax, 就计算出了每个Token对于其他Token的注意力分配, 最终得到的结果可能类似于 下面的结果

$$ softmax\left(\frac{QK^{T}}{\sqrt2}\right)\approx\begin{bmatrix}0.055 & 0.002 & 0.943\\ 0.304 & 0.074 & 0.622\\ 0.007 & 0.0002 & 0.993\end{bmatrix} $$

乘以$V$

根据上面的结果, 我们能发现, 对于Token1来说, 他对三个Token的注意力分布分别是0.055, 0.002, 0.943, 带入所有信息, 计算最终的Attention值

$$ Output=AV\approx\begin{bmatrix}3.94 & 1.06\\ 3.50 & 1.30\\ 3.99 & 1.01\end{bmatrix} $$

最终的结果的行数, 代表Token的总数, 而列数就是我们设定的向量维度的大小

代码实现

public class SingleSelfAttention {
    //Three Matrix which is QKV in the original paper
    private Matrix wq;
    private Matrix wk;
    private Matrix wv;
    //The dimension of every vector
    private int dModel;

    //Cache
    private Matrix lastInput;
    private Matrix lastQ,lastK,lastV,lastAttention;
    public SingleSelfAttention(int dModel){
        if(dModel <= 0){
            throw new IllegalArgumentException("dModel must be greater than 0\n");
        }
        this.dModel = dModel;

        //init the weights
        wq = Matrix.random(dModel, dModel);
        wk = Matrix.random(dModel, dModel);
        wv = Matrix.random(dModel, dModel);
    }

    public Matrix forward(Matrix matrix){


        lastInput = matrix;
        lastQ = matrix.multiply(wq);
        lastK = matrix.multiply(wk);
        lastV = matrix.multiply(wv);

        //Attention Attention = QK^T/sqrt(dModel)
        Matrix scores = lastQ.multiply(lastK.transpose());
        scores = scores.constantMultiply(1.0 / Math.sqrt(dModel));
        scores = applyCausalAttention(scores);

        //use softmax
        lastAttention = scores.softmaxByRow();

        return lastAttention.multiply(lastV);

    }

    public Matrix backward(Matrix gradOutput, double learningRate){

        //The gradient of attention
        Matrix gradAttention = gradOutput.multiply(lastV.transpose());
        // output = attention * V, so dV = attention^T * dOutput
        Matrix gradV = lastAttention.transpose().multiply(gradOutput);

        //The gradient of scores
        Matrix gradScores = new Matrix (lastAttention.getRow(), lastAttention.getCol());
        for(int i = 0 ; i < lastAttention.getRow() ; i++){
            double dot = 0.0;
            for(int j = 0 ; j < lastAttention.getCol(); j++){
                dot += gradAttention.getElement(i, j) * lastAttention.getElement(i, j);
            }
            for(int j = 0; j < lastAttention.getCol(); j++){
                double value = lastAttention.getElement(i, j)* (gradAttention.getElement(i, j) - dot);
                if(j > i){
                    value = 0.0;
                }
                gradScores.setElement(i, j, value/ Math.sqrt(dModel));
            }
        }

        Matrix gradQ = gradScores.multiply(lastK);
        Matrix gradK = gradScores.transpose().multiply(lastQ);

        Matrix gradWq = lastInput.transpose().multiply(gradQ);
        Matrix gradWk = lastInput.transpose().multiply(gradK);
        Matrix gradWv = lastInput.transpose().multiply(gradV);

        Matrix gradInput = gradQ.multiply(wq.transpose())
                .add(gradK.multiply(wk.transpose()))
                .add(gradV.multiply(wv.transpose()));
        updateWeights(wq, gradWq, learningRate);
        updateWeights(wk, gradWk, learningRate);
        updateWeights(wv, gradWv, learningRate);

        return gradInput;
    }
    public void updateWeights(Matrix weights, Matrix gradWeights, double learningRate){
        for(int i = 0; i < weights.getRow(); i++){
            for(int j = 0 ; j < weights.getCol(); j++){
                double value = weights.getElement(i, j);
                weights.setElement(i, j, value - learningRate * gradWeights.getElement(i, j));
            }
        }
    }

    private Matrix applyCausalAttention(Matrix scores){
        Matrix result = scores.copy();

        for(int i = 0; i < result.getRow(); i++){
            for(int j = i + 1; j < result.getCol(); j++){
                result.setElement(i, j, -1e9);
            }
        }
        return result;
    }

    public Matrix getWq() {
        return wq;
    }
    public Matrix getWk() {
        return wk;
    }
    public Matrix getWv() {
        return wv;
    }

    public void setWq(Matrix newWq){
        validateWeightShape(newWq, "wq");
        this.wq = newWq.copy();
    }
    public void setWk(Matrix newWk){
        validateWeightShape(newWk, "wk");
        this.wk = newWk.copy();
    }
    public void setWv(Matrix newWv){
        validateWeightShape(newWv, "wv");
        this.wv = newWv.copy();
    }

    private void validateWeightShape(Matrix matrix, String weightName){
        if(matrix == null){
            throw new IllegalArgumentException("The " + weightName + " cannot be null");
        }
        if(matrix.getRow() != dModel || matrix.getCol() != dModel){
            throw new IllegalArgumentException("The shape of " + weightName + " must be (" + dModel + ", " + dModel + ")");
        }
    }
}