package de.tu_dortmund.sfb876.optimplugin.optimizers;

import com.rapidminer.optimplugin.OptimPluginUtil;
import de.tu_dortmund.sfb876.optimplugin.costfunctions.CostFunction;
import de.tu_dortmund.sfb876.optimplugin.regularizers.Regularizer;
import java.util.ArrayList;
import java.util.Random;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/tu_dortmund/sfb876/optimplugin/optimizers/StochasticGradientDescent.class */
/**
 * Stochastic gradient descent with iterate averaging.
 *
 * <p>Each epoch shuffles the rows of the training data and performs one update per row. The step
 * size for the t-th update is {@code gamma / sqrt(t)} when the decreasing schedule is enabled, and
 * the constant {@code gamma / sqrt(numEpochs * rows)} otherwise. {@link #optimize()} returns the
 * arithmetic mean of all iterates produced during that call.
 *
 * <p>NOTE(review): {@code OptimPluginUtil.shuffleRows} is called on the stored matrix/vector and
 * its result is ignored, so it presumably reorders them in place — i.e. the caller's data is
 * permuted by {@link #optimize()}. Confirm before sharing those objects.
 */
public class StochasticGradientDescent implements Optimizer {
    /** Maximum number of consecutive updates averaged into one cost-history entry. */
    private static final int COST_WINDOW = 100;

    private static final Logger LOGGER = LoggerFactory.getLogger(StochasticGradientDescent.class);

    private final RealMatrix x;
    private final RealVector y;
    private RealVector theta;
    private final double gamma;
    private final int numEpochs;
    private final CostFunction costFunction;
    private final Regularizer regularizer;
    private final boolean isPlotConvergence;
    private final boolean isDecreasingStepSize;
    // Sum of all iterates of the current optimize() run; divided by the update count to average.
    private ArrayRealVector runningAvgTheta;
    private final ArrayList<Double> costHistory = new ArrayList<>();

    /**
     * Creates an optimizer; no work is done until {@link #optimize()} is called.
     *
     * @param x training data, one example per row
     * @param y one value per row of {@code x} (row {@code i} of {@code x} is paired with entry
     *     {@code i} of {@code y})
     * @param initialTheta starting parameter vector; not modified (updates create new vectors)
     * @param costFunction loss whose gradient drives the updates and whose cost is recorded
     * @param regularizer regularizer whose gradient is subtracted at every update
     * @param gamma base step size
     * @param numEpochs number of full passes over the rows
     * @param isPlotConvergence whether to show a live convergence plot and log windowed costs
     * @param isDecreasingStepSize {@code true} for {@code gamma / sqrt(t)}, {@code false} for the
     *     constant {@code gamma / sqrt(numEpochs * rows)}
     */
    public StochasticGradientDescent(RealMatrix x, RealVector y, RealVector initialTheta,
            CostFunction costFunction, Regularizer regularizer, double gamma, int numEpochs,
            boolean isPlotConvergence, boolean isDecreasingStepSize) {
        this.x = x;
        this.y = y;
        this.theta = initialTheta;
        this.gamma = gamma;
        this.costFunction = costFunction;
        this.regularizer = regularizer;
        this.numEpochs = numEpochs;
        this.isPlotConvergence = isPlotConvergence;
        this.isDecreasingStepSize = isDecreasingStepSize;
        this.runningAvgTheta = new ArrayRealVector(initialTheta.getDimension(), 0.0);
    }

    /**
     * Returns the recorded cost history: one entry per window of {@code min(100, rows)} updates,
     * holding the mean cost over that window. This is the live internal list, not a copy; entries
     * from repeated calls to {@link #optimize()} accumulate.
     */
    @Override
    public ArrayList<Double> getCostHistory() {
        return this.costHistory;
    }

