package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.Serializable;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/semi_supervised/CRFOptimizableByEntropyRegularization.class */
/**
 * {@link Optimizable.ByGradientValue} whose value is an entropy-regularization term for a
 * {@link CRF}: {@code scalingFactor} times the sum, over every instance in {@code data}, of the
 * entropy reported by an {@code EntropyLattice} built from that instance's
 * {@link SumLatticeDefault} gammas and xis.
 *
 * <p>The value and gradient are cached. They are recomputed only when the CRF's
 * weights-value change stamp moves or {@link #setScalingFactor(double)} is called. The parameter
 * accessors delegate directly to the CRF's own parameters and notify the CRF of changes.
 *
 * <p>NOTE(review): only {@code getData()} is read from each instance, so {@code data} is
 * presumably the unlabeled portion of a semi-supervised training set — confirm against callers.
 */
public class CRFOptimizableByEntropyRegularization implements Optimizable.ByGradientValue, Serializable {
    private static final long serialVersionUID = 1L;

    private static final Logger logger =
            MalletLogger.getLogger(CRFOptimizableByEntropyRegularization.class.getName());

    /** Sentinel stamp meaning "no valid cached result"; the constructor starts from this state. */
    private static final int INVALID_STAMP = -1;

    /** CRF weights stamp that {@link #cachedValue} corresponds to. */
    private int cachedValueWeightsStamp = INVALID_STAMP;
    /** CRF weights stamp that {@link #cachedGradient} corresponds to. */
    private int cachedGradientWeightsStamp = INVALID_STAMP;

    /** Factor accumulators; zeroed and refilled by {@link #computeExpectations()}. */
    protected CRF.Factors expectations;
    /** Incrementor bound to {@link #expectations}; handed to each EntropyLattice. */
    protected Transducer.Incrementor incrementor;
    protected InstanceList data;
    protected CRF crf;
    protected double scalingFactor;
    protected double cachedValue;
    protected double[] cachedGradient;

    /**
     * Creates the regularizer.
     *
     * @param crf           the CRF whose parameters are optimized
     * @param data          instances whose entropy is summed (only {@code getData()} is read)
     * @param scalingFactor multiplier applied to the summed entropy and passed to each
     *                      EntropyLattice
     */
    public CRFOptimizableByEntropyRegularization(CRF crf, InstanceList data, double scalingFactor) {
        this.data = data;
        this.crf = crf;
        this.scalingFactor = scalingFactor;
        this.expectations = new CRF.Factors(crf);
        // Incrementor is an inner (non-static) class of CRF.Factors, so it must be created
        // against an enclosing Factors instance; an unqualified `new` does not compile.
        this.incrementor = this.expectations.new Incrementor();
        this.cachedValue = 0.0d;
        this.cachedGradient = new double[crf.getParameters().getNumFactors()];
    }

    /** Creates the regularizer with a scaling factor of 1.0. */
    public CRFOptimizableByEntropyRegularization(CRF crf, InstanceList data) {
        this(crf, data, 1.0d);
    }

    /** Returns the multiplier applied to the entropy term. */
    public double getScalingFactor() {
        return this.scalingFactor;
    }

    /**
     * Sets the multiplier applied to the entropy term.
     *
     * <p>Both cached results depend on the factor. The value is multiplied by it, and the
     * expectations are filled by lattices that receive it. Both caches are therefore invalidated
     * even though the CRF weights have not changed.
     */
    public void setScalingFactor(double d) {
        this.scalingFactor = d;
        this.cachedValueWeightsStamp = INVALID_STAMP;
        this.cachedGradientWeightsStamp = INVALID_STAMP;
    }

    /**
     * Zeroes {@link #expectations} and runs an EntropyLattice over every instance in
     * {@link #data}.
     *
     * <p>On return, {@code cachedValue} holds the <em>unscaled</em> sum of the per-instance
     * entropies. The value stamp is invalidated, so a later {@link #getValue()} recomputes rather
     * than returning this raw sum as if it were the scaled value.
     */
    public void computeExpectations() {
        this.cachedValueWeightsStamp = INVALID_STAMP;
        this.cachedValue = 0.0d;
        this.expectations.zero();
        for (int i = 0; i < this.data.size(); i++) {
            FeatureVectorSequence input = (FeatureVectorSequence) this.data.get(i).getData();
            // `true` asks the lattice to keep its xis, which EntropyLattice consumes below.
            SumLatticeDefault lattice = new SumLatticeDefault(this.crf, input, true);
            // Presumably the entropy lattice also pushes its gradient contribution into
            // `expectations` through `incrementor` — confirm in EntropyLattice.
            EntropyLattice entropyLattice = new EntropyLattice(input, lattice.getGammas(),
                    lattice.getXis(), this.crf, this.incrementor, this.scalingFactor);
            this.cachedValue += entropyLattice.getEntropy();
        }
    }

    /**
     * Returns {@code scalingFactor * sum of per-instance entropies}, recomputing only when the
     * cached value is stale.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        int stamp = this.crf.getWeightsValueChangeStamp();
        if (stamp != this.cachedValueWeightsStamp) {
            computeExpectations();
            this.cachedValue = this.scalingFactor * this.cachedValue;
            assert !Double.isNaN(this.cachedValue) && !Double.isInfinite(this.cachedValue)
                    : "Likelihood due to Entropy Regularization is NaN/Infinite";
            // Mark the cache fresh only after a successful computation.
            this.cachedValueWeightsStamp = stamp;
            logger.info("getValue() (entropy regularization) = " + this.cachedValue);
        }
        return this.cachedValue;
    }

    /**
     * Copies the gradient (the current {@link #expectations} parameters) into {@code dArr}.
     *
     * @param dArr destination; must have at least {@link #getNumParameters()} entries
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] dArr) {
        int stamp = this.crf.getWeightsValueChangeStamp();
        if (stamp != this.cachedGradientWeightsStamp) {
            // getValue() refreshes `expectations` if they are stale for the current weights.
            getValue();
            this.expectations.assertNotNaNOrInfinite();
            this.expectations.getParameters(this.cachedGradient);
            this.cachedGradientWeightsStamp = stamp;
        }
        System.arraycopy(this.cachedGradient, 0, dArr, 0, this.cachedGradient.length);
    }

    /** Returns the number of factors in the CRF's parameters. */
    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return this.crf.getParameters().getNumFactors();
    }

    /** Copies the CRF's current parameters into {@code dArr}. */
    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] dArr) {
        this.crf.getParameters().getParameters(dArr);
    }

    /** Replaces the CRF's parameters and notifies the CRF so cached results go stale. */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] dArr) {
        this.crf.getParameters().setParameters(dArr);
        this.crf.weightsValueChanged();
    }

    /** Returns the CRF parameter at index {@code i}. */
    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int i) {
        return this.crf.getParameters().getParameter(i);
    }

    /** Sets the CRF parameter at index {@code i} and notifies the CRF so cached results go stale. */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int i, double d) {
        this.crf.getParameters().setParameter(i, d);
        this.crf.weightsValueChanged();
    }
}
