package game.classifiers.ensemble;

import configuration.classifiers.ClassifierConfig;
import configuration.classifiers.ensemble.ClassifierThresholdingConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.classifiers.Classifier;
import game.evolution.treeEvolution.evolutionControl.EvolutionUtils;
import game.evolution.treeEvolution.supportAlgorithms.StratifiedRandomSampling;
import game.utils.Utils;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:game/classifiers/ensemble/ClassifierThresholding.class */
/**
 * Ensemble that arbitrates between its members using accuracy tables estimated on the training
 * data.
 *
 * <p>During {@link #learn()} the member at index {@code c} (for {@code c < outputs}) is trained on
 * a rebalanced copy of the training data built for class {@code c}. The member at index
 * {@code outputs} is trained on all stored training vectors. Afterwards an accuracy table indexed
 * {@code [member][thresholdIndex][class]} is computed on the training data. It is used by
 * {@link #getOutputProbabilities(double[])} to pick the most trustworthy (member, class) answer
 * for each input.
 */
public class ClassifierThresholding extends OutputRelatedEnsemble {
    /** Number of evenly spaced probability thresholds (from 1.0 down to 0.0) built by learn(). */
    private static final int THRESHOLD_COUNT = 100;

    /**
     * Rebalancing factor taken from the configuration. Values above 1 oversample the classes
     * other than the member's own; values up to 1 randomly drop part of the member's own class.
     */
    protected double dataMultiplier;

    /** Probability thresholds in descending order (thresholds[0] == 1.0, last == 0.0). */
    protected double[] thresholds;

    /** Training-set accuracy per {@code [member][thresholdIndex][class]}; see computeMatrices. */
    protected double[][][] accuracyMatrix;

    @Override
    public void init(ClassifierConfig classifierConfig) {
        super.init(classifierConfig);
        this.dataMultiplier = ((ClassifierThresholdingConfig) classifierConfig).getDataMultiplier();
    }

    /**
     * Trains every member, then builds the threshold grid and the accuracy table.
     *
     * <p>Members {@code 0 .. outputs-1} receive class-specific rebalanced data. Member
     * {@code outputs} receives all training vectors, so the ensemble must contain at least
     * {@code outputs + 1} members.
     */
    @Override
    public void learn() {
        Random random = new Random();
        List<List<Integer>> classIndexes = StratifiedRandomSampling.getClassIndexes(
                EvolutionUtils.convertOutputData(this.target), this.outputs);
        for (int cls = 0; cls < this.outputs; cls++) {
            Classifier specialist = this.ensClassifiers.get(cls);
            prepareData(specialist, cls, random, classIndexes);
            specialist.learn();
        }
        Classifier generalist = this.ensClassifiers.get(this.outputs);
        prepareData(generalist);
        generalist.learn();

        this.thresholds = buildThresholds(THRESHOLD_COUNT);
        this.accuracyMatrix = computeMatrices(this.thresholds);
        postLearnActions();
    }

    /**
     * Builds {@code count} (expected to be at least 2) evenly spaced thresholds descending from
     * 1.0 to 0.0.
     */
    private static double[] buildThresholds(int count) {
        double[] grid = new double[count];
        int last = count - 1;
        for (int k = 0; k <= last; k++) {
            // Cast before dividing: int / int would truncate every value except the last to 0.
            grid[last - k] = (double) k / last;
        }
        return grid;
    }

