package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;

/* loaded from: input_file:smile/classification/NeuralNetwork.class */
public class NeuralNetwork implements OnlineClassifier<double[]>, SoftClassifier<double[]>, Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger(NeuralNetwork.class);

    /** Target value of the "on" output unit under least mean squares error. */
    private static final double LMS_TARGET_ON = 0.9;
    /** Target value of an "off" output unit under least mean squares error. */
    private static final double LMS_TARGET_OFF = 0.1;

    private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
    private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
    /** Number of input units (dimension of the input vector). */
    private int p;
    /** Number of classes; a single output unit stands for 2 classes. */
    private int k;
    /** All layers, input layer first and output layer last. */
    private Layer[] net;
    private Layer inputLayer;
    private Layer outputLayer;
    /** Learning rate. */
    private double eta = 0.1;
    /** Momentum factor. */
    private double alpha = 0.0;
    /** Weight decay factor. */
    private double lambda = 0.0;
    /** Reusable buffer holding the target output vector of the current training sample. */
    private double[] target;

    /** Activation function applied by the output layer (hidden layers always use the logistic sigmoid). */
    public enum ActivationFunction {
        LINEAR,
        LOGISTIC_SIGMOID,
        SOFTMAX
    }

    /** Error function minimized during training. */
    public enum ErrorFunction {
        LEAST_MEAN_SQUARES,
        CROSS_ENTROPY
    }

    /**
     * One layer of units.
     *
     * <p>Static so that a serialized layer does not carry a hidden reference to the enclosing network.
     */
    public static class Layer implements Serializable {
        private static final long serialVersionUID = 1;
        /** Number of units, excluding the bias unit. */
        int units;
        /** Unit outputs; {@code output[units]} is the constant bias input 1.0. */
        double[] output;
        /** Error signal of each unit. */
        double[] error;
        /** {@code weight[j][i]}: weight from unit {@code i} of the layer below (column {@code units} is the bias) to unit {@code j}. */
        double[][] weight;
        /** Last weight change of each connection, used for momentum. */
        double[][] delta;

        private Layer() {
        }
    }

    /** Builder-style trainer that creates a network and runs stochastic gradient descent for a number of epochs. */
    public static class Trainer extends ClassifierTrainer<double[]> {
        private ErrorFunction errorFunction = ErrorFunction.LEAST_MEAN_SQUARES;
        private ActivationFunction activationFunction = ActivationFunction.LOGISTIC_SIGMOID;
        private int[] numUnits;
        private double eta = 0.1;
        private double alpha = 0.0;
        private double lambda = 0.0;
        private int epochs = 25;

        /**
         * Creates a trainer whose output activation is the natural one for the error function:
         * softmax for multi-class cross entropy, logistic sigmoid otherwise.
         *
         * @param errorFunction the error function.
         * @param numUnits the number of units in each layer, input layer first.
         * @throws IllegalArgumentException if the layer configuration is invalid.
         */
        public Trainer(ErrorFunction errorFunction, int... numUnits) {
            this(errorFunction, NeuralNetwork.naturalActivation(errorFunction, numUnits), numUnits);
        }

        /**
         * Creates a trainer.
         *
         * @param errorFunction the error function.
         * @param activationFunction the output layer activation function.
         * @param numUnits the number of units in each layer, input layer first.
         * @throws IllegalArgumentException if the layer configuration or function combination is invalid.
         * @throws NullPointerException if either function is null.
         */
        public Trainer(ErrorFunction errorFunction, ActivationFunction activationFunction, int... numUnits) {
            NeuralNetwork.checkArguments(errorFunction, activationFunction, numUnits);
            this.errorFunction = errorFunction;
            this.activationFunction = activationFunction;
            // Defensive copy: later changes to the caller's array must not alter the topology.
            this.numUnits = numUnits.clone();
        }

        /**
         * Sets the learning rate.
         *
         * @param eta the learning rate; must be positive.
         * @return this trainer.
         */
        public Trainer setLearningRate(double eta) {
            if (eta <= 0.0) {
                throw new IllegalArgumentException("Invalid learning rate: " + eta);
            }
            this.eta = eta;
            return this;
        }

        /**
         * Sets the momentum factor.
         *
         * @param alpha the momentum factor, in [0, 1).
         * @return this trainer.
         */
        public Trainer setMomentum(double alpha) {
            if (alpha < 0.0 || alpha >= 1.0) {
                throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
            }
            this.alpha = alpha;
            return this;
        }

        /**
         * Sets the weight decay factor.
         *
         * @param lambda the weight decay factor, in [0, 0.1].
         * @return this trainer.
         */
        public Trainer setWeightDecay(double lambda) {
            if (lambda < 0.0 || lambda > 0.1) {
                throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
            }
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the number of training epochs.
         *
         * @param epochs the number of epochs; at least 1.
         * @return this trainer.
         */
        public Trainer setNumEpochs(int epochs) {
            if (epochs < 1) {
                throw new IllegalArgumentException("Invalid number of epochs of stochastic learning:" + epochs);
            }
            this.epochs = epochs;
            return this;
        }

        /**
         * Builds a fresh network and trains it on the data for the configured number of epochs.
         *
         * @param x the training samples.
         * @param y the class labels.
         * @return the trained network.
         */
        @Override
        public NeuralNetwork train(double[][] x, int[] y) {
            NeuralNetwork model = new NeuralNetwork(errorFunction, activationFunction, numUnits);
            model.setLearningRate(eta);
            model.setMomentum(alpha);
            model.setWeightDecay(lambda);
            for (int epoch = 1; epoch <= epochs; epoch++) {
                model.learn(x, y);
                NeuralNetwork.logger.info("Neural network learns epoch {}", epoch);
            }
            return model;
        }
    }

    /**
     * Creates a network whose output activation is the natural one for the error function:
     * softmax for multi-class cross entropy, logistic sigmoid otherwise.
     *
     * @param errorFunction the error function.
     * @param numUnits the number of units in each layer, input layer first.
     * @throws IllegalArgumentException if the layer configuration is invalid.
     */
    public NeuralNetwork(ErrorFunction errorFunction, int... numUnits) {
        this(errorFunction, naturalActivation(errorFunction, numUnits), numUnits);
    }

    /**
     * Returns the natural output activation for an error function and output layer size.
     *
     * @param errorFunction the error function.
     * @param outputUnits the number of output units.
     * @return {@code SOFTMAX} for cross entropy with more than one output unit, otherwise {@code LOGISTIC_SIGMOID}.
     */
    public static ActivationFunction natural(ErrorFunction errorFunction, int outputUnits) {
        if (errorFunction == ErrorFunction.CROSS_ENTROPY && outputUnits != 1) {
            return ActivationFunction.SOFTMAX;
        }
        return ActivationFunction.LOGISTIC_SIGMOID;
    }

    /**
     * Chooses the natural activation for a layer configuration. If the configuration is empty, it
     * returns a placeholder so that {@link #checkArguments} can report "Invalid number of layers"
     * instead of an array index error.
     */
    private static ActivationFunction naturalActivation(ErrorFunction errorFunction, int[] numUnits) {
        if (numUnits.length == 0) {
            return ActivationFunction.LOGISTIC_SIGMOID;
        }
        return natural(errorFunction, numUnits[numUnits.length - 1]);
    }

    /**
     * Validates a network configuration; shared by the network and trainer constructors.
     *
     * @throws NullPointerException if either function is null.
     * @throws IllegalArgumentException if there are fewer than 2 layers, a layer has no units,
     *     or the error/activation combination is invalid for the output layer size.
     */
    private static void checkArguments(ErrorFunction errorFunction, ActivationFunction activationFunction, int[] numUnits) {
        Objects.requireNonNull(errorFunction, "errorFunction");
        Objects.requireNonNull(activationFunction, "activationFunction");

        int numLayers = numUnits.length;
        if (numLayers < 2) {
            throw new IllegalArgumentException("Invalid number of layers: " + numLayers);
        }
        for (int i = 0; i < numLayers; i++) {
            if (numUnits[i] < 1) {
                throw new IllegalArgumentException(String.format("Invalid number of units of layer %d: %d", i + 1, numUnits[i]));
            }
        }

        int outputUnits = numUnits[numLayers - 1];
        if (errorFunction == ErrorFunction.LEAST_MEAN_SQUARES && activationFunction == ActivationFunction.SOFTMAX) {
            throw new IllegalArgumentException("Softmax activation function is invalid for least mean squares error.");
        }
        if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
            if (activationFunction == ActivationFunction.LINEAR) {
                throw new IllegalArgumentException("Linear activation function is invalid with cross entropy error.");
            }
            if (activationFunction == ActivationFunction.SOFTMAX && outputUnits == 1) {
                throw new IllegalArgumentException("Softmax activation function is for multi-class.");
            }
            if (activationFunction == ActivationFunction.LOGISTIC_SIGMOID && outputUnits != 1) {
                throw new IllegalArgumentException("For cross entropy error, logistic sigmoid output is for binary classification.");
            }
        }
    }

    /**
     * Creates a network with randomly initialized weights.
     *
     * @param errorFunction the error function.
     * @param activationFunction the output layer activation function.
     * @param numUnits the number of units in each layer, input layer first.
     * @throws IllegalArgumentException if the layer configuration or function combination is invalid.
     * @throws NullPointerException if either function is null.
     */
    public NeuralNetwork(ErrorFunction errorFunction, ActivationFunction activationFunction, int... numUnits) {
        checkArguments(errorFunction, activationFunction, numUnits);
        this.errorFunction = errorFunction;
        this.activationFunction = activationFunction;

        int numLayers = numUnits.length;
        int outputUnits = numUnits[numLayers - 1];
        this.p = numUnits[0];
        // A single output unit encodes a binary problem.
        this.k = outputUnits == 1 ? 2 : outputUnits;
        this.target = new double[outputUnits];

        this.net = new Layer[numLayers];
        for (int l = 0; l < numLayers; l++) {
            Layer layer = new Layer();
            layer.units = numUnits[l];
            layer.output = new double[numUnits[l] + 1];
            layer.error = new double[numUnits[l] + 1];
            // The extra trailing unit is the constant bias input.
            layer.output[numUnits[l]] = 1.0;
            net[l] = layer;
        }
        inputLayer = net[0];
        outputLayer = net[numLayers - 1];

        for (int l = 1; l < numLayers; l++) {
            Layer lower = net[l - 1];
            Layer layer = net[l];
            layer.weight = new double[layer.units][lower.units + 1];
            layer.delta = new double[layer.units][lower.units + 1];
            // Uniform initialization in [-1/sqrt(fan-in), 1/sqrt(fan-in)].
            double range = 1.0 / Math.sqrt(lower.units);
            for (int j = 0; j < layer.units; j++) {
                for (int i = 0; i <= lower.units; i++) {
                    layer.weight[j][i] = Math.random(-range, range);
                }
            }
        }
    }

    /** Creates an empty network; used only by {@link #m54clone()}. */
    private NeuralNetwork() {
    }

    /**
     * Returns a deep copy of this network. Weights, momentum deltas and work buffers are copied, so
     * training the copy does not affect this network.
     *
     * <p>The decompiler named this method {@code m54clone}; it was originally {@code clone()}. The
     * name is kept for existing callers.
     */
    public NeuralNetwork m54clone() {
        NeuralNetwork copy = new NeuralNetwork();
        copy.errorFunction = errorFunction;
        copy.activationFunction = activationFunction;
        copy.p = p;
        copy.k = k;
        copy.eta = eta;
        copy.alpha = alpha;
        copy.lambda = lambda;
        copy.target = target.clone();

        copy.net = new Layer[net.length];
        for (int l = 0; l < net.length; l++) {
            Layer src = net[l];
            Layer dst = new Layer();
            dst.units = src.units;
            dst.output = src.output.clone();
            dst.error = src.error.clone();
            // The input layer has no incoming weights.
            if (l > 0) {
                dst.weight = Math.clone(src.weight);
                dst.delta = Math.clone(src.delta);
            }
            copy.net[l] = dst;
        }
        copy.inputLayer = copy.net[0];
        copy.outputLayer = copy.net[net.length - 1];
        return copy;
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be positive.
     */
    public void setLearningRate(double eta) {
        if (eta <= 0.0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        this.eta = eta;
    }

    /** Returns the learning rate. */
    public double getLearningRate() {
        return eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @param alpha the momentum factor, in [0, 1).
     */
    public void setMomentum(double alpha) {
        if (alpha < 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        this.alpha = alpha;
    }

    /** Returns the momentum factor. */
    public double getMomentum() {
        return alpha;
    }

    /**
     * Sets the weight decay factor.
     *
     * @param lambda the weight decay factor, in [0, 0.1].
     */
    public void setWeightDecay(double lambda) {
        if (lambda < 0.0 || lambda > 0.1) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /** Returns the weight decay factor. */
    public double getWeightDecay() {
        return lambda;
    }

    /**
     * Returns the incoming weight matrix of a layer (the live array, not a copy).
     *
     * @param layer the layer index; 0 is the input layer, which has no weights.
     */
    public double[][] getWeight(int layer) {
        return net[layer].weight;
    }

    /** Copies an input vector into the input layer, leaving the bias unit untouched. */
    private void setInput(double[] x) {
        if (x.length != inputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, inputLayer.units));
        }
        System.arraycopy(x, 0, inputLayer.output, 0, inputLayer.units);
    }

    /** Copies the output layer's unit outputs, without the bias unit, into {@code y}. */
    private void getOutput(double[] y) {
        if (y.length != outputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid output vector size: %d, expected: %d", y.length, outputLayer.units));
        }
        System.arraycopy(outputLayer.output, 0, y, 0, outputLayer.units);
    }

    /**
     * Propagates signals from a lower layer to the layer above it. Hidden layers use the logistic
     * sigmoid. The output layer uses the configured activation; softmax is applied after all units
     * have their net input.
     */
    private void propagate(Layer lower, Layer upper) {
        boolean isOutput = upper == outputLayer;
        for (int j = 0; j < upper.units; j++) {
            double[] w = upper.weight[j];
            double sum = 0.0;
            for (int i = 0; i <= lower.units; i++) {
                sum += w[i] * lower.output[i];
            }

            if (!isOutput || activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                upper.output[j] = Math.logistic(sum);
            } else if (activationFunction == ActivationFunction.LINEAR || activationFunction == ActivationFunction.SOFTMAX) {
                upper.output[j] = sum;
            } else {
                throw new UnsupportedOperationException("Unsupported activation function.");
            }
        }

        if (isOutput && activationFunction == ActivationFunction.SOFTMAX) {
            softmax();
        }
    }

    /** Applies softmax in place to the output layer. The maximum is subtracted first for numerical stability. */
    private void softmax() {
        double[] out = outputLayer.output;
        int n = outputLayer.units;

        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; i++) {
            if (out[i] > max) {
                max = out[i];
            }
        }

        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            out[i] = Math.exp(out[i] - max);
            sum += out[i];
        }

        for (int i = 0; i < n; i++) {
            out[i] /= sum;
        }
    }

    /** Forward pass through all layers. */
    private void propagate() {
        for (int l = 0; l < net.length - 1; l++) {
            propagate(net[l], net[l + 1]);
        }
    }

    /** Natural logarithm clamped to about log(1e-300) for tiny or non-positive arguments. */
    private static double log(double x) {
        return x < 1.0E-300d ? -690.7755d : Math.log(x);
    }

    /** Computes the output error against {@code t}, storing the error signal in the output layer. */
    private double computeOutputError(double[] t) {
        return computeOutputError(t, outputLayer.error);
    }

    /**
     * Computes the error of the current output against target {@code t} and writes the per-unit
     * error signal (gradient direction) into {@code gradient}.
     *
     * @return the error value under the configured error function.
     */
    private double computeOutputError(double[] t, double[] gradient) {
        if (t.length != outputLayer.units) {
            throw new IllegalArgumentException(String.format("Invalid output vector size: %d, expected: %d", t.length, outputLayer.units));
        }

        double error = 0.0;
        for (int i = 0; i < outputLayer.units; i++) {
            double out = outputLayer.output[i];
            double g = t[i] - out;

            if (errorFunction == ErrorFunction.LEAST_MEAN_SQUARES) {
                error += 0.5 * g * g;
            } else if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
                if (activationFunction == ActivationFunction.SOFTMAX) {
                    error -= t[i] * log(out);
                } else if (activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                    // Binary cross entropy; the constructor guarantees a single output unit here.
                    error -= t[i] * log(out) + (1.0 - t[i]) * log(1.0 - out);
                }
            }

            // With cross entropy (or a linear output) the sigmoid/softmax derivative cancels out.
            if (errorFunction == ErrorFunction.LEAST_MEAN_SQUARES && activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
                g *= out * (1.0 - out);
            }
            gradient[i] = g;
        }
        return error;
    }

    /**
     * Propagates error signals from a layer down to the hidden layer below it. Only real units get
     * an error; the bias unit has no incoming weights, so its error is never used.
     */
    private void backpropagate(Layer upper, Layer lower) {
        for (int i = 0; i < lower.units; i++) {
            double out = lower.output[i];
            double sum = 0.0;
            for (int j = 0; j < upper.units; j++) {
                sum += upper.weight[j][i] * upper.error[j];
            }
            lower.error[i] = out * (1.0 - out) * sum;
        }
    }

    /**
     * Backward pass from the output layer down to the first hidden layer. The input layer's error
     * is never used by any weight update, so it is not computed.
     */
    private void backpropagate() {
        for (int l = net.length - 1; l > 1; l--) {
            backpropagate(net[l], net[l - 1]);
        }
    }

    /** Updates all weights using the error signals, with momentum and (non-bias) weight decay. */
    private void adjustWeights() {
        double decay = 1.0 - (eta * lambda);
        for (int l = 1; l < net.length; l++) {
            Layer lower = net[l - 1];
            Layer layer = net[l];
            for (int j = 0; j < layer.units; j++) {
                double[] w = layer.weight[j];
                double[] dw = layer.delta[j];
                double err = layer.error[j];
                for (int i = 0; i <= lower.units; i++) {
                    double change = (1.0 - alpha) * eta * err * lower.output[i] + alpha * dw[i];
                    dw[i] = change;
                    w[i] += change;
                    // Bias weights (i == lower.units) are not decayed.
                    if (lambda != 0.0 && i < lower.units) {
                        w[i] *= decay;
                    }
                }
            }
        }
    }

    /**
     * Returns the class with the highest output. A single output unit encodes class 0 as an output
     * above 0.5 and class 1 otherwise.
     */
    private int classify() {
        double[] out = outputLayer.output;
        if (outputLayer.units == 1) {
            return out[0] > 0.5 ? 0 : 1;
        }

        double max = Double.NEGATIVE_INFINITY;
        int best = -1;
        for (int i = 0; i < outputLayer.units; i++) {
            if (out[i] > max) {
                max = out[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Predicts the class of {@code x} and fills {@code posteriori} with the output layer values.
     *
     * <p>For a network with a single output unit, {@code posteriori} may have length 1, giving the
     * raw output, or length 2, giving {@code [out, 1 - out]} for classes 0 and 1. With a
     * cross-entropy/logistic network, {@code out} estimates the probability of class 0.
     *
     * @param x the input vector.
     * @param posteriori output buffer for the network outputs.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        setInput(x);
        propagate();
        if (outputLayer.units == 1 && posteriori.length == 2) {
            double out = outputLayer.output[0];
            posteriori[0] = out;
            posteriori[1] = 1.0 - out;
        } else {
            getOutput(posteriori);
        }
        return classify();
    }

    /**
     * Predicts the class of {@code x}.
     *
     * @param x the input vector.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x) {
        setInput(x);
        propagate();
        return classify();
    }

    /**
     * Performs one weighted stochastic gradient step toward target output {@code y}.
     *
     * @param x the input vector.
     * @param y the target output vector, one value per output unit.
     * @param weight the sample weight that scales the error signal.
     * @return the weighted error of the network output before the update.
     */
    public double learn(double[] x, double[] y, double weight) {
        setInput(x);
        propagate();

        double error = weight * computeOutputError(y);
        if (weight != 1.0) {
            for (int i = 0; i < outputLayer.units; i++) {
                outputLayer.error[i] *= weight;
            }
        }

        backpropagate();
        adjustWeights();
        return error;
    }

    /**
     * Updates the network with one labeled sample of unit weight.
     *
     * @param x the input vector.
     * @param y the class label.
     */
    @Override
    public void learn(double[] x, int y) {
        learn(x, y, 1.0);
    }

    /**
     * Updates the network with one weighted, labeled sample. Samples with zero weight are skipped.
     *
     * @param x the input vector.
     * @param y the class label: 0 or 1 for a single output unit, otherwise in [0, output units).
     * @param weight the non-negative sample weight.
     * @throws IllegalArgumentException if the weight is negative or the label is out of range.
     */
    public void learn(double[] x, int y, double weight) {
        if (weight < 0.0) {
            throw new IllegalArgumentException("Invalid weight: " + weight);
        }
        if (weight == 0.0) {
            logger.info("Ignore the training instance with zero weight.");
            return;
        }
        if (y < 0) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        if (outputLayer.units == 1 && y > 1) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
        if (outputLayer.units > 1 && y >= outputLayer.units) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }

        if (errorFunction == ErrorFunction.CROSS_ENTROPY && activationFunction == ActivationFunction.LOGISTIC_SIGMOID) {
            // Single logistic output: 1 for class 0, matching classify().
            target[0] = y == 0 ? 1.0 : 0.0;
        } else if (errorFunction == ErrorFunction.CROSS_ENTROPY) {
            // Softmax output: one-hot target.
            Arrays.fill(target, 0.0);
            target[y] = 1.0;
        } else if (target.length == 1) {
            // Least mean squares with a single output unit: y can be 1, but target has no index 1.
            // Encode class 0 as the high value so classify() (> 0.5 means class 0) agrees.
            target[0] = y == 0 ? LMS_TARGET_ON : LMS_TARGET_OFF;
        } else {
            // Least mean squares: soft one-hot target that keeps sigmoid outputs away from saturation.
            Arrays.fill(target, LMS_TARGET_OFF);
            target[y] = LMS_TARGET_ON;
        }

        learn(x, target, weight);
    }

    /**
     * Runs one epoch of stochastic learning over the data in random order.
     *
     * @param x the training samples.
     * @param y the class labels, one per sample.
     * @throws IllegalArgumentException if {@code x} and {@code y} have different lengths.
     */
    public void learn(double[][] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int[] order = Math.permutate(x.length);
        for (int i : order) {
            learn(x[i], y[i]);
        }
    }
}
