Java手搓简易Transformer - 矩阵运算部分
本文内容较新
·
今天更新
最后更新: 2026年03月12日
预计阅读时间: 17.1 分钟
4270 字
250 字/分
这学期上了第一节Java课,没想到还在教Swing,这是何意味
不过不管了,正好想研究Transformer,所以我决定这学期就Java手搓Transformer了
首先从矩阵运算的部分来,毕竟数学的精髓就都在这里了
Matrix对象
我先创建了一个 Matrix 类,其中包含如下字段。为了提高运算时的速度,我选择使用一维数组(按行优先存储)来代替二维数组
/**
 * Dense matrix of {@code double} values backed by a single flat array
 * instead of a two-dimensional one.
 */
public class Matrix {
// Element storage; element (r, c) lives at index r * col + c (row-major).
private double[] data;
// Number of rows.
private int row;
// Number of columns.
private int col;
/**
 * Creates a {@code row} x {@code col} matrix; every element starts at 0.0
 * (the default value of a freshly allocated {@code double[]}).
 *
 * @param row number of rows, must be positive
 * @param col number of columns, must be positive
 * @throws IllegalArgumentException if either dimension is not positive, or if
 *         {@code row * col} exceeds the maximum {@code int} array length
 */
public Matrix(int row, int col) {
    // IllegalArgumentException is a RuntimeException, so existing callers that
    // catch RuntimeException keep working.
    if (row <= 0 || col <= 0) {
        throw new IllegalArgumentException("row and col must be positive");
    }
    // Multiply in long: an int product can silently wrap (65536 * 65536 == 0)
    // and would allocate a backing array of the wrong size.
    long size = (long) row * col;
    if (size > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("matrix too large: " + row + " x " + col);
    }
    this.row = row;
    this.col = col;
    this.data = new double[(int) size];
}
/**
 * Factory method: allocates a matrix of the requested size and then calls
 * {@code randFill()} on it (presumably random initial contents; that method
 * is defined elsewhere).
 *
 * @param row number of rows
 * @param col number of columns
 * @return the newly created, filled matrix
 */
public static Matrix random(int row, int col) {
    Matrix randomized = new Matrix(row, col);
    randomized.randFill();
    return randomized;
}
Getter & Setter
接下来实现getter和setter,使得矩阵中的数据可以被操作
/**
 * Rejects coordinates that fall outside this matrix.
 *
 * @param row row index to check, valid range is [0, this.row)
 * @param col column index to check, valid range is [0, this.col)
 * @throws RuntimeException if either index is out of range
 */
private void validateIndex(int row, int col) {
    boolean rowInRange = 0 <= row && row < this.row;
    boolean colInRange = 0 <= col && col < this.col;
    if (rowInRange && colInRange) {
        return;
    }
    throw new RuntimeException("index out of range");
}
/**
 * Converts a (row, col) coordinate into its offset in the flat backing array.
 *
 * @param row row index
 * @param col column index
 * @return {@code row * this.col + col}
 */
private int calPos(int row, int col) {
    int rowStart = row * this.col;
    return rowStart + col;
}
/**
 * Stores {@code value} at the given coordinate.
 *
 * @param row   row index
 * @param col   column index
 * @param value value to store
 * @throws RuntimeException if the coordinate is out of range
 */
public void setElement(int row, int col, double value) {
    validateIndex(row, col);
    int offset = calPos(row, col);
    this.data[offset] = value;
}
/**
 * Reads the value stored at the given coordinate.
 *
 * @param row row index
 * @param col column index
 * @return the element at ({@code row}, {@code col})
 * @throws RuntimeException if the coordinate is out of range
 */
public double getElement(int row, int col) {
    validateIndex(row, col);
    int offset = calPos(row, col);
    return this.data[offset];
}
数学计算
矩阵转置
/**
 * Builds the transpose of this matrix without modifying it.
 *
 * @return a new {@code col} x {@code row} matrix whose (r, c) element equals
 *         this matrix's (c, r) element
 */
public Matrix transpose() {
    final int outRows = this.col;
    final int outCols = this.row;
    Matrix flipped = new Matrix(outRows, outCols);
    // Walk the destination coordinates and pull each value from the mirrored source cell.
    for (int r = 0; r < outRows; r++) {
        for (int c = 0; c < outCols; c++) {
            double source = this.getElement(c, r);
            flipped.setElement(r, c, source);
        }
    }
    return flipped;
}
矩阵乘法
/**
 * Computes the matrix product {@code this * another} into a new matrix.
 *
 * @param another right-hand operand; its row count must equal this matrix's column count
 * @return a new {@code this.row} x {@code another.col} matrix
 * @throws RuntimeException if the inner dimensions do not match
 */
public Matrix multiply(Matrix another) {
    if (this.col != another.row) {
        throw new RuntimeException("matrix size not match");
    }
    final int inner = this.col;
    final int outCols = another.col;
    Matrix result = new Matrix(this.row, outCols);
    // Index the flat arrays directly: every index below is in range by
    // construction, so per-access validation in the innermost loop is pure overhead.
    // The i-k-j order reads both operands row by row. Each output cell still
    // accumulates its terms for k = 0..inner-1 in ascending order starting from
    // 0.0, so the floating-point result is identical to the i-j-k formulation.
    for (int i = 0; i < this.row; i++) {
        final int lhsBase = i * inner;
        final int outBase = i * outCols;
        for (int k = 0; k < inner; k++) {
            final double lhs = this.data[lhsBase + k];
            final int rhsBase = k * outCols;
            for (int j = 0; j < outCols; j++) {
                result.data[outBase + j] += lhs * another.data[rhsBase + j];
            }
        }
    }
    return result;
}
矩阵数乘
/**
 * Multiplies every element by a scalar, leaving this matrix untouched.
 *
 * @param constant the scalar factor
 * @return a new matrix of the same size holding the scaled values
 */
public Matrix constantMultiply(double constant) {
    final int rows = this.row;
    final int cols = this.col;
    Matrix scaled = new Matrix(rows, cols);
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            double original = this.getElement(r, c);
            scaled.setElement(r, c, original * constant);
        }
    }
    return scaled;
}
矩阵减法
public Matrix minus(Matrix another){
if(this.row != another.row || this.col != another.col){
throw new RuntimeException("matrix size not match");
}else{
Matrix result = new Matrix(this.row,this.col);
for(int i = 0; i < this.row; i++){
for(int j = 0; j < this.col; j++){
result.setElement(i, j, this.getElement(i, j) - another.getElement(i, j));
}
}
return result;
}
}矩阵逐个加
/**
 * Element-wise (Hadamard) product of this matrix and {@code another}.
 *
 * @param another matrix to multiply with; must have the same number of rows and columns
 * @return a new matrix whose (r, c) element is the product of both (r, c) elements
 * @throws RuntimeException if the two matrices differ in size
 */
