package com.hk.neuralnetwork;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

/* loaded from: input_file:com/hk/neuralnetwork/Mat.class */
/**
 * A dense matrix of {@code double} values, stored as {@code data[row][col]}, with the element-wise
 * and linear-algebra helpers used by a small neural network.
 *
 * <p>Most operations ({@link #add(Mat)}, {@link #subtract(double)}, {@link #mult(Mat)},
 * {@link #transpose()}, ...) return a new matrix and leave this one untouched. The exceptions are
 * {@link #map(MatFunc)} and {@link #randomize()}, which overwrite this matrix in place and return
 * {@code this} for chaining.
 *
 * <p>NOTE(review): {@link #data} is a public, mutable array, so any caller can change the contents
 * of any matrix.
 */
public class Mat implements Serializable, Cloneable {
    /** Number of rows. */
    public final int rows;
    /** Number of columns. */
    public final int cols;
    /** Element storage indexed as {@code data[row][col]}; constructors allocate rows of length {@link #cols}. */
    public final double[][] data;

    // NOTE(review): the activation functions below are public but not final, so any code can
    // reassign them globally. They stay non-final only to avoid breaking existing callers.

    /** Logistic sigmoid {@code 1 / (1 + e^-x)}, applied to each element. */
    public static MatFunc SIGMOID = new MatFunc() {
        @Override
        public double perform(double value, int row, int col) {
            return 1.0d / (1.0d + Math.exp(-value));
        }
    };

    /**
     * Sigmoid derivative expressed in terms of the sigmoid's output: computes {@code y * (1 - y)},
     * so each element must already be a sigmoid output {@code y}, not the raw input.
     */
    public static MatFunc SIGMOID_DERIVATIVE = new MatFunc() {
        @Override
        public double perform(double value, int row, int col) {
            return value * (1.0d - value);
        }
    };

    /** Hyperbolic tangent, applied to each element. */
    public static MatFunc TANH = new MatFunc() {
        @Override
        public double perform(double value, int row, int col) {
            return Math.tanh(value);
        }
    };

    /**
     * Tanh derivative expressed in terms of tanh's output: computes {@code 1 - y * y}, so each
     * element must already be a tanh output {@code y}, not the raw input.
     */
    public static MatFunc TANH_DERIVATIVE = new MatFunc() {
        @Override
        public double perform(double value, int row, int col) {
            return 1.0d - (value * value);
        }
    };

    private static final long serialVersionUID = 3107367440033528127L;

    /** A function applied to a single matrix element, given its current value and position. */
    public interface MatFunc {
        /**
         * Computes the new value of one element.
         *
         * @param value the element's current value
         * @param row the element's row index
         * @param col the element's column index
         * @return the value to store at {@code [row][col]}
         */
        double perform(double value, int row, int col);
    }

    /**
     * Creates a zero-filled matrix.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @throws NegativeArraySizeException if either dimension is negative
     */
    public Mat(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.data = new double[rows][cols];
    }

    /**
     * Creates a matrix holding a deep copy of {@code source}; later changes to {@code source} do
     * not affect this matrix. An empty outer array yields a 0x0 matrix.
     *
     * @param source rectangular array indexed as {@code source[row][col]}
     * @throws IllegalArgumentException if the rows of {@code source} differ in length
     * @throws NullPointerException if {@code source} or any of its rows is null
     */
    public Mat(double[][] source) {
        this.rows = source.length;
        // An empty outer array has no row 0 to take the width from.
        this.cols = this.rows == 0 ? 0 : source[0].length;
        this.data = new double[this.rows][];
        for (int r = 0; r < this.rows; r++) {
            double[] row = source[r];
            // Ragged input would otherwise be silently truncated or fail part-way through the copy.
            if (row.length != this.cols) {
                throw new IllegalArgumentException("Row " + r + " has length " + row.length
                        + " but row 0 has length " + this.cols);
            }
            this.data[r] = row.clone();
        }
    }

    /**
     * Overwrites every element, in place, with a value in {@code [-1, 1)} drawn from the current
     * thread's {@link ThreadLocalRandom}.
     *
     * @return this matrix
     */
    public Mat randomize() {
        return randomize(ThreadLocalRandom.current());
    }

    /**
     * Overwrites every element, in place, with {@code random.nextDouble() * 2 - 1}, a value in
     * {@code [-1, 1)}.
     *
     * @param random source of randomness
     * @return this matrix
     */
    public Mat randomize(final Random random) {
        return map(new MatFunc() {
            @Override
            public double perform(double value, int row, int col) {
                return (random.nextDouble() * 2.0d) - 1.0d;
            }
        });
    }

    /**
     * Returns the element-wise sum {@code this + other} as a new matrix.
     *
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Mat add(final Mat other) {
        requireSameShape(other);
        return new Mat(this.data).map(new MatFunc() {
            @Override
            public double perform(double value, int row, int col) {
                return value + other.data[row][col];
            }
        });
    }

    /** Returns a new matrix with {@code scalar} added to every element. */
    public Mat add(final double scalar) {
        return new Mat(this.data).map(new MatFunc() {
            @Override
            public double perform(double value, int row, int col) {
                return value + scalar;
            }
        });
    }

