package game.trainers.gradient.numopt;

import common.MachineAccuracy;
import common.MathUtil;
import common.function.ObjectiveFunction;

/**
 * Line search that looks for a step length satisfying the strong Wolfe conditions.
 *
 * <p>Writing {@code phi(alpha) = f(x0 + alpha * dir)}, a trial step is accepted when both
 * <ul>
 *   <li>{@code phi(alpha) <= phi(0) + c1 * alpha * phi'(0)} (sufficient decrease) and</li>
 *   <li>{@code |phi'(alpha)| <= -c2 * phi'(0)} (strong curvature)</li>
 * </ul>
 * hold. The search runs in two phases:
 * <ul>
 *   <li>a bracketing phase, which extrapolates the step until an interval that brackets
 *       acceptable steps is found;</li>
 *   <li>a zoom phase, which shrinks that interval.</li>
 * </ul>
 * Each new trial step comes from a cubic that interpolates values and slopes at two points.
 * The step is restricted to a safeguarding sub-interval. The sub-interval formulas and the
 * default constants ({@code tau1 = 9}, {@code tau2 = min(0.1, c2)}, {@code tau3 = 0.5}) match
 * the bracketing/sectioning scheme described by Fletcher, <i>Practical Methods of
 * Optimization</i>.
 *
 * <p>Per-search state is kept in mutable fields, so an instance must not be shared between
 * concurrent searches.
 */
public class LineSearchStrongWolfe extends LineSearch {
    /** Sufficient-decrease constant (first Wolfe condition). */
    private final double c1 = 0.01d;
    /** Curvature constant (strong Wolfe condition). */
    private final double c2 = 0.9d;
    /** Upper extrapolation factor used when choosing the next step while bracketing. */
    private final double tau1 = 9.0d;
    /** Fraction of the bracket width kept clear of the {@code b} end while zooming. */
    private final double tau3 = 0.5d;

    /** Number of objective evaluations performed by the current search. */
    private int evaluateCalls;
    /** Zoom-phase stopping threshold: {@code functionTolerance / 1000}. */
    private double tol;
    /** Search direction of the current search. */
    private double[] dir;
    /** Copy of the starting point of the current search. */
    private double[] x0;
    /** Objective value at the starting point. */
    private double fx0;
    /** Directional derivative at the starting point (required to be negative). */
    private double slope0;
    /** Directional derivative at the step {@code alpha}. */
    private double slopeAlpha;
    /** Previous bracketing-phase step, with its value and directional derivative. */
    private double alphaOld;
    private double fAlphaOld;
    private double slopeAlphaOld;
    /**
     * Ends of the zoom bracket with their values and directional derivatives. The zoom phase
     * only moves {@code a} to a trial step that satisfies sufficient decrease and has a lower
     * value than the current {@code a}.
     */
    private double a;
    private double b;
    private double fA;
    private double fB;
    private double slopeA;
    private double slopeB;
    /** Upper bound on the step length tried during bracketing. */
    private double alphaMax;
    /** Maximum number of objective evaluations per call to {@link #minimize}. */
    private int maxEvaluateCalls;
    /** Base tolerance from which the zoom-phase stopping threshold is derived. */
    private double functionTolerance;

    /**
     * Creates a line search for the given objective with default settings:
     * <ul>
     *   <li>{@code alphaMax = Double.MAX_VALUE};</li>
     *   <li>at most 1000 evaluations per search;</li>
     *   <li>{@code functionTolerance = MachineAccuracy.EPSILON}.</li>
     * </ul>
     *
     * @param objectiveFunction objective to minimize along the search direction
     */
    public LineSearchStrongWolfe(ObjectiveFunction objectiveFunction) {
        super(objectiveFunction);
        this.alphaMax = Double.MAX_VALUE;
        this.maxEvaluateCalls = 1000;
        this.functionTolerance = MachineAccuracy.EPSILON;
    }

    /**
     * Returns the base tolerance of the zoom phase.
     *
     * @return the tolerance; zooming stops once {@code |(alpha - a) * slopeA|} is at most
     *         this value divided by 1000
     */
    public double getFunctionTolerance() {
        return this.functionTolerance;
    }

