package game.models.single.neuralNet;

import configuration.models.ModelConfig;
import configuration.models.single.NeuralNetModelConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.classifiers.single.neuralNet.ActivationFunction;
import game.classifiers.single.neuralNet.InnerNode;
import game.classifiers.single.neuralNet.InputNode;
import game.classifiers.single.neuralNet.Node;
import game.classifiers.single.neuralNet.OutputNode;
import game.classifiers.single.neuralNet.SigmoidFunction;
import game.models.ModelLearnableBase;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;

/* loaded from: input_file:game/models/single/neuralNet/NeuralNetModel.class */
/**
 * Feed-forward neural network model with at most two hidden layers and a
 * single output.
 *
 * <p>The network is (re)built and trained by {@link #learn()} from the
 * learning data held by the base class ({@code inputVect}, {@code target},
 * {@code learning_vectors}, {@code inputsNumber}). Weights are updated after
 * every learning vector.
 *
 * <p>Layout of {@link #innerNodes}: indices {@code 0 .. outputs-1} hold the
 * nodes feeding the output nodes; the hidden-layer nodes follow, first hidden
 * layer first. Every pass over the network in this class relies on that
 * layout.
 */
public class NeuralNetModel extends ModelLearnableBase {
    private static final ActivationFunction SIGMOID_FUNCTION = new SigmoidFunction();
    private static final ActivationFunction LINEAR_FUNCTION = new LinearFunction();
    private int firstLayerNeurons;
    private int secondLayerNeurons;
    private int trainingCycles;
    private double learningRate;
    private boolean decay;
    private double momentum;
    private boolean normalize;
    private double errorEpsilon;
    private double tempLearningRate;
    /** Per-input half-width of the observed value range; set by normalizeLearnData(). */
    private double[] range;
    /** Per-input midpoint of the observed value range; set by normalizeLearnData(). */
    private double[] offset;
    private InputNode[] inputNodes = new InputNode[0];
    private InnerNode[] innerNodes = new InnerNode[0];
    private OutputNode[] outputNodes = new OutputNode[0];
    private int outputs = 1;

    /**
     * Copies the training parameters from the given configuration.
     *
     * @param modelConfig configuration; must be a {@link NeuralNetModelConfig}
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        super.init(modelConfig);
        NeuralNetModelConfig config = (NeuralNetModelConfig) modelConfig;
        this.learningRate = config.getLearningRate();
        this.momentum = config.getMomentum();
        this.firstLayerNeurons = config.getFirstLayerNeurons();
        this.secondLayerNeurons = config.getSecondLayerNeurons();
        this.trainingCycles = config.getTrainingCycles();
        this.decay = config.getDecay();
        this.normalize = config.getNormalize();
        this.errorEpsilon = config.getErrorEpsilon();
        this.tempLearningRate = this.learningRate;
    }

    /**
     * Returns the base configuration with this model's training parameters
     * written into it.
     *
     * @return the configuration object obtained from the superclass, updated
     */
    @Override // game.models.ModelLearnableBase, game.models.Model
    public ModelConfig getConfig() {
        NeuralNetModelConfig config = (NeuralNetModelConfig) super.getConfig();
        config.setLearningRate(this.learningRate);
        config.setMomentum(this.momentum);
        config.setFirstLayerNeurons(this.firstLayerNeurons);
        config.setSecondLayerNeurons(this.secondLayerNeurons);
        config.setTrainingCycles(this.trainingCycles);
        config.setDecay(this.decay);
        config.setNormalize(this.normalize);
        config.setErrorEpsilon(this.errorEpsilon);
        return config;
    }

    /**
     * Rescales every input attribute of the learning data so that its observed
     * minimum maps to -1 and its maximum to +1, and records the transform in
     * {@link #offset} (midpoint) and {@link #range} (half-width).
     *
     * <p>The learning data in {@code inputVect} is overwritten in place.
     * NOTE(review): because of that, a second call would derive range/offset
     * from already normalized values - confirm whether the base class restores
     * the raw data between learn() calls.
     */
    private void normalizeLearnData() {
        this.range = new double[this.inputsNumber];
        this.offset = new double[this.inputsNumber];
        for (int attr = 0; attr < this.inputsNumber; attr++) {
            double min = this.inputVect[0][attr];
            double max = min;
            for (int row = 1; row < this.learning_vectors; row++) {
                double value = this.inputVect[row][attr];
                if (value < min) {
                    min = value;
                }
                if (value > max) {
                    max = value;
                }
            }
            double halfWidth = (max - min) / 2.0d;
            // A constant attribute would give a zero divisor; use 1 so it maps to 0.
            this.range[attr] = halfWidth == 0.0d ? 1.0d : halfWidth;
            this.offset[attr] = (max + min) / 2.0d;
            for (int row = 0; row < this.learning_vectors; row++) {
                this.inputVect[row][attr] = (this.inputVect[row][attr] - this.offset[attr]) / this.range[attr];
            }
        }
    }

