package de.tu_dortmund.sfb876.optimplugin.optimizers;

import de.tu_dortmund.sfb876.optimplugin.costfunctions.CostFunction;
import de.tu_dortmund.sfb876.optimplugin.optimizers.linesearch.LineSearch;
import de.tu_dortmund.sfb876.optimplugin.regularizers.Regularizer;
import java.util.ArrayList;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/tu_dortmund/sfb876/optimplugin/optimizers/GradientDescent.class */
/**
 * Gradient descent over a smooth cost function plus a regularizer.
 *
 * <p>Each iteration steps {@code theta} along the negative of the combined gradient
 * (cost-function gradient + regularizer gradient). The step size is either the fixed learning
 * rate given to the constructor or, when a {@link LineSearch} is supplied, the rate returned by
 * the line search for the current direction. The cost after every step is appended to the cost
 * history. The run stops early when the inherited {@code checkForTermination} reports
 * convergence against the configured tolerance, and otherwise after the configured number of
 * iterations.
 */
public class GradientDescent extends AbstractSmoothOptimizer {
    private static final Logger LOG = LoggerFactory.getLogger(GradientDescent.class);

    /** Termination metrics reported on the first iteration, where no previous step exists. */
    private static final double NO_PREVIOUS_STEP_METRIC = 0.0;

    private final RealMatrix X;
    private final RealVector y;
    private RealVector theta;
    private double alpha;
    private final int numOfIterations;
    private final ArrayList<Double> costHistory = new ArrayList<>();
    private final CostFunction costFunction;
    private final Regularizer regularizer;
    private final LineSearch lineSearch;
    private final double tolerance;
    private final boolean isPlotConvergence;

    /**
     * Creates a gradient-descent optimizer.
     *
     * @param x                 training data matrix
     * @param y                 target vector passed to the cost function alongside {@code x}
     * @param initialTheta      starting parameter vector
     * @param costFunction      smooth cost whose value and gradient are evaluated each iteration
     * @param regularizer       regularizer whose gradient is added to the cost gradient; must be
     *                          non-null
     * @param lineSearch        optional line search; when non-null it replaces the fixed learning
     *                          rate on every iteration, when {@code null} {@code learningRate} is
     *                          used throughout
     * @param learningRate      fixed step size used when {@code lineSearch} is {@code null}
     * @param numOfIterations   maximum number of iterations
     * @param tolerance         tolerance forwarded to the termination check
     * @param isPlotConvergence whether to start a live convergence plot while optimizing
     */
    public GradientDescent(RealMatrix x, RealVector y, RealVector initialTheta,
            CostFunction costFunction, Regularizer regularizer, LineSearch lineSearch,
            double learningRate, int numOfIterations, double tolerance,
            boolean isPlotConvergence) {
        this.X = x;
        this.y = y;
        this.theta = initialTheta;
        this.alpha = learningRate;
        this.costFunction = costFunction;
        this.regularizer = regularizer;
        this.numOfIterations = numOfIterations;
        this.lineSearch = lineSearch;
        this.tolerance = tolerance;
        this.isPlotConvergence = isPlotConvergence;
    }

    /**
     * Returns the cost recorded after each iteration.
     *
     * <p>The returned list is the live internal list. Entries from every call to
     * {@link #optimize()} on this instance are appended to it.
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.optimizers.Optimizer
    public ArrayList<Double> getCostHistory() {
        return this.costHistory;
    }

    /**
     * Runs gradient descent starting from the current {@code theta}.
     *
     * @return the parameter vector after convergence or after the last iteration
     * @throws Exception any exception raised by the cost function, regularizer, line search or
     *                   convergence plotter propagates unchanged
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.optimizers.Optimizer
    public RealVector optimize() throws Exception {
        logConfiguration();
        ConvergencePlotter convergencePlotter = null;
        if (this.isPlotConvergence) {
            convergencePlotter = new ConvergencePlotter(true, false, this.numOfIterations, 3);
            new Thread(convergencePlotter, "Plot Refresh Thread").start();
        }
        try {
            LOG.debug("Initial cost: {}",
                    Double.valueOf(this.costFunction.computeCost(this.X, this.y, this.theta)));
            // Cost after the previous iteration of THIS run. It is tracked locally rather than
            // read back from costHistory, which may still hold entries from earlier runs.
            double previousCost = Double.NaN;
            for (int i = 1; i <= this.numOfIterations; i++) {
                RealVector direction = this.costFunction.getGradient(this.X, this.y, this.theta)
                        .add(this.regularizer.getGradient(this.theta))
                        .mapMultiply(-1.0d);
                if (this.lineSearch != null) {
                    this.alpha = this.lineSearch.getLearningRate(this.costFunction,
                            this.regularizer, this.X, this.y, this.theta, direction);
                }
                RealVector previousTheta = this.theta;
                this.theta = this.theta.add(direction.mapMultiply(this.alpha));

                double cost = this.costFunction.computeCost(this.X, this.y, this.theta);
                this.costHistory.add(Double.valueOf(cost));
                LOG.debug("iteration,cost: {}, {}", Integer.valueOf(i), Double.valueOf(cost));

                // checkForTermination yields {metric, metric, converged}. On the first
                // iteration there is no previous step to compare against, so placeholders
                // are used and the run never stops there.
                Object[] termination = (i == 1)
                        ? new Object[] {Double.valueOf(NO_PREVIOUS_STEP_METRIC),
                                Double.valueOf(NO_PREVIOUS_STEP_METRIC), Boolean.FALSE}
                        : checkForTermination(cost, previousCost, this.theta, previousTheta,
                                this.tolerance);
                previousCost = cost;

                boolean converged = ((Boolean) termination[2]).booleanValue();
                Double firstMetric = (Double) termination[0];
                Double secondMetric = (Double) termination[1];
                if (convergencePlotter != null) {
                    convergencePlotter.addSmoothConvergenceData(i, this.alpha, cost,
                            firstMetric.doubleValue(), secondMetric.doubleValue());
                }
                if (converged) {
                    LOG.info("stopped in iteration: {} based on tolerance", Integer.valueOf(i));
                    return this.theta;
                }
            }
            return this.theta;
        } finally {
            // Signal the refresh thread to stop on every exit path: convergence, exhausted
            // iteration budget, or an exception. Previously only the convergence path did.
            if (convergencePlotter != null) {
                convergencePlotter.isStopRefresh = true;
            }
        }
    }

    /**
     * Logs the optimizer configuration at INFO level, and delegates to the regularizer's and,
     * if present, the line search's own {@code logConfiguration}.
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.optimizers.Optimizer
    public void logConfiguration() {
        LOG.info("training data matrix Dimension: {} x {}",
                Integer.valueOf(this.X.getRowDimension()),
                Integer.valueOf(this.X.getColumnDimension()));
        LOG.info("loss function: {}", this.costFunction.getClass().getSimpleName());
        LOG.info("Regularizer: {}", this.regularizer.getClass().getSimpleName());
        this.regularizer.logConfiguration();
        if (this.lineSearch == null) {
            LOG.info("Learning rate: {}", Double.valueOf(this.alpha));
        } else {
            LOG.info("lineSearch:{}", this.lineSearch.getClass().getSimpleName());
            this.lineSearch.logConfiguration();
        }
        LOG.info("No of iterations: {}", Integer.valueOf(this.numOfIterations));
        LOG.info("Tolerance: {}", Double.valueOf(this.tolerance));
    }
}