    /**
     * Computes, for every member, class and threshold, the fraction of training vectors that are
     * ranked at or above the threshold crossing point and whose target for that class is 1.0.
     *
     * <p>Training vectors are ordered by descending member probability for the class. This relies
     * on {@code Utils.quickSort} returning the index permutation that sorts its argument
     * ascending, applied here to negated probabilities (TODO confirm against Utils). Thresholds
     * never crossed by any vector receive the overall ratio {@code positives / learning_vectors}.
     *
     * @param thresholds descending threshold grid
     * @return table indexed {@code [member][thresholdIndex][class]}
     */
    private double[][][] computeMatrices(double[] thresholds) {
        int memberCount = this.ensClassifiers.size();
        double[][][] table = new double[memberCount][thresholds.length][this.outputs];
        double[] sortKeys = new double[this.learning_vectors];
        double[][] memberOutputs = new double[this.learning_vectors][];
        for (int m = 0; m < memberCount; m++) {
            Classifier member = this.ensClassifiers.get(m);
            for (int v = 0; v < this.learning_vectors; v++) {
                memberOutputs[v] = member.getOutputProbabilities(this.inputVect[v]);
            }
            for (int cls = 0; cls < this.outputs; cls++) {
                for (int v = 0; v < this.learning_vectors; v++) {
                    sortKeys[v] = -memberOutputs[v][cls];
                }
                int[] order = Utils.quickSort(sortKeys);
                double hits = 0.0d;
                int t = 0;
                for (int rank = 0; rank < this.learning_vectors; rank++) {
                    int vector = order[rank];
                    if (this.target[vector][cls] == 1.0d) {
                        hits += 1.0d;
                    }
                    // NOTE(review): the ratio includes the current vector even though its
                    // probability is already below thresholds[t]; a strict precision would be
                    // (hits before this vector) / rank, but that is 0/0 for rank 0. Kept as-is.
                    while (t < thresholds.length && memberOutputs[vector][cls] < thresholds[t]) {
                        table[m][t][cls] = hits / (rank + 1);
                        t++;
                    }
                }
                for (; t < thresholds.length; t++) {
                    table[m][t][cls] = hits / this.learning_vectors;
                }
            }
        }
        return table;
    }

    /** Feeds every stored training vector, unmodified and in order, to {@code classifier}. */
    private void prepareData(Classifier classifier) {
        classifier.setMaxLearningVectors(this.learning_vectors);
        for (int v = 0; v < this.learning_vectors; v++) {
            classifier.storeLearningVector(this.inputVect[v], this.target[v]);
        }
    }

    /**
     * Feeds {@code classifier} a rebalanced training set built for class {@code ownClass}.
     *
     * <p>If {@code dataMultiplier > 1}, every other class gets
     * {@code (int) ((dataMultiplier - 1) * size)} additional samples drawn with replacement.
     * Otherwise {@code (int) ((1 - dataMultiplier) * size)} random samples of {@code ownClass}
     * are removed (never more than the class holds). The shared {@code classIndexes} lists are
     * not modified.
     *
     * @param classifier   member to receive the data
     * @param ownClass     class this member is specialised for
     * @param random       source of randomness for sampling
     * @param classIndexes training-vector indexes grouped per class
     */
    private void prepareData(Classifier classifier, int ownClass, Random random, List<List<Integer>> classIndexes) {
        List<Integer> extraSamples = new ArrayList<>(this.learning_vectors);
        // Shallow copy so that replacing one class list below does not touch the caller's list.
        List<List<Integer>> perClass = new ArrayList<>(classIndexes);
        int total = 0;
        if (this.dataMultiplier > 1.0d) {
            for (int cls = 0; cls < this.outputs; cls++) {
                List<Integer> members = perClass.get(cls);
                int size = members.size();
                total += size;
                if (cls != ownClass) {
                    int extraCount = (int) ((this.dataMultiplier - 1.0d) * size);
                    for (int k = 0; k < extraCount; k++) {
                        extraSamples.add(members.get(random.nextInt(size)));
                    }
                }
            }
            total += extraSamples.size();
        } else {
            List<Integer> reduced = new ArrayList<>(perClass.get(ownClass));
            // Clamp so that a negative multiplier empties the class instead of calling nextInt(0).
            int removeCount = Math.min(reduced.size(), (int) ((1.0d - this.dataMultiplier) * reduced.size()));
            for (int k = 0; k < removeCount; k++) {
                reduced.remove(random.nextInt(reduced.size()));
            }
            perClass.set(ownClass, reduced);
            for (int cls = 0; cls < this.outputs; cls++) {
                total += perClass.get(cls).size();
            }
        }
        classifier.setMaxLearningVectors(total);
        for (int cls = 0; cls < this.outputs; cls++) {
            for (int index : perClass.get(cls)) {
                classifier.storeLearningVector(this.inputVect[index], this.target[index]);
            }
        }
        for (int index : extraSamples) {
            classifier.storeLearningVector(this.inputVect[index], this.target[index]);
        }
    }

