package cc.mallet.types;

import cc.mallet.classify.Classification;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/types/ExpGain.class */
public class ExpGain extends RankedFeatureVector {
    /** Class-wide logger; assigned once in the static initializer from {@code MalletLogger.getLogger}. */
    private static Logger logger;
    /** Set to {@code false} by both constructors; not read anywhere in this file. */
    boolean usingHyperbolicPrior;
    /** Set to {@code 0.2} by both constructors; not read anywhere in this file. */
    double hyperbolicSlope;
    /** Set to {@code 10.0} by both constructors; not read anywhere in this file. */
    double hyperbolicSharpness;
    /**
     * Decompiled form of the compiler's assertion flag: {@code true} when assertions are
     * disabled for {@code ExpGain}. It guards the hand-expanded assertion checks in this class.
     */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:cc/mallet/types/ExpGain$Factory.class */
    /**
     * Builds {@link ExpGain} rankings from a fixed set of per-instance model label
     * distributions and a gaussian prior variance.
     *
     * <p>Serialization is hand-written. Version 0 streams contained only the label vectors,
     * so the prior variance was lost on a round trip and the restored factory was left with
     * no variance set (0.0 under standard Java deserialization, which does not run the
     * constructors), which {@code ExpGain} then uses as a divisor. Version 1 streams append
     * the variance; version 0 streams are still readable and fall back to the default
     * variance.
     */
    public static class Factory implements RankedFeatureVector.Factory {
        private static final long serialVersionUID = 1;
        /** Stream layout version written by {@link #writeObject}: 1 = label vectors followed by the variance. */
        private static final int CURRENT_SERIAL_VERSION = 1;
        /** Variance used when the caller does not supply one, and when reading a version-0 stream. */
        private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0d;

        /** Model label distributions handed to every {@code ExpGain} this factory creates. */
        LabelVector[] classifications;
        /** Gaussian prior variance handed to every {@code ExpGain} this factory creates. */
        double gaussianPriorVariance;

        /**
         * Creates a factory that uses the default prior variance of 10.0.
         *
         * @param labelVectorArr model label distributions; stored by reference, not copied
         */
        public Factory(LabelVector[] labelVectorArr) {
            this(labelVectorArr, DEFAULT_GAUSSIAN_PRIOR_VARIANCE);
        }

        /**
         * Creates a factory with an explicit prior variance.
         *
         * @param labelVectorArr model label distributions; stored by reference, not copied
         * @param d gaussian prior variance passed on to {@code ExpGain}
         */
        public Factory(LabelVector[] labelVectorArr, double d) {
            this.classifications = labelVectorArr;
            this.gaussianPriorVariance = d;
        }

        /**
         * Builds an {@link ExpGain} for the given instances using this factory's label
         * vectors and prior variance.
         *
         * @param instanceList instances to rank features for; when assertions are enabled its
         *        target alphabet must be the same object as the alphabet of the first stored
         *        label vector
         * @return a new {@code ExpGain}
         */
        @Override // cc.mallet.types.RankedFeatureVector.Factory
        public RankedFeatureVector newRankedFeatureVector(InstanceList instanceList) {
            assert instanceList.getTargetAlphabet() == this.classifications[0].getAlphabet();
            return new ExpGain(instanceList, this.classifications, this.gaussianPriorVariance);
        }

        /**
         * Writes the version, the number of label vectors, each label vector, and finally the
         * prior variance.
         *
         * @param objectOutputStream stream to write to
         * @throws IOException if the underlying stream fails
         */
        private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
            objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
            objectOutputStream.writeInt(this.classifications.length);
            for (LabelVector classification : this.classifications) {
                objectOutputStream.writeObject(classification);
            }
            // Appended after the version-0 payload so the leading layout is unchanged.
            objectOutputStream.writeDouble(this.gaussianPriorVariance);
        }

