package cc.mallet.optimize;

import cc.mallet.classify.MCMaxEntTrainer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.logging.Logger;
import salvo.jesus.graph.xml.XGMML;

/* loaded from: input_file:cc/mallet/optimize/StochasticMetaAscent.class */
/**
 * Batch-wise maximizer adapted from Stochastic Meta-Descent (SMD).
 *
 * <p>The batch gradient of an {@link Optimizable.ByBatchGradient} is negated so that the
 * descent-style SMD update increases the batch value. Every parameter has its own step size
 * ("gain"). The gain is adapted after each batch using a gradient trace. The trace update uses
 * either a finite-difference Hessian-vector product or, when {@link #setUseHessian(boolean)} is
 * {@code false}, treats the Hessian as the identity.
 *
 * <p>Gains and the gradient trace are allocated on the first call to {@code optimize} and kept
 * afterwards, so optimization can be resumed a few iterations at a time.
 */
public class StochasticMetaAscent implements Optimizer.ByBatches {
    private static final Logger logger = MalletLogger.getLogger(StochasticMetaAscent.class.getName());

    /** Iteration cap used by {@link #optimize(int, int[])}. */
    private static final int MAX_ITER = 200;
    /** Multiplier on the previous trace and on the curvature term in the trace update. */
    private static final double LAMBDA = 1.0d;
    /** Relative tolerance of the convergence test. */
    private static final double TOLERANCE = 0.01d;
    /** Keeps the convergence test's right-hand side positive when both values are zero. */
    private static final double EPS = 1.0E-10d;
    /** Forward finite-difference step used for the Hessian-vector product. */
    private static final double HESSIAN_EPS = 1.0E-6d;

    /** Meta step size controlling how quickly gains adapt. */
    private double mu = 0.1d;
    /** Number of iterations completed over all calls to {@code optimize}. */
    private int totalIterations = 0;
    /** Gain assigned to every parameter when the gains are first allocated. */
    private double eta_init = 0.03d;
    /** Whether to use a finite-difference Hessian product (true) or an identity Hessian. */
    private boolean useHessian = true;
    /** Per-parameter step sizes; lazily allocated. */
    private double[] gain;
    /** Per-parameter gradient trace; lazily allocated (all zeros initially). */
    private double[] gradientTrace;
    Optimizable.ByBatchGradient maxable;

    /**
     * Creates an optimizer for the given batch-decomposable objective.
     *
     * @param byBatchGradient the objective to maximize
     */
    public StochasticMetaAscent(Optimizable.ByBatchGradient byBatchGradient) {
        this.maxable = byBatchGradient;
    }

    /** Sets the initial per-parameter gain; only takes effect before the first {@code optimize} call. */
    public void setInitialStep(double d) {
        this.eta_init = d;
    }

    /** Sets the meta step size used to adapt the gains. */
    public void setMu(double d) {
        this.mu = d;
    }

    /** Chooses between a finite-difference Hessian product (true) and an identity Hessian (false). */
    public void setUseHessian(boolean z) {
        this.useHessian = z;
    }

    /**
     * Runs up to {@value #MAX_ITER} iterations.
     *
     * @param numBatches       number of batches
     * @param batchAssignments batch assignment array passed through to the objective
     * @return true if the convergence test was met
     */
    @Override // cc.mallet.optimize.Optimizer.ByBatches
    public boolean optimize(int numBatches, int[] batchAssignments) {
        return optimize(MAX_ITER, numBatches, batchAssignments);
    }

