Last updated: March 12, 2026

I had my first Java class this semester and was surprised to find it still teaches Swing. What is that supposed to mean?

Never mind, though. I have been wanting to dig into the Transformer anyway, so I decided to hand-write one in Java this semester.

I'll start with the matrix operations, since that is where most of the math lives.

The Matrix object

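I first created a Matrix class with the fields below. To speed up the arithmetic, I store the elements in a one-dimensional array (row by row) instead of a two-dimensional one.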

public class Matrix {

    private double[] data;
    private int row;
    private int col;


    public Matrix(int row, int col){
        if(row <= 0 || col <= 0){
            throw new RuntimeException("row and col must be positive");
        }
        this.row = row;
        this.col = col;
        this.data = new double[row * col];
    }
    public static Matrix random(int row, int col) {
        // Factory method: creates a matrix and randomly initializes its contents
        Matrix result = new Matrix(row, col);
        result.randFill();
        return result;
    }
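
The `random` factory calls `randFill()`, which the post does not show. As a rough idea of what such a method could do, here is a minimal sketch that fills a flat buffer with uniform values in a small symmetric range. The class name `RandFillSketch`, the `scale` parameter, and the seeded `Random` are all assumptions made for illustration, not the original implementation.

```java
import java.util.Random;

// Hypothetical stand-in for the randFill() that the post does not show.
class RandFillSketch {
    // Fills a flat row-major buffer with uniform values in [-scale, scale).
    static void randFill(double[] data, double scale, Random rng) {
        for (int i = 0; i < data.length; i++) {
            // nextDouble() is in [0, 1), so this maps it to [-scale, scale)
            data[i] = (rng.nextDouble() * 2.0 - 1.0) * scale;
        }
    }

    public static void main(String[] args) {
        double[] data = new double[2 * 3];      // backing array of a 2x3 matrix
        randFill(data, 0.1, new Random(42));    // fixed seed keeps runs reproducible
        System.out.println(java.util.Arrays.toString(data));
    }
}
```

Small initial values are a common choice for weight matrices, but the right range depends on the layer, so treat `0.1` as a placeholder.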

Getter & Setter

Next come the getter and setter, so the data in the matrix can be read and modified.

    private void validateIndex(int row, int col){
        if(row < 0 || row >= this.row || col < 0 || col >= this.col){
            throw new RuntimeException("index out of range");
        }
    }
    private int calPos(int row, int col){ // maps (row, col) to its index in the flat array
        return row * this.col + col;
    }
    public void setElement(int row, int col, double value){
        validateIndex(row, col);
        this.data[calPos(row, col)] = value;
    }
    public double getElement(int row,int col){
        validateIndex(row, col);
        return this.data[calPos(row, col)];
    }
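
To make the mapping in `calPos` concrete, here is a small standalone sketch of the same row-major formula applied to a 2x3 matrix. `IndexSketch` and the extra `cols` parameter are made up for the example; in the class above the column count comes from the `col` field.

```java
// Standalone illustration of row-major indexing; names are hypothetical.
class IndexSketch {
    // Same formula as calPos: element (row, col) of a matrix with `cols` columns.
    static int calPos(int row, int col, int cols) {
        return row * cols + col;
    }

    public static void main(String[] args) {
        int rows = 2, cols = 3;
        double[] data = {1, 2, 3, 4, 5, 6};   // stores [[1, 2, 3], [4, 5, 6]]
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                int pos = calPos(r, c, cols);
                System.out.println("(" + r + ", " + c + ") -> data[" + pos + "] = " + data[pos]);
            }
        }
    }
}
```

Element (1, 2) lands at index 1 * 3 + 2 = 5, the last slot of the array.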

Math operations

Matrix transpose

    public Matrix transpose(){
        Matrix result = new Matrix(this.col, this.row);

        for(int i = 0; i < this.row; i++){
            for(int j = 0; j < this.col; j++){
                result.setElement(j, i, this.getElement(i, j));
            }
        }
        return result;
    }

Matrix multiplication


    public Matrix multiply(Matrix another){
        if(this.col != another.row){
            throw new RuntimeException("matrix size not match");
        }
        Matrix result = new Matrix(this.row, another.col);

        for(int i = 0; i < this.row; i++){
            for(int j = 0; j < another.col; j++){
                double sum = 0;
                for(int k = 0; k < this.col; k++){
                    sum += this.getElement(i, k) * another.getElement(k, j);
                }
                result.setElement(i, j, sum);
            }
        }
        return result;
    }
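
Since the flat array was chosen for speed, one possible refinement is to skip the per-element bounds checks inside the triple loop and index the arrays directly. The sketch below is only an assumption about how that could look, not the code from this post. It works on raw `double[]` buffers, and the names `FlatMultiplySketch`, `n`, `m`, and `p` are made up. It uses an i-k-j loop order so the innermost loop walks both `b` and `out` contiguously.

```java
// Hypothetical flat-array multiply; operates on raw row-major buffers.
class FlatMultiplySketch {
    // a is n x m, b is m x p, the result is n x p, all stored row by row.
    static double[] multiply(double[] a, double[] b, int n, int m, int p) {
        if (a.length != n * m || b.length != m * p) {
            throw new IllegalArgumentException("matrix size not match");
        }
        double[] out = new double[n * p];   // starts as all zeros
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < m; k++) {
                double aik = a[i * m + k];  // reused across the whole inner loop
                for (int j = 0; j < p; j++) {
                    out[i * p + j] += aik * b[k * p + j];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] a = {1, 2, 3, 4};          // [[1, 2], [3, 4]]
        double[] b = {5, 6, 7, 8};          // [[5, 6], [7, 8]]
        System.out.println(java.util.Arrays.toString(multiply(a, b, 2, 2, 2)));
        // prints [19.0, 22.0, 43.0, 50.0]
    }
}
```

Whether this is actually faster than the `getElement` version depends on the JIT and the matrix sizes, so it is worth measuring before switching.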

Scalar multiplication

    public Matrix constantMultiply(double constant){
        Matrix result = new Matrix(this.row, this.col);
        for(int i = 0; i < this.row; i++){
            for(int j = 0; j < this.col; j++){
                result.setElement(i, j, this.getElement(i, j) * constant);
            }
        }
        return result;
    }

Matrix subtraction

    public Matrix minus(Matrix another){
        if(this.row != another.row || this.col != another.col){
            throw new RuntimeException("matrix size not match");
        }else{
            Matrix result = new Matrix(this.row,this.col);
            for(int i = 0; i < this.row; i++){
                for(int j = 0; j < this.col; j++){
                    result.setElement(i, j, this.getElement(i, j) - another.getElement(i, j));
                }
            }
            return result;
        }
    }

Element-wise (Hadamard) product

    public Matrix hadamard(Matrix another) {
        if (this.row != another.row || this.col != another.col) {
            throw new RuntimeException("matrix size not match");
        }

        Matrix result = new Matrix(this.row, this.col);

        for (int i = 0; i < this.row; i++) {
            for (int j = 0; j < this.col; j++) {
                result.setElement(i, j, this.getElement(i, j) * another.getElement(i, j));
            }
        }

        return result;
    }

Matrix addition

    public Matrix add(Matrix another){
        if(this.row != another.row || this.col != another.col){
            throw new RuntimeException("matrix size not match");
        }else{
            Matrix result = new Matrix(this.row,this.col);
            for(int i = 0; i < this.row; i++){
                for(int j = 0; j < this.col; j++){
                    result.setElement(i, j, this.getElement(i, j) + another.getElement(i, j));
                }
            }
            return result;
        }
    }

toString

Next, to make debugging easier, I override toString so that it formats a Matrix object one row per line.

    @Override
    public String toString(){
        StringBuilder sb = new StringBuilder();

        for(int i = 0 ; i< this.row; i++){
            sb.append("[");
            for(int j=0; j<this.col; j++){
                sb.append(this.getElement(i,j));
                if(j != this.col -1){
                    sb.append(", ");
                }
            }
            sb.append("]");
            if(i != this.row -1) sb.append("\n");
        }
        return sb.toString();
    }

Copy

I also want a deep copy of a Matrix.

    public Matrix copy() {
        Matrix result = new Matrix(this.row, this.col);

        for (int i = 0; i < this.row; i++) {
            for (int j = 0; j < this.col; j++) {
                result.setElement(i, j, this.getElement(i, j));
            }
        }

        return result;
    }
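
Because all the elements live in one primitive array, a deep copy can also be done by duplicating that array in a single call. A minimal sketch of the idea on a raw buffer, using the standard `Arrays.copyOf`; the class and method names are hypothetical:

```java
import java.util.Arrays;

// Hypothetical helper showing a one-call deep copy of the backing array.
class CopySketch {
    // Returns a new array with the same contents; no memory is shared with the input.
    static double[] deepCopy(double[] data) {
        return Arrays.copyOf(data, data.length);
    }

    public static void main(String[] args) {
        double[] original = {1.0, 2.0, 3.0, 4.0};
        double[] copy = deepCopy(original);
        copy[0] = 99.0;                                   // only the copy changes
        System.out.println(Arrays.toString(original));   // prints [1.0, 2.0, 3.0, 4.0]
        System.out.println(Arrays.toString(copy));       // prints [99.0, 2.0, 3.0, 4.0]
    }
}
```

This works as a deep copy only because `double` is a primitive type; an array of objects would still share its elements.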

Softmax

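Softmax, a staple of deep learning. Here it is applied to each row independently, and the row maximum is subtracted before exponentiating to keep the values numerically stable.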

    public Matrix softmaxByRow(){
        Matrix result = new Matrix(this.row, this.col);
        for(int i = 0; i < this.row; i++){
            double max = this.getElement(i, 0);
            for(int j = 1; j < this.col; j++){
                if(this.getElement(i, j) > max){
                    max = this.getElement(i, j);
                }
            }
            double sum = 0.0;
            for(int j =0 ; j < this.col; j++){
                double value = Math.exp(this.getElement(i, j) - max);
                result.setElement(i, j, value);
                sum += value;
            }

            for(int j=0; j < this.col; j++){
                result.setElement(i, j, result.getElement(i, j) / sum);
            }
        }
        return result;
    }
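
The reason `softmaxByRow` looks for the row maximum first: softmax does not change when the same constant is subtracted from every input, but without the subtraction large scores make `Math.exp` overflow to infinity. A small sketch comparing the two on a single row; the class and method names are made up for the example.

```java
// Hypothetical comparison of naive and max-shifted softmax on one row.
class SoftmaxSketch {
    // Naive version: exponentiates the raw scores, overflows for large inputs.
    static double[] naive(double[] x) {
        double[] out = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.exp(x[i]);
            sum += out[i];
        }
        for (int i = 0; i < x.length; i++) out[i] /= sum;
        return out;
    }

    // Stable version: same per-row logic as softmaxByRow.
    static double[] stable(double[] x) {
        double max = x[0];
        for (int i = 1; i < x.length; i++) max = Math.max(max, x[i]);
        double[] out = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.exp(x[i] - max);   // largest exponent is exp(0) = 1
            sum += out[i];
        }
        for (int i = 0; i < x.length; i++) out[i] /= sum;
        return out;
    }

    public static void main(String[] args) {
        double[] row = {1000.0, 1001.0, 1002.0};
        // exp(1000) is Infinity, and Infinity / Infinity is NaN
        System.out.println(java.util.Arrays.toString(naive(row)));
        // finite probabilities that sum to 1
        System.out.println(java.util.Arrays.toString(stable(row)));
    }
}
```

For this row the naive version yields NaN in every slot, while the stable one gives roughly 0.09, 0.24, and 0.67.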

That wraps up the design of the Matrix class.