    /** Ignores the argument and performs a full {@link #learn()}. */
    @Override
    public void learn(int i) {
        learn();
    }

    /** Performs a full {@link #learn()} again. */
    @Override
    public void relearn() {
        learn();
    }

    /**
     * Classifies {@code dArr} and returns a one-hot vector of length {@code outputs}.
     *
     * <p>For each member and each class it outputs, the member probability is mapped to a
     * threshold index and the corresponding accuracy table entry is looked up. The class with the
     * highest entry wins, with ties broken by the larger raw probability. Class 0 is returned if
     * no entry compares greater than negative infinity (e.g. every entry is NaN).
     */
    @Override
    public double[] getOutputProbabilities(double[] dArr) {
        double bestAccuracy = Double.NEGATIVE_INFINITY;
        double bestProbability = Double.NEGATIVE_INFINITY;
        int winner = 0;
        for (int m = 0; m < this.ensClassifiers.size(); m++) {
            double[] memberOutput = this.ensClassifiers.get(m).getOutputProbabilities(dArr);
            for (int cls = 0; cls < memberOutput.length; cls++) {
                double accuracy = this.accuracyMatrix[m][findThresholdIndex(memberOutput[cls])][cls];
                boolean better = accuracy > bestAccuracy
                        || (accuracy == bestAccuracy && memberOutput[cls] > bestProbability);
                if (better) {
                    bestAccuracy = accuracy;
                    bestProbability = memberOutput[cls];
                    winner = cls;
                }
            }
        }
        double[] oneHot = new double[this.outputs];
        oneHot[winner] = 1.0d;
        return oneHot;
    }

    /**
     * Returns the index of the first (largest) threshold not exceeding {@code d}. Falls back to
     * the last index when {@code d} is below every threshold (for example when it is NaN).
     */
    protected int findThresholdIndex(double d) {
        for (int i = 0; i < this.thresholds.length; i++) {
            if (d >= this.thresholds[i]) {
                return i;
            }
        }
        return this.thresholds.length - 1;
    }

    @Override
    public Class getConfigClass() {
        return ClassifierThresholdingConfig.class;
    }

    /**
     * Emits C code for this ensemble. Successor models are serialised into {@code sb}, while XML
     * metadata is written to {@code sb2}.
     *
     * @return the generated C function name
     */
    @Override
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        XMLBuildUtils.outputXMLStart(sb2, this);
        String uniqueFunctionName = CCodeUtils.getUniqueFunctionName(getClass());
        String[] successorsCode = getSuccessorsCode(sb, sb2);
        XMLBuildUtils.outputXMLEnd(sb2, this, uniqueFunctionName);
        sb.append("#include \"").append(CCodeUtils.getClassificationModelPath()).append("ThresholdEnsemble.h\"\n");
        CCodeUtils.getCClassificationHeader(uniqueFunctionName, this.inputs, sb);
        CCodeUtils.getCClsModelArray(successorsCode, "models", sb);
        CCodeUtils.convertArray(this.thresholds, "thresholds", sb);
        CCodeUtils.convertArray(this.accuracyMatrix, "accuracyMatrix", sb);
        sb.append("return thresholdEnsembleOutput<").append(this.inputs).append(",").append(this.outputs).append(",").append(this.numClassifiers).append(",").append(this.thresholds.length).append(">(input,models,thresholds,accuracyMatrix);\n");
        sb.append("}\n");
        return uniqueFunctionName;
    }
}
