package game.models.single;

import configuration.models.ModelConfig;
import configuration.models.single.PolynomialModelConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.models.Element;
import game.trainers.GradientTrainable;
import game.utils.GlobalRandom;
import java.util.ArrayList;
import java.util.Iterator;
import weka.core.matrix.Matrix;

/* loaded from: input_file:game/models/single/PolynomialModel.class */
/**
 * Polynomial regression model.
 *
 * <p>The output for an input vector {@code x} is
 * {@code bias + sum_k coef_k * prod_j x[j]^power(k, j)}, where {@code power(k, j)} is
 * {@code elements.get(k).getIndex(j)}. A term whose powers are all non-positive contributes 0.
 *
 * <p>Coefficients are fitted first by least squares (normal equations). If that system is
 * singular or underdetermined, the configured gradient trainer is used instead.
 */
public class PolynomialModel extends SingleModel {
    /** Polynomial terms; each term's coefficient is stored on its {@link Element}. */
    protected ArrayList<Element> elements;
    /** Constant (intercept) term added to every output. */
    protected double bias;
    /**
     * Precomputed term values: {@code elemVal[v][k]} is the value of term {@code k} for learning
     * vector {@code v}. Column {@code coef - 1} is allocated but never written.
     */
    protected transient double[][] elemVal;
    /** Highest degree generated per input by {@link #setInputsNumber(int)}. */
    protected int maxdeg;

    /**
     * Reads the maximum degree from the configuration, then delegates to the superclass.
     *
     * <p>{@code maxdeg} is assigned before {@code super.init} because {@link #setInputsNumber(int)}
     * reads it. NOTE(review): presumably {@code super.init} can trigger that call — confirm in
     * SingleModel.
     *
     * @param modelConfig must be a {@link PolynomialModelConfig}
     */
    @Override
    public void init(ModelConfig modelConfig) {
        this.maxdeg = ((PolynomialModelConfig) modelConfig).getMaxDegree();
        super.init(modelConfig);
    }

    /**
     * Returns the number of trainable parameters.
     *
     * @return one coefficient per term plus one for the bias
     */
    public int getCoefsNumber() {
        return this.elements.size() + 1;
    }

    /**
     * Evaluates the polynomial with an externally supplied parameter vector.
     *
     * @param dArr  input vector
     * @param dArr2 parameters: {@code dArr2[k]} is the coefficient of term {@code k} and
     *              {@code dArr2[elements.size()]} is the bias
     * @return the model output for {@code dArr}
     */
    @Override
    protected double getOutputWith(double[] dArr, double[] dArr2) {
        int termCount = this.elements.size();
        double output = dArr2[termCount];
        for (int k = 0; k < termCount; k++) {
            output += termValue(dArr2[k], this.elements.get(k), dArr);
        }
        return output;
    }

    /**
     * Evaluates the polynomial with the model's own coefficients and bias.
     *
     * @param dArr input vector
     * @return the model output for {@code dArr}
     */
    @Override
    public double getOutput(double[] dArr) {
        double output = this.bias;
        for (Element element : this.elements) {
            output += termValue(element.getCoefficent(), element, dArr);
        }
        return output;
    }

    /**
     * Computes {@code coefficient * prod_j input[j]^power(j)} for one term.
     *
     * @return the term's contribution, or 0 if the term raises no input to a positive power
     */
    private double termValue(double coefficient, Element element, double[] input) {
        double product = 1.0d;
        boolean usesAnyInput = false;
        for (int j = 0; j < this.inputsNumber; j++) {
            int power = element.getIndex(j);
            // Repeated multiplication (not Math.pow) keeps results bit-identical to the training path.
            for (int p = 0; p < power; p++) {
                usesAnyInput = true;
                product *= input[j];
            }
        }
        return usesAnyInput ? coefficient * product : 0.0d;
    }

    /**
     * Fits the coefficients and runs the post-learning hooks.
     *
     * <p>The gradient trainer is used only if least squares fails.
     */
    @Override
    public void learn() {
        precomputeElements();
        if (!estimateCoefficientsUsingLMSMethod()) {
            estimateCoefficientsUsingDefaultTrainer(this);
        }
        postLearnActions();
    }

    /**
     * Builds the term list for {@code i} inputs and allocates the precomputation buffer.
     *
     * <p>One term is created per (degree, input) pair, for degrees {@code 1..maxdeg}, in
     * degree-major order. Each term gets a random initial coefficient in [-0.5, 0.5).
     *
     * @param i number of model inputs
     */
    @Override
    public void setInputsNumber(int i) {
        GlobalRandom random = GlobalRandom.getInstance();
        this.elements = new ArrayList<>();
        for (int degree = 1; degree <= this.maxdeg; degree++) {
            for (int input = 0; input < i; input++) {
                // Fresh arrays per term: Element may keep references instead of copying, and
                // sharing one mutable array across all terms would alias them to the last one.
                int[] inputFlags = new int[i];
                int[] degrees = new int[i];
                inputFlags[input] = 1;
                degrees[input] = degree;
                this.elements.add(new Element(random.nextDouble() - 0.5d, inputFlags, degrees, i));
            }
        }
        this.coef = this.elements.size() + 1;
        this.elemVal = new double[this.maxLearningVectors][this.coef];
        super.setInputsNumber(i);
    }