    /**
     * Builds a fresh network and trains it on the learning data.
     *
     * <p>Runs at most {@code trainingCycles} passes over the (once shuffled)
     * learning vectors, updating the weights after each vector. Training stops
     * early when the mean of the summed squared output errors of a pass drops
     * below {@code errorEpsilon}. With {@code decay} enabled the learning rate
     * of pass {@code k} (0-based) is the base rate divided by {@code k + 1}.
     */
    @Override // game.models.ModelLearnable
    public void learn() {
        if (this.normalize) {
            normalizeLearnData();
        }
        this.learned = true;
        Random random = new Random();
        // addNode() only appends, so nodes left over from a previous learn() call
        // would occupy the low indices reserved for the output-side nodes.
        this.innerNodes = new InnerNode[0];
        initInputLayer();
        initOutputLayer(random);
        initHiddenLayers(random);
        shuffleLearnData(random);
        for (int cycle = 0; cycle < this.trainingCycles; cycle++) {
            // The rate depends only on the cycle, so compute it once per pass.
            double rate = this.decay ? this.tempLearningRate / (cycle + 1) : this.tempLearningRate;
            double squaredErrorSum = 0.0d;
            for (int row = 0; row < this.learning_vectors; row++) {
                calculateValue(this.inputVect[row]);
                squaredErrorSum += calculateError(this.inputVect[row], this.target[row]);
                update(rate, this.momentum);
            }
            if (squaredErrorSum / this.learning_vectors < this.errorEpsilon) {
                return;
            }
        }
    }

    /**
     * Randomly permutes the first {@code learning_vectors} rows of
     * {@code inputVect} together with their {@code target} values.
     *
     * <p>New outer arrays are installed; the row arrays themselves are reused,
     * and entries beyond {@code learning_vectors} are carried over unchanged.
     *
     * @param random source of randomness for the permutation
     */
    private void shuffleLearnData(Random random) {
        ArrayList<Integer> order = new ArrayList<Integer>(this.learning_vectors);
        for (int row = 0; row < this.learning_vectors; row++) {
            order.add(Integer.valueOf(row));
        }
        Collections.shuffle(order, random);
        // Shallow copies: keeps any trailing entries and avoids allocating row
        // arrays that would be replaced immediately.
        double[][] shuffledInputs = this.inputVect.clone();
        double[] shuffledTargets = this.target.clone();
        for (int row = 0; row < order.size(); row++) {
            int source = order.get(row).intValue();
            shuffledInputs[row] = this.inputVect[source];
            shuffledTargets[row] = this.target[source];
        }
        this.target = shuffledTargets;
        this.inputVect = shuffledInputs;
    }

    /**
     * Applies a weight update to every inner node: output-side nodes first,
     * then the hidden nodes from the last one back to the first.
     *
     * @param d  learning rate passed to each node
     * @param d2 momentum passed to each node
     */
    private void update(double d, double d2) {
        for (int out = 0; out < this.outputs; out++) {
            this.innerNodes[out].update(d, d2);
        }
        for (int hidden = this.innerNodes.length - 1; hidden >= this.outputs; hidden--) {
            this.innerNodes[hidden].update(d, d2);
        }
    }

    /**
     * Lets every node compute its error for one learning vector, walking from
     * the output nodes back to the input nodes.
     *
     * @param dArr input vector that was just evaluated
     * @param d    target value for that vector
     * @return sum of the squared errors reported by the output nodes
     */
    private double calculateError(double[] dArr, double d) {
        double squaredErrorSum = 0.0d;
        for (OutputNode outputNode : this.outputNodes) {
            double error = outputNode.calculateError(dArr, d);
            squaredErrorSum += error * error;
        }
        for (int out = 0; out < this.outputs; out++) {
            this.innerNodes[out].calculateError(dArr, d);
        }
        for (int hidden = this.innerNodes.length - 1; hidden >= this.outputs; hidden--) {
            this.innerNodes[hidden].calculateError(dArr, d);
        }
        for (InputNode inputNode : this.inputNodes) {
            inputNode.calculateError(dArr, d);
        }
        return squaredErrorSum;
    }

    /**
     * Evaluates the whole network, including all output nodes, for one input
     * vector.
     *
     * @param dArr input vector
     */
    private void calculateValue(double[] dArr) {
        forwardPass(dArr);
        for (OutputNode outputNode : this.outputNodes) {
            outputNode.calculateValue(dArr);
        }
    }

