package game.models.ensemble;

import configuration.models.ModelConfig;
import configuration.models.ensemble.AreaSpecializationModelConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.models.Model;
import game.utils.Utils;

/* loaded from: input_file:game/models/ensemble/ModelAreaSpecialization.class */
/**
 * Ensemble whose members are trained on error-reweighted instances and whose prediction is
 * decided by the stored training vectors nearest to the query.
 *
 * <p>Training: {@link #modifyWeights(double[], int)} multiplies each instance weight by a member
 * model's absolute error on that instance raised to {@code modelsSpecialization}; how those
 * weights are consumed is defined by {@link ModelInstanceWeights}.
 *
 * <p>Prediction: {@link #getOutput(double[])} looks at up to {@code area} stored training vectors
 * close to the query (Euclidean distance over the first {@code inputsNumber} inputs), assigns
 * each of them to the member with the smallest absolute error on it, and returns a score-weighted
 * mean of those members' outputs for the query.
 */
public class ModelAreaSpecialization extends ModelInstanceWeights {
    /** Upper bound on the number of stored training vectors consulted per prediction. */
    private int area;
    /** Exponent applied to per-instance absolute errors when reweighting instances. */
    private double modelsSpecialization;
    /** Training inputs captured at the last learn call (the same array object, not a copy). */
    private double[][] savedVectors;
    /** Training targets captured alongside {@code savedVectors} (not a copy). */
    private double[] savedVectorsOutput;

    /**
     * Scales each of the first {@code learning_vectors} instance weights by
     * {@code |target - prediction| ^ modelsSpecialization} for one member model.
     *
     * @param weights    per-instance weights, updated in place
     * @param modelIndex index into {@code ensembleModels} of the model whose errors are used
     */
    @Override // game.models.ensemble.ModelInstanceWeights
    protected void modifyWeights(double[] weights, int modelIndex) {
        final Model model = this.ensembleModels.get(modelIndex);
        int vector = 0;
        while (vector < this.learning_vectors) {
            // With a positive exponent, instances this model predicts accurately receive a
            // smaller factor than instances it misses.
            weights[vector] *= Math.pow(
                    Math.abs(this.target[vector] - model.getOutput(this.inputVect[vector])),
                    this.modelsSpecialization);
            vector++;
        }
    }

    /**
     * Finds the member model with the smallest absolute error on a single sample.
     * Ties keep the lower index; a comparison involving NaN never changes the current choice.
     *
     * @param sample   input vector to evaluate
     * @param expected target value for {@code sample}
     * @return index into {@code ensembleModels} of the best member
     */
    private int getBestModel(double[] sample, double expected) {
        int winner = 0;
        double winnerError = Math.abs(this.ensembleModels.get(0).getOutput(sample) - expected);
        for (int candidate = 1; candidate < this.modelsNumber; candidate++) {
            final double candidateError =
                    Math.abs(this.ensembleModels.get(candidate).getOutput(sample) - expected);
            if (candidateError < winnerError) {
                winner = candidate;
                winnerError = candidateError;
            }
        }
        return winner;
    }

    /** Remembers the current training arrays by reference; nothing is copied. */
    private void saveLearningVectors() {
        this.savedVectors = this.inputVect;
        this.savedVectorsOutput = this.target;
    }

