package de.tu_dortmund.sfb876.optimplugin.optimizers.linesearch;

import de.tu_dortmund.sfb876.optimplugin.costfunctions.CostFunction;
import de.tu_dortmund.sfb876.optimplugin.regularizers.Regularizer;
import java.text.DecimalFormat;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/tu_dortmund/sfb876/optimplugin/optimizers/linesearch/Armijo.class */
/**
 * Backtracking line search enforcing the Armijo (sufficient decrease) condition.
 *
 * <p>Each call starts from the configured maximum step and multiplies the step by {@code rho}
 * until {@code f(w + alpha * d) <= f(w) + c1 * alpha * (g . d)} holds. Here {@code f} is the
 * cost-function value plus the regularizer's cost term, and {@code g} is the cost-function
 * gradient plus the regularizer gradient divided by the number of data rows.
 *
 * <p>The last accepted step size is remembered across calls and returned as a fallback when a
 * search gives up. Instances therefore carry mutable state and should not be shared between
 * concurrent optimizations.
 */
public class Armijo implements LineSearch {
    private static final Logger LOG = LoggerFactory.getLogger(Armijo.class);

    /** Pattern used to render costs and step sizes in log output. */
    private static final String LOG_NUMBER_PATTERN = "#.###########";

    /** Lower bound for the trial step; the search gives up once the step drops below it. */
    private final double _alphaMin;
    /** Step size tried first on every call. */
    private final double _alphaMax;
    /** Factor the trial step is multiplied by after each rejected trial. */
    private final double rho;
    /** Sufficient-decrease constant of the Armijo condition. */
    private final double c1;
    /** Number of shrink steps after which the search falls back to {@link #prevGoodAlpha}. */
    private final int maxIterations;
    /** Last step accepted by this instance; returned when a search gives up. */
    private double prevGoodAlpha = 0.01d;

    /**
     * Creates an Armijo line search.
     *
     * @param alphaMin      lower bound for the trial step
     * @param alphaMax      step size tried first on every call
     * @param rho           factor the trial step is multiplied by after each rejection
     * @param c1            sufficient-decrease constant of the Armijo condition
     * @param maxIterations shrink steps allowed before falling back to the last accepted step
     */
    public Armijo(double alphaMin, double alphaMax, double rho, double c1, int maxIterations) {
        this._alphaMin = alphaMin;
        this._alphaMax = alphaMax;
        this.rho = rho;
        this.c1 = c1;
        this.maxIterations = maxIterations;
    }

    /**
     * Returns a step size along {@code realVector3} that satisfies the Armijo condition.
     *
     * @param costFunction cost whose value and gradient are evaluated
     * @param regularizer  regularizer whose cost and gradient are added to the cost function's
     * @param realMatrix   data matrix passed to the cost function; its row count divides the
     *                     regularizer gradient
     * @param realVector   second data argument forwarded to the cost function
     * @param realVector2  current parameter vector {@code w}
     * @param realVector3  search direction {@code d}; must be a descent direction
     * @return the first trial step satisfying the condition. If the shrink-step limit is exceeded
     *         or the step drops below the minimum, the last step accepted by this instance.
     *         {@code 0.0} if the trial step becomes exactly equal to the minimum step, which
     *         includes a configuration where minimum and maximum step are equal.
     * @throws Exception if {@code d} is not a descent direction, or if the configured maximum
     *         step is smaller than the configured minimum step
     */
    @Override
    public double getLearningRate(CostFunction costFunction, Regularizer regularizer,
            RealMatrix realMatrix, RealVector realVector, RealVector realVector2,
            RealVector realVector3) throws Exception {
        double alphaMax = this._alphaMax;
        final double alphaMin = this._alphaMin;

        // Directional derivative g . d, where the regularizer gradient is averaged over the rows.
        // NOTE(review): the regularizer *cost* below is added without dividing by the row count;
        // confirm that addtoCost applies matching scaling, otherwise slope and cost disagree.
        RealVector gradient = costFunction.getGradient(realMatrix, realVector, realVector2)
                .add(regularizer.getGradient(realVector2).mapDivide(realMatrix.getRowDimension()));
        double slope = gradient.dotProduct(realVector3);
        if (slope >= 0.0d) {
            throw new Exception("Supplied direction is not a descent direction");
        }
        if (alphaMax < alphaMin) {
            throw new Exception("Stepsize is very small");
        }

        // f(w) does not depend on the trial step: evaluate it once instead of on every iteration.
        double currentCost = costFunction.computeCost(realMatrix, realVector, realVector2)
                + regularizer.addtoCost(realVector2);
        // DecimalFormat is not thread-safe, so use one instance per call rather than a static one.
        DecimalFormat format = new DecimalFormat(LOG_NUMBER_PATTERN);

        int i = 0;
        while (alphaMax > alphaMin) {
            i++;
            RealVector candidate = realVector2.add(realVector3.mapMultiply(alphaMax));
            double candidateCost = costFunction.computeCost(realMatrix, realVector, candidate)
                    + regularizer.addtoCost(candidate);
            double armijoBound = currentCost + (alphaMax * this.c1 * slope);
            boolean sufficientDecrease = candidateCost <= armijoBound;

            // Skip the formatting work entirely unless trace output is actually wanted.
            if (LOG.isTraceEnabled()) {
                LOG.trace("iteration: {},firstTerm: {}, secondTerm: {}, alphaMax: {},(firstTerm <= secondTerm): {} ",
                        new Object[] {Integer.valueOf(i), format.format(candidateCost),
                            format.format(armijoBound), format.format(alphaMax),
                            Boolean.valueOf(sufficientDecrease)});
            }

            if (sufficientDecrease) {
                LOG.debug("Armijo Learning rate: {}", format.format(alphaMax));
                this.prevGoodAlpha = alphaMax;
                return alphaMax;
            }

            alphaMax = this.rho * alphaMax;
            boolean iterationsExceeded = i > this.maxIterations;
            boolean belowMinimum = alphaMax < alphaMin;
            if (iterationsExceeded || belowMinimum) {
                LOG.warn("Max iterations exceeded:{},alphaMax < alphaMin: {}, Learning rate: {}",
                        new Object[] {Boolean.valueOf(iterationsExceeded),
                            Boolean.valueOf(belowMinimum), format.format(this.prevGoodAlpha)});
                return this.prevGoodAlpha;
            }
        }
        // Reached only when the trial step is exactly equal to alphaMin
        // (either initially or after shrinking).
        return 0.0d;
    }

    /** Logs the configured parameters of this line search at INFO level. */
    @Override
    public void logConfiguration() {
        LOG.info("c1: {}", Double.valueOf(this.c1));
        LOG.info("alpha min: {}", Double.valueOf(this._alphaMin));
        LOG.info("alpha max: {}", Double.valueOf(this._alphaMax));
        LOG.info("rho: {}", Double.valueOf(this.rho));
        LOG.info("Max no of line search iterations: {}", Integer.valueOf(this.maxIterations));
    }
}