    /**
     * Evaluates the input nodes, then the hidden nodes in creation order, then
     * the output-side inner nodes. Output nodes are not touched.
     *
     * @param dArr input vector
     */
    private void forwardPass(double[] dArr) {
        for (InputNode inputNode : this.inputNodes) {
            inputNode.calculateValue(dArr);
        }
        for (int hidden = this.outputs; hidden < this.innerNodes.length; hidden++) {
            this.innerNodes[hidden].calculateValue(dArr);
        }
        for (int out = 0; out < this.outputs; out++) {
            this.innerNodes[out].calculateValue(dArr);
        }
    }

    /** Creates one input node per input attribute. */
    private void initInputLayer() {
        this.inputNodes = new InputNode[this.inputsNumber];
        for (int attr = 0; attr < this.inputsNumber; attr++) {
            InputNode node = new InputNode();
            node.setAttribute(attr);
            this.inputNodes[attr] = node;
        }
    }

    /**
     * Creates the output nodes and, for each, an inner node connected to it.
     * Must run on an empty {@link #innerNodes} array so that these inner nodes
     * end up at indices {@code 0 .. outputs-1}.
     *
     * @param random random source handed to the inner node constructor
     */
    private void initOutputLayer(Random random) {
        this.outputNodes = new OutputNode[this.outputs];
        for (int out = 0; out < this.outputNodes.length; out++) {
            this.outputNodes[out] = new OutputNode(out);
            InnerNode feeder = new InnerNode(-2, random, SIGMOID_FUNCTION);
            addNode(feeder);
            Node.connect(feeder, this.outputNodes[out]);
        }
    }

    /**
     * Appends a node to {@link #innerNodes}.
     *
     * @param innerNode node to append
     */
    private void addNode(InnerNode innerNode) {
        int oldLength = this.innerNodes.length;
        InnerNode[] grown = new InnerNode[oldLength + 1];
        System.arraycopy(this.innerNodes, 0, grown, 0, oldLength);
        grown[oldLength] = innerNode;
        this.innerNodes = grown;
    }

    /**
     * Creates the hidden layers and wires the network together: inputs to the
     * first hidden layer, each hidden layer to the next, and the last hidden
     * layer to the output-side inner nodes. When no hidden layer is configured
     * the inputs are connected to the output-side inner nodes directly.
     *
     * @param random random source handed to the inner node constructor
     */
    private void initHiddenLayers(Random random) {
        int[] layerSizes = hiddenLayerSizes();

        int previousLayerSize = 0;
        for (int layer = 0; layer < layerSizes.length; layer++) {
            int layerSize = layerSizes[layer];
            for (int n = 0; n < layerSize; n++) {
                InnerNode node = new InnerNode(layer, random, SIGMOID_FUNCTION);
                addNode(node);
                if (layer > 0) {
                    // The previous layer sits directly before the n nodes of this
                    // layer that were appended ahead of the current one.
                    int previousLayerEnd = this.innerNodes.length - n - 1;
                    for (int p = previousLayerEnd - previousLayerSize; p < previousLayerEnd; p++) {
                        Node.connect(this.innerNodes[p], node);
                    }
                }
            }
            previousLayerSize = layerSize;
        }

        int firstLayerSize = layerSizes[0];
        if (firstLayerSize == 0) {
            // No hidden layer at all: inputs feed the output-side nodes directly.
            for (int in = 0; in < this.inputsNumber; in++) {
                for (int out = 0; out < this.outputs; out++) {
                    Node.connect(this.inputNodes[in], this.innerNodes[out]);
                }
            }
            return;
        }
        // The first hidden layer starts right after the output-side nodes.
        for (int in = 0; in < this.inputsNumber; in++) {
            for (int hidden = this.outputs; hidden < this.outputs + firstLayerSize; hidden++) {
                Node.connect(this.inputNodes[in], this.innerNodes[hidden]);
            }
        }
        // The last hidden layer occupies the tail of the array.
        for (int last = this.innerNodes.length - previousLayerSize; last < this.innerNodes.length; last++) {
            for (int out = 0; out < this.outputs; out++) {
                Node.connect(this.innerNodes[last], this.innerNodes[out]);
            }
        }
    }

