package game.evolution.treeEvolution.context;

import configuration.classifiers.ClassifierConfig;
import game.classifiers.ConnectableClassifier;
import game.data.AbstractGameData;
import game.evolution.treeEvolution.FitnessNode;
import game.evolution.treeEvolution.context.FitnessContextBase;
import game.evolution.treeEvolution.context.evaluators.AccuracyClassifierEvaluator;
import game.evolution.treeEvolution.context.evaluators.ClassifierEvaluator;
import game.evolution.treeEvolution.context.evaluators.CostFunctionClassifierEvaluator;
import game.evolution.treeEvolution.evolutionControl.EvolutionUtils;
import game.preprocessing.CloneBalancePreprocessing;
import game.preprocessing.NormalizationPreprocessing;
import java.util.ArrayList;
import org.apache.commons.lang.StringUtils;
import weka.core.TestInstances;

/* loaded from: input_file:game/evolution/treeEvolution/context/ClassifierContextBase.class */
/**
 * Base fitness context for evolving classifier configurations.
 *
 * <p>Each candidate {@link FitnessNode} is treated as a {@link ClassifierConfig}. It is turned into a
 * {@link ConnectableClassifier} that carries the shared global preprocessing plus fresh clones of the
 * local preprocessing. The classifier is trained on a subset of {@code data}, then scored by
 * {@link #evaluator} on a validation subset and a test subset. The best classifiers seen so far on
 * each subset are remembered in {@link #bestValidModel} and {@link #bestTestModel}.
 */
public abstract class ClassifierContextBase extends FitnessContextBase {
    /** Best classifier recorded so far by test-set fitness; {@code null} until one is recorded. */
    protected ConnectableClassifier bestTestModel;
    /** Best classifier recorded so far by validation-set fitness; {@code null} until one is recorded. */
    protected ConnectableClassifier bestValidModel;
    /**
     * Scores a trained classifier on a subset of instances. Set to an accuracy evaluator in
     * {@link #initContextVariables} and to a cost-function evaluator in
     * {@link #init(AbstractGameData, double[][])}.
     */
    protected ClassifierEvaluator evaluator;

    /**
     * Trains the candidate on the learn indices and scores it on the validation and test indices.
     */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    protected FitnessContextBase.Fitness getModelFitness(FitnessNode fitnessNode) {
        return getModelFitnessLearnedOnData(fitnessNode, this.learnIndex, this.validIndex);
    }

    /**
     * Trains the candidate on the final learn indices and scores it on the validation and test
     * indices.
     */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    protected FitnessContextBase.Fitness getFitnessOnLearnValid(FitnessNode fitnessNode) {
        return getModelFitnessLearnedOnData(fitnessNode, this.finalLearnIndex, this.validIndex);
    }

    /**
     * Builds a classifier from {@code fitnessNode}, trains it on {@code learnIndices} and evaluates it.
     *
     * <p>The decompiler reports this method was originally {@code protected}. It is kept
     * {@code public} for compatibility.
     *
     * @param fitnessNode candidate configuration; must be a {@link ClassifierConfig}
     * @param learnIndices row indices of {@code data} used for training
     * @param validIndices row indices of {@code data} used for the validation score
     * @return the validation and test fitness of the trained classifier
     */
    public FitnessContextBase.Fitness getModelFitnessLearnedOnData(FitnessNode fitnessNode, int[] learnIndices, int[] validIndices) {
        ConnectableClassifier classifier = initClassifier((ClassifierConfig) fitnessNode, learnIndices);
        classifier.learn();
        return evaluateClassifier(classifier, fitnessNode, validIndices, this.testIndex);
    }

    /**
     * Creates a classifier for {@code classifierConfig} and loads the rows at {@code learnIndices}
     * into it as learning vectors. The classifier is not trained yet.
     */
    protected ConnectableClassifier initClassifier(ClassifierConfig classifierConfig, int[] learnIndices) {
        ConnectableClassifier classifier = initClassifier(classifierConfig);
        classifier.setMaxLearningVectors(learnIndices.length);
        storeLearningVectors(this.data, classifier, learnIndices);
        return classifier;
    }

