package org.generallib.deeplearning.neuralnetwork;

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/* loaded from: input_file:org/generallib/deeplearning/neuralnetwork/NeuralNetwork.class */
/**
 * A fully connected feed-forward neural network trained with batch gradient descent.
 *
 * <p>Each pair of adjacent layers is connected by one weight matrix of shape
 * {@code layerCounts[k + 1] x (layerCounts[k] + 1)}; column 0 of every weight matrix holds
 * the bias weights. Samples are stored one per row.
 *
 * <p>The backward pass uses {@code a * (1 - a)} as the derivative of the activation, which is
 * the derivative of the logistic sigmoid. Supplying a different {@link ActivationFunction}
 * therefore yields gradients that do not match the forward pass.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class NeuralNetwork {
    /** Number of units in each layer, input layer first (bias units not included). */
    private final int[] layerCounts;
    /** Weight matrices; {@code theta[k]} maps layer {@code k} to layer {@code k + 1}. */
    private final DoubleMatrix[] theta;
    /** Number of units in the output layer, i.e. the number of distinct class labels. */
    private final int outputRange;
    /** Activation applied to every non-input layer. */
    private final ActivationFunction act;

    /**
     * Demo: trains a 5-6-3 network on ASCII codes of five-character strings for several
     * regularization strengths and prints the final cost and a few predictions for each.
     *
     * @param strArr unused command-line arguments
     * @throws Exception if the network cannot be constructed
     */
    public static void main(String[] strArr) throws Exception {
        NeuralNetwork neuralNetwork = new NeuralNetwork(new int[]{5, 6, 3}, new ActivationFunction() {
            @Override
            public DoubleMatrix activate(DoubleMatrix doubleMatrix) {
                // Logistic sigmoid: 1 / (1 + e^(-x)), element-wise.
                return MatrixFunctions.exp(doubleMatrix.mul(-1.0d)).add(1.0d).rdiv(1.0d);
            }
        });
        System.out.println(neuralNetwork);

        // One sample per row; every value is the ASCII code of one character.
        DoubleMatrix rawSamples = new DoubleMatrix(new double[][]{
            {72.0d, 69.0d, 76.0d, 76.0d, 79.0d},
            {72.0d, 69.0d, 76.0d, 76.0d, 111.0d},
            {72.0d, 69.0d, 76.0d, 108.0d, 79.0d},
            {72.0d, 69.0d, 76.0d, 108.0d, 111.0d},
            {72.0d, 69.0d, 108.0d, 76.0d, 79.0d},
            {72.0d, 69.0d, 108.0d, 108.0d, 79.0d},
            {72.0d, 69.0d, 108.0d, 76.0d, 111.0d},
            {104.0d, 101.0d, 108.0d, 108.0d, 111.0d},
            {104.0d, 105.0d, 32.0d, 32.0d, 32.0d},
            {104.0d, 73.0d, 32.0d, 32.0d, 32.0d},
            {72.0d, 105.0d, 32.0d, 32.0d, 32.0d},
            {72.0d, 73.0d, 32.0d, 32.0d, 32.0d},
            {32.0d, 104.0d, 105.0d, 32.0d, 32.0d},
            {32.0d, 72.0d, 105.0d, 32.0d, 32.0d},
            {32.0d, 104.0d, 73.0d, 32.0d, 32.0d},
            {32.0d, 72.0d, 73.0d, 32.0d, 32.0d},
            {104.0d, 32.0d, 73.0d, 32.0d, 32.0d}
        });
        // The same scale factor must be applied to anything passed to predict() later,
        // otherwise the network sees inputs far outside the range it was trained on.
        final double scale = rawSamples.max() - rawSamples.min();
        DoubleMatrix samples = rawSamples.div(scale);
        // Class index per sample: the first 8 rows are class 0, the remaining 9 are class 1.
        DoubleMatrix labels = new DoubleMatrix(new double[]{
            0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d,
            1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d, 1.0d
        });

        DoubleMatrix hello = new DoubleMatrix(new double[][]{{104.0d, 101.0d, 108.0d, 108.0d, 111.0d}}).div(scale);
        DoubleMatrix hi = new DoubleMatrix(new double[][]{{104.0d, 105.0d, 32.0d, 32.0d, 32.0d}}).div(scale);
        DoubleMatrix happy = new DoubleMatrix(new double[][]{{104.0d, 97.0d, 112.0d, 112.0d, 121.0d}}).div(scale);

        double d = 0.0d;
        for (int i = 0; i < 10; i++) {
            neuralNetwork.resetLayers();
            double d2 = 0.001d * i * i;
            for (int i2 = 0; i2 < 300; i2++) {
                d = neuralNetwork.trainNetwork(samples, labels, d2);
            }
            System.out.println("\nlambda [" + d2 + "] >>> " + d + "\n");
            System.out.println("hello: " + neuralNetwork.predict(hello));
            System.out.println("hi: " + neuralNetwork.predict(hi));
            System.out.println("happy: " + neuralNetwork.predict(happy));
        }
    }

    /**
     * Creates a network and initializes its weights.
     *
     * @param iArr number of units per layer, input layer first; at least three entries
     *             (input, one or more hidden layers, output)
     * @param activationFunction activation applied to every hidden layer and the output layer
     * @throws NeuralNetworkInitializeException if fewer than three layer sizes are supplied
     *         (raised as an {@code InvalidLayerCountException})
     */
    public NeuralNetwork(int[] iArr, ActivationFunction activationFunction) throws NeuralNetworkInitializeException {
        if (iArr.length < 3) {
            throw new InvalidLayerCountException();
        }
        this.act = activationFunction;
        // Defensive copy: the weight shapes depend on these values, so later changes made
        // by the caller to its own array must not affect this network.
        this.layerCounts = iArr.clone();
        this.theta = new DoubleMatrix[iArr.length - 1];
        this.outputRange = iArr[iArr.length - 1];
        // Private helper instead of resetLayers() so that no overridable method runs
        // from the constructor.
        initializeWeights();
    }

    /**
     * Discards all learned weights and re-initializes them: bias weights are set to 1 and
     * all other weights are drawn from {@code DoubleMatrix.rand}.
     */
    public void resetLayers() {
        initializeWeights();
    }

    /** Fills {@link #theta} with fresh matrices whose first column (bias) is all ones. */
    private void initializeWeights() {
        for (int layer = 0; layer < this.layerCounts.length - 1; layer++) {
            int outputs = this.layerCounts[layer + 1];
            DoubleMatrix weights = DoubleMatrix.rand(outputs, this.layerCounts[layer]);
            this.theta[layer] = DoubleMatrix.concatHorizontally(DoubleMatrix.ones(outputs), weights);
        }
    }

    /**
     * Performs one unregularized training step.
     *
     * @param doubleMatrix training samples, one per row
     * @param doubleMatrix2 class index of each sample
     * @return the cost measured before the weights were updated
     * @see #trainNetwork(DoubleMatrix, DoubleMatrix, double)
     */
    public double trainNetwork(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        return trainNetwork(doubleMatrix, doubleMatrix2, 0.0d);
    }

    /**
     * Performs one step of batch gradient descent with L2 regularization.
     *
     * <p>The gradient is subtracted from the weights as is (step size 1). Bias weights take
     * part in the update but are excluded from the regularization penalty.
     *
     * @param doubleMatrix training samples, one per row, {@code layerCounts[0]} columns
     * @param doubleMatrix2 class index of each sample; each value is used as a row index into
     *                      an identity matrix of size {@code outputRange} to build the one-hot
     *                      target
     * @param d regularization strength (lambda)
     * @return the regularized cross-entropy cost measured before the weights were updated
     * @throws IllegalArgumentException if there are no samples or the number of labels differs
     *         from the number of samples
     */
    public double trainNetwork(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, double d) {
        final int sampleCount = doubleMatrix.rows;
        if (sampleCount == 0) {
            throw new IllegalArgumentException("Training requires at least one sample");
        }
        if (doubleMatrix2.length != sampleCount) {
            throw new IllegalArgumentException(
                    "Expected " + sampleCount + " labels but got " + doubleMatrix2.length);
        }
        final int last = this.layerCounts.length - 1;

        DoubleMatrix[] activations = forward(doubleMatrix, this.theta);
        DoubleMatrix oneHot = DoubleMatrix.eye(this.outputRange).getRows(doubleMatrix2.toIntArray());
        double cost = cost(activations[last], oneHot, this.theta, sampleCount, d);

        // Back-propagate the output error through the hidden layers.
        DoubleMatrix[] deltas = new DoubleMatrix[this.layerCounts.length];
        deltas[last] = activations[last].sub(oneHot);
        for (int layer = last - 1; layer > 0; layer--) {
            DoubleMatrix weights = this.theta[layer];
            DoubleMatrix derivative = activations[layer].mul(activations[layer].rsub(1.0d));
            // Column 0 belongs to the bias unit, which has no incoming connections,
            // so it is dropped from both the weights and the derivative.
            deltas[layer] = deltas[layer + 1]
                    .mmul(weights.getRange(0, weights.rows, 1, weights.columns))
                    .mul(derivative.getRange(0, derivative.rows, 1, derivative.columns));
        }

        for (int k = 0; k < this.theta.length; k++) {
            DoubleMatrix dataGradient = deltas[k + 1].transpose().mmul(activations[k]).mul(1.0d / sampleCount);
            // mul() allocates a new matrix, so the in-place mulColumn() only clears the bias
            // column of the penalty term and leaves the live weights untouched.
            DoubleMatrix penalty = this.theta[k].mul(d / sampleCount).mulColumn(0, 0.0d);
            this.theta[k] = this.theta[k].sub(dataGradient.add(penalty));
        }
        return cost;
    }

    /**
     * Computes the regularized cross-entropy cost.
     *
     * @param output activations of the output layer, one row per sample
     * @param oneHot one-hot encoded targets with the same shape as {@code output}
     * @param weights weight matrices whose non-bias entries are penalized
     * @param sampleCount number of samples
     * @param lambda regularization strength
     * @return mean cross-entropy plus {@code lambda / (2 * sampleCount)} times the sum of the
     *         squared non-bias weights
     */
    private double cost(DoubleMatrix output, DoubleMatrix oneHot, DoubleMatrix[] weights, int sampleCount, double lambda) {
        DoubleMatrix negatedTargets = oneHot.mul(-1.0d);
        DoubleMatrix positiveTerm = negatedTargets.mul(MatrixFunctions.log(output));
        DoubleMatrix negativeTerm = negatedTargets.add(1.0d).mul(MatrixFunctions.log(output.rsub(1.0d)));
        double dataCost = (1.0d / sampleCount) * positiveTerm.sub(negativeTerm).sum();

        double squaredWeights = 0.0d;
        for (DoubleMatrix layerWeights : weights) {
            // Skip column 0: bias weights are not regularized.
            DoubleMatrix withoutBias = layerWeights.getRange(0, layerWeights.rows, 1, layerWeights.columns);
            squaredWeights += MatrixFunctions.pow(withoutBias, 2.0d).sum();
        }
        return dataCost + (squaredWeights * (lambda / (2.0d * sampleCount)));
    }

    /**
     * Runs the forward pass with the current weights.
     *
     * @param doubleMatrix samples, one per row, {@code layerCounts[0]} columns
     * @return output-layer activations, one row per sample and one column per output unit
     */
    public DoubleMatrix predict(DoubleMatrix doubleMatrix) {
        return forward(doubleMatrix, this.theta)[this.layerCounts.length - 1];
    }

    /**
     * Runs the forward pass and keeps the activations of every layer.
     *
     * @param input samples, one per row
     * @param weights weight matrices to use, one per pair of adjacent layers
     * @return activations per layer; every layer except the last carries a leading bias
     *         column of ones
     */
    private DoubleMatrix[] forward(DoubleMatrix input, DoubleMatrix[] weights) {
        final int last = this.layerCounts.length - 1;
        DoubleMatrix[] activations = new DoubleMatrix[this.layerCounts.length];
        activations[0] = withBias(input);
        for (int layer = 1; layer < last; layer++) {
            DoubleMatrix weighted = activations[layer - 1].mmul(weights[layer - 1].transpose());
            activations[layer] = withBias(this.act.activate(weighted));
        }
        activations[last] = this.act.activate(activations[last - 1].mmul(weights[last - 1].transpose()));
        return activations;
    }

    /** Returns a copy of {@code matrix} with a column of ones prepended as the bias unit. */
    private static DoubleMatrix withBias(DoubleMatrix matrix) {
        return DoubleMatrix.concatHorizontally(DoubleMatrix.ones(matrix.rows), matrix);
    }

    /**
     * Renders the layer sizes followed by every weight matrix, formatted with five decimals.
     *
     * @return a multi-line description of this network
     */
    @Override
    public String toString() {
        StringBuilder sizes = new StringBuilder("[ ");
        for (int count : this.layerCounts) {
            sizes.append(count).append(' ');
        }
        sizes.append(']');

        StringBuilder layers = new StringBuilder();
        for (int k = 0; k < this.theta.length; k++) {
            layers.append("Layer").append(k + 1).append(" -> ").append(k + 2).append(" \n")
                    .append(this.theta[k].toString("%.5f", "", "", " ", "\n"))
                    .append('\n');
        }
        return "NeuralNetwork -- " + sizes + "\n" + layers;
    }
}