    /**
     * Resolves the configured hidden-layer sizes.
     *
     * <p>A configured size of 0 means the layer is absent; -1 is replaced by
     * {@code round((inputsNumber + outputs) / 2) + 1}.
     *
     * @return sizes of the hidden layers in order; a single 0 when no hidden
     *         layer is configured
     */
    private int[] hiddenLayerSizes() {
        int layerCount = 0;
        if (this.firstLayerNeurons != 0) {
            layerCount++;
        }
        if (this.secondLayerNeurons != 0) {
            layerCount++;
        }
        // Always keep one slot so callers can read element 0.
        int[] sizes = new int[layerCount == 0 ? 1 : layerCount];
        int next = 0;
        if (this.firstLayerNeurons != 0) {
            sizes[next++] = this.firstLayerNeurons;
        }
        if (this.secondLayerNeurons != 0) {
            sizes[next] = this.secondLayerNeurons;
        }
        for (int layer = 0; layer < sizes.length; layer++) {
            if (sizes[layer] == -1) {
                sizes[layer] = ((int) Math.round((this.inputsNumber + this.outputs) / 2.0d)) + 1;
            }
        }
        return sizes;
    }

    /**
     * Evaluates the network and returns the value of the first output node.
     *
     * @param dArr input vector, already normalized if normalization is in use
     * @return value computed by {@code outputNodes[0]}
     */
    private double outputValue(double[] dArr) {
        forwardPass(dArr);
        return this.outputNodes[0].calculateValue(dArr);
    }

    /**
     * Computes the model response for one input vector, applying the learned
     * normalization first when it is enabled. The network and, with
     * normalization, {@link #range}/{@link #offset} are only created by
     * {@link #learn()}.
     *
     * @param dArr raw input vector; the first {@code inputsNumber} entries are used
     * @return network output
     */
    @Override // game.models.Model
    public double getOutput(double[] dArr) {
        if (!this.normalize) {
            return outputValue(dArr);
        }
        double[] normalized = new double[dArr.length];
        for (int attr = 0; attr < this.inputsNumber; attr++) {
            normalized[attr] = (dArr[attr] - this.offset[attr]) / this.range[attr];
        }
        return outputValue(normalized);
    }

    /**
     * @return the configuration class used by this model
     */
    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return NeuralNetModelConfig.class;
    }

    /**
     * Appends C source for this network to {@code sb} and an XML description
     * to {@code sb2}.
     *
     * <p>The generated function reads its inputs through {@code inp} (either
     * the raw {@code input} pointer or a normalized copy), emits the hidden
     * nodes followed by the output-side nodes, and returns {@code res}.
     * Requires a built network: {@code innerNodes[0]} is accessed.
     *
     * @param sb  receives the C code
     * @param sb2 receives the XML produced by {@link XMLBuildUtils#outputXML}
     * @return the generated function name
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        sb.append("#include \"").append("NeuralNet.h\"\n");
        String functionName = CCodeUtils.getUniqueFunctionName(getClass());
        CCodeUtils.getCRegressionHeader(functionName, this.inputsNumber, sb);
        assignVariableNames();
        if (this.normalize) {
            normalizeInput(sb);
        } else {
            sb.append("double * inp = input;\n");
        }
        sb.append("double res;\n");
        // Same evaluation order as the forward pass: hidden nodes, then output side.
        for (int hidden = this.outputs; hidden < this.innerNodes.length; hidden++) {
            this.innerNodes[hidden].toCCode(sb);
        }
        for (int out = 0; out < this.outputs; out++) {
            this.innerNodes[out].toCCode(sb);
        }
        sb.append("return res;\n");
        sb.append("}\n");
        XMLBuildUtils.outputXML(sb2, this, functionName);
        return functionName;
    }

    /**
     * Emits C code that declares the {@code range}/{@code offset} arrays and
     * fills {@code inp} with the normalized inputs.
     *
     * @param sb receives the C code
     */
    private void normalizeInput(StringBuilder sb) {
        CCodeUtils.convertArray(this.range, "range", sb);
        CCodeUtils.convertArray(this.offset, "offset", sb);
        sb.append("double inp[").append(this.inputsNumber).append("];\n");
        sb.append("for(int i = 0; i < ").append(this.inputsNumber).append("; i++){\n");
        sb.append("inp[i] = (input[i]-offset[i])/range[i];\n");
        sb.append("}\n");
    }

    /**
     * Assigns the C variable names used by {@link #toCCode}: {@code inp[i]}
     * for input node {@code i}, {@code n<i>} for inner node {@code i >= 1},
     * and {@code res} for inner node 0.
     */
    private void assignVariableNames() {
        for (int in = 0; in < this.inputNodes.length; in++) {
            this.inputNodes[in].setVariableName("inp[" + in + "]");
        }
        for (int inner = 1; inner < this.innerNodes.length; inner++) {
            this.innerNodes[inner].setVariableName("n" + inner);
        }
        this.innerNodes[0].setVariableName("res");
    }
}
