package smile.regression;

import java.io.Serializable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;

/**
 * Multilayer perceptron for regression, trained by online (stochastic) back-propagation.
 *
 * <p>Hidden layers apply the configured {@link ActivationFunction}; the output layer has exactly
 * one unit and is linear. The per-sample loss is {@code 0.5 * (target - output)^2}. Every layer
 * carries one extra bias unit whose output is fixed at 1.
 *
 * <p>Weight updates use momentum {@code alpha} and optional weight decay {@code lambda}; bias
 * weights are never decayed.
 *
 * <p>Not thread-safe: {@link #predict(double[])} and the {@code learn} methods write into the
 * network's internal activation and error buffers.
 */
public class NeuralNetwork implements OnlineRegression<double[]>, Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger(NeuralNetwork.class);

    /** Momentum used by the constructors that do not take one explicitly. */
    private static final double DEFAULT_MOMENTUM = 0.9d;
    /** Weight decay used by the constructors that do not take one explicitly. */
    private static final double DEFAULT_WEIGHT_DECAY = 1.0E-4d;

    /** Activation applied by every hidden layer (the output layer is linear). */
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /** Number of input variables. */
    private int p;
    /** All layers, input first and output last. */
    private Layer[] net;
    /** Alias of {@code net[0]}. */
    private Layer inputLayer;
    /** Alias of {@code net[net.length - 1]}. */
    private Layer outputLayer;
    /** Learning rate. */
    private double eta = 0.1d;
    /** Momentum factor, in [0, 1). */
    private double alpha = 0.0d;
    /** Weight decay factor, in [0, 0.1]. */
    private double lambda = 0.0d;

    /** Activation function of the hidden units. */
    public enum ActivationFunction {
        /** {@code 1 / (1 + exp(-x))}. */
        LOGISTIC_SIGMOID,
        /** Hyperbolic tangent, computed as {@code 2 * logistic(2x) - 1}. */
        TANH
    }

    /**
     * State of one layer. {@code output} and {@code error} have {@code units + 1} slots; the last
     * slot belongs to the bias unit.
     */
    public class Layer implements Serializable {
        private static final long serialVersionUID = 1;
        /** Number of units, excluding the bias unit. */
        int units;
        /** Unit outputs; {@code output[units]} is the bias unit and stays 1. */
        double[] output;
        /** Back-propagated error signal of each unit. */
        double[] error;
        /** Incoming weights, {@code [units][previousLayer.units + 1]}; null for the input layer. */
        double[][] weight;
        /** Most recent weight change (momentum memory); same shape as {@code weight}. */
        double[][] delta;

        private Layer() {
        }
    }

    /** Builder that configures and trains a {@link NeuralNetwork} for a fixed number of epochs. */
    public static class Trainer extends RegressionTrainer<double[]> {
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        private int[] numUnits;
        private double eta = 0.1d;
        private double alpha = 0.0d;
        private double lambda = 0.0d;
        private int epochs = 25;

        /**
         * Creates a trainer using the logistic sigmoid activation.
         *
         * @param iArr number of units in each layer; at least two layers, the last with 1 unit
         * @throws IllegalArgumentException if the layer layout is invalid
         */
        public Trainer(int... iArr) {
            this(ActivationFunction.LOGISTIC_SIGMOID, iArr);
        }

        /**
         * Creates a trainer.
         *
         * @param activationFunction activation of the hidden layers; must not be null
         * @param iArr number of units in each layer; at least two layers, the last with 1 unit
         * @throws NullPointerException if {@code activationFunction} is null
         * @throws IllegalArgumentException if the layer layout is invalid
         */
        public Trainer(ActivationFunction activationFunction, int... iArr) {
            checkActivationFunction(activationFunction);
            checkUnits(iArr);
            this.activationFunction = activationFunction;
            // Copy so later changes to the caller's array cannot bypass the validation above.
            this.numUnits = iArr.clone();
        }

        /**
         * Sets the learning rate.
         *
         * @throws IllegalArgumentException unless {@code d} is positive and finite
         */
        public Trainer setLearningRate(double d) {
            this.eta = checkLearningRate(d);
            return this;
        }

        /**
         * Sets the momentum factor.
         *
         * @throws IllegalArgumentException unless {@code 0 <= d < 1}
         */
        public Trainer setMomentum(double d) {
            this.alpha = checkMomentum(d);
            return this;
        }

        /**
         * Sets the weight decay factor.
         *
         * @throws IllegalArgumentException unless {@code 0 <= d <= 0.1}
         */
        public Trainer setWeightDecay(double d) {
            this.lambda = checkWeightDecay(d);
            return this;
        }

        /**
         * Sets the number of passes over the training data.
         *
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer setNumEpochs(int i) {
            if (i < 1) {
                throw new IllegalArgumentException("Invalid number of epochs of stochastic learning: " + i);
            }
            this.epochs = i;
            return this;
        }

        /**
         * Builds a fresh network and runs {@code epochs} shuffled passes of online learning.
         *
         * @param x training inputs
         * @param y training targets, one per row of {@code x}
         * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
         */
        @Override // smile.regression.RegressionTrainer
        public NeuralNetwork train(double[][] x, double[] y) {
            NeuralNetwork network = new NeuralNetwork(activationFunction, alpha, lambda, numUnits);
            network.setLearningRate(eta);
            for (int epoch = 1; epoch <= epochs; epoch++) {
                network.learn(x, y);
                NeuralNetwork.logger.info("Neural network learns epoch {}", epoch);
            }
            return network;
        }
    }

    /**
     * Creates a network with the logistic sigmoid activation, momentum 0.9 and weight decay 1e-4.
     *
     * @param iArr number of units in each layer; at least two layers, the last with 1 unit
     */
    public NeuralNetwork(int... iArr) {
        this(ActivationFunction.LOGISTIC_SIGMOID, iArr);
    }

    /**
     * Creates a network with momentum 0.9 and weight decay 1e-4.
     *
     * @param activationFunction activation of the hidden layers
     * @param iArr number of units in each layer; at least two layers, the last with 1 unit
     */
    public NeuralNetwork(ActivationFunction activationFunction, int... iArr) {
        // These defaults were previously passed in swapped order (momentum 1e-4, decay 0.9),
        // which put the decay outside the range setWeightDecay accepts.
        this(activationFunction, DEFAULT_MOMENTUM, DEFAULT_WEIGHT_DECAY, iArr);
    }

    /**
     * Creates a network with randomly initialised weights. The learning rate starts at 0.1.
     *
     * @param activationFunction activation of the hidden layers; must not be null
     * @param d momentum factor, {@code 0 <= d < 1}
     * @param d2 weight decay factor, {@code 0 <= d2 <= 0.1}
     * @param iArr number of units in each layer; at least two layers, the last with 1 unit
     * @throws NullPointerException if {@code activationFunction} is null
     * @throws IllegalArgumentException if the layout, momentum or weight decay is invalid
     */
    public NeuralNetwork(ActivationFunction activationFunction, double d, double d2, int... iArr) {
        checkActivationFunction(activationFunction);
        checkUnits(iArr);
        this.activationFunction = activationFunction;
        this.alpha = checkMomentum(d);
        this.lambda = checkWeightDecay(d2);

        int numLayers = iArr.length;
        this.p = iArr[0];
        this.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; l++) {
            Layer layer = new Layer();
            layer.units = iArr[l];
            layer.output = new double[iArr[l] + 1];
            layer.error = new double[iArr[l] + 1];
            layer.output[iArr[l]] = 1.0d; // bias unit
            this.net[l] = layer;
        }
        this.inputLayer = this.net[0];
        this.outputLayer = this.net[numLayers - 1];

        // Uniform init in [-1/sqrt(fanIn), 1/sqrt(fanIn)], bias weight included.
        for (int l = 1; l < numLayers; l++) {
            Layer below = this.net[l - 1];
            Layer layer = this.net[l];
            layer.weight = new double[layer.units][below.units + 1];
            layer.delta = new double[layer.units][below.units + 1];
            double range = 1.0d / Math.sqrt(below.units);
            for (int j = 0; j < layer.units; j++) {
                for (int k = 0; k <= below.units; k++) {
                    layer.weight[j][k] = Math.random(-range, range);
                }
            }
        }
    }

    /** Bare instance used by {@link #m109clone()}; all structure is filled in by the caller. */
    private NeuralNetwork() {
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    /**
     * Returns a deep copy of this network: hyper-parameters, weights, momentum memory and the
     * current activation/error buffers.
     */
    public NeuralNetwork m109clone() {
        NeuralNetwork copy = new NeuralNetwork();
        copy.activationFunction = this.activationFunction;
        copy.p = this.p;
        copy.eta = this.eta;
        copy.alpha = this.alpha;
        copy.lambda = this.lambda;

        int numLayers = this.net.length;
        copy.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; l++) {
            Layer source = this.net[l];
            // Bind the new layer to the copy. A plain "new Layer()" would make the original network
            // its enclosing instance, so serializing the clone would also serialize the original.
            Layer target = copy.new Layer();
            target.units = source.units;
            target.output = source.output.clone();
            target.error = source.error.clone();
            if (l > 0) {
                target.weight = Math.clone(source.weight);
                target.delta = Math.clone(source.delta);
            }
            copy.net[l] = target;
        }
        copy.inputLayer = copy.net[0];
        copy.outputLayer = copy.net[numLayers - 1];
        return copy;
    }

    /**
     * Sets the learning rate.
     *
     * @throws IllegalArgumentException unless {@code d} is positive and finite
     */
    public void setLearningRate(double d) {
        this.eta = checkLearningRate(d);
    }

    /** Returns the learning rate. */
    public double getLearningRate() {
        return this.eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @throws IllegalArgumentException unless {@code 0 <= d < 1}
     */
    public void setMomentum(double d) {
        this.alpha = checkMomentum(d);
    }

    /** Returns the momentum factor. */
    public double getMomentum() {
        return this.alpha;
    }

    /**
     * Sets the weight decay factor.
     *
     * @throws IllegalArgumentException unless {@code 0 <= d <= 0.1}
     */
    public void setWeightDecay(double d) {
        this.lambda = checkWeightDecay(d);
    }

    /** Returns the weight decay factor. */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Returns the live (not copied) incoming weight matrix of layer {@code i}, shaped
     * {@code [units][previousUnits + 1]} with the bias weight last. Returns null for {@code i == 0}.
     */
    public double[][] getWeight(int i) {
        return this.net[i].weight;
    }

    /** Copies {@code x} into the input layer's outputs, leaving the bias slot untouched. */
    private void setInput(double[] x) {
        if (x.length != this.inputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.inputLayer.units));
        }
        System.arraycopy(x, 0, this.inputLayer.output, 0, this.inputLayer.units);
    }

    /** Applies the hidden-layer activation to a weighted sum. */
    private double activate(double sum) {
        if (this.activationFunction == ActivationFunction.TANH) {
            return (2.0d * Math.logistic(2.0d * sum)) - 1.0d;
        }
        return Math.logistic(sum);
    }

    /** Derivative of the hidden-layer activation, expressed in terms of its output {@code y}. */
    private double activationDerivative(double y) {
        if (this.activationFunction == ActivationFunction.TANH) {
            return 1.0d - (y * y);
        }
        return y * (1.0d - y);
    }

    /** Computes the outputs of {@code upper} from the outputs of {@code lower} (bias included). */
    private void propagate(Layer lower, Layer upper) {
        boolean linear = upper == this.outputLayer;
        for (int j = 0; j < upper.units; j++) {
            double[] w = upper.weight[j];
            double sum = 0.0d;
            for (int k = 0; k <= lower.units; k++) {
                sum += w[k] * lower.output[k];
            }
            upper.output[j] = linear ? sum : activate(sum);
        }
    }

    /** Forward pass through every layer. */
    private void propagate() {
        for (int l = 0; l < this.net.length - 1; l++) {
            propagate(this.net[l], this.net[l + 1]);
        }
    }

    /**
     * Stores {@code target - output} as the output layer's error signal and returns the half
     * squared error.
     */
    private double computeOutputError(double target) {
        double residual = target - this.outputLayer.output[0];
        this.outputLayer.error[0] = residual;
        return 0.5d * residual * residual;
    }

    /**
     * Propagates error signals from {@code upper} down to the real (non-bias) units of
     * {@code lower}. Bias errors are never read, so they are not computed.
     */
    private void backpropagate(Layer upper, Layer lower) {
        for (int k = 0; k < lower.units; k++) {
            double sum = 0.0d;
            for (int j = 0; j < upper.units; j++) {
                sum += upper.weight[j][k] * upper.error[j];
            }
            lower.error[k] = activationDerivative(lower.output[k]) * sum;
        }
    }

    /** Backward pass. Errors are only needed for layers that own weights, so the input layer is skipped. */
    private void backpropagate() {
        for (int l = this.net.length - 1; l > 1; l--) {
            backpropagate(this.net[l], this.net[l - 1]);
        }
    }

    /**
     * Applies one gradient step with momentum, then weight decay on all non-bias weights.
     * The momentum update is {@code delta = (1 - alpha) * eta * error * input + alpha * delta}.
     */
    private void adjustWeights() {
        double decay = 1.0d - (this.eta * this.lambda);
        for (int l = 1; l < this.net.length; l++) {
            Layer below = this.net[l - 1];
            Layer layer = this.net[l];
            for (int j = 0; j < layer.units; j++) {
                double[] w = layer.weight[j];
                double[] dw = layer.delta[j];
                double err = layer.error[j];
                for (int k = 0; k <= below.units; k++) {
                    double change = ((1.0d - this.alpha) * this.eta * err * below.output[k]) + (this.alpha * dw[k]);
                    dw[k] = change;
                    w[k] += change;
                    // Index below.units is the bias weight, which is exempt from decay.
                    if (this.lambda != 0.0d && k < below.units) {
                        w[k] *= decay;
                    }
                }
            }
        }
    }

    /**
     * Predicts the response for {@code x}. Mutates internal buffers; not thread-safe.
     *
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    @Override // smile.regression.Regression
    public double predict(double[] x) {
        setInput(x);
        propagate();
        return this.outputLayer.output[0];
    }

    /**
     * Performs one weighted online update.
     *
     * @param x input vector
     * @param y target value
     * @param weight sample weight; scales both the returned loss and the gradient
     * @return the weighted half squared error measured before the update
     */
    public double learn(double[] x, double y, double weight) {
        setInput(x);
        propagate();
        double loss = weight * computeOutputError(y);
        if (weight != 1.0d) {
            this.outputLayer.error[0] *= weight;
        }
        backpropagate();
        adjustWeights();
        return loss;
    }

    /** Performs one unweighted online update. */
    @Override // smile.regression.OnlineRegression
    public void learn(double[] x, double y) {
        learn(x, y, 1.0d);
    }

    /**
     * Runs one epoch of online learning over the samples in random order.
     *
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     */
    public void learn(double[][] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        for (int index : Math.permutate(x.length)) {
            learn(x[index], y[index]);
        }
    }

    /** Rejects a null activation function. */
    private static void checkActivationFunction(ActivationFunction activationFunction) {
        if (activationFunction == null) {
            throw new NullPointerException("activationFunction is null");
        }
    }

    /** Validates a layer layout: at least two layers, every layer non-empty, one output unit. */
    private static void checkUnits(int[] numUnits) {
        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }
        for (int i = 0; i < numLayers; i++) {
            if (numUnits[i] < 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", i + 1, numUnits[i]));
            }
        }
        if (numUnits[numLayers - 1] != 1) {
            throw new IllegalArgumentException(String.format("Invalid number of units in output layer %d", numUnits[numLayers - 1]));
        }
    }

    /** Returns {@code eta} if it is positive and finite (NaN is rejected). */
    private static double checkLearningRate(double eta) {
        if (!(eta > 0.0d) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        return eta;
    }

    /** Returns {@code alpha} if {@code 0 <= alpha < 1} (NaN is rejected). */
    private static double checkMomentum(double alpha) {
        if (!(alpha >= 0.0d && alpha < 1.0d)) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        return alpha;
    }

    /** Returns {@code lambda} if {@code 0 <= lambda <= 0.1} (NaN is rejected). */
    private static double checkWeightDecay(double lambda) {
        if (!(lambda >= 0.0d && lambda <= 0.1d)) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        return lambda;
    }
}
