package game.trainers;

import configuration.game.trainers.PALDifferentialEvolutionConfig;
import game.pal.math.DifferentialEvolution;
import game.pal.math.MultivariateFunction;
import game.utils.GlobalRandom;

/* loaded from: input_file:game/trainers/PALDifferentialEvolutionTrainer.class */
/**
 * Trainer that fits a unit's coefficients with the PAL Differential Evolution optimizer.
 *
 * <p>The trainer is its own {@link MultivariateFunction}: the optimizer calls
 * {@link #evaluate(double[])}, which is routed through the superclass hook
 * {@code getAndRecordError}. The lowest-error coefficient vector seen by
 * {@link #getError(double[])} is kept in {@code best}. The same array is also passed to the
 * optimizer as its starting vector.
 */
public class PALDifferentialEvolutionTrainer extends Trainer implements MultivariateFunction {
    private static final long serialVersionUID = 1;

    /** Symmetric search bound reported for every coefficient by getLowerBound/getUpperBound. */
    private static final double COEFFICIENT_BOUND = 10000.0d;

    /** Unit whose error is being minimized; assigned in {@link #init}. */
    private transient GradientTrainable unit;
    /** Number of coefficients being optimized; assigned in {@link #setCoef}. */
    private int coefficients;
    /** Best coefficient vector found so far; allocated in {@link #setCoef}. */
    private double[] best;
    /** Error associated with {@code best}; +infinity until the first evaluation. */
    private double errorBestSoFar = Double.POSITIVE_INFINITY;
    /** Optimizer instance; created in {@link #setCoef}. */
    private transient DifferentialEvolution de;
    double lastError = -1.0d;
    double firstError = -1.0d;
    int cnt = 0;
    /** Function-value tolerance passed to the optimizer. */
    private double tolfx = 0.1d;
    /** Argument tolerance passed to the optimizer. */
    private double tolx = 0.1d;

    /**
     * Binds this trainer to the unit it trains.
     *
     * @param gradientTrainable the unit whose error is minimized
     * @param obj configuration object, forwarded to the superclass
     */
    @Override
    public void init(GradientTrainable gradientTrainable, Object obj) {
        // getError() evaluates through this class's own 'unit' field, so it must be set here;
        // previously it stayed null and the first evaluation threw NullPointerException.
        this.unit = gradientTrainable;
        super.init(gradientTrainable, obj);
    }

    /**
     * Prepares the optimizer for {@code i} coefficients.
     *
     * <p>Allocates the best-vector, seeds it with small random values and resets the
     * best-error tracker.
     *
     * @param i number of coefficients to optimize
     */
    @Override
    public void setCoef(int i) {
        super.setCoef(i);
        // These fields belong to this class, so they must be initialized here.
        // Previously 'best' was never allocated (NullPointerException) and
        // 'coefficients' stayed 0.
        this.coefficients = i;
        this.best = new double[i];
        this.errorBestSoFar = Double.POSITIVE_INFINITY;
        this.de = new DifferentialEvolution(i);
        for (int k = 0; k < i; k++) {
            this.best[k] = GlobalRandom.getInstance().getSmallDouble();
        }
    }

    @Override
    public String getMethodName() {
        return "PAL: Differential Evolution";
    }

    @Override
    public Class getConfigClass() {
        return PALDifferentialEvolutionConfig.class;
    }

    @Override
    public boolean allowedByDefault() {
        return false;
    }

    /**
     * Runs the differential evolution optimization, starting from and writing into {@code best}.
     *
     * @throws IllegalStateException if {@link #setCoef(int)} has not been called
     */
    @Override
    public void teach() {
        if (this.de == null || this.best == null) {
            throw new IllegalStateException("setCoef must be called before teach");
        }
        this.de.optimize(this, this.best, this.tolfx, this.tolx);
    }

    /** Returns the best coefficient vector (the internal array, not a copy). */
    @Override
    public double[] getBest() {
        return this.best;
    }

    /** Returns coefficient {@code i} of the best vector. */
    @Override
    public double getBest(int i) {
        return this.best[i];
    }

    /**
     * Objective function called by the optimizer.
     *
     * <p>Delegates to the superclass {@code getAndRecordError}. NOTE(review): the meaning of the
     * arguments {@code 10}, {@code 100} and {@code true} is defined in {@code Trainer}; confirm
     * them there.
     */
    @Override
    public double evaluate(double[] dArr) {
        return getAndRecordError(dArr, 10, 100, true);
    }

    /**
     * Evaluates the unit's error for {@code dArr} and remembers it if it is the lowest seen.
     *
     * @param dArr candidate coefficient vector (at least {@code coefficients} long)
     * @return the unit's error for {@code dArr}
     */
    @Override
    public double getError(double[] dArr) {
        double error = this.unit.getError(dArr);
        if (error < this.errorBestSoFar) {
            // Update the threshold too; previously it stayed at its initial value, so improvements
            // were never detected for non-negative errors.
            this.errorBestSoFar = error;
            System.arraycopy(dArr, 0, this.best, 0, this.coefficients);
        }
        return error;
    }

    /** Dimensionality of the search space: the number of coefficients set by setCoef. */
    @Override
    public int getNumArguments() {
        return this.coefficients;
    }

    @Override
    public double getLowerBound(int i) {
        return -COEFFICIENT_BOUND;
    }

    @Override
    public double getUpperBound(int i) {
        return COEFFICIENT_BOUND;
    }

    @Override
    public boolean isExecutableInParallelMode() {
        return true;
    }
}
