package game.evolution.treeEvolution.context;

import configuration.models.ModelConfig;
import game.data.AbstractGameData;
import game.evolution.treeEvolution.FitnessNode;
import game.evolution.treeEvolution.context.FitnessContextBase;
import game.evolution.treeEvolution.context.evaluators.ModelEvaluator;
import game.evolution.treeEvolution.context.evaluators.RMSEModelEvaluator;
import game.models.ConnectableModel;
import game.preprocessing.NormalizationPreprocessing;
import java.util.ArrayList;

/* loaded from: input_file:game/evolution/treeEvolution/context/ModelContextBase.class */
public abstract class ModelContextBase extends FitnessContextBase {
    /** Model with the highest fitness on the test indices seen so far, or {@code null}. */
    protected ConnectableModel bestTestModel;
    /** Model with the highest fitness on the validation indices seen so far, or {@code null}. */
    protected ConnectableModel bestValidModel;
    /** Scores a learned model on a subset of instances; set in {@link #initContextVariables}. */
    protected ModelEvaluator evaluator;

    /**
     * Learns the model described by {@code fitnessNode} on the learning indices and
     * evaluates it on the validation (and test) indices.
     */
    @Override
    protected FitnessContextBase.Fitness getModelFitness(FitnessNode fitnessNode) {
        return getModelFitnessLearnedOnData(fitnessNode, this.learnIndex, this.validIndex);
    }

    /**
     * Learns the model described by {@code fitnessNode} on the final learning indices and
     * evaluates it on the validation (and test) indices.
     */
    @Override
    protected FitnessContextBase.Fitness getFitnessOnLearnValid(FitnessNode fitnessNode) {
        return getModelFitnessLearnedOnData(fitnessNode, this.finalLearnIndex, this.validIndex);
    }

    /**
     * Builds a fresh model from {@code fitnessNode}, learns it on {@code learnIndices} and
     * evaluates it on {@code validIndices} and the context's test indices.
     *
     * @param fitnessNode  node that must be a {@link ModelConfig}
     * @param learnIndices instance indices used for learning
     * @param validIndices instance indices used for validation fitness
     * @return validation and test fitness of the learned model
     */
    public FitnessContextBase.Fitness getModelFitnessLearnedOnData(
            FitnessNode fitnessNode, int[] learnIndices, int[] validIndices) {
        ModelConfig modelConfig = (ModelConfig) fitnessNode;
        ConnectableModel model = initModel(modelConfig);
        learnModel(model, modelConfig, learnIndices);
        return evaluateModel(model, modelConfig, validIndices, this.testIndex);
    }

    /**
     * Evaluates a learned model and records it as the best test/validation model when it
     * improves on the current best. Higher fitness values are treated as better.
     *
     * <p>When {@code parallelLock} is set, the best-model bookkeeping runs while holding it.
     * The lock is always released, even if cloning the config or trimming the model throws.
     *
     * @param model        learned model to score
     * @param modelConfig  configuration the model was built from
     * @param validIndices instance indices for the validation fitness
     * @param testIndices  instance indices for the test fitness; if {@code null} or empty,
     *                     the validation fitness is reused as the test fitness
     * @return validation fitness and test fitness, in that order
     */
    public FitnessContextBase.Fitness evaluateModel(
            ConnectableModel model, ModelConfig modelConfig, int[] validIndices, int[] testIndices) {
        double validFitness = this.evaluator.performTestOnData(model, modelConfig, validIndices, this.data);
        double testFitness = (testIndices == null || testIndices.length == 0)
                ? validFitness
                : this.evaluator.performTestOnData(model, modelConfig, testIndices, this.data);

        boolean locked = this.parallelLock != null;
        if (locked) {
            // NOTE(review): getLock() is defined in the superclass; presumably it acquires parallelLock.
            getLock();
        }
        try {
            if (testFitness > this.bestTestFitness) {
                this.bestTestFitness = testFitness;
                this.bestTestModel = model;
                this.bestTestModelConfig = (FitnessNode) modelConfig.mo161clone();
                // Stored best models do not need their training data any more.
                this.bestTestModel.deleteLearningVectors();
            }
            if (validFitness > this.bestValidFitness) {
                this.bestValidFitness = validFitness;
                this.bestValidModel = model;
                this.bestValidModelConfig = (FitnessNode) modelConfig.mo161clone();
                this.bestValidModel.deleteLearningVectors();
            }
        } finally {
            // Release in finally so an exception above cannot leave other workers blocked.
            if (locked) {
                this.parallelLock.release();
            }
        }
        return new FitnessContextBase.Fitness(validFitness, testFitness);
    }

