package game.trainers;

import configuration.game.trainers.PALConjugateGradientConfig;
import game.pal.math.ConjugateGradientSearch;
import game.pal.math.MFWithGradient;
import game.pal.math.NumericalDerivative;
import game.utils.GlobalRandom;

/* loaded from: input_file:game/trainers/PALConjugateGradientTrainer.class */
/**
 * Trainer that fits a {@link GradientTrainable} unit's coefficients using PAL's
 * {@link ConjugateGradientSearch}.
 *
 * <p>The trainer is itself the objective function handed to the optimizer:
 * {@link #evaluate(double[])} returns the error obtained through {@code getAndRecordError}, and
 * {@link #computeGradient(double[], double[])} asks the unit for an analytic gradient, falling back
 * to {@link NumericalDerivative} when the unit leaves the gradient array untouched.
 */
public class PALConjugateGradientTrainer extends Trainer implements MFWithGradient {
    private static final long serialVersionUID = 1;

    /** Absolute box bound reported to the optimizer for every coefficient. */
    private static final double COEFFICIENT_BOUND = 1000.0d;

    /** Number of {@link #teach()} calls between verbose statistics reports. */
    private static final long REPORT_INTERVAL = 100L;

    /** Model being trained; read by {@link #getError} and {@link #computeGradient}, set in {@link #init}. */
    private transient GradientTrainable unit;
    private transient ConjugateGradientSearch cgs;
    /** Function-value tolerance passed to {@link ConjugateGradientSearch#optimize}. */
    private transient double tolfx;
    /** Argument tolerance passed to {@link ConjugateGradientSearch#optimize}. */
    private transient double tolx;
    // NOTE(review): never updated after construction; see getError().
    private double lastError = -1.0d;
    // NOTE(review): not read anywhere in this class; kept package-visible in case other code uses it.
    double firstError = -1.0d;

    // Diagnostic call counters, shared by all instances and printed only when `verbose` is true.
    // They are updated without synchronization even though isExecutableInParallelMode() returns
    // true, so counts may be lost under parallel training; within this class they only feed the
    // verbose report and never influence optimization.
    private static long teachCalls = 0;
    private static long errorCalls = 0;
    private static long gradCalls = 0;
    private static boolean verbose = false;

    /**
     * Binds the unit to train and reads the optimizer settings from the configuration.
     *
     * @param gradientTrainable the model whose coefficients are optimized
     * @param obj configuration; must be a {@link PALConjugateGradientConfig}
     * @throws ClassCastException if {@code obj} is not a {@link PALConjugateGradientConfig}
     */
    @Override // game.trainers.Trainer
    public void init(GradientTrainable gradientTrainable, Object obj) {
        // getError() and computeGradient() read this class's own private `unit` field, which
        // nothing else can assign. Without this line they dereference null. It is set before
        // super.init() in case the superclass calls overridden methods while initializing.
        this.unit = gradientTrainable;
        super.init(gradientTrainable, obj);
        PALConjugateGradientConfig config = (PALConjugateGradientConfig) obj;
        config.setValues();
        this.tolfx = config.getTolfx();
        this.tolx = config.getTolx();
        this.cgs = new ConjugateGradientSearch(config.getMethod(), config.getMaxIterations());
    }

    /**
     * Sets the number of coefficients and seeds the starting point with small random values.
     *
     * @param i number of coefficients to optimize
     */
    @Override // game.trainers.Trainer
    public void setCoef(int i) {
        super.setCoef(i);
        for (int k = 0; k < i; k++) {
            this.best[k] = GlobalRandom.getInstance().getSmallDouble();
        }
    }

    @Override // game.trainers.Trainer
    public String getMethodName() {
        return "PAL : Conjugate Gradient";
    }

    @Override // game.trainers.Trainer, game.configuration.Configurable
    public Class getConfigClass() {
        return PALConjugateGradientConfig.class;
    }

    @Override // game.trainers.Trainer
    public boolean allowedByDefault() {
        return false;
    }