    /** Fills {@code elemVal[v][k]} for every learning vector {@code v} and term {@code k}. */
    void precomputeElements() {
        int termCount = this.coef - 1;
        for (int v = 0; v < this.learning_vectors; v++) {
            for (int k = 0; k < termCount; k++) {
                this.elemVal[v][k] = getElementValueForVector(v, k);
            }
        }
    }

    /**
     * Evaluates one term (without its coefficient) on one learning vector.
     *
     * @param i  learning-vector index into {@code inputVect}
     * @param i2 term index into {@code elements}
     * @return {@code prod_j inputVect[i][j]^power(j)}, or 0 if the term uses no input
     */
    double getElementValueForVector(int i, int i2) {
        Element element = this.elements.get(i2);
        boolean usesAnyInput = false;
        double product = 1.0d;
        for (int j = 0; j < this.inputsNumber; j++) {
            int power = element.getIndex(j);
            if (power > 0) {
                usesAnyInput = true;
                for (int p = 0; p < power; p++) {
                    product *= this.inputVect[i][j];
                }
            }
        }
        return usesAnyInput ? product : 0.0d;
    }

    /**
     * Computes the gradient of the training sum of squared errors.
     *
     * <p>The sum runs over non-validation vectors only.
     *
     * @param dArr  parameters (term coefficients, then bias at index {@code coef - 1})
     * @param dArr2 output array of length {@code coef}; overwritten with the gradient
     * @return {@code false} if the precomputed term values are unavailable, else {@code true}
     */
    @Override
    public boolean gradient(double[] dArr, double[] dArr2) {
        if (this.elemVal == null) {
            return false;
        }
        int biasIndex = this.coef - 1;
        for (int k = 0; k < this.coef; k++) {
            dArr2[k] = 0.0d;
        }
        for (int v = 0; v < this.learning_vectors; v++) {
            if (this.inValidationSet[v]) {
                continue;
            }
            // d/dc of (prediction - target)^2 is 2 * (prediction - target) * dPrediction/dc.
            double scaledResidual = (predictPrecomputed(dArr, v) - this.target[v]) * 2.0d;
            double[] row = this.elemVal[v];
            for (int k = 0; k < biasIndex; k++) {
                dArr2[k] += scaledResidual * row[k];
            }
            dArr2[biasIndex] += scaledResidual;
        }
        return true;
    }

    /**
     * Returns the sum of squared errors over the training (non-validation) vectors.
     *
     * @param dArr parameters (term coefficients, then bias at index {@code coef - 1})
     */
    @Override
    public double getError(double[] dArr) {
        return sumSquaredError(dArr, false);
    }

    /**
     * Returns the sum of squared errors over the validation vectors.
     *
     * @param dArr parameters (term coefficients, then bias at index {@code coef - 1})
     */
    @Override
    public double getValidationError(double[] dArr) {
        return sumSquaredError(dArr, true);
    }

    /**
     * Sums squared residuals over the selected learning vectors.
     *
     * @param coefs      parameters (term coefficients, then bias)
     * @param validation {@code true} to use only validation vectors; {@code false} to use only
     *                   training vectors
     */
    private double sumSquaredError(double[] coefs, boolean validation) {
        double sum = 0.0d;
        for (int v = 0; v < this.learning_vectors; v++) {
            if (this.inValidationSet[v] == validation) {
                double residual = predictPrecomputed(coefs, v) - this.target[v];
                sum += residual * residual;
            }
        }
        return sum;
    }

    /**
     * Computes the prediction for learning vector {@code v} from the precomputed term values.
     *
     * @param coefs parameters (term coefficients, then bias)
     * @param v     learning-vector index
     */
    private double predictPrecomputed(double[] coefs, int v) {
        int biasIndex = this.coef - 1;
        double[] row = this.elemVal[v];
        double sum = coefs[biasIndex];
        for (int k = 0; k < biasIndex; k++) {
            sum += coefs[k] * row[k];
        }
        return sum;
    }

