package game.classifiers.single;

import configuration.CfgTemplate;
import configuration.classifiers.ClassifierConfig;
import configuration.classifiers.single.ClassifierModelConfig;
import configuration.models.ModelConfig;
import configuration.models.ensemble.BaseModelsDefinition;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.classifiers.ClassifierBase;
import game.classifiers.ClassifierSingle;
import game.evolution.treeEvolution.context.InterruptibleArrayList;
import game.models.Model;
import game.models.ModelLearnable;
import game.models.ensemble.ModelEnsemble;
import game.utils.MyRandom;
import java.util.List;

/* loaded from: input_file:game/classifiers/single/ClassifierModel.class */
/**
 * Classifier that holds one model per output class and predicts the class whose model
 * produces the highest output.
 *
 * <p>The per-class models are not created in {@link #init}. Instead, the configuration is
 * cloned and kept in {@link #classifierCfg}. The models are built on the first call to
 * {@link #storeLearningVector}, once the base class knows the number of outputs.
 */
public class ClassifierModel extends ClassifierBase implements ClassifierSingle {
    /** Per-class models; each one predicts the target variable it was assigned. */
    protected List<Model> classifierModels;
    /** Number of entries of {@link #classifierModels} that the loops in this class use. */
    protected int numModels;
    /**
     * Cloned configuration waiting to be turned into models.
     * It is non-null until the models are created, then reset to {@code null}.
     */
    protected ClassifierModelConfig classifierCfg;

    /**
     * Replaces the learning data of {@code modelLearnable} with a random subset of this
     * classifier's stored vectors.
     *
     * <p>At most {@code min(model max learning vectors, learning_vectors)} vectors are copied.
     * The target value is taken from the model's own target-variable column.
     * NOTE(review): {@code MyRandom.getRandom} presumably draws indices without repetition;
     * confirm against MyRandom.
     */
    private void prepareData(ModelLearnable modelLearnable) {
        modelLearnable.resetLearningData();
        MyRandom myRandom = new MyRandom(this.learning_vectors);
        int maxLearningVectors = modelLearnable.getMaxLearningVectors() > this.learning_vectors ? this.learning_vectors : modelLearnable.getMaxLearningVectors();
        for (int i = 0; i < maxLearningVectors; i++) {
            int random = myRandom.getRandom(this.learning_vectors);
            modelLearnable.storeLearningVector(this.inputVect[random], this.target[random][modelLearnable.getTargetVariable()]);
        }
    }

    /** Relearns a model: ensembles use {@code relearn()}, every other model uses {@code learn()}. */
    private void relearnModel(ModelLearnable modelLearnable) {
        if (modelLearnable instanceof ModelEnsemble) {
            ((ModelEnsemble) modelLearnable).relearn();
        } else {
            modelLearnable.learn();
        }
    }

    /** Sets {@code learned} once every learnable model reports that it is learned. */
    private void checkLearned() {
        if (this.learned) {
            return;
        }
        for (int i = 0; i < this.numModels; i++) {
            if ((this.classifierModels.get(i) instanceof ModelLearnable) && !((ModelLearnable) this.classifierModels.get(i)).isLearned()) {
                return;
            }
        }
        this.learned = true;
    }

    /** Sets the learning-vector limit on this classifier and passes it to every learnable model. */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public void setMaxLearningVectors(int i) {
        super.setMaxLearningVectors(i);
        for (int i2 = 0; i2 < this.numModels; i2++) {
            if (this.classifierModels.get(i2) instanceof ModelLearnable) {
                ((ModelLearnable) this.classifierModels.get(i2)).setMaxLearningVectors(i);
            }
        }
    }

    @Override // game.classifiers.ClassifierBase, game.configuration.Configurable
    public Class getConfigClass() {
        return ClassifierModelConfig.class;
    }

    /**
     * Resets the model list and stores a clone of the configuration.
     * The models themselves are created later, in {@link #storeLearningVector}.
     */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public void init(ClassifierConfig classifierConfig) {
        this.classifierModels = new InterruptibleArrayList();
        this.classifierCfg = (ClassifierModelConfig) classifierConfig.mo161clone();
        super.init(classifierConfig);
    }

    /**
     * Stores a learning vector.
     * On the first call it also creates the per-class models from the pending configuration.
     */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public void storeLearningVector(double[] dArr, double[] dArr2) {
        super.storeLearningVector(dArr, dArr2);
        if (this.classifierCfg != null) {
            // Expected count, kept visible to subclasses that override createClassModels.
            this.numModels = this.outputs;
            createClassModels(this.classifierCfg);
            // createClassModel skips a model when reflective instantiation fails.
            // Count what really exists, so the index loops never run past the list.
            this.numModels = this.classifierModels.size();
            this.classifierCfg = null;
        }
    }