public Matrix hadamard(Matrix another) {
    boolean sameShape = this.row == another.row && this.col == another.col;
    if (!sameShape) {
        throw new RuntimeException("matrix size not match");
    }
    final int rows = this.row;
    final int cols = this.col;
    Matrix product = new Matrix(rows, cols);
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            double lhs = this.getElement(r, c);
            double rhs = another.getElement(r, c);
            product.setElement(r, c, lhs * rhs);
        }
    }
    return product;
}
矩阵加法
/**
 * Element-wise addition: {@code this + another}.
 *
 * @param another matrix to add; must have the same number of rows and columns
 * @return a new matrix holding the sums
 * @throws RuntimeException if the two matrices differ in size
 */
public Matrix add(Matrix another) {
    boolean sameShape = this.row == another.row && this.col == another.col;
    if (!sameShape) {
        throw new RuntimeException("matrix size not match");
    }
    final int rows = this.row;
    final int cols = this.col;
    Matrix total = new Matrix(rows, cols);
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            double lhs = this.getElement(r, c);
            double rhs = another.getElement(r, c);
            total.setElement(r, c, lhs + rhs);
        }
    }
    return total;
}
toString
接下来,为了调试时更加便捷,我选择用 @Override 重写 toString 方法,让它能够格式化输出 Matrix 对象
/**
 * Formats the matrix for debugging: one bracketed row per line, elements
 * separated by {@code ", "}, rows separated by a newline with no trailing newline.
 *
 * @return the textual form of this matrix
 */
@Override
public String toString() {
    StringBuilder text = new StringBuilder();
    for (int r = 0; r < this.row; r++) {
        // Separators go before every row/element except the first.
        if (r > 0) {
            text.append("\n");
        }
        text.append("[");
        for (int c = 0; c < this.col; c++) {
            if (c > 0) {
                text.append(", ");
            }
            text.append(this.getElement(r, c));
        }
        text.append("]");
    }
    return text.toString();
}
Copy
考虑实现一下Matrix的深拷贝
/**
 * Creates a deep copy: a new matrix of the same size with its own backing
 * array, so later writes to either matrix do not affect the other.
 *
 * @return an independent copy of this matrix
 */
public Matrix copy() {
    Matrix result = new Matrix(this.row, this.col);
    // Both arrays share the same row-major layout, so one bulk copy replaces
    // the element-by-element loop and its per-element index validation.
    System.arraycopy(this.data, 0, result.data, 0, result.data.length);
    return result;
}
Softmax
深度学习的精髓,Softmax,这里做每行独立的Softmax
/**
 * Applies softmax to each row independently and returns the result as a new matrix.
 *
 * <p>For every row, the row maximum is subtracted from each element before
 * exponentiation, so every exponent passed to {@code Math.exp} is at most zero;
 * the exponentials are then divided by their row sum.
 *
 * @return a new matrix of the same size holding the row-wise softmax values
 */
public Matrix softmaxByRow() {
    final int rows = this.row;
    final int cols = this.col;
    Matrix probs = new Matrix(rows, cols);
    for (int r = 0; r < rows; r++) {
        // Pass 1: find the largest entry in the row.
        double rowMax = this.getElement(r, 0);
        for (int c = 1; c < cols; c++) {
            double candidate = this.getElement(r, c);
            if (candidate > rowMax) {
                rowMax = candidate;
            }
        }
        // Pass 2: store the shifted exponentials and accumulate their sum.
        double total = 0.0;
        for (int c = 0; c < cols; c++) {
            double shifted = this.getElement(r, c) - rowMax;
            double exponential = Math.exp(shifted);
            probs.setElement(r, c, exponential);
            total += exponential;
        }
        // Pass 3: normalise by the row sum.
        for (int c = 0; c < cols; c++) {
            double normalized = probs.getElement(r, c) / total;
            probs.setElement(r, c, normalized);
        }
    }
    return probs;
}
整个 Matrix 类的设计就到这里结束了
评论 暂无