package game.models.single;

import Jama.Matrix;
import com.rapidminer.tools.math.LinearRegression;
import com.rapidminer.tools.math.smoothing.SmoothingKernel;
import com.rapidminer.tools.math.smoothing.TriweightSmoothingKernel;
import configuration.models.ModelConfig;
import configuration.models.single.LocalPolynomialModelConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.evolution.treeEvolution.exception.LearnException;
import game.models.ModelLearnableBase;
import game.tools.distance.DistanceMeasure;
import game.utils.Utils;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Locale;
import org.ytoh.configurations.ui.SelectionSetModel;

/* loaded from: input_file:game/models/single/LocalPolynomialModel.class */
public class LocalPolynomialModel extends ModelLearnableBase {
    private int degree;
    private double ridge;
    private int nearestNeighbours;
    private SelectionSetModel<String> measureType;
    private double[][] learnDataInput;
    private double[] learnDataTarget;
    private DistanceMeasure distance;
    private SmoothingKernel kernelSmoother;
    private boolean closestToTestSet;
    private int expansions;
    private int bestSolutions;
    private int combinationComplexity = 30;
    private int[][][] polynomialMasks;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:game/models/single/LocalPolynomialModel$InstanceSet.class */
    public class InstanceSet {
        public double[][] inputSet;
        public double[] targetSet;
        public double[] distanceSet;

        public InstanceSet(double[][] dArr, double[] dArr2, double[] dArr3) {
            this.inputSet = dArr;
            this.targetSet = dArr2;
            this.distanceSet = dArr3;
        }

        public InstanceSet(int i, int i2) {
            this.inputSet = new double[i][i2];
            this.targetSet = new double[i];
            this.distanceSet = new double[i];
        }

        public void copyInstance(InstanceSet instanceSet, int i, int i2) {
            this.inputSet[i2] = instanceSet.inputSet[i];
            this.targetSet[i2] = instanceSet.targetSet[i];
            this.distanceSet[i2] = instanceSet.distanceSet[i];
        }

        public int size() {
            return this.inputSet.length;
        }