    /**
     * Instantiates the model class named by {@code cfgTemplate}, initialises it, assigns it
     * target variable {@code i} and appends it to {@link #classifierModels}.
     * If reflection fails, the stack trace is printed and no model is added.
     */
    protected void createClassModel(int i, CfgTemplate cfgTemplate) {
        try {
            ModelLearnable modelLearnable = (ModelLearnable) cfgTemplate.getClassRef().newInstance();
            modelLearnable.init((ModelConfig) cfgTemplate);
            modelLearnable.setTargetVariable(i);
            modelLearnable.setMaxLearningVectors(this.maxLearningVectors);
            this.classifierModels.add(modelLearnable);
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InstantiationException e2) {
            e2.printStackTrace();
        }
    }

    /**
     * Creates one model per output, choosing base configurations according to
     * {@code getBaseModelsDef()}:
     * <ul>
     *   <li>PREDEFINED: the i-th configuration goes to output i. If the count does not match
     *       the number of outputs, falls back to a random pick per output.</li>
     *   <li>RANDOM: a random configuration for each output.</li>
     *   <li>UNIFORM: the first configuration for every output.</li>
     *   <li>UNIFORM_RANDOM: one randomly chosen configuration, shared by every output.</li>
     * </ul>
     */
    protected void createClassModels(ClassifierModelConfig classifierModelConfig) {
        List<CfgTemplate> baseModelCfgs = classifierModelConfig.getBaseModelCfgs();
        MyRandom myRandom = new MyRandom(baseModelCfgs.size());
        switch (classifierModelConfig.getBaseModelsDef()) {
            case PREDEFINED:
                if (this.outputs != baseModelCfgs.size()) {
                    randomClassModelDistribution(baseModelCfgs, myRandom);
                    return;
                }
                for (int i = 0; i < this.outputs; i++) {
                    createClassModel(i, baseModelCfgs.get(i));
                }
                return;
            case RANDOM:
                randomClassModelDistribution(baseModelCfgs, myRandom);
                return;
            case UNIFORM:
                for (int i2 = 0; i2 < this.outputs; i2++) {
                    createClassModel(i2, baseModelCfgs.get(0));
                }
                return;
            case UNIFORM_RANDOM:
                int nextInt = myRandom.nextInt(baseModelCfgs.size());
                for (int i3 = 0; i3 < this.outputs; i3++) {
                    createClassModel(i3, baseModelCfgs.get(nextInt));
                }
                return;
            default:
                return;
        }
    }

    /** Creates one model per output, each from a configuration picked with {@code myRandom.nextInt}. */
    private void randomClassModelDistribution(List<CfgTemplate> list, MyRandom myRandom) {
        for (int i = 0; i < this.outputs; i++) {
            createClassModel(i, list.get(myRandom.nextInt(list.size())));
        }
    }

    /** Trains every learnable model that is not learned yet, then runs {@code postLearnActions()}. */
    @Override // game.classifiers.Classifier
    public void learn() {
        for (int i = 0; i < this.numModels; i++) {
            if (this.classifierModels.get(i) instanceof ModelLearnable) {
                ModelLearnable modelLearnable = (ModelLearnable) this.classifierModels.get(i);
                if (!modelLearnable.isLearned()) {
                    prepareData(modelLearnable);
                    modelLearnable.learn();
                }
            }
        }
        postLearnActions();
    }

    /** Re-samples data for and relearns every learnable model, then runs {@code postLearnActions()}. */
    @Override // game.classifiers.Classifier
    public void relearn() {
        for (int i = 0; i < this.numModels; i++) {
            if (this.classifierModels.get(i) instanceof ModelLearnable) {
                ModelLearnable modelLearnable = (ModelLearnable) this.classifierModels.get(i);
                prepareData(modelLearnable);
                relearnModel(modelLearnable);
            }
        }
        postLearnActions();
    }

    /**
     * Re-samples data for and relearns only the model at index {@code i}.
     * Afterwards it re-checks whether the whole classifier is now learned.
     */
    public void learn(int i) {
        if (this.classifierModels.get(i) instanceof ModelLearnable) {
            ModelLearnable modelLearnable = (ModelLearnable) this.classifierModels.get(i);
            prepareData(modelLearnable);
            relearnModel(modelLearnable);
            checkLearned();
        }
    }

    /**
     * Returns the configuration.
     * Before the models exist, this is the pending clone. Afterwards it is a PREDEFINED
     * configuration listing the current model configurations in order.
     */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public ClassifierConfig getConfig() {
        if (this.classifierCfg != null) {
            return this.classifierCfg;
        }
        ClassifierModelConfig classifierModelConfig = (ClassifierModelConfig) super.getConfig();
        classifierModelConfig.setModelsNumber(this.numModels);
        classifierModelConfig.setBaseModelsDef(BaseModelsDefinition.PREDEFINED);
        for (int i = 0; i < this.numModels; i++) {
            classifierModelConfig.addBaseModelCfg(this.classifierModels.get(i).getConfig());
        }
        return classifierModelConfig;
    }

    @Override // game.classifiers.ClassifierSingle
    public Model getModel(int i) {
        return this.classifierModels.get(i);
    }

