package game.classifiers.single;

import configuration.classifiers.ClassifierConfig;
import configuration.classifiers.single.KNNClassifierConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.classifiers.ClassifierBase;
import game.evolution.treeEvolution.evolutionControl.EvolutionUtils;
import game.evolution.treeEvolution.exception.LearnException;
import game.tools.distance.DistanceMeasure;
import game.utils.Utils;
import org.ytoh.configurations.ui.SelectionSetModel;
import utils.UtilsCommon;

/* loaded from: input_file:game/classifiers/single/KNNClassifier.class */
/**
 * k-nearest-neighbour classifier.
 *
 * <p>{@link #learn()} stores the training input vectors and their class indices, and builds a
 * {@link DistanceMeasure} over the inputs. The measure class is chosen by name from
 * {@code game.tools.distance}. Classification finds the {@code nearestNeighbours} training vectors
 * closest to the query and lets each one vote for its class. Votes can optionally be weighted by
 * distance.
 */
public class KNNClassifier extends ClassifierBase {
    /** Number of neighbours (k) that vote on each query. */
    private int nearestNeighbours;
    /** When true, a neighbour's vote shrinks as its share of the total neighbour distance grows. */
    private boolean weightByDistance;
    /** Names of distance-measure classes. The first enabled entry is used. */
    private SelectionSetModel<String> measureType;
    /** Training input vectors captured by {@link #learn()}. */
    private double[][] learnDataInput;
    /** Class index of each training vector, parallel to {@link #learnDataInput}. */
    private int[] learnDataClasses;
    /** Distance measure constructed over {@link #learnDataInput}. */
    private DistanceMeasure distance;

    /**
     * Copies the k-NN settings out of the configuration.
     *
     * @param classifierConfig must be a {@link KNNClassifierConfig}
     */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public void init(ClassifierConfig classifierConfig) {
        super.init(classifierConfig);
        KNNClassifierConfig kNNClassifierConfig = (KNNClassifierConfig) classifierConfig;
        this.nearestNeighbours = kNNClassifierConfig.getNearestNeighbours();
        this.weightByDistance = kNNClassifierConfig.getWeightedVote();
        // Keep a private copy so later edits to the config object do not affect this classifier.
        this.measureType = UtilsCommon.cloneSelectionSet(kNNClassifierConfig.getMeasureType());
    }

    /**
     * Stores the training data and instantiates the selected distance measure over it.
     *
     * @throws LearnException if the measure class cannot be found or instantiated, or if
     *     {@code postLearnActions()} fails. The original exception is attached as the cause.
     */
    @Override // game.classifiers.Classifier
    public void learn() {
        this.learnDataInput = this.inputVect;
        this.learnDataClasses = EvolutionUtils.convertOutputData(this.target);
        try {
            // The (Object) cast is required. Without it, the double[][] is treated as the varargs
            // array itself, so its rows would be passed as separate constructor arguments.
            this.distance = (DistanceMeasure) Class.forName("game.tools.distance." + getMeasureName())
                    .getConstructor(double[][].class)
                    .newInstance((Object) this.learnDataInput);
            postLearnActions();
        } catch (Exception e) {
            // Reflection exceptions (e.g. InvocationTargetException) often have a null message.
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            LearnException learnException = new LearnException(message);
            learnException.initCause(e);
            throw learnException;
        }
    }

    /** Learns again from scratch. This is the same as {@link #learn()}. */
    @Override // game.classifiers.Classifier
    public void relearn() {
        learn();
    }

