package game.models.single;

import configuration.models.ModelConfig;
import configuration.models.TrainerSelectable;
import configuration.models.single.ModelSingleConfigBase;
import game.models.ModelLearnableBase;
import game.trainers.GradientTrainable;
import game.trainers.Trainer;
import game.utils.GlobalRandom;
import java.util.Random;

/* loaded from: input_file:game/models/single/SingleModel.class */
/**
 * Abstract base for models whose learned state is a flat coefficient vector {@link #a} that is
 * fitted by a pluggable {@link Trainer}.
 *
 * <p>Subclasses supply the model function through {@link #getOutputWith(double[], double[])} and
 * may override {@link #gradient(double[], double[])}, {@link #hessian(double[], double[][])} and
 * {@link #computeErrorAndGradient(double[])}, whose defaults here return {@code false}.
 *
 * <p>When validation is enabled, {@link #learn()} first marks a random subset of the learning
 * vectors as a validation set. {@link #getError(double[])} then sums squared errors over the
 * remaining (training) vectors only, and {@link #getValidationError(double[])} sums them over the
 * validation vectors.
 */
public abstract class SingleModel extends ModelLearnableBase implements GradientTrainable {
    /** Learned coefficients used by {@link #getOutput(double[])}; overwritten by {@link #learn()}. */
    protected double[] a;
    /** Buffer returned by {@link #getGradient()}; released in {@link #deleteLearningVectors()}. */
    protected transient double[] gradient;
    /** Value returned by {@link #getError()}. */
    protected transient double error;
    /** Number of coefficients: length of the starting point and of the copy-back in learn(). */
    protected int coef;
    /** Per learning vector, {@code true} if it is in the validation set. */
    transient boolean[] inValidationSet;
    /** Percentage of the learning vectors to put into the validation set. */
    int validationPercent;
    /** Whether a validation set is carved out of the learning vectors before training. */
    boolean validationEnabled;
    /** Random source used to pick validation vectors. */
    transient Random rnd;
    /** Trainer that fits {@link #a}; created in init() or replaced via setTrainer(). */
    protected transient Trainer trainer;

    /**
     * Initializes the model and its trainer from {@code modelConfig}.
     *
     * <p>The config is cast both to {@link TrainerSelectable} and to {@link ModelSingleConfigBase},
     * so it must be both (otherwise a {@link ClassCastException} is thrown). A new trainer is
     * created reflectively from the config's trainer class. It is initialized with the config's
     * trainer configuration or, when that is null, with a fresh instance of the trainer's own
     * config class. Reflection failures are only printed; the validation settings are read
     * regardless.
     *
     * <p>NOTE(review): if the trainer cannot be instantiated, {@code super.init} is skipped and
     * {@link #trainer} may stay null, so a later {@link #learn()} would fail.
     *
     * @param modelConfig model configuration
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        try {
            TrainerSelectable trainerSelectable = (TrainerSelectable) modelConfig;
            this.trainer = (Trainer) trainerSelectable.getTrainerClass().newInstance();
            super.init(trainerSelectable);
            if (trainerSelectable.getTrainerCfg() == null) {
                this.trainer.init(this, this.trainer.getConfigClass().newInstance());
            } else {
                this.trainer.init(this, trainerSelectable.getTrainerCfg());
            }
            this.trainer.setCoef(this.coef);
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InstantiationException e2) {
            e2.printStackTrace();
        }
        ModelSingleConfigBase modelSingleConfigBase = (ModelSingleConfigBase) modelConfig;
        this.validationPercent = modelSingleConfigBase.getValidationPercent();
        this.validationEnabled = modelSingleConfigBase.isValidationEnabled();
        this.rnd = new Random();
    }

    /**
     * Sets the number of inputs via the superclass, then pushes the current {@link #coef} to the
     * trainer if one is attached.
     *
     * @param i number of inputs
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void setInputsNumber(int i) {
        super.setInputsNumber(i);
        if (this.trainer != null) {
            this.trainer.setCoef(this.coef);
        }
    }

    /** @return the trainer currently attached to this model, possibly {@code null} */
    @Override // game.trainers.GradientTrainable
    public Trainer getTrainer() {
        return this.trainer;
    }

    /**
     * Replaces the trainer. When the new trainer is non-null, its method name is recorded in
     * {@code trainedBy}.
     *
     * @param trainer new trainer, may be {@code null}
     */
    @Override // game.trainers.GradientTrainable
    public void setTrainer(Trainer trainer) {
        this.trainer = trainer;
        if (this.trainer != null) {
            this.trainedBy = this.trainer.getMethodName();
        }
    }

    /**
     * Builds the initial coefficient vector handed to the trainer.
     *
     * @return a new array of length {@link #coef}, each entry drawn from
     *     {@code GlobalRandom.getInstance().getSmallDouble()}
     */
    public double[] computeStartingPoint() {
        double[] dArr = new double[this.coef];
        for (int i = 0; i < this.coef; i++) {
            dArr[i] = GlobalRandom.getInstance().getSmallDouble();
        }
        return dArr;
    }

    /**
     * Trains the model.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Optionally build the validation set.</li>
     *   <li>Seed the trainer with {@link #computeStartingPoint()} and run it.</li>
     *   <li>Copy the trainer's best coefficients into {@link #a}. Non-finite (NaN or infinite)
     *       values are skipped, so the previous coefficient is kept for those positions.</li>
     *   <li>Call {@code postLearnActions()}.</li>
     * </ol>
     */
    @Override // game.models.ModelLearnable
    public void learn() {
        if (this.validationEnabled) {
            initializeValidationSet();
        }
        this.trainer.setStartingPoint(computeStartingPoint());
        this.trainer.teach();
        for (int i = 0; i < this.coef; i++) {
            double best = this.trainer.getBest(i);
            if (!Double.isNaN(best) && !Double.isInfinite(best)) {
                this.a[i] = best;
            }
        }
        postLearnActions();
    }