    /**
     * Scores a trained classifier and, if it beats the current best, records it.
     *
     * <p>When {@code testIndices} is empty, the validation score is reused as the test score. When
     * {@code parallelComputation} is set, updates to the best-model fields happen while holding the
     * context lock. The lock is always released, even if an update throws.
     *
     * <p>The decompiler reports this method was originally {@code protected}. It is kept
     * {@code public} for compatibility.
     *
     * @param connectableClassifier trained classifier to score
     * @param fitnessNode configuration the classifier was built from; cloned when recorded
     * @param validIndices row indices used for the validation score
     * @param testIndices row indices used for the test score; may be empty
     * @return the validation and test fitness
     */
    public FitnessContextBase.Fitness evaluateClassifier(ConnectableClassifier connectableClassifier, FitnessNode fitnessNode, int[] validIndices, int[] testIndices) {
        double validFitness = this.evaluator.performTestOnData(connectableClassifier, validIndices, this.data);
        double testFitness = testIndices.length == 0
                ? validFitness
                : this.evaluator.performTestOnData(connectableClassifier, testIndices, this.data);
        // Read the flag once so that acquiring and releasing the lock always pair up.
        final boolean locked = this.parallelComputation;
        if (locked) {
            getLock();
        }
        try {
            if (testFitness > this.bestTestFitness) {
                this.bestTestFitness = testFitness;
                this.bestTestModel = connectableClassifier;
                this.bestTestModelConfig = fitnessNode.clone();
                // The stored model no longer needs its training rows.
                this.bestTestModel.deleteLearningVectors();
            }
            if (validFitness > this.bestValidFitness) {
                this.bestValidFitness = validFitness;
                this.bestValidModel = connectableClassifier;
                this.bestValidModelConfig = fitnessNode.clone();
                this.bestValidModel.deleteLearningVectors();
            }
        } finally {
            // Release even if clone()/deleteLearningVectors() throws; otherwise other evaluations
            // waiting on the lock would block forever.
            if (locked) {
                this.parallelLock.release();
            }
        }
        return new FitnessContextBase.Fitness(validFitness, testFitness);
    }

    /**
     * Creates an untrained classifier for {@code classifierConfig}.
     *
     * <p>Global preprocessing objects are shared by reference. Local preprocessing objects are
     * cloned, so each classifier gets its own instance.
     *
     * <p>The decompiler reports this method was originally {@code protected}. It is kept
     * {@code public} for compatibility.
     */
    public ConnectableClassifier initClassifier(ClassifierConfig classifierConfig) {
        ConnectableClassifier classifier = new ConnectableClassifier();
        ArrayList preprocessing = new ArrayList(this.globalPreprocessing.size() + this.localPreprocessing.size());
        preprocessing.addAll(this.globalPreprocessing);
        for (int i = 0; i < this.localPreprocessing.size(); i++) {
            preprocessing.add(this.localPreprocessing.get(i).m233clone());
        }
        classifier.init(classifierConfig, preprocessing);
        return classifier;
    }

    /**
     * Initializes the context to evaluate classifiers with a cost matrix instead of plain accuracy.
     *
     * <p>Each row of {@code dArr} is reduced to one clone-balance multiplier: the row sum divided
     * by (row length - 1).
     *
     * @param abstractGameData data set to evolve on
     * @param dArr cost matrix; one multiplier is derived per row
     */
    public void init(AbstractGameData abstractGameData, double[][] dArr) {
        super.init(abstractGameData);
        double[] multipliers = new double[dArr.length];
        for (int row = 0; row < multipliers.length; row++) {
            double rowSum = 0.0;
            for (double cost : dArr[row]) {
                rowSum += cost;
            }
            // NOTE(review): dividing by (length - 1) presumably averages only the off-diagonal
            // costs (the diagonal is assumed to be zero). A single-column row divides by zero.
            // Confirm intent.
            multipliers[row] = rowSum / (dArr[row].length - 1);
        }
        CloneBalancePreprocessing cloneBalancePreprocessing = new CloneBalancePreprocessing();
        cloneBalancePreprocessing.setDataNumMultipliers(multipliers);
        this.localPreprocessing.add(cloneBalancePreprocessing);
        this.evaluator = new CostFunctionClassifierEvaluator(dArr);
    }