    /**
     * Runs up to {@code numIterations} passes over all batches.
     *
     * <p>After each pass, convergence is declared when the summed batch values before and after
     * the updates of that pass agree to within {@link #TOLERANCE} (relative).
     *
     * @param numIterations    maximum number of passes
     * @param numBatches       number of batches per pass
     * @param batchAssignments batch assignment array passed through to the objective
     * @return true if the convergence test was met, false if the iteration limit was reached
     * @throws IllegalArgumentException if a batch value is NaN
     */
    @Override // cc.mallet.optimize.Optimizer.ByBatches
    public boolean optimize(int numIterations, int numBatches, int[] batchAssignments) {
        int numParameters = this.maxable.getNumParameters();
        double[] parameters = new double[numParameters];
        double[] gradient = new double[numParameters];
        double[] hessianProduct = new double[numParameters];

        // Allocate only once so that repeated calls continue from the adapted gains and trace.
        if (this.gain == null) {
            System.err.println("StochasticMetaAscent: initialStep=" + this.eta_init + "  metaStep=" + this.mu);
            this.gain = new double[numParameters];
            Arrays.fill(this.gain, this.eta_init);
            this.gradientTrace = new double[numParameters];
        }

        for (int iteration = 0; iteration < numIterations; iteration++) {
            double valueBefore = 0.0d;
            double valueAfter = 0.0d;
            for (int batch = 0; batch < numBatches; batch++) {
                logger.info("Iteration " + (this.totalIterations + iteration) + ", batch " + batch + " of " + numBatches);
                this.maxable.getParameters(parameters);

                double initialValue = this.maxable.getBatchValue(batch, batchAssignments);
                if (Double.isNaN(initialValue)) {
                    throw new IllegalArgumentException("NaN in value computation.  Probably you need to reduce initialStep or metaStep.");
                }
                valueBefore += initialValue;

                this.maxable.getBatchValueGradient(gradient, batch, batchAssignments);
                // The update below is written as a descent; flip the gradient so it maximizes.
                MatrixOps.timesEquals(gradient, -1.0d);
                if (this.useHessian) {
                    computeHessianProduct(this.maxable, parameters, batch, batchAssignments, gradient, this.gradientTrace, hessianProduct);
                }

                reportOnVec("x", parameters);
                reportOnVec("step", this.gain);
                reportOnVec("grad", gradient);
                reportOnVec("trace", this.gradientTrace);

                updateGainsAndParameters(parameters, gradient, hessianProduct);
                this.maxable.setParameters(parameters);

                double finalValue = this.maxable.getBatchValue(batch, batchAssignments);
                valueAfter += finalValue;
                logger.info("StochasticMetaAscent: initial value: " + initialValue + "  final value:" + finalValue);
            }
            logger.info("StochasticMetaDescent: Value at iteration (" + (this.totalIterations + iteration) + ")= " + valueAfter);
            if (2.0d * Math.abs(valueAfter - valueBefore) <= TOLERANCE * (Math.abs(valueAfter) + Math.abs(valueBefore) + EPS)) {
                logger.info("Stochastic Meta Ascent: Value difference " + Math.abs(valueAfter - valueBefore) + " below tolerance; saying converged.");
                // The pass at index `iteration` has completed, so iteration + 1 passes were run.
                this.totalIterations += iteration + 1;
                return true;
            }
        }
        this.totalIterations += numIterations;
        return false;
    }

    /**
     * Applies one SMD step to every coordinate.
     *
     * <p>For each coordinate the method adapts the gain, moves the parameter against
     * {@code gradient}, and updates the gradient trace.
     *
     * @param parameters     parameter vector, updated in place
     * @param gradient       negated batch gradient
     * @param hessianProduct Hessian-times-trace product; read only when {@code useHessian} is true
     */
    private void updateGainsAndParameters(double[] parameters, double[] gradient, double[] hessianProduct) {
        for (int k = 0; k < parameters.length; k++) {
            // While the trace is still all zeros (first step), this leaves the gain unchanged.
            this.gain[k] *= Math.max(0.5d, 1.0d - this.mu * gradient[k] * this.gradientTrace[k]);
            parameters[k] -= this.gain[k] * gradient[k];
            // With an identity Hessian, H * trace is just the trace itself.
            double curvature = this.useHessian ? hessianProduct[k] : this.gradientTrace[k];
            this.gradientTrace[k] = LAMBDA * this.gradientTrace[k] - this.gain[k] * (gradient[k] + LAMBDA * curvature);
        }
    }

    /** Prints min/max/mean/2-norm/abs-norm of {@code dArr} to standard output, labelled by {@code str}. */
    private void reportOnVec(String str, double[] dArr) {
        DecimalFormat fmt = new DecimalFormat("0.####");
        System.out.println("StochasticMetaAscent: " + str
                + ":  min " + fmt.format(MatrixOps.min(dArr))
                + "  max " + fmt.format(MatrixOps.max(dArr))
                + "  mean " + fmt.format(MatrixOps.mean(dArr))
                + "  2norm " + fmt.format(MatrixOps.twoNorm(dArr))
                + "  abs-norm " + fmt.format(MatrixOps.absNorm(dArr)));
    }

    /**
     * Approximates a Hessian-vector product by forward finite differences.
     *
     * <p>The result is {@code (g(x + eps*v) - g(x)) / eps}, where {@code g} is the negated batch
     * gradient. The model is temporarily set to {@code parameters + eps*vector} and then restored
     * to {@code parameters}. The caller's {@code parameters} array is left unchanged; previously it
     * was perturbed in place, which corrupted the subsequent update.
     *
     * @param model            the objective
     * @param parameters       current parameters x (not modified)
     * @param batch            batch index
     * @param batchAssignments batch assignment array
     * @param currentGradient  negated batch gradient at x
     * @param vector           direction v (the gradient trace)
     * @param result           output array receiving the approximate product
     */
    private void computeHessianProduct(Optimizable.ByBatchGradient model, double[] parameters, int batch, int[] batchAssignments, double[] currentGradient, double[] vector, double[] result) {
        int numParameters = model.getNumParameters();
        double[] perturbed = Arrays.copyOf(parameters, numParameters);
        MatrixOps.plusEquals(perturbed, vector, HESSIAN_EPS);

        double[] perturbedGradient = new double[numParameters];
        model.setParameters(perturbed);
        model.getBatchValueGradient(perturbedGradient, batch, batchAssignments);
        model.setParameters(parameters);

        for (int k = 0; k < result.length; k++) {
            // perturbedGradient is un-negated; negate it to match currentGradient's convention.
            result[k] = (-perturbedGradient[k] - currentGradient[k]) / HESSIAN_EPS;
        }
    }
}