    /**
     * Sets the base tolerance of the zoom phase.
     *
     * @param d new tolerance; it takes effect at the next call to {@link #minimize}
     */
    public void setFunctionTolerance(double d) {
        this.functionTolerance = d;
    }

    /**
     * Returns the upper bound on the step length tried during bracketing.
     *
     * @return the maximum step length
     */
    public double getAlphaMax() {
        return this.alphaMax;
    }

    /**
     * Sets the upper bound on the step length tried during bracketing.
     *
     * @param d new maximum step length
     */
    public void setAlphaMax(double d) {
        this.alphaMax = d;
    }

    /**
     * Returns the maximum number of objective evaluations per search.
     *
     * @return the evaluation budget
     */
    public int getMaxEvaluateCalls() {
        return this.maxEvaluateCalls;
    }

    /**
     * Sets the maximum number of objective evaluations per search.
     *
     * @param i new evaluation budget
     */
    public void setMaxEvaluateCalls(int i) {
        this.maxEvaluateCalls = i;
    }

    /**
     * Searches along {@code dArr2} from {@code dArr} for a step satisfying the strong Wolfe
     * conditions.
     *
     * <p>On return, {@code alpha}, {@code xAlpha} (the array {@code dArr}), {@code fAlpha} and
     * {@code gAlpha} describe the same step: the last one actually evaluated. If nothing was
     * evaluated, that is step 0, i.e. the starting point.
     *
     * @param dArr  starting point; overwritten in place with the last evaluated point
     * @param dArr2 search direction; its dot product with {@code dArr3} must be negative
     * @param d     objective value at the starting point
     * @param dArr3 gradient at the starting point. It is passed to the objective on every
     *              evaluation, presumably so it is overwritten with the gradient at the last
     *              evaluated point (TODO confirm against {@code ObjectiveFunction})
     * @return objective value at the last evaluated step ({@code d} if none was evaluated)
     * @throws LineSearchException if the initial directional derivative is not negative
     *                             (including NaN)
     */
    @Override
    public double minimize(double[] dArr, double[] dArr2, double d, double[] dArr3) throws LineSearchException {
        this.xAlpha = dArr;
        this.gAlpha = dArr3;
        this.x0 = this.xAlpha.clone();
        this.dir = dArr2;
        this.fx0 = d;
        this.slope0 = MathUtil.dotProduct(this.gAlpha, this.dir);
        // Negated test so that a NaN slope is rejected as well (NaN >= 0 is false).
        if (!(this.slope0 < 0.0d)) {
            throw new LineSearchException("Slope not negative.");
        }
        this.alpha = this.initAlpha;
        this.fAlpha = this.fx0;
        this.slopeAlpha = this.slope0;
        this.alphaOld = 0.0d;
        this.evaluateCalls = 0;
        this.tol = this.functionTolerance / 1000.0d;

        boolean firstIteration = true;
        while (this.evaluateCalls < this.maxEvaluateCalls) {
            this.fAlphaOld = this.fAlpha;
            this.slopeAlphaOld = this.slopeAlpha;
            evaluateAt(this.alpha);

            // Step too long, or the value stopped decreasing: [alphaOld, alpha] is a bracket.
            if (!satisfiesSufficientDecrease(this.alpha, this.fAlpha)
                    || (this.fAlpha >= this.fAlphaOld && !firstIteration)) {
                setBracket(this.alphaOld, this.fAlphaOld, this.slopeAlphaOld,
                        this.alpha, this.fAlpha, this.slopeAlpha);
                zoom();
                return this.fAlpha;
            }
            firstIteration = false;
            if (satisfiesStrongCurvature()) {
                return this.fAlpha;
            }
            // The slope is no longer negative, so the bracket runs from alpha back to alphaOld.
            if (this.slopeAlpha >= 0.0d) {
                setBracket(this.alpha, this.fAlpha, this.slopeAlpha,
                        this.alphaOld, this.fAlphaOld, this.slopeAlphaOld);
                zoom();
                return this.fAlpha;
            }

            // Still descending: extrapolate within
            // [2*alpha - alphaOld, alpha + tau1*(alpha - alphaOld)], capped at alphaMax.
            double lowest = 2.0d * this.alpha - this.alphaOld;
            double next;
            if (lowest < this.alphaMax) {
                double highest = Math.min(this.alphaMax, this.alpha + this.tau1 * (this.alpha - this.alphaOld));
                next = chooseNewAlpha(lowest, highest,
                        this.alphaOld, this.fAlphaOld, this.slopeAlphaOld,
                        this.alpha, this.fAlpha, this.slopeAlpha);
            } else {
                next = this.alphaMax;
            }
            // Advance alphaOld on both paths. fAlphaOld/slopeAlphaOld are copied from the
            // current step at the top of the next iteration, and alphaOld must match them.
            this.alphaOld = this.alpha;
            this.alpha = next;
        }
        // Budget exhausted. alphaOld is the last step actually evaluated (0 if none was).
        // Report it so alpha stays consistent with xAlpha, fAlpha and gAlpha.
        this.alpha = this.alphaOld;
        return this.fAlpha;
    }