    /**
     * Returns the element-wise difference {@code this - other} as a new matrix.
     *
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Mat subtract(final Mat other) {
        requireSameShape(other);
        return new Mat(this.data).map(new MatFunc() {
            @Override
            public double perform(double value, int row, int col) {
                return value - other.data[row][col];
            }
        });
    }

    /** Returns a new matrix with {@code scalar} subtracted from every element. */
    public Mat subtract(final double scalar) {
        return new Mat(this.data).map(new MatFunc() {
            @Override
            public double perform(double value, int row, int col) {
                return value - scalar;
            }
        });
    }

    /** Returns the transpose of this matrix as a new {@code cols x rows} matrix. */
    public Mat transpose() {
        return new Mat(this.cols, this.rows).map(new MatFunc() {
            @Override
            public double perform(double ignored, int row, int col) {
                return Mat.this.data[col][row];
            }
        });
    }

    /** Returns a new matrix with every element multiplied by {@code scalar}. */
    public Mat mult(final double scalar) {
        return new Mat(this.data).map(new MatFunc() {
            @Override
            public double perform(double value, int row, int col) {
                return value * scalar;
            }
        });
    }

    /**
     * Returns the element-wise (Hadamard) product of this matrix and {@code other}.
     *
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Mat elementMult(final Mat other) {
        requireSameShape(other);
        return new Mat(this.data).map(new MatFunc() {
            @Override
            public double perform(double value, int row, int col) {
                return value * other.data[row][col];
            }
        });
    }

    /**
     * Returns the matrix product {@code this * other} as a new {@code rows x other.cols} matrix.
     *
     * @throws IllegalArgumentException if {@code this.cols != other.rows}
     */
    public Mat mult(final Mat other) {
        if (this.cols != other.rows) {
            throw new IllegalArgumentException(
                    "Rows don't match columns: " + shape() + " * " + other.shape());
        }
        final int inner = this.cols;
        return new Mat(this.rows, other.cols).map(new MatFunc() {
            @Override
            public double perform(double ignored, int row, int col) {
                double sum = 0.0d;
                for (int k = 0; k < inner; k++) {
                    sum += Mat.this.data[row][k] * other.data[k][col];
                }
                return sum;
            }
        });
    }

    /**
     * Replaces every element of this matrix, in place, with
     * {@code matFunc.perform(value, row, col)}. Elements are visited row by row, left to right.
     *
     * @param matFunc the function to apply
     * @return this matrix, for chaining
     */
    public Mat map(MatFunc matFunc) {
        for (int r = 0; r < this.rows; r++) {
            double[] row = this.data[r];
            for (int c = 0; c < this.cols; c++) {
                row[c] = matFunc.perform(row[c], r, c);
            }
        }
        return this;
    }

    /** Returns the elements flattened row by row into a new array of length {@code rows * cols}. */
    public double[] toArray() {
        double[] flat = new double[this.rows * this.cols];
        for (int r = 0; r < this.rows; r++) {
            System.arraycopy(this.data[r], 0, flat, r * this.cols, this.cols);
        }
        return flat;
    }

    /**
     * Returns a copy of column {@code col}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code col} is out of range and {@code rows > 0}
     */
    public double[] getColumn(int col) {
        double[] column = new double[this.rows];
        for (int r = 0; r < this.rows; r++) {
            column[r] = this.data[r][col];
        }
        return column;
    }

    /**
     * Returns a deep copy of this matrix.
     *
     * <p>The copy is built with the copy constructor rather than {@code super.clone()}, because a
     * shallow clone would share the final {@link #data} array. The result is therefore always a
     * plain {@code Mat}.
     */
    @Override
    public Mat clone() {
        return new Mat(this.data);
    }

    /**
     * Returns a deep copy of this matrix.
     *
     * @deprecated decompiler-generated name kept for existing callers; use {@link #clone()}.
     */
    @Deprecated
    public Mat m39clone() {
        return clone();
    }

    /** Returns the contents in {@link Arrays#deepToString(Object[])} form. */
    public String toArrayString() {
        return Arrays.deepToString(this.data);
    }

    /** Returns one bracketed, comma-separated line per row, rows separated by {@code '\n'}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int r = 0; r < this.rows; r++) {
            sb.append("[");
            for (int c = 0; c < this.cols; c++) {
                sb.append(this.data[r][c]);
                if (c < this.cols - 1) {
                    sb.append(", ");
                }
            }
            sb.append(']');
            if (r < this.rows - 1) {
                sb.append('\n');
            }
        }
        return sb.toString();
    }

    /** Returns a new column vector ({@code values.length x 1}) holding a copy of {@code values}. */
    public static Mat fromArray(double[] values) {
        Mat mat = new Mat(values.length, 1);
        for (int r = 0; r < values.length; r++) {
            mat.data[r][0] = values[r];
        }
        return mat;
    }

    /** Throws if {@code other} does not have exactly this matrix's dimensions. */
    private void requireSameShape(Mat other) {
        if (other.rows != this.rows || other.cols != this.cols) {
            throw new IllegalArgumentException(
                    "Matrix dimensions don't match: " + shape() + " vs " + other.shape());
        }
    }

    /** Formats the dimensions as {@code <rows>x<cols>} for error messages. */
    private String shape() {
        return this.rows + "x" + this.cols;
    }
}