    /**
     * Fits the bias and the enabled terms' coefficients by ordinary least squares.
     *
     * <p>The fit solves the normal equations {@code (X^T X) b = X^T y}. Column 0 of {@code X} is
     * the intercept, and the remaining columns hold the enabled terms' precomputed values. Disabled
     * terms keep their current coefficients.
     *
     * @return {@code true} if coefficients were updated; {@code false} if the system is singular
     *         or underdetermined, in which case nothing is modified
     */
    private boolean estimateCoefficientsUsingLMSMethod() {
        int termCount = this.elements.size();
        int[] enabledTerms = new int[termCount];
        int enabledCount = 0;
        for (int k = 0; k < termCount; k++) {
            if (this.elements.get(k).isEnabled()) {
                enabledTerms[enabledCount++] = k;
            }
        }
        int columns = enabledCount + 1;
        // With fewer vectors than unknowns X^T X is rank-deficient. Rounding can hide that from
        // the exact det() test below and yield meaningless coefficients. This also avoids
        // constructing a Matrix with zero rows.
        if (this.learning_vectors < columns) {
            return false;
        }
        double[][] design = new double[this.learning_vectors][columns];
        double[][] targets = new double[this.learning_vectors][1];
        for (int v = 0; v < this.learning_vectors; v++) {
            design[v][0] = 1.0d;
            for (int c = 0; c < enabledCount; c++) {
                design[v][c + 1] = this.elemVal[v][enabledTerms[c]];
            }
            targets[v][0] = this.target[v];
        }
        Matrix x = new Matrix(design);
        Matrix xt = x.transpose();
        Matrix normal = xt.times(x);
        if (normal.det() == 0.0d) {
            return false;
        }
        double[][] solution;
        try {
            // solve() avoids forming an explicit inverse, which is less accurate.
            solution = normal.solve(xt.times(new Matrix(targets))).getArray();
        } catch (RuntimeException ignored) {
            // Numerically singular system: the caller falls back to the gradient trainer.
            return false;
        }
        for (int c = 0; c < enabledCount; c++) {
            this.elements.get(enabledTerms[c]).setCoefficient(solution[c + 1][0]);
        }
        this.bias = solution[0][0];
        this.trainedBy = "LMS method";
        return true;
    }

    /**
     * Fits all coefficients and the bias with the configured gradient trainer.
     *
     * <p>Any NaN or infinite value returned by the trainer is ignored. In that case the existing
     * coefficient or bias is kept.
     *
     * @param gradientTrainable unused; kept for interface compatibility (this model is passed)
     */
    void estimateCoefficientsUsingDefaultTrainer(GradientTrainable gradientTrainable) {
        this.a = new double[getCoefsNumber()];
        if (this.validationEnabled) {
            initializeValidationSet();
        }
        this.trainer.setCoef(getCoefsNumber());
        this.trainer.setStartingPoint(computeStartingPoint());
        this.trainer.teach();
        int termCount = this.elements.size();
        for (int k = 0; k < termCount; k++) {
            double best = this.trainer.getBest(k);
            if (isFiniteValue(best)) {
                this.elements.get(k).setCoefficient(best);
            }
        }
        // Guard the bias like the coefficients: a non-finite bias would poison every output.
        double bestBias = this.trainer.getBest(termCount);
        if (isFiniteValue(bestBias)) {
            this.bias = bestBias;
        }
    }

    /** Returns {@code true} if {@code value} is neither NaN nor infinite. */
    private static boolean isFiniteValue(double value) {
        return !Double.isNaN(value) && !Double.isInfinite(value);
    }

    /** Releases the precomputed term values and trainer buffer along with the learning data. */
    @Override
    public void deleteLearningVectors() {
        this.elemVal = null;
        this.a = null;
        super.deleteLearningVectors();
    }

    /** Returns the configuration class for this model. */
    @Override
    public Class getConfigClass() {
        return PolynomialModelConfig.class;
    }

    /**
     * Emits C code that evaluates this model, and appends its XML description.
     *
     * <p>The generated code defines a {@code parameters} array (term coefficients, then bias) and
     * a {@code powers} matrix (one row per term, one column per input). It then delegates to
     * {@code polynomialModelOutput}.
     *
     * @param sb  receives the C source
     * @param sb2 receives the XML description
     * @return the generated C function name
     */
    @Override
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        int termCount = this.elements.size();
        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("PolynomialModel.h\"\n");
        String functionName = CCodeUtils.getUniqueFunctionName(getClass());
        CCodeUtils.getCRegressionHeader(functionName, this.inputsNumber, sb);
        double[] parameters = new double[termCount + 1];
        int[][] powers = new int[termCount][this.inputsNumber];
        for (int k = 0; k < termCount; k++) {
            Element element = this.elements.get(k);
            parameters[k] = element.getCoefficent();
            for (int j = 0; j < this.inputsNumber; j++) {
                powers[k][j] = element.getIndex(j);
            }
        }
        parameters[termCount] = this.bias;
        CCodeUtils.convertArray(parameters, "parameters", sb);
        CCodeUtils.convertArray(powers, "powers", sb);
        sb.append("return polynomialModelOutput<").append(this.inputsNumber).append(",").append(termCount).append(">(input,parameters,powers);\n");
        sb.append("}\n");
        XMLBuildUtils.outputXML(sb2, this, functionName);
        return functionName;
    }
}