    /**
     * Randomly selects {@code validationPercent}% of the learning vectors (rounded down, clamped to
     * {@code [0, learning_vectors]}) and marks them in {@link #inValidationSet}.
     */
    public void initializeValidationSet() {
        // Java zero-initializes boolean arrays, so every vector starts outside the validation set.
        this.inValidationSet = new boolean[this.learning_vectors];
        if (this.rnd == null) {
            // rnd is transient and only created in init(), so it is null after deserialization.
            this.rnd = new Random();
        }
        int requested = (int) (this.learning_vectors * (this.validationPercent / 100.0d));
        // Clamp so the probing in setRandomValidationVector always has a free slot to find.
        int count = Math.max(0, Math.min(requested, this.learning_vectors));
        for (int k = 0; k < count; k++) {
            setRandomValidationVector();
        }
    }

    /**
     * Adds exactly one new vector to the validation set. A random index is chosen; if it is taken,
     * the search moves forward (wrapping around) to the next vector not yet selected.
     *
     * <p>The caller must guarantee that at least one vector is still unselected; otherwise this
     * loops forever.
     */
    private void setRandomValidationVector() {
        int index = this.rnd.nextInt(this.learning_vectors);
        while (this.inValidationSet[index]) {
            index = (index + 1) % this.learning_vectors;
        }
        this.inValidationSet[index] = true;
    }

    /**
     * Evaluates the model with its learned coefficients {@link #a}.
     *
     * @param dArr input vector
     * @return model output
     */
    @Override // game.models.Model
    public double getOutput(double[] dArr) {
        return getOutputWith(dArr, this.a);
    }

    /**
     * Evaluates the model for an input vector using the given coefficients. The default
     * implementation returns {@link Double#NaN}; concrete models override it.
     *
     * @param dArr input vector
     * @param dArr2 coefficients
     * @return model output
     */
    protected double getOutputWith(double[] dArr, double[] dArr2) {
        return Double.NaN;
    }

    /**
     * @param i index into the stored learning inputs ({@code inputVect})
     * @return model output for the {@code i}-th learning input
     */
    public double getOutputTo(int i) {
        return getOutput(this.inputVect[i]);
    }

    /**
     * Default implementation does nothing and returns {@code false}. Presumably {@code false}
     * signals "not supported" to the trainer; confirm against {@code GradientTrainable}.
     *
     * @param dArr coefficients
     * @return {@code false}
     */
    public boolean computeErrorAndGradient(double[] dArr) {
        return false;
    }

    /** @return the {@link #gradient} buffer, possibly {@code null} */
    @Override // game.trainers.GradientTrainable
    public double[] getGradient() {
        return this.gradient;
    }

    /** @return the stored {@link #error} value */
    @Override // game.trainers.GradientTrainable
    public double getError() {
        return this.error;
    }

    /**
     * Sums squared errors {@code (getOutputWith(input, dArr) - target)^2} over a subset of the
     * learning vectors.
     *
     * <p>When validation is disabled, training error ({@code z == false}) covers all vectors and
     * validation error ({@code z == true}) is reported as {@code -1.0}. When validation is enabled,
     * only vectors whose {@link #inValidationSet} flag equals {@code z} are summed. This requires
     * {@link #initializeValidationSet()} to have run.
     *
     * @param dArr coefficients to evaluate
     * @param z {@code true} for the validation error, {@code false} for the training error
     * @return sum of squared errors, or {@code -1.0} as described above
     */
    public synchronized double getTrainingOrValidationError(double[] dArr, boolean z) {
        double d = 0.0d;
        if (!this.validationEnabled && z) {
            return -1.0d;
        }
        for (int i = 0; i < this.learning_vectors; i++) {
            if (!this.validationEnabled || this.inValidationSet[i] == z) {
                double outputWith = getOutputWith(this.inputVect[i], dArr) - this.target[i];
                d += outputWith * outputWith;
            }
        }
        return d;
    }

    /**
     * @param dArr coefficients to evaluate
     * @return training-set sum of squared errors for {@code dArr}
     */
    public double getError(double[] dArr) {
        return getTrainingOrValidationError(dArr, false);
    }

    /**
     * @param dArr coefficients to evaluate
     * @return validation-set sum of squared errors, or {@code -1.0} if validation is disabled
     */
    public double getValidationError(double[] dArr) {
        return getTrainingOrValidationError(dArr, true);
    }

    /**
     * Default implementation leaves {@code dArr2} untouched and returns {@code false}
     * (presumably "Hessian not available"; confirm against {@code GradientTrainable}).
     *
     * @param dArr coefficients
     * @param dArr2 output Hessian matrix
     * @return {@code false}
     */
    @Override // game.trainers.GradientTrainable
    public boolean hessian(double[] dArr, double[][] dArr2) {
        return false;
    }

    /**
     * Default implementation leaves {@code dArr2} untouched and returns {@code false}
     * (presumably "gradient not available"; confirm against the trainer's usage).
     *
     * @param dArr coefficients
     * @param dArr2 output gradient
     * @return {@code false}
     */
    public boolean gradient(double[] dArr, double[] dArr2) {
        return false;
    }

    /** Releases the gradient buffer, then the superclass's learning data. */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void deleteLearningVectors() {
        this.gradient = null;
        super.deleteLearningVectors();
    }

    /**
     * Default implementation emits nothing and returns {@code null}.
     *
     * @param sb code buffer (unused here)
     * @param sb2 code buffer (unused here)
     * @return {@code null}
     */
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        return null;
    }
}
