package game.models.ensemble;

import configuration.models.ensemble.StackingModelConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.models.ModelLearnable;
import game.utils.MyRandom;

/* loaded from: input_file:game/models/ensemble/ModelStacking.class */
/**
 * Stacking ensemble: the last entry of {@code ensembleModels} acts as a meta-model
 * whose inputs are the outputs of all preceding (base) models.
 *
 * <p>Base models are trained on vectors drawn from this ensemble's learning data;
 * the meta-model is trained on the base models' responses to such vectors.
 */
public class ModelStacking extends ModelEnsembleBase {

    /**
     * Refills the learning data of a base model with vectors picked from this
     * ensemble's learning set.
     *
     * @param modelLearnable base model whose learning data is reset and refilled
     */
    private void prepareData(ModelLearnable modelLearnable) {
        modelLearnable.resetLearningData();
        // NOTE(review): seeding with learning_vectors presumably makes the draw repeatable - confirm in MyRandom.
        MyRandom picker = new MyRandom(this.learning_vectors);
        // Never store more vectors than the model accepts or than the ensemble holds.
        int sampleCount = Math.min(modelLearnable.getMaxLearningVectors(), this.learning_vectors);
        for (int n = 0; n < sampleCount; n++) {
            int row = picker.getRandom(this.learning_vectors);
            modelLearnable.storeLearningVector(this.inputVect[row], this.target[row]);
        }
    }

    /**
     * Trains the meta-model on the outputs of the base models.
     *
     * <p>Each stored learning vector holds one output per base model for a picked
     * input vector, paired with that vector's target value.
     *
     * @param modelLearnable the meta-model; its learning data is reset before training
     */
    private void learnMetamodel(ModelLearnable modelLearnable) {
        modelLearnable.resetLearningData();
        int baseCount = this.modelsNumber - 1;
        // The same buffer is handed over on every iteration, as in the original code.
        // NOTE(review): this relies on storeLearningVector copying its argument - confirm.
        double[] baseOutputs = new double[baseCount];
        MyRandom picker = new MyRandom(this.learning_vectors);
        int sampleCount = Math.min(modelLearnable.getMaxLearningVectors(), this.learning_vectors);
        for (int n = 0; n < sampleCount; n++) {
            int row = picker.getRandom(this.learning_vectors);
            for (int m = 0; m < baseCount; m++) {
                baseOutputs[m] = this.ensembleModels.get(m).getOutput(this.inputVect[row]);
            }
            modelLearnable.storeLearningVector(baseOutputs, this.target[row]);
        }
        modelLearnable.learn();
    }

    /**
     * Trains every learnable base model that is not learned yet, then the meta-model
     * if it is learnable and not learned yet, and finally runs the post-learn actions.
     */
    @Override // game.models.ModelLearnable
    public void learn() {
        int metaIndex = this.modelsNumber - 1;
        for (int m = 0; m < metaIndex; m++) {
            Object member = this.ensembleModels.get(m);
            if (!(member instanceof ModelLearnable)) {
                continue;
            }
            ModelLearnable base = (ModelLearnable) member;
            if (!base.isLearned()) {
                prepareData(base);
                base.learn();
            }
        }
        Object metaCandidate = this.ensembleModels.get(metaIndex);
        if (metaCandidate instanceof ModelLearnable) {
            ModelLearnable meta = (ModelLearnable) metaCandidate;
            if (!meta.isLearned()) {
                learnMetamodel(meta);
            }
        }
        postLearnActions();
    }

    /**
     * Retrains every learnable base model on freshly prepared data, retrains the
     * meta-model if it is learnable, and runs the post-learn actions.
     */
    @Override // game.models.ensemble.ModelEnsemble
    public void relearn() {
        int metaIndex = this.modelsNumber - 1;
        for (int m = 0; m < metaIndex; m++) {
            Object member = this.ensembleModels.get(m);
            if (member instanceof ModelLearnable) {
                ModelLearnable base = (ModelLearnable) member;
                prepareData(base);
                relearnModel(base);
            }
        }
        Object metaCandidate = this.ensembleModels.get(metaIndex);
        if (metaCandidate instanceof ModelLearnable) {
            learnMetamodel((ModelLearnable) metaCandidate);
        }
        postLearnActions();
    }

    /**
     * Trains all not-yet-learned base models, retrains the already learned base model
     * at the given index, and then always retrains the meta-model.
     *
     * @param i index of the base model to retrain if it is already learned
     */
    @Override // game.models.ensemble.ModelEnsemble
    public void learn(int i) {
        int metaIndex = this.modelsNumber - 1;
        for (int m = 0; m < metaIndex; m++) {
            Object member = this.ensembleModels.get(m);
            if (!(member instanceof ModelLearnable)) {
                continue;
            }
            ModelLearnable base = (ModelLearnable) member;
            if (!base.isLearned()) {
                prepareData(base);
                base.learn();
            } else if (m == i) {
                prepareData(base);
                relearnModel(base);
            }
        }
        // Guard the cast like learn() and relearn() do; a meta-model that is not
        // learnable is skipped instead of raising ClassCastException.
        Object metaCandidate = this.ensembleModels.get(metaIndex);
        if (metaCandidate instanceof ModelLearnable) {
            learnMetamodel((ModelLearnable) metaCandidate);
        }
        // NOTE(review): learn() and relearn() call postLearnActions() here instead - confirm this difference is intended.
        this.learned = true;
    }

    /**
     * Computes the ensemble response: the meta-model's output for the vector of
     * base-model outputs. Triggers {@link #learn()} first when not learned yet.
     *
     * @param dArr input vector passed to every base model
     * @return output of the meta-model
     */
    @Override // game.models.Model
    public double getOutput(double[] dArr) {
        if (!this.learned) {
            learn();
        }
        int metaIndex = this.modelsNumber - 1;
        double[] baseOutputs = new double[metaIndex];
        for (int m = 0; m < metaIndex; m++) {
            baseOutputs[m] = this.ensembleModels.get(m).getOutput(dArr);
        }
        return this.ensembleModels.get(metaIndex).getOutput(baseOutputs);
    }

    /**
     * @return the configuration class associated with this model
     */
    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return StackingModelConfig.class;
    }

    /**
     * Emits the C source and XML description of this ensemble.
     *
     * @param sb  receives the generated C code
     * @param sb2 receives the XML description
     * @return the unique name of the generated C function
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        XMLBuildUtils.outputXMLStart(sb2, this);
        String functionName = CCodeUtils.getUniqueFunctionName(getClass());
        // The successors' code is generated between the XML start and end calls.
        String[] memberFunctions = getSuccessorsCode(sb, sb2);
        XMLBuildUtils.outputXMLEnd(sb2, this, functionName);

        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("StackingEnsemble.h\"\n");
        CCodeUtils.getCRegressionHeader(functionName, this.inputsNumber, sb);
        CCodeUtils.getCRegModelArray(memberFunctions, "models", sb);
        sb.append("return stackingEnsembleOutput<").append(this.modelsNumber).append(",").append(this.inputsNumber).append(">(input,models);\n");
        sb.append("}\n");
        return functionName;
    }
}
