package game.models.single;

import configuration.models.ModelConfig;
import configuration.models.single.KNNModelConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.evolution.treeEvolution.exception.LearnException;
import game.models.ModelLearnableBase;
import game.tools.distance.DistanceMeasure;
import game.utils.Utils;
import org.ytoh.configurations.ui.SelectionSetModel;
import utils.UtilsCommon;

/* loaded from: input_file:game/models/single/KNNModel.class */
/**
 * k-nearest-neighbours regression model.
 *
 * <p>{@link #learn()} keeps references to the training inputs and targets and
 * builds a {@link DistanceMeasure} over the inputs. {@link #getOutput(double[])}
 * combines the targets of the training vectors selected by
 * {@code Utils.insertSort} (presumably the {@code nearestNeighbours} closest
 * ones - confirm against that helper), either as a plain mean or weighted by
 * distance.
 */
public class KNNModel extends ModelLearnableBase {
    /** Number of neighbours (k) taken into account for a prediction. */
    private int nearestNeighbours;
    /** When true, closer neighbours contribute more to the prediction. */
    private boolean weightByDistance;
    /** Selection set whose first enabled element names the distance measure class. */
    private SelectionSetModel<String> measureType;
    /** Training input vectors (same array instance as {@code inputVect} at learn time). */
    private double[][] learnDataInput;
    /** Training targets (same array instance as {@code target} at learn time). */
    private double[] learnDataTarget;
    /** Distance measure built over {@link #learnDataInput} by {@link #learn()}. */
    private DistanceMeasure distance;

    /**
     * Copies the k-NN settings out of the supplied configuration.
     *
     * @param modelConfig configuration; must be a {@link KNNModelConfig}
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        super.init(modelConfig);
        KNNModelConfig knnConfig = (KNNModelConfig) modelConfig;
        this.nearestNeighbours = knnConfig.getNearestNeighbours();
        this.weightByDistance = knnConfig.getWeightedVote();
        // Work on a private copy so later edits of the config do not change this model.
        this.measureType = UtilsCommon.cloneSelectionSet(knnConfig.getMeasureType());
    }

    /**
     * Stores the training data and instantiates the configured distance measure.
     *
     * <p>The measure class is looked up reflectively in package
     * {@code game.tools.distance} by the first enabled element of the measure
     * type selection and constructed with the training inputs.
     *
     * @throws LearnException if the distance measure cannot be created or the
     *                        post-learn actions fail; the original exception is
     *                        attached as the cause where possible
     */
    @Override // game.models.ModelLearnable
    public void learn() {
        this.learnDataInput = this.inputVect;
        this.learnDataTarget = this.target;
        try {
            String measureName = this.measureType.getEnabledElements(String.class)[0];
            Class<?> measureClass = Class.forName("game.tools.distance." + measureName);
            // Cast to Object so the 2-D array is passed as the single constructor argument.
            this.distance = (DistanceMeasure) measureClass
                    .getConstructor(double[][].class)
                    .newInstance((Object) this.learnDataInput);
            postLearnActions();
        } catch (Exception e) {
            throw toLearnException(e);
        }
    }

    /**
     * Wraps a failure in a {@link LearnException} without losing diagnostics.
     */
    private static LearnException toLearnException(Exception failure) {
        // Reflection wrappers frequently carry a null message; fall back to toString().
        String message = failure.getMessage() != null ? failure.getMessage() : failure.toString();
        LearnException wrapped = new LearnException(message);
        try {
            wrapped.initCause(failure);
        } catch (IllegalStateException ignored) {
            // The constructor already fixed the cause; keep the exception as built.
        }
        return wrapped;
    }

