package de.tu_dortmund.sfb876.optimplugin.optimizers;

import com.rapidminer.optimplugin.OptimPluginUtil;
import de.tu_dortmund.sfb876.optimplugin.costfunctions.CostFunction;
import java.util.ArrayList;
import java.util.Random;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/tu_dortmund/sfb876/optimplugin/optimizers/L1RDAOptimizer.class */
/**
 * Stochastic optimizer that maintains a running average of per-sample gradients and derives the
 * parameter vector from that average with an L1-style soft threshold.
 *
 * <p>For every training row the optimizer
 * <ol>
 *   <li>folds the (gamma-scaled) gradient of the cost function into {@code dualAvgGradient},</li>
 *   <li>recomputes {@code theta}: entry 0 is never thresholded, every other entry is set to zero
 *       when the magnitude of its averaged gradient is at most {@code lambda} and is otherwise
 *       shrunk by {@code lambda} and scaled by {@code -sqrt(t) / gamma},</li>
 *   <li>adds the new {@code theta} to {@code runningAvgTheta}.</li>
 * </ol>
 * The value returned by {@link #optimize()} is the arithmetic mean of all {@code theta} iterates.
 *
 * <p>The {@code theta} vector handed to the constructor is updated in place, and the data matrix
 * and label vector are passed to {@code OptimPluginUtil.shuffleRows} once per epoch
 * (presumably reordering them in place; verify against that helper).
 *
 * <p>Instances are stateful and not designed for concurrent use.
 */
public class L1RDAOptimizer implements Optimizer {
    /** Upper bound on the number of consecutive updates averaged into one cost-history entry. */
    private static final int REPORT_INTERVAL = 100;

    private static final Logger LOGGER = LoggerFactory.getLogger(L1RDAOptimizer.class);

    private final RealMatrix X;
    private final RealVector y;
    private final RealVector theta;
    private final double lambda;
    private final double gamma;
    private final int numEpochs;
    private final CostFunction costFunction;
    private final boolean isPlotConvergence;
    private final ArrayList<Double> costHistory = new ArrayList<>();
    /** Sum of every theta iterate produced so far; divided by the update count to get the mean. */
    private RealVector runningAvgTheta;
    /** Running mean of the gamma-scaled per-sample gradients. */
    private RealVector dualAvgGradient;

    /**
     * Creates an optimizer over the given training data.
     *
     * @param realMatrix training data; rows are processed one at a time
     * @param realVector labels; entry {@code r} is paired with row {@code r} of {@code realMatrix}
     * @param realVector2 initial parameter vector (theta); it is overwritten in place during
     *     {@link #optimize()} and must have at least one entry
     * @param d lambda, the threshold below which an averaged gradient entry zeroes its parameter
     * @param d2 gamma, the scaling factor applied to gradients and to the step size
     * @param i number of passes (epochs) over the training data
     * @param costFunction supplies the per-sample gradient and cost
     * @param z whether to log and plot the averaged cost while optimizing
     */
    public L1RDAOptimizer(RealMatrix realMatrix, RealVector realVector, RealVector realVector2, double d, double d2, int i, CostFunction costFunction, boolean z) {
        this.X = realMatrix;
        this.y = realVector;
        this.theta = realVector2;
        this.lambda = d;
        this.gamma = d2;
        this.numEpochs = i;
        this.costFunction = costFunction;
        this.dualAvgGradient = new ArrayRealVector(this.theta.getDimension(), 0.0d);
        this.runningAvgTheta = new ArrayRealVector(this.theta.getDimension(), 0.0d);
        this.isPlotConvergence = z;
    }

