package game.models.ensemble;

import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;

/**
 * Boosting ensemble whose output is the weighted sum of its member models' outputs,
 * with one weight per member stored in {@link #modelWeights}.
 *
 * <p>After every learning pass ({@link #learn()}, {@link #learn(int)}, {@link #relearn()})
 * the weights are normalized so that the first {@code modelsNumber} entries sum to 1.
 * Subclasses implement {@link #modifyWeights(double[], int)}; presumably that is where
 * {@code modelWeights} gets populated — TODO confirm against concrete subclasses.
 */
public abstract class ModelBoosting extends ModelInstanceWeights {
    /** Per-member weights used by {@link #getOutput(double[])}, indexed like {@code ensembleModels}. */
    protected double[] modelWeights = new double[0];

    /**
     * Initializes learning state in the superclass, then (re)allocates {@link #modelWeights}
     * as a zero-filled array whenever its length differs from the current member count.
     *
     * @param dArr array forwarded unchanged to the superclass initializer
     */
    @Override
    public void initLearnArrays(double[] dArr) {
        super.initLearnArrays(dArr);
        if (this.modelWeights.length != this.modelsNumber) {
            this.modelWeights = new double[this.modelsNumber];
        }
    }

    /**
     * Rescales the first {@code modelsNumber} entries of {@link #modelWeights} so they sum to 1.
     *
     * <ul>
     *   <li>Sum is 0: every member receives the uniform weight {@code 1 / modelsNumber}.</li>
     *   <li>Sum is {@code +Infinity}: members whose weight is {@code +Infinity} share the mass
     *       equally and all others get 0. If the sum overflowed without any individually
     *       infinite weight, every entry becomes 0 and the uniform case applies.</li>
     *   <li>Otherwise: each weight is divided by the sum.</li>
     * </ul>
     *
     * <p>NOTE(review): a NaN sum (a NaN weight, or {@code +Infinity} together with
     * {@code -Infinity}) falls into the plain division branch and turns every weight into NaN.
     */
    private void normalizeModelWeights() {
        double sum = 0.0d;
        for (int i = 0; i < this.modelsNumber; i++) {
            sum += this.modelWeights[i];
        }

        if (sum == 0.0d) {
            // Must be floating-point division: the integer expression 1 / modelsNumber
            // evaluates to 0 for any ensemble with more than one member.
            final double uniform = 1.0d / this.modelsNumber;
            for (int i = 0; i < this.modelsNumber; i++) {
                this.modelWeights[i] = uniform;
            }
            return;
        }

        if (sum == Double.POSITIVE_INFINITY) {
            for (int i = 0; i < this.modelsNumber; i++) {
                this.modelWeights[i] = (this.modelWeights[i] == Double.POSITIVE_INFINITY) ? 1.0d : 0.0d;
            }
            // Every entry is now 0 or 1, so the next pass takes either the zero-sum
            // or the plain-division branch and the recursion stops there.
            normalizeModelWeights();
            return;
        }

        for (int i = 0; i < this.modelsNumber; i++) {
            this.modelWeights[i] = this.modelWeights[i] / sum;
        }
    }

    /**
     * Boosting-specific instance reweighting; re-declared abstract so each concrete
     * boosting variant must supply it.
     */
    @Override
    protected abstract void modifyWeights(double[] dArr, int i);

    /** Runs the superclass learning pass, then normalizes {@link #modelWeights}. */
    @Override
    public void learn() {
        super.learn();
        normalizeModelWeights();
    }

    /** Runs the superclass re-learning pass, then normalizes {@link #modelWeights}. */
    @Override
    public void relearn() {
        super.relearn();
        normalizeModelWeights();
    }

    /**
     * Runs the superclass {@code learn(int)} pass, then normalizes {@link #modelWeights}.
     *
     * @param i argument forwarded unchanged to the superclass
     */
    @Override
    public void learn(int i) {
        super.learn(i);
        normalizeModelWeights();
    }

    /**
     * Computes the ensemble output for one input vector, learning first if the model
     * is not yet marked as learned.
     *
     * @param dArr input vector passed to every member model
     * @return sum over members of {@code memberOutput * modelWeights[member]}
     */
    @Override
    public double getOutput(double[] dArr) {
        if (!this.learned) {
            learn();
        }
        double weightedSum = 0.0d;
        for (int i = 0; i < this.modelsNumber; i++) {
            weightedSum += this.ensembleModels.get(i).getOutput(dArr) * this.modelWeights[i];
        }
        return weightedSum;
    }

    /**
     * Emits a C implementation of this ensemble.
     *
     * <p>The XML description goes to {@code sb2} and the C source goes to {@code sb}. The C
     * function includes {@code BoostingEnsemble.h}, declares the member-model array and the
     * {@code modelWeights} array, and returns
     * {@code boostingEnsembleOutput<modelsNumber,inputsNumber>(input,models,modelWeights)}.
     *
     * @param sb  receives the generated C source
     * @param sb2 receives the generated XML description
     * @return the unique C function name generated for this model
     */
    @Override
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        XMLBuildUtils.outputXMLStart(sb2, this);
        String uniqueFunctionName = CCodeUtils.getUniqueFunctionName(getClass());
        // Presumably generates the member models' code and returns their function names — TODO confirm.
        String[] successorsCode = getSuccessorsCode(sb, sb2);
        XMLBuildUtils.outputXMLEnd(sb2, this, uniqueFunctionName);
        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("BoostingEnsemble.h\"\n");
        CCodeUtils.getCRegressionHeader(uniqueFunctionName, this.inputsNumber, sb);
        CCodeUtils.getCRegModelArray(successorsCode, "models", sb);
        CCodeUtils.convertArray(this.modelWeights, "modelWeights", sb);
        sb.append("return boostingEnsembleOutput<").append(this.modelsNumber).append(",").append(this.inputsNumber).append(">(input,models,modelWeights);\n");
        sb.append("}\n");
        return uniqueFunctionName;
    }
}