    /**
     * Shrinks the bracket between {@code a} and {@code b}. It stops as soon as one of these
     * holds:
     * <ul>
     *   <li>a trial step satisfies the strong Wolfe conditions;</li>
     *   <li>the predicted change {@code |(candidate - a) * slopeA|} is at most {@code tol};</li>
     *   <li>the bracket is narrower than {@link MachineAccuracy#EPSILON};</li>
     *   <li>the evaluation budget runs out.</li>
     * </ul>
     * On return, {@code alpha}, {@code xAlpha}, {@code fAlpha}, {@code gAlpha} and
     * {@code slopeAlpha} describe the last evaluated step.
     */
    private void zoom() {
        double tau2 = Math.min(0.1d, this.c2);
        while (this.evaluateCalls < this.maxEvaluateCalls) {
            double width = this.b - this.a;
            // Interpolate on the current bracket ends, not on the bracketing-phase pair,
            // which goes stale once the bracket starts shrinking.
            double candidate = chooseNewAlpha(this.a + tau2 * width, this.b - this.tau3 * width,
                    this.a, this.fA, this.slopeA,
                    this.b, this.fB, this.slopeB);
            // Predicted change too small to matter: stop without another evaluation. The
            // candidate is not adopted, so the state still describes the last evaluated step.
            if (Math.abs((candidate - this.a) * this.slopeA) <= this.tol) {
                return;
            }
            evaluateAt(candidate);

            if (!satisfiesSufficientDecrease(this.alpha, this.fAlpha) || !(this.fAlpha < this.fA)) {
                // The candidate is no better than a, so it becomes the far end of the bracket.
                this.b = this.alpha;
                this.fB = this.fAlpha;
                this.slopeB = this.slopeAlpha;
            } else {
                if (satisfiesStrongCurvature()) {
                    return;
                }
                double oldA = this.a;
                double oldFA = this.fA;
                double oldSlopeA = this.slopeA;
                this.a = this.alpha;
                this.fA = this.fAlpha;
                this.slopeA = this.slopeAlpha;
                // phi rises from the new a toward b, so the old a becomes the far end.
                if (this.slopeAlpha * (this.b - this.a) >= 0.0d) {
                    this.b = oldA;
                    this.fB = oldFA;
                    this.slopeB = oldSlopeA;
                }
            }
            if (Math.abs(this.b - this.a) < MachineAccuracy.EPSILON) {
                return;
            }
        }
    }

    /**
     * Evaluates the objective at {@code x0 + step * dir}. Afterwards {@code alpha},
     * {@code xAlpha}, {@code fAlpha} and {@code slopeAlpha} describe that step, and the
     * evaluation counter has been incremented. {@code gAlpha} is handed to the objective,
     * presumably to receive the gradient there.
     */
    private void evaluateAt(double step) {
        this.alpha = step;
        for (int i = 0; i < this.n; i++) {
            this.xAlpha[i] = this.x0[i] + step * this.dir[i];
        }
        this.fAlpha = this.func.evaluate(this.xAlpha, this.gAlpha);
        this.evaluateCalls++;
        this.slopeAlpha = MathUtil.dotProduct(this.gAlpha, this.dir);
    }