    /**
     * Scores each output class by a vote of the nearest training vectors.
     *
     * <p>Without distance weighting, every neighbour adds {@code 1/k} to its class. With
     * weighting, a neighbour at distance {@code d} adds {@code (1 - d/sum) / max(k-1, 1)}, where
     * {@code sum} is the total neighbour distance. If that sum is zero, or there is only a single
     * neighbour, the equal {@code 1/k} vote is used instead. In those cases the distances carry no
     * relative information, and for a single neighbour the formula would give a weight of 0.
     *
     * @param dArr query input vector
     * @return array of length {@code outputs}. Entry {@code c} holds the accumulated votes of
     *     neighbours whose training class is {@code c}.
     */
    @Override // game.classifiers.Classifier
    public double[] getOutputProbabilities(double[] dArr) {
        double[] distanceToAll = this.distance.getDistanceToAll(dArr);
        // Indices into the training set of the selected neighbours (see Utils.insertSort).
        int[] neighbours = Utils.insertSort(distanceToAll, this.nearestNeighbours);
        double[] votes = new double[this.outputs];
        double equalVote = 1.0d / this.nearestNeighbours; // 1.0d: avoid integer division
        if (this.weightByDistance) {
            double distanceSum = 0.0d;
            for (int neighbour : neighbours) {
                distanceSum += distanceToAll[neighbour];
            }
            if (distanceSum == 0.0d || neighbours.length == 1) {
                for (int neighbour : neighbours) {
                    votes[this.learnDataClasses[neighbour]] += equalVote;
                }
            } else {
                double normalizer = Math.max(this.nearestNeighbours - 1, 1);
                for (int neighbour : neighbours) {
                    votes[this.learnDataClasses[neighbour]] +=
                            (1.0d - distanceToAll[neighbour] / distanceSum) / normalizer;
                }
            }
        } else {
            for (int neighbour : neighbours) {
                votes[this.learnDataClasses[neighbour]] += equalVote;
            }
        }
        return votes;
    }

    /**
     * Returns the configuration with the current k-NN settings written back into it.
     *
     * @return a {@link KNNClassifierConfig} holding a copy of the measure selection
     */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public ClassifierConfig getConfig() {
        KNNClassifierConfig kNNClassifierConfig = (KNNClassifierConfig) super.getConfig();
        kNNClassifierConfig.setNearestNeighbours(this.nearestNeighbours);
        kNNClassifierConfig.setWeightedVote(this.weightByDistance);
        // Hand out a copy, mirroring init(), so callers editing the config cannot alter this model.
        kNNClassifierConfig.setMeasureType(UtilsCommon.cloneSelectionSet(this.measureType));
        return kNNClassifierConfig;
    }

    /** @return {@link KNNClassifierConfig}, the configuration type this classifier accepts */
    @Override // game.classifiers.ClassifierBase, game.configuration.Configurable
    public Class getConfigClass() {
        return KNNClassifierConfig.class;
    }

    /**
     * Emits C code for this trained classifier.
     *
     * <p>The generated function embeds the training vectors and classes and delegates to
     * {@code knnClassifierOutput} from {@code KNN.h}.
     *
     * @param sb receives the generated C code
     * @param sb2 receives the XML description written by {@code XMLBuildUtils.outputXML}
     * @return the unique name of the generated function
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        sb.append("#include \"").append(CCodeUtils.getClassificationModelPath()).append("KNN.h\"\n");
        String uniqueFunctionName = CCodeUtils.getUniqueFunctionName(getClass());
        CCodeUtils.getCClassificationHeader(uniqueFunctionName, this.inputs, sb);
        String measureName = getMeasureName();
        // The generated code names the measure with its first letter lower-cased.
        String cMeasureName = measureName.substring(0, 1).toLowerCase() + measureName.substring(1);
        CCodeUtils.convertArray(this.learnDataInput, "learningVectorsInput", sb);
        CCodeUtils.convertArray(this.learnDataClasses, "learningVectorsClasses", sb);
        sb.append("return knnClassifierOutput<")
                .append(this.inputs).append(",")
                .append(this.learnDataInput.length).append(",")
                .append(this.nearestNeighbours).append(",")
                .append(this.outputs)
                .append(">(input,learningVectorsInput,learningVectorsClasses,")
                .append(cMeasureName).append(",")
                .append(this.weightByDistance).append(");\n");
        sb.append("}\n");
        XMLBuildUtils.outputXML(sb2, this, uniqueFunctionName);
        return uniqueFunctionName;
    }

    /**
     * Returns the name of the first enabled distance measure.
     *
     * @throws IllegalStateException if no distance measure is enabled
     */
    private String getMeasureName() {
        String[] enabled = this.measureType.getEnabledElements(String.class);
        if (enabled.length == 0) {
            throw new IllegalStateException("No distance measure is enabled for the KNN classifier");
        }
        return enabled[0];
    }
}