        /**
         * Restores the label vectors and, for version 1 or later streams, the prior variance.
         * Version-0 streams carry no variance, so the default is used instead of leaving the
         * field unset.
         *
         * @param objectInputStream stream to read from
         * @throws IOException if the underlying stream fails
         * @throws ClassNotFoundException if a serialized label vector's class cannot be resolved
         */
        private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            int version = objectInputStream.readInt();
            int count = objectInputStream.readInt();
            this.classifications = new LabelVector[count];
            for (int i = 0; i < count; i++) {
                this.classifications[i] = (LabelVector) objectInputStream.readObject();
            }
            this.gaussianPriorVariance = version >= 1
                    ? objectInputStream.readDouble()
                    : DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        }
    }

    static {
        $assertionsDisabled = !ExpGain.class.desiredAssertionStatus();
        logger = MalletLogger.getLogger(ExpGain.class.getName());
    }

    /**
     * Computes one non-negative gain score per feature of the data alphabet.
     *
     * <p>The computation runs in four phases:
     * <ol>
     *   <li>For every (class, feature) pair, sum the true label weight ({@code p}) and the
     *       model label weight ({@code q}) over the instances that contain the feature.</li>
     *   <li>Fit a weight {@code alpha[class][feature]} with at most 50 Newton iterations,
     *       stopping early once the largest absolute first derivative is at most 1e-8.
     *       Each weight is kept inside a bracket {@code [min, max]} that is tightened from
     *       the sign of the proposed step; a step leaving the bracket is replaced by the
     *       bracket midpoint.</li>
     *   <li>For every (class, feature) pair, sum {@code log(q_i * exp(alpha) + (1 - q_i))}
     *       over the instances that contain the feature.</li>
     *   <li>Per feature, add up over classes the value
     *       {@code alpha * p - logSum - alpha^2 / (2 * d)}, skipping pairs with
     *       {@code alpha == 0} and values below zero.</li>
     * </ol>
     *
     * <p>The checks guarded by {@code $assertionsDisabled} only run when assertions are
     * enabled for this class. They require a non-empty instance list, label vectors sharing
     * the list's target alphabet, binary feature values, and model label weights summing to
     * 1 per instance.
     *
     * @param instanceList instances whose data is cast to {@code FeatureVector}
     * @param labelVectorArr model label distribution for each instance, indexed like
     *        {@code instanceList}
     * @param d gaussian prior variance; used as a divisor, so 0 yields non-finite derivatives
     * @return array of length {@code instanceList.getDataAlphabet().size()} holding the
     *         accumulated gain of each feature
     */
    private static double[] calcExpGains(InstanceList instanceList, LabelVector[] labelVectorArr, double d) {
        final int numInstances = instanceList.size();
        final int numClasses = instanceList.getTargetAlphabet().size();
        final int numFeatures = instanceList.getDataAlphabet().size();
        if (!$assertionsDisabled && numInstances <= 0) {
            throw new AssertionError();
        }
        // p: true label weight, q: model label weight, alpha: fitted weight; all [class][feature].
        double[][] p = new double[numClasses][numFeatures];
        double[][] q = new double[numClasses][numFeatures];
        double[][] alpha = new double[numClasses][numFeatures];
        logger.info("Starting klgains, #instances=" + numInstances);

        // ---- Phase 1: accumulate p and q over the features present in each instance. ----
        double trueLabelWeightSum = 0.0d;
        double modelLabelWeightSum = 0.0d;
        for (int ii = 0; ii < numInstances; ii++) {
            if (!$assertionsDisabled && labelVectorArr[ii].getLabelAlphabet() != instanceList.getTargetAlphabet()) {
                throw new AssertionError();
            }
            Instance instance = instanceList.get(ii);
            Labeling labeling = instance.getLabeling();
            FeatureVector fv = (FeatureVector) instance.getData();
            double perInstanceModelWeight = 0.0d;
            for (int li = 0; li < numClasses; li++) {
                double trueWeight = labeling.value(li);
                double modelWeight = labelVectorArr[ii].value(li);
                trueLabelWeightSum += trueWeight;
                modelLabelWeightSum += modelWeight;
                perInstanceModelWeight += modelWeight;
                if (trueWeight == 0.0d && modelWeight == 0.0d) {
                    // Nothing to add for this class.
                    continue;
                }
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int fi = fv.indexAtLocation(loc);
                    if (!$assertionsDisabled && fv.valueAtLocation(loc) != 1.0d) {
                        throw new AssertionError();
                    }
                    p[li][fi] += trueWeight;
                    q[li][fi] += modelWeight;
                }
            }
            if (!$assertionsDisabled && Math.abs(perInstanceModelWeight - 1.0d) >= 0.001d) {
                throw new AssertionError();
            }
        }
        // The checks below compare the per-instance average to 1.0; the messages report the raw sum.
        if (!$assertionsDisabled && Math.abs((trueLabelWeightSum / numInstances) - 1.0d) >= 0.001d) {
            throw new AssertionError("trueLabelWeightSum should be 1.0, it was " + trueLabelWeightSum);
        }
        if (!$assertionsDisabled && Math.abs((modelLabelWeightSum / numInstances) - 1.0d) >= 0.001d) {
            throw new AssertionError("modelLabelWeightSum should be 1.0, it was " + modelLabelWeightSum);
        }

        // ---- Phase 2: Newton iterations on alpha, with a bracketing safeguard. ----
        double[][] dalpha = new double[numClasses][numFeatures];         // first derivative
        double[][] alphaChangeOld = new double[numClasses][numFeatures]; // step applied last iteration
        double[][] alphaMax = new double[numClasses][numFeatures];       // upper bracket
        double[][] alphaMin = new double[numClasses][numFeatures];       // lower bracket
        double[][] ddalpha = new double[numClasses][numFeatures];        // second derivative
        for (int li = 0; li < numClasses; li++) {
            for (int fi = 0; fi < numFeatures; fi++) {
                alphaMax[li][fi] = Double.POSITIVE_INFINITY;
                alphaMin[li][fi] = Double.NEGATIVE_INFINITY;
            }
        }
        final double convergenceThreshold = 1.0E-8d;
        final int maxNewtonIterations = 50;
        // Starts above the threshold so that the first iteration always runs.
        double maxDalpha = 99.0d;
        for (int iteration = 0; maxDalpha > convergenceThreshold && iteration < maxNewtonIterations; iteration++) {
            // Start both derivatives from the prior term, then subtract the model expectations.
            for (int li = 0; li < numClasses; li++) {
                for (int fi = 0; fi < numFeatures; fi++) {
                    dalpha[li][fi] = p[li][fi] - (alpha[li][fi] / d);
                    ddalpha[li][fi] = (-1.0d) / d;
                }
            }
            for (int ii = 0; ii < numInstances; ii++) {
                if (!$assertionsDisabled && labelVectorArr[ii].getLabelAlphabet() != instanceList.getTargetAlphabet()) {
                    throw new AssertionError();
                }
                FeatureVector fv = (FeatureVector) instanceList.get(ii).getData();
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int fi = fv.indexAtLocation(loc);
                    for (int li = 0; li < numClasses; li++) {
                        double modelWeight = labelVectorArr[ii].value(li);
                        double expTerm = modelWeight * Math.exp(alpha[li][fi]);
                        double normalizer = expTerm + (1.0d - modelWeight);
                        dalpha[li][fi] -= expTerm / normalizer;
                        ddalpha[li][fi] += ((expTerm * expTerm) / (normalizer * normalizer)) - (expTerm / normalizer);
                    }
                }
            }

            maxDalpha = 0.0d;
            double maxAlphaChange = 0.0d;
            for (int li = 0; li < numClasses; li++) {
                for (int fi = 0; fi < numFeatures; fi++) {
                    if (p[li][fi] == 0.0d && q[li][fi] == 0.0d) {
                        // Pair never observed in phase 1: alpha stays 0 and it does not affect convergence.
                        continue;
                    }
                    double alphaChange = -(dalpha[li][fi] / ddalpha[li][fi]);
                    if (Double.isNaN(alpha[li][fi]) || Double.isNaN(alphaChange)) {
                        logger.info("alpha[" + li + "][" + fi + "]=" + alpha[li][fi] + " p=" + p[li][fi] + " q=" + q[li][fi] + " dalpha=" + dalpha[li][fi] + " ddalpha=" + ddalpha[li][fi] + " alphachange=" + alphaChange + " min=" + alphaMin[li][fi] + " max=" + alphaMax[li][fi]);
                    }
                    boolean nonFinite = Double.isNaN(alpha[li][fi]) || Double.isNaN(dalpha[li][fi]) || Double.isNaN(ddalpha[li][fi])
                            || Double.isInfinite(alpha[li][fi]) || Double.isInfinite(dalpha[li][fi]) || Double.isInfinite(ddalpha[li][fi]);
                    if (nonFinite) {
                        // Do not step from or with non-finite quantities.
                        alphaChange = 0.0d;
                    }
                    double previousAlpha = alpha[li][fi];
                    double proposedAlpha;
                    if (Math.abs(alphaChange + alphaChangeOld[li][fi]) / Math.abs(alphaChange) < 0.01d) {
                        // This step almost exactly cancels the previous one: take half of it instead.
                        proposedAlpha = alpha[li][fi] + (alphaChange / 2.0d);
                    } else {
                        proposedAlpha = alpha[li][fi] + alphaChange;
                    }
                    // Tighten the bracket using the direction of the proposed step.
                    if (alphaChange < 0.0d && alphaMax[li][fi] > alpha[li][fi]) {
                        alphaMax[li][fi] = alpha[li][fi];
                    }
                    if (alphaChange > 0.0d && alphaMin[li][fi] < alpha[li][fi]) {
                        alphaMin[li][fi] = alpha[li][fi];
                    }
                    if (proposedAlpha <= alphaMax[li][fi] && proposedAlpha >= alphaMin[li][fi]) {
                        alpha[li][fi] = proposedAlpha;
                    } else {
                        // The step leaves the bracket: fall back to the bracket midpoint.
                        if (!$assertionsDisabled && alphaMax[li][fi] == Double.POSITIVE_INFINITY) {
                            throw new AssertionError();
                        }
                        if (!$assertionsDisabled && alphaMin[li][fi] == Double.NEGATIVE_INFINITY) {
                            throw new AssertionError();
                        }
                        alpha[li][fi] = alphaMin[li][fi] + ((alphaMax[li][fi] - alphaMin[li][fi]) / 2.0d);
                    }
                    double appliedChange = alpha[li][fi] - previousAlpha;
                    if (Math.abs(appliedChange) > maxAlphaChange) {
                        maxAlphaChange = Math.abs(appliedChange);
                    }
                    if (Math.abs(dalpha[li][fi]) > maxDalpha) {
                        maxDalpha = Math.abs(dalpha[li][fi]);
                    }
                    alphaChangeOld[li][fi] = appliedChange;
                }
            }
            logger.info("After " + iteration + " Newton iterations, maximum alphachange=" + maxAlphaChange + " dalpha=" + maxDalpha);
        }

        // ---- Phase 3: per (class, feature), sum log(q * exp(alpha) + (1 - q)) over instances. ----
        double[][] logSum = new double[numClasses][numFeatures];
        for (int ii = 0; ii < numInstances; ii++) {
            if (!$assertionsDisabled && labelVectorArr[ii].getLabelAlphabet() != instanceList.getTargetAlphabet()) {
                throw new AssertionError();
            }
            FeatureVector fv = (FeatureVector) instanceList.get(ii).getData();
            for (int li = 0; li < numClasses; li++) {
                double modelWeight = labelVectorArr[ii].value(li);
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int fi = fv.indexAtLocation(loc);
                    logSum[li][fi] += Math.log((modelWeight * Math.exp(alpha[li][fi])) + (1.0d - modelWeight));
                }
            }
        }

        // ---- Phase 4: add up the non-negative per-class gains of every feature. ----
        double[] gains = new double[numFeatures];
        for (int li = 0; li < numClasses; li++) {
            for (int fi = 0; fi < numFeatures; fi++) {
                if (!$assertionsDisabled && Double.isInfinite(alpha[li][fi])) {
                    throw new AssertionError();
                }
                double a = alpha[li][fi];
                if (a == 0.0d) {
                    continue;
                }
                double gain = ((a * p[li][fi]) - logSum[li][fi]) - ((a * a) / (2.0d * d));
                if (gain >= 0.0d) {
                    gains[fi] += gain;
                }
            }
        }
        return gains;
    }

    /**
     * Ranks the features of {@code instanceList}'s data alphabet by the gains returned from
     * {@code calcExpGains}.
     *
     * @param instanceList instances to score
     * @param labelVectorArr model label distribution for each instance
     * @param d gaussian prior variance forwarded to {@code calcExpGains}
     */
    public ExpGain(InstanceList instanceList, LabelVector[] labelVectorArr, double d) {
        super(instanceList.getDataAlphabet(), calcExpGains(instanceList, labelVectorArr, d));
        // Hyperbolic-prior settings: assigned here, not read anywhere in this file.
        this.hyperbolicSharpness = 10.0d;
        this.hyperbolicSlope = 0.2d;
        this.usingHyperbolicPrior = false;
    }

    /**
     * Collects the label vector of every classification, preserving order.
     *
     * @param classificationArr classifications to read from
     * @return a new array of the same length holding each element's {@code getLabelVector()}
     */
    private static LabelVector[] getLabelVectorsFromClassifications(Classification[] classificationArr) {
        final LabelVector[] vectors = new LabelVector[classificationArr.length];
        int cursor = 0;
        for (Classification classification : classificationArr) {
            vectors[cursor] = classification.getLabelVector();
            cursor++;
        }
        return vectors;
    }

    /**
     * Ranks features exactly like the {@code LabelVector[]} constructor, taking the model
     * label distributions from the given classifications.
     *
     * @param instanceList instances to score
     * @param classificationArr one classification per instance; only its label vector is used
     * @param d gaussian prior variance forwarded to {@code calcExpGains}
     */
    public ExpGain(InstanceList instanceList, Classification[] classificationArr, double d) {
        // Same super call and field values as the LabelVector[] constructor.
        this(instanceList, getLabelVectorsFromClassifications(classificationArr), d);
    }
}