    /**
     * Runs {@code numEpochs} passes of SGD over the (shuffled) rows and returns the averaged
     * iterate. If no update is performed (zero epochs or zero rows), returns a copy of the
     * current parameter vector instead of an undefined average.
     *
     * @return the mean of all iterates from this call
     */
    @Override
    public RealVector optimize() {
        logConfiguration();
        final int rowDimension = x.getRowDimension();
        final int costWindow = Math.min(COST_WINDOW, rowDimension);
        final double fixedAlpha = fixedStepSize();
        final Random random = new Random();

        // Start a fresh sum so repeated calls do not average in iterates from earlier runs.
        runningAvgTheta = new ArrayRealVector(theta.getDimension(), 0.0);

        ConvergencePlotter convergencePlotter = null;
        if (isPlotConvergence) {
            convergencePlotter = new ConvergencePlotter(false, false, numEpochs * rowDimension, 3);
            new Thread(convergencePlotter, "Plot Refresh Thread").start();
        }

        int iteration = 0;
        // Kept across epochs: a window may straddle an epoch boundary when the row count is not a
        // multiple of the window size, and every window must sum exactly costWindow terms.
        double windowCost = 0.0;
        try {
            for (int epoch = 1; epoch <= numEpochs; epoch++) {
                OptimPluginUtil.shuffleRows(x, y, random);
                for (int row = 0; row < rowDimension; row++) {
                    iteration++;
                    final double alpha =
                            isDecreasingStepSize ? gamma / Math.sqrt(iteration) : fixedAlpha;
                    final RealMatrix xRow = x.getRowMatrix(row);
                    final RealVector yRow = y.getSubVector(row, 1);
                    // NOTE(review): the regularizer gradient is subtracted without the step size
                    // alpha, unlike the loss gradient. Confirm this is intended (e.g. the
                    // regularizer already scales its own gradient) before changing it.
                    theta = theta.subtract(costFunction.getGradient(xRow, yRow, theta)
                            .mapMultiply(alpha)
                            .add(regularizer.getGradient(theta)));
                    runningAvgTheta = runningAvgTheta.add(theta);
                    // Cost is evaluated at the running average, i.e. the point optimize() returns.
                    windowCost += costFunction.computeCost(
                            xRow, yRow, runningAvgTheta.mapDivide(iteration));
                    if (iteration % costWindow == 0) {
                        final double avgCost = windowCost / costWindow;
                        costHistory.add(avgCost);
                        if (convergencePlotter != null) {
                            LOGGER.info("iteration: {} cost: {}", iteration, avgCost);
                            convergencePlotter.addNonSmoothConvergenceData(iteration, avgCost);
                        }
                        windowCost = 0.0;
                    }
                }
            }
        } finally {
            // Stop the refresh thread even if an update throws, so it does not outlive the run.
            if (convergencePlotter != null) {
                convergencePlotter.isStopRefresh = true;
            }
        }

        // With no updates the average would be 0/0 = NaN in every coordinate.
        final RealVector solution =
                iteration == 0 ? theta.copy() : runningAvgTheta.mapDivide(iteration);
        LOGGER.info("NonZeros :{}", countNonZeros(solution));
        return solution;
    }

    /**
     * Logs the data dimensions, loss, regularizer, epoch count and step-size schedule.
     */
    @Override
    public void logConfiguration() {
        LOGGER.info("training data matrix Dimension: {} x {}",
                x.getRowDimension(), x.getColumnDimension());
        LOGGER.info("loss function: {}", costFunction.getClass().getCanonicalName());
        LOGGER.info("Regularizer: {}", regularizer.getClass().getCanonicalName());
        regularizer.logConfiguration();
        LOGGER.info("No of epochs: {}", numEpochs);
        LOGGER.info("isDecreasing Stepsize: {}", isDecreasingStepSize);
        if (isDecreasingStepSize) {
            LOGGER.info("stepsize: {}/sqrt(t)", gamma);
        } else {
            // Computed here because this is logged before the first update is made.
            LOGGER.info("step size: {}", fixedStepSize());
        }
    }

    /** Returns the constant step size {@code gamma / sqrt(numEpochs * rows)}. */
    private double fixedStepSize() {
        // Multiply in double: numEpochs * rows can overflow int for large data sets.
        return gamma / Math.sqrt((double) numEpochs * x.getRowDimension());
    }

    /** Counts the entries of {@code vector} that are not exactly zero. */
    private static int countNonZeros(RealVector vector) {
        int count = 0;
        for (double value : vector.toArray()) {
            if (value != 0.0) {
                count++;
            }
        }
        return count;
    }
}