    /**
     * Initializes the context. If {@link CloneBalancePreprocessing#autoBalance} returns weights for
     * the output class counts, this also enables clone balancing and logs those weights.
     */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public void init(AbstractGameData abstractGameData) {
        super.init(abstractGameData);
        double[] autoBalance = CloneBalancePreprocessing.autoBalance(EvolutionUtils.getClassCount(EvolutionUtils.convertOutputData(abstractGameData.getOutputAttrs())));
        if (autoBalance != null) {
            // NOTE(review): the computed weights are only logged. The preprocessing uses its default
            // settings, unlike init(data, costMatrix), which calls setDataNumMultipliers. Presumably
            // it balances on its own; confirm.
            this.localPreprocessing.add(new CloneBalancePreprocessing());
            StringBuilder weights = new StringBuilder();
            for (double weight : autoBalance) {
                weights.append(TestInstances.DEFAULT_SEPARATORS).append(weight);
            }
            this.log.info("Enabling class auto balance with weights:" + weights);
        }
    }

    /**
     * Extends the base initialization: adds a global normalization preprocessing step fitted on
     * the whole data set, and selects the accuracy evaluator.
     *
     * <p>The decompiler reports this method was originally {@code protected}. It is kept
     * {@code public} for compatibility.
     */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public void initContextVariables(AbstractGameData abstractGameData) {
        super.initContextVariables(abstractGameData);
        NormalizationPreprocessing normalizationPreprocessing = new NormalizationPreprocessing();
        normalizationPreprocessing.init(abstractGameData.getInputVectors(), abstractGameData.getOutputAttrs());
        this.globalPreprocessing.add(normalizationPreprocessing);
        this.evaluator = new AccuracyClassifierEvaluator();
    }

    /** Builds and trains a classifier for {@code fitnessNode} on every instance of {@code data}. */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public ConnectableClassifier learnModelOnAllData(FitnessNode fitnessNode) {
        int[] allIndices = new int[this.data.getInstanceNumber()];
        for (int i = 0; i < allIndices.length; i++) {
            allIndices[i] = i;
        }
        ConnectableClassifier classifier = initClassifier((ClassifierConfig) fitnessNode, allIndices);
        classifier.learn();
        return classifier;
    }

    /** Clears the base best-model state and both remembered classifiers. */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public void clearBestModels() {
        super.clearBestModels();
        this.bestTestModel = null;
        this.bestValidModel = null;
    }

    /**
     * Copies the rows at {@code indices} into {@code connectableClassifier} as learning vectors.
     * If the data set has instance weights, each row is stored with its weight.
     *
     * <p>The decompiler reports this method was originally {@code protected}. It is kept
     * {@code public} for compatibility.
     */
    public void storeLearningVectors(AbstractGameData abstractGameData, ConnectableClassifier connectableClassifier, int[] indices) {
        double[][] inputVectors = abstractGameData.getInputVectors();
        double[][] outputAttrs = abstractGameData.getOutputAttrs();
        double[] instanceWeights = abstractGameData.getInstanceWeights();
        if (instanceWeights == null) {
            for (int index : indices) {
                connectableClassifier.storeLearningVector(inputVectors[index], outputAttrs[index]);
            }
            return;
        }
        for (int index : indices) {
            connectableClassifier.storeLearningVector(inputVectors[index], outputAttrs[index], instanceWeights[index]);
        }
    }

    /** @return the best classifier by test fitness, or {@code null} if none has been recorded */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public ConnectableClassifier getBestTestModel() {
        return this.bestTestModel;
    }

    /** @return the best classifier by validation fitness, or {@code null} if none has been recorded */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public ConnectableClassifier getBestValidModel() {
        return this.bestValidModel;
    }

    /**
     * @return the configuration held by the best test classifier, or {@code null} if none has been
     *     recorded (for example after {@link #clearBestModels()})
     */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public FitnessNode getBestTestModelPredefinedConfig() {
        if (this.bestTestModel == null) {
            return null;
        }
        return (FitnessNode) this.bestTestModel.getConfig();
    }

    /**
     * @return the configuration held by the best validation classifier, or {@code null} if none has
     *     been recorded (for example after {@link #clearBestModels()})
     */
    @Override // game.evolution.treeEvolution.context.FitnessContextBase
    public FitnessNode getBestValidModelPredefinedConfig() {
        if (this.bestValidModel == null) {
            return null;
        }
        return (FitnessNode) this.bestValidModel.getConfig();
    }
}