    /**
     * Initialises the base ensemble, then reads {@code area} and {@code modelsSpecialization}.
     *
     * @param modelConfig must be an {@link AreaSpecializationModelConfig}
     */
    @Override // game.models.ensemble.ModelEnsembleBase, game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        super.init(modelConfig);
        final AreaSpecializationModelConfig config = (AreaSpecializationModelConfig) modelConfig;
        this.area = config.getArea();
        this.modelsSpecialization = config.getModelsSpecialization();
    }

    /** Captures the training arrays, runs base-class learning, then {@code postLearnActions()}. */
    @Override // game.models.ensemble.ModelInstanceWeights, game.models.ModelLearnable
    public void learn() {
        saveLearningVectors();
        super.learn();
        postLearnActions();
    }

    /** Captures the training arrays, runs base-class relearning, then {@code postLearnActions()}. */
    @Override // game.models.ensemble.ModelInstanceWeights, game.models.ensemble.ModelEnsemble
    public void relearn() {
        saveLearningVectors();
        super.relearn();
        postLearnActions();
    }

    /**
     * Captures the training arrays and delegates to {@code super.learn(int)}.
     *
     * <p>NOTE(review): unlike {@link #learn()} and {@link #relearn()}, this does not call
     * {@code postLearnActions()} — confirm this is intentional.
     *
     * @param index forwarded unchanged to the base implementation
     */
    @Override // game.models.ensemble.ModelInstanceWeights, game.models.ensemble.ModelEnsemble
    public void learn(int index) {
        saveLearningVectors();
        super.learn(index);
    }

    /**
     * Predicts the output for {@code input} from the members that best fit its neighbourhood.
     *
     * <ol>
     *   <li>Computes the Euclidean distance from {@code input} to every stored training vector,
     *       over the first {@code inputsNumber} components.</li>
     *   <li>Asks {@code Utils.insertSort} for {@code min(area, savedVectors.length)} neighbour
     *       indices (presumably nearest-first — TODO confirm).</li>
     *   <li>For each neighbour, picks the member with the smallest error on it and adds
     *       {@code Utils.gaussian(neighbourTarget - memberOutput, width) + 1e-4} to that member's
     *       score, where {@code width} is the distance of the neighbour at the last position of
     *       the index array.</li>
     *   <li>Returns the sum of {@code memberOutput * score / totalScore} over scored members.</li>
     * </ol>
     *
     * <p>NOTE(review): requires {@code savedVectors} to be set (it is only assigned by the learn
     * methods in this class). If insertSort returns only the selected indices, an empty
     * neighbourhood makes the {@code width} lookup index -1.
     *
     * @param input query vector
     * @return the ensemble prediction
     */
    @Override // game.models.Model
    public double getOutput(double[] input) {
        final int storedCount = this.savedVectors.length;
        final int neighbourCount = Math.min(this.area, storedCount);

        final double[] distances = new double[storedCount];
        for (int v = 0; v < storedCount; v++) {
            final double[] stored = this.savedVectors[v];
            double squaredSum = 0.0d;
            for (int k = 0; k < this.inputsNumber; k++) {
                final double diff = input[k] - stored[k];
                squaredSum += diff * diff;
            }
            distances[v] = Math.sqrt(squaredSum);
        }

        final int[] nearest = Utils.insertSort(distances, neighbourCount);

        // Lazily filled cache of each member's output for `input`; NaN means "not evaluated".
        final double[] memberOutput = new double[this.modelsNumber];
        for (int m = 0; m < memberOutput.length; m++) {
            memberOutput[m] = Double.NaN;
        }
        final double width = distances[nearest[nearest.length - 1]];
        final double[] score = new double[this.modelsNumber];

        for (int rank = 0; rank < neighbourCount; rank++) {
            final int idx = nearest[rank];
            final double neighbourTarget = this.savedVectorsOutput[idx];
            final int owner = getBestModel(this.savedVectors[idx], neighbourTarget);
            if (Double.isNaN(memberOutput[owner])) {
                memberOutput[owner] = this.ensembleModels.get(owner).getOutput(input);
            }
            // Kept as (score + gaussian) + 1e-4 to preserve the original rounding order.
            score[owner] = score[owner] + Utils.gaussian(neighbourTarget - memberOutput[owner], width) + 1.0E-4d;
        }

        double totalScore = 0.0d;
        for (double s : score) {
            totalScore += s;
        }

        double prediction = 0.0d;
        for (int m = 0; m < this.modelsNumber; m++) {
            if (score[m] == 0.0d) {
                continue; // member received no score
            }
            if (Double.isNaN(memberOutput[m])) {
                // Re-evaluate if the cached output is still NaN.
                memberOutput[m] = this.ensembleModels.get(m).getOutput(input);
            }
            prediction += (memberOutput[m] * score[m]) / totalScore;
        }
        return prediction;
    }

    /**
     * Returns the base configuration extended with {@code area} and {@code modelsSpecialization}.
     *
     * @return an {@link AreaSpecializationModelConfig}
     */
    @Override // game.models.ensemble.ModelEnsembleBase, game.models.ModelLearnableBase, game.models.Model
    public ModelConfig getConfig() {
        final AreaSpecializationModelConfig config = (AreaSpecializationModelConfig) super.getConfig();
        config.setArea(this.area);
        config.setModelsSpecialization(this.modelsSpecialization);
        return config;
    }

    /** Returns the configuration type this model is initialised from. */
    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return AreaSpecializationModelConfig.class;
    }

    /**
     * Writes the XML description of this ensemble to {@code sb2} and a C function evaluating it
     * to {@code sb}. The C function embeds the stored training vectors, their targets and every
     * member's response on them.
     *
     * @param sb  receives C source
     * @param sb2 receives XML
     * @return the generated C function name
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        XMLBuildUtils.outputXMLStart(sb2, this);
        final String functionName = CCodeUtils.getUniqueFunctionName(getClass());
        // Successor code is generated before anything belonging to this function reaches sb.
        final String[] memberFunctions = getSuccessorsCode(sb, sb2);
        XMLBuildUtils.outputXMLEnd(sb2, this, functionName);

        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("AreaSpecializationEnsemble.h\"\n");
        CCodeUtils.getCRegressionHeader(functionName, this.inputsNumber, sb);
        CCodeUtils.getCRegModelArray(memberFunctions, "models", sb);
        CCodeUtils.convertArray(this.savedVectors, "savedVectors", sb);
        CCodeUtils.convertArray(this.savedVectorsOutput, "savedVectorsOutput", sb);
        CCodeUtils.convertArray(computeModelResponses(), "modelResponses", sb);
        sb.append("return AreaSpecializationEnsembleOutput<")
                .append(this.savedVectors.length).append(",")
                .append(this.modelsNumber).append(",")
                .append(this.inputsNumber)
                .append(">(input,models,")
                .append(this.area).append(",")
                .append(this.modelsSpecialization).append(",")
                .append("savedVectors,savedVectorsOutput,modelResponses);\n");
        sb.append("}\n");
        return functionName;
    }

    /**
     * Evaluates every member model on every stored training vector.
     *
     * @return responses indexed as {@code [model][trainingVector]}
     */
    public double[][] computeModelResponses() {
        final double[][] responses = new double[this.modelsNumber][this.savedVectors.length];
        for (int m = 0; m < responses.length; m++) {
            final Model model = this.ensembleModels.get(m);
            final double[] row = responses[m];
            for (int v = 0; v < row.length; v++) {
                row[v] = model.getOutput(this.savedVectors[v]);
            }
        }
        return responses;
    }
}