        public int dimension() {
            return this.inputSet[0].length;
        }
    }

    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        super.init(modelConfig);
        LocalPolynomialModelConfig localPolynomialModelConfig = (LocalPolynomialModelConfig) modelConfig;
        this.ridge = 1.0E-9d;
        this.degree = localPolynomialModelConfig.getMaxDegree();
        this.nearestNeighbours = localPolynomialModelConfig.getNearestNeighbours();
        this.closestToTestSet = localPolynomialModelConfig.getClosestUsedAsTest();
        this.measureType = localPolynomialModelConfig.getMeasureType();
    }

    @Override // game.models.ModelLearnable
    public void learn() {
        this.learnDataInput = this.inputVect;
        this.learnDataTarget = this.target;
        try {
            this.distance = (DistanceMeasure) Class.forName("game.tools.distance." + this.measureType.getEnabledElements(String.class)[0]).getConstructor(double[][].class).newInstance(this.learnDataInput);
            this.polynomialMasks = multipleMaskPolynomialExpansion(this.inputsNumber, this.degree);
            setPolynomialCombinationsParameters(this.inputsNumber);
            this.kernelSmoother = new TriweightSmoothingKernel();
            postLearnActions();
        } catch (Exception e) {
            throw new LearnException(e.getMessage());
        }
    }

    @Override // game.models.Model
    public double getOutput(double[] dArr) {
        double[] distanceToAll = this.distance.getDistanceToAll(dArr);
        InstanceSet instanceSet = new InstanceSet(this.nearestNeighbours, this.inputsNumber);
        int[] insertSort = Utils.insertSort(distanceToAll, this.nearestNeighbours);
        for (int i = 0; i < insertSort.length; i++) {
            instanceSet.targetSet[i] = this.learnDataTarget[insertSort[i]];
            instanceSet.distanceSet[i] = distanceToAll[insertSort[i]];
            instanceSet.inputSet[i] = this.learnDataInput[insertSort[i]];
        }
        int i2 = (int) ((0.3d * this.nearestNeighbours) + 0.5d);
        int i3 = this.nearestNeighbours - i2;
        InstanceSet instanceSet2 = new InstanceSet(i2, this.inputsNumber);
        InstanceSet instanceSet3 = new InstanceSet(i3, this.inputsNumber);
        if (this.closestToTestSet) {
            for (int i4 = 0; i4 < i2; i4++) {
                instanceSet2.copyInstance(instanceSet, i4, i4);
            }
            int i5 = 0;
            for (int i6 = i2; i6 < this.nearestNeighbours; i6++) {
                instanceSet3.copyInstance(instanceSet, i6, i5);
                i5++;
            }
        } else {
            for (int i7 = 0; i7 < i3; i7++) {
                instanceSet3.copyInstance(instanceSet, i7, i7);
            }
            int i8 = 0;
            for (int i9 = i3; i9 < this.nearestNeighbours; i9++) {
                instanceSet3.copyInstance(instanceSet, i9, i8);
                i8++;
            }
        }
        return performAreaPrediction(instanceSet, instanceSet3, instanceSet2, dArr);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v14, types: [int[][], java.lang.Object] */
    /* JADX WARN: Type inference failed for: r0v27 */
    /* JADX WARN: Type inference failed for: r0v56 */
    /* JADX WARN: Type inference failed for: r1v31, types: [int[][]] */
    /* JADX WARN: Type inference failed for: r5v1, types: [int[][]] */
    /* JADX WARN: Type inference failed for: r8v0, types: [game.models.single.LocalPolynomialModel] */
    protected double performAreaPrediction(InstanceSet instanceSet, InstanceSet instanceSet2, InstanceSet instanceSet3, double[] dArr) {
        int numberOfTerms;
        double[] applyKernelSmoother = applyKernelSmoother(instanceSet2.distanceSet);
        double[] dArr2 = new double[this.polynomialMasks.length];
        for (int i = 0; i < this.polynomialMasks.length; i++) {
            dArr2[i] = getMaskFitness(instanceSet2, instanceSet3, this.polynomialMasks[i], applyKernelSmoother);
        }
        int[][][] combineBestPolynomials = combineBestPolynomials(this.polynomialMasks, dArr2);
        ?? r0 = new int[this.polynomialMasks.length + combineBestPolynomials.length];
        System.arraycopy(this.polynomialMasks, 0, r0, 0, this.polynomialMasks.length);
        System.arraycopy(combineBestPolynomials, 0, r0, this.polynomialMasks.length, combineBestPolynomials.length);
        double[] dArr3 = new double[r0.length];
        System.arraycopy(dArr2, 0, dArr3, 0, dArr2.length);
        for (int length = this.polynomialMasks.length; length < r0.length; length++) {
            dArr3[length] = getMaskFitness(instanceSet2, instanceSet3, r0[length], applyKernelSmoother);
        }
        int[][] iArr = r0[0];
        double d = dArr3[0];
        int numberOfTerms2 = getNumberOfTerms(iArr);
        for (int i2 = 1; i2 < dArr3.length; i2++) {
            if (dArr3[i2] > d && (((numberOfTerms = getNumberOfTerms(r0[i2])) > numberOfTerms2 && Math.abs(dArr3[i2] - d) > Math.abs(0.1d * d)) || numberOfTerms <= numberOfTerms2)) {
                d = dArr3[i2];
                iArr = r0[i2];
                numberOfTerms2 = getNumberOfTerms(iArr);
            }
        }
        double[] regressionCoefficients = getRegressionCoefficients(instanceSet, iArr, applyKernelSmoother(instanceSet.distanceSet));
        double[] polynomialExpansion = polynomialExpansion(dArr, iArr);
        double d2 = 0.0d;
        for (int i3 = 0; i3 < regressionCoefficients.length; i3++) {
            d2 += polynomialExpansion[i3] * regressionCoefficients[i3];
        }
        return d2;
    }

    protected void setPolynomialCombinationsParameters(int i) {
        if (i >= Math.sqrt(this.combinationComplexity) || i <= 3) {
            this.expansions = 3;
        } else {
            this.expansions = i;
        }
        this.bestSolutions = (int) Math.round(this.combinationComplexity / this.expansions);
        if (this.polynomialMasks.length < this.bestSolutions) {
            this.bestSolutions = this.polynomialMasks.length;
            if (this.polynomialMasks.length < this.expansions) {
                this.expansions = this.polynomialMasks.length;
            } else {
                this.expansions = Math.min(this.bestSolutions, (int) Math.round(this.combinationComplexity / this.bestSolutions));
            }
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v18, types: [int[][]] */
    /* JADX WARN: Type inference failed for: r0v32, types: [int[]] */
    /* JADX WARN: Type inference failed for: r0v4, types: [int[][], int[][][]] */
    /* JADX WARN: Type inference failed for: r12v2 */
    /* JADX WARN: Type inference failed for: r2v1 */
    /* JADX WARN: Type inference failed for: r2v7 */
    /* JADX WARN: Type inference failed for: r2v8 */
    /* JADX WARN: Type inference failed for: r6v0, types: [game.models.single.LocalPolynomialModel] */
    protected int[][][] combineBestPolynomials(int[][][] iArr, double[] dArr) {
        int[] insertSortDesc = Utils.insertSortDesc(dArr, this.bestSolutions);
        ?? r0 = new int[insertSortDesc.length];
        for (int i = 0; i < insertSortDesc.length; i++) {
            r0[i] = iArr[insertSortDesc[i]];
        }
        int i2 = 0;
        for (int i3 = 2; i3 <= this.expansions; i3++) {
            i2 += (int) nOverK(r0.length, r0.length - i3);
        }
        ?? r12 = new int[i2];
        int i4 = 0;
        for (int i5 = 2; i5 <= this.expansions; i5++) {
            int[][] combinations = getCombinations(i5, r0.length);
            for (int i6 = 0; i6 < combinations.length; i6++) {
                if (isAllowed(combinations[i6], r0, this.expansions)) {
                    ?? r02 = new int[combinations[i6].length];
                    for (int i7 = 0; i7 < combinations[i6].length; i7++) {
                        r02[i7] = r0[combinations[i6][i7]][0];
                    }
                    r12[i4] = r02;
                    i4++;
                }
            }
        }
        int i8 = i4;
        int length = r12.length;
        int[][][] iArr2 = r12;
        if (i8 != length) {
            ?? r03 = new int[i4];
            for (int i9 = 0; i9 < r03.length; i9++) {
                r03[i9] = r12[i9];
            }
            iArr2 = r03;
        }
        return iArr2;
    }

    protected boolean isAllowed(int[] iArr, int[][][] iArr2, int i) {
        int i2 = 0;
        for (int i3 : iArr) {
            i2 += maxDegree(iArr2[i3]);
        }
        return i2 <= i;
    }

    protected int maxDegree(int[][] iArr) {
        int i = 0;
        for (int i2 = 0; i2 < iArr.length; i2++) {
            if (iArr[i2].length > i) {
                i = iArr[i2].length;
            }
        }
        return i;
    }

    protected int getNumberOfTerms(int[][] iArr) {
        int i = 0;
        for (int[] iArr2 : iArr) {
            i += iArr2.length;
        }
        return i;
    }

    protected double getMaskFitness(InstanceSet instanceSet, InstanceSet instanceSet2, int[][] iArr, double[] dArr) {
        return (-1.0d) * regressionError(instanceSet2, getRegressionCoefficients(instanceSet, iArr, dArr), iArr);
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [double[], double[][]] */
    protected double[] getRegressionCoefficients(InstanceSet instanceSet, int[][] iArr, double[] dArr) {
        ?? r0 = new double[instanceSet.size()];
        double[][] dArr2 = new double[instanceSet.size()][1];
        for (int i = 0; i < r0.length; i++) {
            r0[i] = polynomialExpansion(instanceSet.inputSet[i], iArr);
            dArr2[i][0] = instanceSet.targetSet[i];
        }
        return LinearRegression.performRegression(new Matrix((double[][]) r0), new Matrix(dArr2), dArr, this.ridge);
    }

    protected double[] polynomialExpansion(double[] dArr, int[][] iArr) {
        double[] dArr2 = new double[iArr.length + 1];
        dArr2[0] = 1.0d;
        for (int i = 0; i < iArr.length; i++) {
            dArr2[i + 1] = 1.0d;
            for (int i2 = 0; i2 < iArr[i].length; i2++) {
                int i3 = i + 1;
                dArr2[i3] = dArr2[i3] * dArr[iArr[i][i2]];
            }
        }
        return dArr2;
    }

    protected double regressionError(InstanceSet instanceSet, double[] dArr, int[][] iArr) {
        double d = 0.0d;
        for (int i = 0; i < instanceSet.size(); i++) {
            double d2 = 0.0d;
            double[] polynomialExpansion = polynomialExpansion(instanceSet.inputSet[i], iArr);
            for (int i2 = 0; i2 < dArr.length; i2++) {
                d2 += polynomialExpansion[i2] * dArr[i2];
            }
            double d3 = d2 - instanceSet.targetSet[i];
            d += d3 * d3;
        }
        return Math.sqrt(d / instanceSet.size());
    }

    protected double[] applyKernelSmoother(double[] dArr) {
        double d = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] > d) {
                d = dArr[i];
            }
        }
        double[] dArr2 = new double[dArr.length];
        for (int i2 = 0; i2 < dArr.length; i2++) {
            dArr2[i2] = this.kernelSmoother.getWeight(dArr[i2], d);
        }
        return dArr2;
    }

    public int getDegree() {
        return this.degree;
    }

    public double getRidge() {
        return this.ridge;
    }

    private static int[][] getCombinations(int i, int i2) {
        int[][] iArr = new int[(int) nOverK(i2, i2 - i)][i];
        int i3 = 0;
        int[] iArr2 = new int[i];
        for (int i4 = 0; i4 < i; i4++) {
            iArr2[i4] = i4;
        }
        while (true) {
            for (int i5 = 1; i5 < i; i5++) {
                if (iArr2[i5] >= i2 - ((i - 1) - i5)) {
                    iArr2[i5] = iArr2[i5 - 1] + 1;
                }
            }
            for (int i6 = 0; i6 < i; i6++) {
                if (iArr2[i6] >= i2) {
                    return iArr;
                }
            }
            iArr[i3] = (int[]) iArr2.clone();
            i3++;
            int i7 = i - 1;
            iArr2[i7] = iArr2[i7] + 1;
            for (int i8 = i - 1; i8 >= 1; i8--) {
                if (iArr2[i8] >= i2 - ((i - 1) - i8)) {
                    int i9 = i8 - 1;
                    iArr2[i9] = iArr2[i9] + 1;
                }
            }
        }
    }

    public static long nOverK(int i, int i2) {
        int i3 = i2 < i - i2 ? i2 : i - i2;
        long j = 1;
        for (int i4 = 1; i4 <= i3; i4++) {
            j = (j * i) / i4;
            i--;
        }
        return j;
    }

    public static int[][][] multipleMaskPolynomialExpansion(int i, int i2) {
        int[][][] iArr = new int[getPolynomialExpansionSize(i, i2) - 1][1];
        int i3 = 0;
        for (int i4 = 1; i4 <= i2; i4++) {
            int[] iArr2 = new int[i4];
            while (true) {
                int[] iArr3 = new int[i4];
                iArr3[0] = iArr2[0];
                for (int i5 = 1; i5 < i4; i5++) {
                    iArr3[i5] = iArr2[i5];
                }
                iArr[i3][0] = iArr3;
                i3++;
                if (!isLastPosition(iArr2, i)) {
                    iArr2[0] = iArr2[0] + 1;
                    if (iArr2[0] == i) {
                        iArr2[0] = moveCounterAhead(iArr2, 1, i);
                    }
                }
            }
        }
        return iArr;
    }

    public static int getPolynomialExpansionSize(int i, int i2) {
        int[] iArr = new int[i];
        Arrays.fill(iArr, 1);
        int i3 = 1;
        for (int i4 = 1; i4 <= i2; i4++) {
            for (int i5 = 0; i5 < iArr.length; i5++) {
                i3 += iArr[i5];
                int i6 = 0;
                for (int i7 = i5; i7 < iArr.length; i7++) {
                    i6 += iArr[i7];
                }
                iArr[i5] = i6;
            }
        }
        return i3;
    }

    private static boolean isLastPosition(int[] iArr, int i) {
        for (int i2 : iArr) {
            if (i2 < i - 1) {
                return false;
            }
        }
        return true;
    }

    private static int moveCounterAhead(int[] iArr, int i, int i2) {
        iArr[i] = iArr[i] + 1;
        if (iArr[i] == i2) {
            iArr[i] = moveCounterAhead(iArr, i + 1, i2);
        }
        return iArr[i];
    }

    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return LocalPolynomialModelConfig.class;
    }

    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("LocalPolynomialModel.h\"\n");
        String uniqueFunctionName = CCodeUtils.getUniqueFunctionName(getClass());
        CCodeUtils.getCRegressionHeader(uniqueFunctionName, this.inputsNumber, sb);
        int[][][] convertMasks = convertMasks(this.polynomialMasks);
        CCodeUtils.convertArray(this.learnDataInput, "learningVectorsInput", sb, true);
        CCodeUtils.convertArray(this.learnDataTarget, "learningVectorsTarget", sb);
        CCodeUtils.convertArray(convertMasks, "masks", sb);
        int i = this.bestSolutions;
        if (i > convertMasks.length) {
            i = convertMasks.length;
        }
        String str = this.measureType.getEnabledElements(String.class)[0];
        String str2 = str.substring(0, 1).toLowerCase() + str.substring(1);
        String simpleName = this.kernelSmoother.getClass().getSimpleName();
        sb.append("return localPolynomialModelOutput<").append(this.inputsNumber).append(",").append(this.learning_vectors).append(",").append(this.nearestNeighbours).append(",").append(convertMasks.length).append(",").append(this.degree).append(",").append(this.expansions).append(",").append(i).append(">(input,learningVectorsInput,learningVectorsTarget,").append(str2).append(",").append(simpleName.substring(0, 1).toLowerCase() + simpleName.substring(1)).append(",masks,").append(this.closestToTestSet).append(");\n");
        sb.append("}\n");
        XMLBuildUtils.outputXML(sb2, this, uniqueFunctionName);
        return uniqueFunctionName;
    }

    private int[][][] convertMasks(int[][][] iArr) {
        int i = 1;
        for (int i2 = 0; i2 < iArr.length; i2++) {
            if (iArr[i2][0].length > i) {
                i = iArr[i2][0].length;
            }
        }
        int[][][] iArr2 = new int[iArr.length][1][i];
        for (int i3 = 0; i3 < iArr.length; i3++) {
            for (int i4 = 0; i4 < iArr[i3][0].length; i4++) {
                iArr2[i3][0][i4] = iArr[i3][0][i4];
            }
            for (int length = iArr[i3][0].length; length < i; length++) {
                iArr2[i3][0][length] = -1;
            }
        }
        return iArr2;
    }
}