    /**
     * Runs {@code numEpochs} passes over the training rows and returns the mean of all parameter
     * iterates.
     *
     * <p>Every {@code min(100, rowCount)} updates the mean per-sample cost of that window is
     * appended to the cost history (and logged/plotted when plotting is enabled).
     *
     * @return the averaged parameter vector; an all-zero vector when no update was performed
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.optimizers.Optimizer
    public RealVector optimize() {
        logConfiguration();
        final int numSamples = this.X.getRowDimension();
        ConvergencePlotter plotter = null;
        if (this.isPlotConvergence) {
            plotter = new ConvergencePlotter(false, false, this.numEpochs * numSamples, 3);
            new Thread(plotter, "Plot Refresh Thread").start();
        }
        final int reportWindow = Math.min(REPORT_INTERVAL, numSamples);
        final Random random = new Random();
        // Number of completed updates; incremented before use so it always equals the number of
        // iterates contained in runningAvgTheta.
        int step = 0;
        // Accumulated across epoch boundaries so each reported window holds exactly reportWindow costs.
        double windowCost = 0.0d;
        try {
            for (int epoch = 1; epoch <= this.numEpochs; epoch++) {
                OptimPluginUtil.shuffleRows(this.X, this.y, random);
                for (int row = 0; row < numSamples; row++) {
                    step++;
                    RealMatrix sample = this.X.getRowMatrix(row);
                    RealVector label = this.y.getSubVector(row, 1);

                    // Incremental mean: avg_t = avg_{t-1} * (t-1)/t + g_t / t, with g_t pre-scaled by gamma.
                    // NOTE(review): because of that scaling, lambda is compared against
                    // gamma-scaled averages in applyThreshold - confirm this is intended.
                    RealVector decayedAverage = this.dualAvgGradient.mapMultiply((step - 1.0d) / step);
                    RealVector scaledGradient = this.costFunction.getGradient(sample, label, this.theta)
                            .mapMultiply(this.gamma)
                            .mapMultiply(1.0d / step);
                    this.dualAvgGradient = decayedAverage.add(scaledGradient);

                    applyThreshold(((-1.0d) * Math.sqrt(step)) / this.gamma);

                    this.runningAvgTheta = this.runningAvgTheta.add(this.theta);
                    windowCost += this.costFunction.computeCost(sample, label, this.runningAvgTheta.mapDivide(step));

                    if (step % reportWindow == 0) {
                        double meanCost = windowCost / reportWindow;
                        this.costHistory.add(meanCost);
                        if (plotter != null) {
                            LOGGER.info("iteration: {} cost: {}", step, meanCost);
                            plotter.addNonSmoothConvergenceData(step, meanCost);
                        }
                        windowCost = 0.0d;
                    }
                }
            }
        } finally {
            // Always release the refresh thread, even when the cost function throws.
            if (plotter != null) {
                plotter.isStopRefresh = true;
            }
        }
        // With zero updates the sum is still all zeros; avoid dividing by zero in that case.
        RealVector averagedTheta = step == 0 ? this.runningAvgTheta.copy() : this.runningAvgTheta.mapDivide(step);
        LOGGER.info("NonZeros :{}", countNonZeros(averagedTheta));
        return averagedTheta;
    }

    /**
     * Recomputes {@code theta} from the current averaged gradient.
     *
     * @param stepScale the factor {@code -sqrt(t) / gamma} for the current update
     */
    private void applyThreshold(double stepScale) {
        // Entry 0 is exempt from thresholding (presumably the intercept term - confirm with callers).
        this.theta.setEntry(0, stepScale * this.dualAvgGradient.getEntry(0));
        for (int j = 1; j < this.dualAvgGradient.getDimension(); j++) {
            double avgGradient = this.dualAvgGradient.getEntry(j);
            if (Math.abs(avgGradient) <= this.lambda) {
                this.theta.setEntry(j, 0.0d);
            } else {
                this.theta.setEntry(j, (avgGradient - (this.lambda * Math.signum(avgGradient))) * stepScale);
            }
        }
    }

    /** Counts the entries of {@code vector} that are not exactly zero. */
    private static int countNonZeros(RealVector vector) {
        int nonZeros = 0;
        for (double value : vector.toArray()) {
            if (value != 0.0d) {
                nonZeros++;
            }
        }
        return nonZeros;
    }

    /**
     * Returns the windowed mean costs recorded by {@link #optimize()}.
     *
     * @return the live internal list (not a copy); one entry per completed reporting window
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.optimizers.Optimizer
    public ArrayList<Double> getCostHistory() {
        return this.costHistory;
    }

    /** Logs the data dimensions, cost function class and hyper-parameters at INFO level. */
    @Override // de.tu_dortmund.sfb876.optimplugin.optimizers.Optimizer
    public void logConfiguration() {
        LOGGER.info("training data matrix Dimension: {} x {}", this.X.getRowDimension(), this.X.getColumnDimension());
        LOGGER.info("loss function: {}", this.costFunction.getClass().getCanonicalName());
        LOGGER.info("No of epochs: {}", this.numEpochs);
        LOGGER.info("gamma: {}", this.gamma);
        LOGGER.info("lambda: {}", this.lambda);
    }
}