    /**
     * Creates a model for {@code modelConfig}.
     *
     * <p>The model receives the shared global preprocessing steps, followed by fresh clones
     * of the local preprocessing steps.
     */
    public ConnectableModel initModel(ModelConfig modelConfig) {
        ConnectableModel model = new ConnectableModel();
        ArrayList preprocessing =
                new ArrayList(this.globalPreprocessing.size() + this.localPreprocessing.size());
        preprocessing.addAll(this.globalPreprocessing);
        // Local steps are cloned so each model gets its own preprocessing state.
        for (int i = 0; i < this.localPreprocessing.size(); i++) {
            preprocessing.add(this.localPreprocessing.get(i).m233clone());
        }
        model.init(modelConfig, preprocessing);
        return model;
    }

    /**
     * Loads the instances at {@code learnIndices} into {@code model} and runs learning.
     *
     * @param modelConfig unused here; kept for subclasses that override this hook
     */
    protected void learnModel(ConnectableModel model, ModelConfig modelConfig, int[] learnIndices) {
        model.setMaxLearningVectors(learnIndices.length);
        storeLearningVectors(this.data, model, learnIndices);
        model.learn();
    }

    /**
     * Initialises context state. Adds a normalization step (fitted on the whole data set) to
     * the global preprocessing and selects the RMSE-based evaluator.
     */
    @Override
    public void initContextVariables(AbstractGameData abstractGameData) {
        super.initContextVariables(abstractGameData);
        NormalizationPreprocessing normalization = new NormalizationPreprocessing();
        normalization.init(abstractGameData.getInputVectors(), abstractGameData.getOutputAttrs());
        this.globalPreprocessing.add(normalization);
        this.evaluator = new RMSEModelEvaluator();
    }

    /** Builds the model described by {@code fitnessNode} and learns it on every instance. */
    @Override
    public ConnectableModel learnModelOnAllData(FitnessNode fitnessNode) {
        int[] allIndices = new int[this.data.getInstanceNumber()];
        for (int i = 0; i < allIndices.length; i++) {
            allIndices[i] = i;
        }
        ModelConfig modelConfig = (ModelConfig) fitnessNode;
        ConnectableModel model = initModel(modelConfig);
        learnModel(model, modelConfig, allIndices);
        return model;
    }

    /** Forgets the best models tracked by this class in addition to the superclass state. */
    @Override
    public void clearBestModels() {
        super.clearBestModels();
        this.bestTestModel = null;
        this.bestValidModel = null;
    }

    /**
     * Stores the instances at {@code indices} as learning vectors of {@code model}.
     *
     * <p>Only the first output attribute of each instance is used as the target. Instance
     * weights are passed along when the data set provides them.
     */
    public void storeLearningVectors(AbstractGameData abstractGameData, ConnectableModel model, int[] indices) {
        double[][] inputs = abstractGameData.getInputVectors();
        double[][] outputs = abstractGameData.getOutputAttrs();
        double[] weights = abstractGameData.getInstanceWeights();
        for (int index : indices) {
            if (weights == null) {
                model.storeLearningVector(inputs[index], outputs[index][0]);
            } else {
                model.storeLearningVector(inputs[index], outputs[index][0], weights[index]);
            }
        }
    }

    /** @return the model with the best test fitness so far, or {@code null} */
    @Override
    public ConnectableModel getBestTestModel() {
        return this.bestTestModel;
    }

    /** @return the model with the best validation fitness so far, or {@code null} */
    @Override
    public ConnectableModel getBestValidModel() {
        // Previously returned bestTestModel by mistake.
        return this.bestValidModel;
    }

    /** @return config of the best test model, or {@code null} if no model has been recorded */
    @Override
    public FitnessNode getBestTestModelPredefinedConfig() {
        return this.bestTestModel == null ? null : (FitnessNode) this.bestTestModel.getConfig();
    }

    /** @return config of the best validation model, or {@code null} if no model has been recorded */
    @Override
    public FitnessNode getBestValidModelPredefinedConfig() {
        return this.bestValidModel == null ? null : (FitnessNode) this.bestValidModel.getConfig();
    }
}