    /**
     * Predicts the output for one input vector.
     *
     * <p>Without distance weighting the result is the sum of the selected
     * targets divided by {@code nearestNeighbours}. With weighting, neighbour
     * {@code i} gets weight {@code 1 - d_i / sum(d)}, normalised by
     * {@code nearestNeighbours - 1} (the sum of those weights). If all selected
     * distances are zero, every weight is 1 and the result is the plain mean.
     * A single neighbour is returned as-is, because its weight would otherwise
     * be {@code 1 - d/d = 0} and the prediction would always be 0.
     *
     * @param dArr input vector to evaluate
     * @return predicted value
     */
    @Override // game.models.Model
    public double getOutput(double[] dArr) {
        double[] distances = this.distance.getDistanceToAll(dArr);
        int[] nearest = Utils.insertSort(distances, this.nearestNeighbours);

        if (!this.weightByDistance) {
            double targetSum = 0.0d;
            for (int index : nearest) {
                targetSum += this.learnDataTarget[index];
            }
            return targetSum / this.nearestNeighbours;
        }

        // With exactly one neighbour there is nothing to weight against.
        if (this.nearestNeighbours == 1 && nearest.length == 1) {
            return this.learnDataTarget[nearest[0]];
        }

        double distanceSum = 0.0d;
        for (int index : nearest) {
            distanceSum += distances[index];
        }

        double normaliser;
        if (distanceSum == 0.0d) {
            // All neighbours coincide with the query: avoid 0/0 and give each weight 1.
            distanceSum = 1.0d;
            normaliser = this.nearestNeighbours;
        } else {
            normaliser = Math.max(this.nearestNeighbours - 1, 1);
        }

        double output = 0.0d;
        for (int index : nearest) {
            double weight = 1.0d - (distances[index] / distanceSum);
            output += (this.learnDataTarget[index] * weight) / normaliser;
        }
        return output;
    }

    /**
     * Returns the parent configuration updated with this model's k-NN settings.
     *
     * @return the {@link KNNModelConfig} obtained from the superclass
     */
    @Override // game.models.ModelLearnableBase, game.models.Model
    public ModelConfig getConfig() {
        KNNModelConfig knnConfig = (KNNModelConfig) super.getConfig();
        knnConfig.setNearestNeighbours(this.nearestNeighbours);
        knnConfig.setWeightedVote(this.weightByDistance);
        knnConfig.setMeasureType(this.measureType);
        return knnConfig;
    }

    /**
     * @return the configuration class used by this model, {@link KNNModelConfig}
     */
    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return KNNModelConfig.class;
    }

    /**
     * Emits a C function that evaluates this model and registers it in the XML output.
     *
     * <p>The training data is embedded as arrays and the call is delegated to
     * the {@code knnModelOutput} template from {@code KNNRegression.h}.
     * NOTE(review): that header is not visible here - confirm its weighted
     * single-neighbour behaviour matches {@link #getOutput(double[])}.
     *
     * @param sb  receives the generated C code
     * @param sb2 receives the XML description
     * @return the unique name of the generated C function
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("KNNRegression.h\"\n");
        String functionName = CCodeUtils.getUniqueFunctionName(getClass());
        CCodeUtils.getCRegressionHeader(functionName, this.inputsNumber, sb);

        // The C-side measure identifier is the Java class name with a lower-case first
        // letter. Character.toLowerCase is locale-independent, unlike String.toLowerCase().
        String measureName = this.measureType.getEnabledElements(String.class)[0];
        String cMeasureName = Character.toLowerCase(measureName.charAt(0)) + measureName.substring(1);

        CCodeUtils.convertArray(this.learnDataInput, "learningVectorsInput", sb);
        CCodeUtils.convertArray(this.learnDataTarget, "learningVectorsTarget", sb);
        sb.append("return knnModelOutput<")
                .append(this.inputsNumber).append(",")
                .append(this.learnDataInput.length).append(",")
                .append(this.nearestNeighbours)
                .append(">(input,learningVectorsInput,learningVectorsTarget,")
                .append(cMeasureName).append(",")
                .append(this.weightByDistance)
                .append(");\n");
        sb.append("}\n");
        XMLBuildUtils.outputXML(sb2, this, functionName);
        return functionName;
    }
}