    /**
     * Runs one conjugate-gradient optimization starting from {@code best}.
     *
     * <p>{@code best} is passed to the optimizer as its argument vector. Presumably it holds the
     * optimum on return (PAL convention); TODO confirm against ConjugateGradientSearch. When
     * {@code verbose} is enabled, call statistics are printed every {@code REPORT_INTERVAL} calls
     * and the counters are then reset.
     */
    @Override // game.trainers.Trainer
    public void teach() {
        teachCalls++;
        this.cgs.optimize(this, this.best, this.tolfx, this.tolx);
        if (verbose && teachCalls >= REPORT_INTERVAL) {
            System.out.println("CG - error : " + errorCalls + " and gradient : " + gradCalls + " times");
            teachCalls = 0L;
            errorCalls = 0L;
            gradCalls = 0L;
        }
    }

    /** Returns the internal coefficient array (not a copy). */
    @Override // game.trainers.Trainer
    public double[] getBest() {
        return this.best;
    }

    /**
     * Objective function for the optimizer: the error at {@code x}, obtained via
     * {@code getAndRecordError}.
     *
     * @param x candidate coefficients
     * @return the error reported by {@code getAndRecordError}
     */
    @Override // game.pal.math.MultivariateFunction
    public double evaluate(double[] x) {
        errorCalls++;
        // NOTE(review): the meaning of (10, 100, true) is defined by Trainer.getAndRecordError.
        return getAndRecordError(x, 10, 100, true);
    }

    /**
     * Fills {@code gradient} with the gradient at {@code x}, then returns the error at {@code x}.
     *
     * @param x candidate coefficients
     * @param gradient output array receiving the gradient
     * @return the error at {@code x}
     */
    @Override // game.pal.math.MFWithGradient
    public double evaluate(double[] x, double[] gradient) {
        computeGradient(x, gradient);
        return evaluate(x);
    }

    /**
     * Returns the unit's error for {@code x}.
     *
     * <p>NOTE(review): the copy into {@code best} fires only when the error is below
     * {@code lastError}. That field stays at -1.0 because nothing in this class updates it, so
     * for non-negative errors the copy never happens. Enabling it would write trial points into
     * {@code best}, which is also the optimizer's argument array. Left as-is intentionally.
     *
     * @param x candidate coefficients
     * @return the unit's error at {@code x}
     */
    @Override // game.trainers.Trainer
    public double getError(double[] x) {
        double error = this.unit.getError(x);
        if (error < this.lastError) {
            System.arraycopy(x, 0, this.best, 0, this.coefficients);
        }
        return error;
    }

    @Override // game.pal.math.MultivariateFunction
    public double getLowerBound(int i) {
        return -COEFFICIENT_BOUND;
    }

    @Override // game.pal.math.MultivariateFunction
    public double getUpperBound(int i) {
        return COEFFICIENT_BOUND;
    }

    @Override // game.pal.math.MultivariateFunction
    public int getNumArguments() {
        return this.coefficients;
    }

    /**
     * Computes the gradient at {@code x} into {@code gradient}.
     *
     * <p>{@code gradient[0]} is pre-filled with {@link Double#MAX_VALUE} as a sentinel. If the
     * unit leaves it untouched (no analytic gradient), the gradient is estimated with
     * {@link NumericalDerivative}, which evaluates this objective. Assumes {@code gradient} has at
     * least one element.
     *
     * @param x point at which to differentiate
     * @param gradient output array receiving the gradient
     */
    @Override // game.pal.math.MFWithGradient
    public void computeGradient(double[] x, double[] gradient) {
        gradCalls++;
        gradient[0] = Double.MAX_VALUE;
        this.unit.gradient(x, gradient);
        if (gradient[0] == Double.MAX_VALUE) {
            NumericalDerivative.gradient(this, x, gradient);
        }
    }

    @Override // game.trainers.Trainer
    public boolean isExecutableInParallelMode() {
        return true;
    }
}