    /**
     * Tests the sufficient-decrease (first Wolfe) condition. It is phrased with {@code <=} so
     * that a NaN objective value counts as a failure.
     */
    private boolean satisfiesSufficientDecrease(double step, double value) {
        return value <= this.fx0 + step * this.c1 * this.slope0;
    }

    /** Tests the strong curvature condition {@code |slopeAlpha| <= -c2 * slope0}. */
    private boolean satisfiesStrongCurvature() {
        return Math.abs(this.slopeAlpha) <= -this.c2 * this.slope0;
    }

    /** Sets both ends of the zoom bracket together with their values and slopes. */
    private void setBracket(double aStep, double aValue, double aSlope,
                            double bStep, double bValue, double bSlope) {
        this.a = aStep;
        this.fA = aValue;
        this.slopeA = aSlope;
        this.b = bStep;
        this.fB = bValue;
        this.slopeB = bSlope;
    }

    /**
     * Picks a trial step in the interval spanned by {@code lo} and {@code hi}, which may be
     * given in either order. The choice uses the cubic that interpolates the values and slopes
     * at {@code x1} and {@code x2}.
     *
     * <p>The candidate is the point returned by {@code CubicInterpolation.interpolateAndMinimize},
     * presumably the cubic's minimizer. It is used only if it lies in the interval and its
     * interpolated value is no larger than the better endpoint's. Otherwise the endpoint with
     * the smaller interpolated value is returned, with ties going to the upper endpoint.
     */
    private double chooseNewAlpha(double lo, double hi,
                                  double x1, double f1, double s1,
                                  double x2, double f2, double s2) {
        // Pass the interpolation points in ascending order. The bracketing phase always
        // supplies them that way, so this keeps the helper's input form uniform.
        if (x1 > x2) {
            return chooseNewAlpha(lo, hi, x2, f2, s2, x1, f1, s1);
        }
        if (lo > hi) {
            double swap = lo;
            lo = hi;
            hi = swap;
        }
        double fLo = CubicInterpolation.interpolate(x1, x2, f1, f2, s1, s2, lo);
        double fHi = CubicInterpolation.interpolate(x1, x2, f1, f2, s1, s2, hi);
        double bestEnd;
        double fBestEnd;
        if (fLo < fHi) {
            bestEnd = lo;
            fBestEnd = fLo;
        } else {
            bestEnd = hi;
            fBestEnd = fHi;
        }
        double minimizer = CubicInterpolation.interpolateAndMinimize(x1, x2, f1, f2, s1, s2);
        // Positive phrasing: a NaN minimizer from a degenerate cubic fails every comparison
        // and is therefore rejected.
        if (lo <= minimizer && minimizer <= hi
                && CubicInterpolation.interpolate(x1, x2, f1, f2, s1, s2, minimizer) <= fBestEnd) {
            return minimizer;
        }
        return bestEnd;
    }

    /**
     * Debugging aid that prints which acceptance tests the current step fails. It is not
     * called within this class.
     */
    private void checkAndPrintGoodSolution() {
        if (!LineSearches.checkWolfeSuffcientDecrease(this.fx0, this.slope0, this.alpha, this.fAlpha, this.c1)) {
            System.out.println("Wolfe (1) not satisfied");
        }
        // The strong curvature test compares the slope at alpha with slope0, so the slope is
        // passed here rather than the step length.
        // NOTE(review): assumes the signature is (slope0, slopeAlpha, c2); confirm in LineSearches.
        if (!LineSearches.checkStrongWolfeCurvature(this.slope0, this.slopeAlpha, this.c2)) {
            System.out.println("Wolfe (2) not satisfied");
        }
        if (this.fAlpha >= this.fx0) {
            System.out.println("fAlpha >= fx0");
            if (this.fAlpha > this.fx0) {
                System.out.println("fAlpha > fx0");
            }
        }
    }
}
