package game.models.ensemble;

import configuration.CfgTemplate;
import configuration.models.ModelConfig;
import configuration.models.ensemble.EvolvableEnsembleModelConfig;
import configuration.models.game.CfgGame;
import game.evolution.Dna;
import game.evolution.EvolutionContext;
import game.evolution.EvolutionStrategy;
import game.evolution.Genome;
import game.evolution.ObjectEvolvable;
import game.models.Model;
import game.models.ModelLearnable;
import game.models.evolution.EvolvableModel;
import game.models.evolution.ModelEvolvable;
import game.utils.MyRandom;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;

/* loaded from: input_file:game/models/ensemble/ModelEvolvableEnsemble.class */
/**
 * Ensemble model whose members are {@link EvolvableModel}s evolved by a pluggable
 * {@link EvolutionStrategy}.
 *
 * <p>The ensemble is also the {@link EvolutionContext} handed to that strategy. In that role it
 * creates offspring, evaluates their fitness on a sample of the stored learning vectors and
 * measures the distance between two individuals. The ensemble output is the arithmetic mean of
 * the member outputs.
 */
public class ModelEvolvableEnsemble extends ModelEnsembleBase implements EvolutionContext {
    static final Logger logger = Logger.getLogger(ModelEvolvableEnsemble.class);

    /** Validation sample sizes above this value are capped (see {@link #validationSampleSize()}). */
    private static final int VALIDATION_SAMPLE_THRESHOLD = 500;

    /** Weight applied to the output/correlation distance relative to the genome distance. */
    private static final double OUTPUT_DISTANCE_WEIGHT = 1000.0d;

    /** Index generator for the learning/validation split; created by {@link #prepareLearningAndValidationData()}. */
    protected MyRandom rndGenerator;
    /** Number of generations {@link #learn()} asks the strategy to produce. */
    protected int generations;
    /** Strategy instantiated reflectively from the configuration in {@link #init(ModelConfig)}. */
    EvolutionStrategy evolution;
    /** When positive, used as the second {@code Genome} constructor argument for every new genome. */
    int maxInputs;
    /** Percentage of the stored learning vectors used for learning; the remainder is for validation. */
    int learnValidRatio;
    /** Number of learning vectors, derived from {@link #learnValidRatio} in {@link #learn()}. */
    int learnVectNum;
    /** Number of validation vectors, i.e. the stored vectors not counted in {@link #learnVectNum}. */
    int validVectNum;
    /** Include the genome (DNA) distance in {@link #getDistance}. */
    boolean genomeDistEnabled;
    /** Include the opposite-sign error-product distance in {@link #getDistance}. */
    boolean corrDistEnabled;
    /** Include the squared output-difference distance in {@link #getDistance}. */
    boolean outputDistEnabled;