    @Override // game.classifiers.ClassifierSingle
    public Model[] getAllModels() {
        return this.classifierModels.toArray(new Model[0]);
    }

    /** Replaces the model at index {@code i} and marks the classifier as not learned. */
    @Override // game.classifiers.ClassifierSingle
    public void setModel(int i, Model model) {
        this.classifierModels.set(i, model);
        this.learned = false;
    }

    /** Inserts a model at index {@code i} and marks the classifier as not learned. */
    @Override // game.classifiers.ClassifierSingle
    public void addModel(int i, Model model) {
        this.classifierModels.add(i, model);
        this.numModels++;
        this.learned = false;
    }

    /** Removes the model at index {@code i} and marks the classifier as not learned. */
    @Override // game.classifiers.ClassifierSingle
    public void removeModel(int i) {
        this.classifierModels.remove(i);
        this.numModels--;
        this.learned = false;
    }

    /**
     * Returns the target variable of the model with the highest output for {@code dArr}.
     * Ties go to the earliest model. Learns first if needed.
     * NOTE(review): throws if there are no models, because index 0 is read unconditionally.
     */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public int getOutput(double[] dArr) {
        if (!this.learned) {
            learn();
        }
        double output = this.classifierModels.get(0).getOutput(dArr);
        int targetVariable = this.classifierModels.get(0).getTargetVariable();
        for (int i = 1; i < this.numModels; i++) {
            double output2 = this.classifierModels.get(i).getOutput(dArr);
            if (output2 > output) {
                output = output2;
                targetVariable = this.classifierModels.get(i).getTargetVariable();
            }
        }
        return targetVariable;
    }

    /**
     * Returns a vector of length {@code outputs} with per-class scores.
     *
     * <p>Each slot is the sum of the outputs of the models assigned to that class, clamped to
     * [0, 1]. The vector is then divided by its sum, unless the sum is zero. Learns first if
     * needed.
     */
    @Override // game.classifiers.Classifier
    public double[] getOutputProbabilities(double[] dArr) {
        if (!this.learned) {
            learn();
        }
        double[] probabilities = new double[this.outputs];
        // Several models may share a target variable, so accumulate per class first.
        // Clamping and summing must index the class slot, not the model number.
        for (int i = 0; i < this.numModels; i++) {
            Model model = this.classifierModels.get(i);
            probabilities[model.getTargetVariable()] += model.getOutput(dArr);
        }
        double sum = 0.0d;
        for (int c = 0; c < probabilities.length; c++) {
            if (probabilities[c] < 0.0d) {
                probabilities[c] = 0.0d;
            }
            if (probabilities[c] > 1.0d) {
                probabilities[c] = 1.0d;
            }
            sum += probabilities[c];
        }
        if (sum != 0.0d) {
            // Normalise over every class slot, not over the model count.
            for (int c = 0; c < probabilities.length; c++) {
                probabilities[c] /= sum;
            }
        }
        return probabilities;
    }

    /** Deletes stored learning vectors here and in every learnable model. */
    @Override // game.classifiers.ClassifierBase, game.classifiers.Classifier
    public void deleteLearningVectors() {
        super.deleteLearningVectors();
        for (int i = 0; i < this.numModels; i++) {
            if (this.classifierModels.get(i) instanceof ModelLearnable) {
                ((ModelLearnable) this.classifierModels.get(i)).deleteLearningVectors();
            }
        }
    }

    /**
     * Emits C code for this classifier into {@code sb} and its XML description into {@code sb2}.
     * The per-model code is emitted first.
     *
     * @return the unique C function name generated for this classifier
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        XMLBuildUtils.outputXMLStart(sb2, this);
        String uniqueFunctionName = CCodeUtils.getUniqueFunctionName(getClass());
        String[] successorsCode = getSuccessorsCode(sb, sb2);
        XMLBuildUtils.outputXMLEnd(sb2, this, uniqueFunctionName);
        sb.append("#include \"").append(CCodeUtils.getClassificationModelPath()).append("ClassifierModel.h\"\n");
        CCodeUtils.getCClassificationHeader(uniqueFunctionName, this.inputs, sb);
        CCodeUtils.getCRegModelArray(successorsCode, "models", sb);
        sb.append("return classifierModelOutput<").append(this.inputs).append(",").append(this.outputs).append(">(input,models);\n");
        sb.append("}\n");
        return uniqueFunctionName;
    }

    /**
     * Learns if needed, then collects each model's {@code toCCode} result in model order.
     * These results are presumably the models' generated C function names.
     */
    protected String[] getSuccessorsCode(StringBuilder sb, StringBuilder sb2) {
        if (!this.learned) {
            learn();
        }
        String[] strArr = new String[this.numModels];
        for (int i = 0; i < this.numModels; i++) {
            strArr[i] = this.classifierModels.get(i).toCCode(sb, sb2);
        }
        return strArr;
    }
}