    /**
     * Configures the ensemble from an {@link EvolvableEnsembleModelConfig} and creates and
     * initialises its evolution strategy.
     *
     * @param modelConfig the configuration; must be an {@link EvolvableEnsembleModelConfig}
     * @throws ClassCastException if {@code modelConfig} is of another configuration type
     * @throws IllegalStateException if the configured strategy class cannot be instantiated
     */
    @Override // game.models.ensemble.ModelEnsembleBase, game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        super.init(modelConfig);
        EvolvableEnsembleModelConfig cfg = (EvolvableEnsembleModelConfig) modelConfig;
        this.learnValidRatio = cfg.getLearnValidRatio();
        this.generations = cfg.getGenerations();
        this.modelsNumber = cfg.getModelsNumber();
        this.genomeDistEnabled = cfg.isGenoDistanceEnabled();
        this.corrDistEnabled = cfg.isCorrelationDistanceEnabled();
        this.outputDistEnabled = cfg.isOutputsDistanceEnabled();
        this.evolution = createEvolutionStrategy(cfg);
        this.evolution.init(cfg.getEvolutionStrategyConfig(), this);
        this.maxInputs = modelConfig.getMaxInputsNumber();
        this.maxLearningVectors = modelConfig.getMaxLearningVectors();
        this.targetVariable = modelConfig.getTargetVariable();
        this.name = modelConfig.getName();
        this.learning_vectors = 0;
    }

    /**
     * Instantiates the strategy class named by the configuration through its no-arg constructor.
     * Failure is reported instead of swallowed, which previously led to an NPE on the null
     * strategy.
     */
    private static EvolutionStrategy createEvolutionStrategy(EvolvableEnsembleModelConfig cfg) {
        Class<?> strategyClass = cfg.getEvolutionStrategyClass();
        try {
            return (EvolutionStrategy) strategyClass.newInstance();
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot access evolution strategy class " + strategyClass, e);
        } catch (InstantiationException e) {
            throw new IllegalStateException("Cannot instantiate evolution strategy class " + strategyClass, e);
        }
    }

    /**
     * Creates an {@link EvolvableModel} from {@code cfgTemplate} and registers it at position
     * {@code i} via {@code addModel}.
     */
    @Override // game.models.ensemble.ModelEnsembleBase
    protected void addBaseModel(int i, CfgTemplate cfgTemplate) {
        EvolvableModel member = new EvolvableModel();
        member.init(cfgTemplate);
        addModel(i, member);
    }

    /**
     * Sets the number of inputs and gives every current ensemble member a fresh, randomly
     * initialised genome over that many inputs.
     *
     * @param inputs number of model inputs
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void setInputsNumber(int inputs) {
        super.setInputsNumber(inputs);
        Random random = new Random();
        for (Model member : this.ensembleModels) {
            ((EvolvableModel) member).setDna(randomGenome(inputs, random));
        }
    }

    /**
     * Builds a randomly initialised genome over {@code inputs} inputs.
     *
     * <p>The second {@code Genome} argument (presumably the number of connected inputs; confirm in
     * {@code Genome}) is {@link #maxInputs} when that is positive. Otherwise it is a uniform draw
     * from {@code [1, inputs]}. The draw is only made when needed, so {@code inputs == 0} no
     * longer fails when {@code maxInputs} is set.
     */
    private Genome randomGenome(int inputs, Random random) {
        int connected = this.maxInputs > 0 ? this.maxInputs : random.nextInt(inputs) + 1;
        Genome genome = new Genome(inputs, connected);
        genome.initializeRandomly();
        return genome;
    }

    /**
     * Trains and scores every individual whose fitness is still {@code <= 0}. A positive fitness
     * marks an individual as already evaluated.
     *
     * <p>Fitness is {@code 1 / (1 + SSE)}, where SSE is the sum of squared errors over the
     * validation sample.
     *
     * @param arrayList the population; elements must be {@link ModelLearnable} and {@link ModelEvolvable}
     */
    @Override // game.evolution.EvolutionContext
    public <T extends ObjectEvolvable> void computeFitness(ArrayList<T> arrayList) {
        for (T individual : arrayList) {
            if (individual.getFitness() <= 0.0d) {
                trainOnLearningSample((ModelLearnable) individual);
                double sse = sumSquaredValidationError((ModelEvolvable) individual);
                individual.setFitness(1.0d / (1.0d + sse));
            }
        }
    }

    /**
     * Resets the index generator, feeds the model up to its learning-vector capacity from the
     * learning sample, then calls {@code learn()}.
     */
    private void trainOnLearningSample(ModelLearnable model) {
        this.rndGenerator.resetRandom();
        for (int k = 0; k < this.learnVectNum; k++) {
            // An index is drawn on every iteration, even once the model stops accepting vectors.
            int idx = this.rndGenerator.getRandomLearningVector();
            if (k < model.getMaxLearningVectors()) {
                model.storeLearningVector(this.inputVect[idx], this.target[idx]);
            }
        }
        model.learn();
    }

    /** Resets the index generator and returns the model's sum of squared errors on the validation sample. */
    private double sumSquaredValidationError(ModelEvolvable model) {
        this.rndGenerator.resetRandom();
        int samples = validationSampleSize();
        double sse = 0.0d;
        for (int k = 0; k < samples; k++) {
            int idx = this.rndGenerator.getRandomTestingVector();
            double error = model.getOutput(this.inputVect[idx]) - this.target[idx];
            sse += error * error;
        }
        return sse;
    }

    /**
     * Returns the number of validation vectors used for fitness and distance.
     *
     * <p>The value is capped once {@link #validVectNum} exceeds {@link #VALIDATION_SAMPLE_THRESHOLD}.
     * NOTE(review): the cap value is {@code CfgGame.MAX_UNITS_USED} while the threshold is 500.
     * This looks like a decompiler substituting a same-valued constant for the literal 500;
     * confirm {@code CfgGame.MAX_UNITS_USED == 500}.
     */
    private int validationSampleSize() {
        return this.validVectNum > VALIDATION_SAMPLE_THRESHOLD ? CfgGame.MAX_UNITS_USED : this.validVectNum;
    }

    /**
     * Creates an offspring from the given DNA, using a base-model configuration chosen at random.
     *
     * @param dna genome of the new individual
     * @return the new {@link EvolvableModel}
     */
    @Override // game.evolution.EvolutionContext
    @SuppressWarnings("unchecked") // EvolvableModel is the only individual type this context produces
    public <T extends ObjectEvolvable> T produceOffspring(Dna dna) {
        EvolvableModel offspring = newMemberWithRandomConfig(new Random());
        offspring.setDna(dna);
        if (logger.isTraceEnabled()) {
            logger.trace("New model, class:" + offspring.getClass().getName() + ", configuration:" + offspring.getConfigClass());
            logger.trace("Model genome (input connections):" + dna.toString());
        }
        return (T) offspring;
    }

    /**
     * Creates an offspring with a fresh random genome over {@code inputsNumber} inputs and a
     * base-model configuration chosen at random.
     *
     * @return the new {@link EvolvableModel}
     */
    @Override // game.evolution.EvolutionContext
    @SuppressWarnings("unchecked") // EvolvableModel is the only individual type this context produces
    public <T extends ObjectEvolvable> T produceRandomOffspring() {
        Random random = new Random();
        Genome genome = randomGenome(this.inputsNumber, random);
        EvolvableModel offspring = newMemberWithRandomConfig(random);
        offspring.setDna(genome);
        return (T) offspring;
    }

    /** Creates an {@link EvolvableModel} initialised from a uniformly chosen entry of {@code baseModelsCfg}. */
    private EvolvableModel newMemberWithRandomConfig(Random random) {
        EvolvableModel member = new EvolvableModel();
        member.init(this.baseModelsCfg.get(random.nextInt(this.baseModelsCfg.size())));
        return member;
    }

    /**
     * Returns the distance between two individuals: 0 if they are {@code equals}, otherwise the
     * sum of the enabled components.
     *
     * <p>The components are the genome distance and {@link #OUTPUT_DISTANCE_WEIGHT} times the
     * mean output-based distance on the validation sample.
     */
    @Override // game.evolution.EvolutionContext
    public double getDistance(ObjectEvolvable objectEvolvable, ObjectEvolvable objectEvolvable2) {
        if (objectEvolvable.equals(objectEvolvable2)) {
            return 0.0d;
        }
        double distance = 0.0d;
        if (this.genomeDistEnabled) {
            distance += objectEvolvable.getDna().distance(objectEvolvable2.getDna());
        }
        if (this.corrDistEnabled || this.outputDistEnabled) {
            distance += OUTPUT_DISTANCE_WEIGHT
                    * computeDistanceOfOutputs((ModelEvolvable) objectEvolvable, (ModelEvolvable) objectEvolvable2);
        }
        return distance;
    }

    /**
     * Returns the mean, over the validation sample, of the enabled per-vector terms.
     *
     * <ul>
     *   <li>Correlation term: the magnitude of the product of the two models' errors when the
     *       errors have opposite signs.</li>
     *   <li>Output term: the squared difference of the two outputs.</li>
     * </ul>
     *
     * <p>Returns 0 for an empty sample instead of NaN.
     */
    private double computeDistanceOfOutputs(ModelEvolvable first, ModelEvolvable second) {
        this.rndGenerator.resetRandom();
        int samples = validationSampleSize();
        if (samples <= 0) {
            return 0.0d; // previously 0.0 / 0 == NaN, which poisoned getDistance
        }
        double sum = 0.0d;
        for (int k = 0; k < samples; k++) {
            int idx = this.rndGenerator.getRandomTestingVector();
            double out1 = first.getOutput(this.inputVect[idx]);
            double out2 = second.getOutput(this.inputVect[idx]);
            if (this.corrDistEnabled) {
                double errorProduct = (out1 - this.target[idx]) * (out2 - this.target[idx]);
                if (errorProduct < 0.0d) {
                    sum += -errorProduct;
                }
            }
            if (this.outputDistEnabled) {
                double diff = out1 - out2;
                sum += diff * diff;
            }
        }
        return sum / samples;
    }

    /**
     * Runs the evolution.
     *
     * <ol>
     *   <li>Splits the stored vectors into learning ({@link #learnValidRatio} percent) and
     *       validation parts.</li>
     *   <li>Scores the current members.</li>
     *   <li>Lets the strategy produce {@link #generations} generations.</li>
     *   <li>Replaces the members with the strategy's final population.</li>
     * </ol>
     */
    @Override // game.models.ModelLearnable
    public void learn() {
        this.learnVectNum = (this.learning_vectors * this.learnValidRatio) / 100;
        this.validVectNum = this.learning_vectors - this.learnVectNum;
        prepareLearningAndValidationData();
        logger.info("Data prepared");
        logger.info("Initial population generated");
        logger.info("Evolution starts");
        logger.info("Number of generations: " + this.generations);
        computeFitness(this.ensembleModels);
        for (int generation = 0; generation < this.generations; generation++) {
            if (logger.isDebugEnabled()) {
                logger.debug("Generation " + generation);
            }
            this.ensembleModels = this.evolution.newGeneration(this.ensembleModels);
        }
        this.ensembleModels = this.evolution.getFinalPopulation(this.ensembleModels);
        logger.info("Evolution ends, selecting survivals");
    }

    /**
     * Creates the index generator over {@code learning_vectors} vectors and asks it to generate
     * the learning/testing split with {@link #validVectNum} testing vectors.
     *
     * <p>Public for backward compatibility, although the decompiler reports it as originally
     * protected.
     */
    public void prepareLearningAndValidationData() {
        this.rndGenerator = new MyRandom(this.learning_vectors);
        this.rndGenerator.generateLearningAndTestingSet(this.validVectNum);
    }

    /**
     * Returns the arithmetic mean of the member outputs for the input vector.
     *
     * <p>An empty ensemble yields NaN (0.0 / 0).
     */
    @Override // game.models.Model
    public double getOutput(double[] dArr) {
        double sum = 0.0d;
        for (Model member : this.ensembleModels) {
            sum += member.getOutput(dArr);
        }
        return sum / this.ensembleModels.size();
    }

    /** Equation export is not supported by this model; always returns the empty string. */
    public String toEquation(String[] strArr) {
        return StringUtils.EMPTY;
    }

    /** Returns the configuration class for this model type. */
    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return EvolvableEnsembleModelConfig.class;
    }

    /** Not implemented for this ensemble; does nothing. */
    @Override // game.models.ensemble.ModelEnsemble
    public void relearn() {
    }

    /** Not implemented for this ensemble; does nothing. */
    @Override // game.models.ensemble.ModelEnsemble
    public void learn(int i) {
    }

    /** C code export is not supported by this model; always returns {@code null}. */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        return null;
    }
}
