package cc.mallet.grmm.learning;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.util.LabelsAssignment;
import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labels;
import cc.mallet.types.LabelsSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;
import com.meaningcloud.LangRequest;
import gnu.trove.TIntArrayList;
import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/grmm/learning/DefaultAcrfTrainer.class */
public class DefaultAcrfTrainer implements ACRFTrainer {
    /** Optimizer supplied via {@code setMaxer}; when null, {@code createMaxer} builds a LimitedMemoryBFGS instead. */
    private Optimizer maxer;
    /** Output location handed to evaluators through {@code setOutputPrefix} before they are run. */
    private File outputPrefix = new File("");
    /** Iteration count per subset round of incremental training (matches the value {@code incrementalTrain} passes). */
    private static final int SUBSET_ITER = 10;
    /** Shared logger; also used by the nested evaluator and result classes. */
    private static Logger logger = MalletLogger.getLogger(DefaultAcrfTrainer.class.getName());
    /** Global switch consulted by the training loop's exception handler to decide whether to rethrow. */
    private static boolean rethrowExceptions = false;
    /** Split weights used by {@code incrementalTrain}: each entry p is paired with 1 - p when splitting the training list. */
    private static final double[] SIZE = {0.1d, 0.5d};
    /** Fixed-seed random source used when splitting by proportion; exposed through {@code getRandom}. */
    private static final Random r = new Random(1729);

    /* loaded from: input_file:cc/mallet/grmm/learning/DefaultAcrfTrainer$FileEvaluator.class */
    public static class FileEvaluator extends ACRFEvaluator {
        private File file;

        public FileEvaluator(File file) {
            this.file = file;
        }

        @Override // cc.mallet.grmm.learning.ACRFEvaluator
        public boolean evaluate(ACRF acrf, int i, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3) {
            if (!shouldDoEvaluate(i)) {
                return true;
            }
            test(acrf, instanceList3, "Testing ");
            return true;
        }

        @Override // cc.mallet.grmm.learning.ACRFEvaluator
        public void test(InstanceList instanceList, List list, String str) {
            DefaultAcrfTrainer.logger.info("Number of testing instances = " + instanceList.size());
            TestResults computeTestResults = LogEvaluator.computeTestResults(instanceList, list);
            try {
                PrintWriter printWriter = new PrintWriter(new FileWriter(this.file, true));
                computeTestResults.print(str, printWriter);
                printWriter.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /* loaded from: input_file:cc/mallet/grmm/learning/DefaultAcrfTrainer$LogEvaluator.class */
    /**
     * Evaluator that writes per-label statistics to the trainer's logger and
     * remembers the most recent results.
     */
    public static class LogEvaluator extends ACRFEvaluator {
        /** Results of the most recent {@link #test} call; null until then. */
        private TestResults lastResults;

        /**
         * Tests the model on the training and testing lists (whichever are non-null)
         * when {@code shouldDoEvaluate(i)} says so.
         *
         * @param acrf          model to evaluate
         * @param i             iteration number passed to {@code shouldDoEvaluate}
         * @param instanceList  instances reported under "Training"; may be null
         * @param instanceList2 not used by this evaluator
         * @param instanceList3 instances reported under "Testing"; may be null
         * @return always {@code true}
         */
        @Override // cc.mallet.grmm.learning.ACRFEvaluator
        public boolean evaluate(ACRF acrf, int i, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3) {
            if (!shouldDoEvaluate(i)) {
                return true;
            }
            if (instanceList != null) {
                test(acrf, instanceList, "Training");
            }
            if (instanceList3 != null) {
                test(acrf, instanceList3, "Testing");
            }
            return true;
        }

        /**
         * Scores the given labelings against the instance targets, logs the result
         * table and stores it for {@link #getJointAccuracy()}.
         *
         * @param instanceList instances whose targets are the reference labelings
         * @param list         labelings to score, one per instance
         * @param str          description used as prefix of every log line
         */
        @Override // cc.mallet.grmm.learning.ACRFEvaluator
        public void test(InstanceList instanceList, List list, String str) {
            DefaultAcrfTrainer.logger.info(str + ": Number of instances = " + instanceList.size());
            TestResults results = computeTestResults(instanceList, list);
            results.log(str);
            this.lastResults = results;
        }

        /**
         * Builds a {@link TestResults} by comparing each element of {@code list}
         * (cast to {@code LabelsSequence}) with the target of the instance at the
         * same position.
         *
         * @param instanceList instances whose targets are {@code LabelsAssignment}s
         * @param list         labelings to score; must have at least as many elements as {@code instanceList}
         * @return results with statistics already computed
         */
        public static TestResults computeTestResults(InstanceList instanceList, List list) {
            TestResults testResults = new TestResults(instanceList);
            Iterator<Instance> instances = instanceList.iterator();
            Iterator labelings = list.iterator();
            while (instances.hasNext()) {
                LabelsSequence returned = (LabelsSequence) labelings.next();
                LabelsAssignment target = (LabelsAssignment) instances.next().getTarget();
                compareLabelings(testResults, returned, target.getLabelsSequence());
            }
            testResults.computeStatistics();
            return testResults;
        }

        /**
         * Adds one count to {@code testResults} per position of the two sequences.
         *
         * @param testResults     accumulator to update
         * @param labelsSequence  labeling being scored
         * @param labelsSequence2 reference labeling; must have the same length
         */
        static void compareLabelings(TestResults testResults, LabelsSequence labelsSequence, LabelsSequence labelsSequence2) {
            // Plain assert replaces the decompiler-generated $assertionsDisabled bookkeeping.
            assert labelsSequence.size() == labelsSequence2.size()
                    : "Labeling sizes differ: " + labelsSequence.size() + " vs " + labelsSequence2.size();
            for (int position = 0; position < labelsSequence.size(); position++) {
                testResults.incrementCount(labelsSequence.getLabels(position), labelsSequence2.getLabels(position));
            }
        }

        /**
         * @return joint accuracy of the most recent {@link #test} call; fails with a
         *         NullPointerException if no test has been run yet
         */
        public double getJointAccuracy() {
            return this.lastResults.getJointAccuracy();
        }
    }

    /* loaded from: input_file:cc/mallet/grmm/learning/DefaultAcrfTrainer$TestResults.class */
    /**
     * Accumulates a confusion matrix over label indices and derives per-label
     * precision / recall / F1, per-factor accuracy and joint accuracy from it.
     */
    public static class TestResults {
        /** Counts indexed as {@code confusion[referenceIndex][returnedIndex]}. */
        public int[][] confusion;
        /** Size of {@link #alphabet} at construction time; dimension of all per-label arrays. */
        public int numClasses;
        /** Row sums of {@link #confusion}; filled by {@link #computeStatistics()}. */
        public int[] trueCounts;
        /** Column sums of {@link #confusion}; filled by {@link #computeStatistics()}. */
        public int[] returnedCounts;
        public double[] precision;
        public double[] recall;
        public double[] f1;
        /** For each slice of the target, the {@link #alphabet} indices of that slice's labels. */
        public TIntArrayList[] factors;
        /** Number of {@link #incrementCount} calls made so far. */
        public int maxT;
        /** Number of {@link #incrementCount} calls in which every label matched. */
        public int correctT;
        /** Merged alphabet over the labels of all slices. */
        public Alphabet alphabet;

        /**
         * Sets up the label alphabet from the first instance of the list.
         *
         * @param instanceList non-empty list of instances
         */
        TestResults(InstanceList instanceList) {
            this(instanceList.get(0));
        }

        /**
         * Sets up the label alphabet and zeroed statistics from one instance's target.
         *
         * @param instance instance whose target is a {@code LabelsAssignment}
         */
        TestResults(Instance instance) {
            this.maxT = 0;
            this.correctT = 0;
            this.alphabet = new Alphabet();
            setupAlphabet(instance);
            this.numClasses = this.alphabet.size();
            this.confusion = new int[this.numClasses][this.numClasses];
            this.precision = new double[this.numClasses];
            this.recall = new double[this.numClasses];
            this.f1 = new double[this.numClasses];
        }

        /** Registers every label of every slice's output alphabet and records which indices belong to which slice. */
        private void setupAlphabet(Instance instance) {
            LabelsAssignment labelsAssignment = (LabelsAssignment) instance.getTarget();
            int numSlices = labelsAssignment.numSlices();
            this.factors = new TIntArrayList[numSlices];
            for (int slice = 0; slice < numSlices; slice++) {
                LabelAlphabet outputAlphabet = labelsAssignment.getOutputAlphabet(slice);
                this.factors[slice] = new TIntArrayList(outputAlphabet.size());
                for (int k = 0; k < outputAlphabet.size(); k++) {
                    this.factors[slice].add(this.alphabet.lookupIndex(outputAlphabet.lookupObject(k)));
                }
            }
        }

        /**
         * Records one comparison: each label pair increments a confusion cell, and the
         * whole comparison counts as correct only if all pairs agree.
         *
         * @param labels  labels being scored (column index of the confusion matrix)
         * @param labels2 reference labels (row index of the confusion matrix)
         */
        void incrementCount(Labels labels, Labels labels2) {
            boolean allMatch = true;
            for (int k = 0; k < labels.size(); k++) {
                Label returnedLabel = labels.get(k);
                int referenceIndex = this.alphabet.lookupIndex(labels2.get(k).getEntry());
                int returnedIndex = this.alphabet.lookupIndex(returnedLabel.getEntry());
                if (referenceIndex != returnedIndex) {
                    allMatch = false;
                }
                this.confusion[referenceIndex][returnedIndex]++;
            }
            this.maxT++;
            if (allMatch) {
                this.correctT++;
            }
        }

        /**
         * Derives row/column totals and per-label precision, recall and F1 from the
         * confusion matrix. A label that was never returned gets precision 1, a label
         * that never occurs in the reference gets recall 1, and F1 is 0 (rather than
         * NaN) when precision and recall are both 0.
         */
        void computeStatistics() {
            this.trueCounts = new int[this.numClasses];
            this.returnedCounts = new int[this.numClasses];
            for (int row = 0; row < this.numClasses; row++) {
                for (int col = 0; col < this.numClasses; col++) {
                    this.trueCounts[row] += this.confusion[row][col];
                    this.returnedCounts[col] += this.confusion[row][col];
                }
            }
            for (int label = 0; label < this.numClasses; label++) {
                double correct = this.confusion[label][label];
                if (this.returnedCounts[label] == 0) {
                    this.precision[label] = correct == 0.0d ? 1.0d : 0.0d;
                } else {
                    this.precision[label] = correct / this.returnedCounts[label];
                }
                if (this.trueCounts[label] == 0) {
                    this.recall[label] = 1.0d;
                } else {
                    this.recall[label] = correct / this.trueCounts[label];
                }
                double sum = this.precision[label] + this.recall[label];
                // Guard the 0/0 case so the report shows 0 instead of NaN.
                this.f1[label] = sum == 0.0d ? 0.0d : ((2.0d * this.precision[label]) * this.recall[label]) / sum;
            }
        }

        /** Logs the statistics with an empty description. */
        public void log() {
            log("");
        }

        /**
         * Writes the per-label table, per-factor accuracies and joint accuracy to the
         * trainer's logger. Requires {@link #computeStatistics()} to have been called.
         *
         * @param str description used as prefix of every line
         */
        public void log(String str) {
            DefaultAcrfTrainer.logger.info(str + ":  i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
            for (int label = 0; label < this.numClasses; label++) {
                DefaultAcrfTrainer.logger.info(str + ":  " + statisticsRow(label));
            }
            for (int factor = 0; factor < this.factors.length; factor++) {
                int[] totals = factorTotals(factor);
                // LangRequest.DEFAULT_SELECTION is kept as found; NOTE(review): likely a decompiler-substituted separator constant.
                DefaultAcrfTrainer.logger.info(str + ":  Factor " + factor + " accuracy: (" + totals[0] + LangRequest.DEFAULT_SELECTION + totals[1] + ") " + ratio(totals[0], totals[1]));
            }
            DefaultAcrfTrainer.logger.info(str + " CorrectT " + this.correctT + "  maxt " + this.maxT);
            DefaultAcrfTrainer.logger.info(str + " Joint accuracy: " + getJointAccuracy());
        }

        /**
         * Writes the same report as {@link #log(String)} to a writer. Requires
         * {@link #computeStatistics()} to have been called.
         *
         * @param str         description prefixed to the summary lines
         * @param printWriter destination; not closed by this method
         */
        public void print(String str, PrintWriter printWriter) {
            printWriter.println("i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
            for (int label = 0; label < this.numClasses; label++) {
                printWriter.println(statisticsRow(label));
            }
            for (int factor = 0; factor < this.factors.length; factor++) {
                int[] totals = factorTotals(factor);
                printWriter.println(str + " Factor " + factor + " accuracy: (" + totals[0] + LangRequest.DEFAULT_SELECTION + totals[1] + ") " + ratio(totals[0], totals[1]));
            }
            printWriter.println(str + " CorrectT " + this.correctT + "  maxt " + this.maxT);
            printWriter.println(str + " Joint accuracy: " + getJointAccuracy());
        }

        /** Dumps every cell of the confusion matrix to standard output. */
        void printConfusion() {
            System.out.println("True\t\tReturned\tCount");
            for (int row = 0; row < this.numClasses; row++) {
                for (int col = 0; col < this.numClasses; col++) {
                    System.out.println(row + "\t\t" + col + "\t" + this.confusion[row][col]);
                }
            }
        }

        /**
         * @return fraction of recorded comparisons in which every label matched;
         *         NaN if nothing has been recorded
         */
        public double getJointAccuracy() {
            return ratio(this.correctT, this.maxT);
        }

        /** Floating-point quotient of two counts; avoids the truncation (and divide-by-zero exception) of int division. */
        private static double ratio(int numerator, int denominator) {
            return ((double) numerator) / denominator;
        }

        /** Returns {correct, returned} totals summed over the labels that belong to the given factor. */
        private int[] factorTotals(int factorIndex) {
            int correct = 0;
            int returned = 0;
            TIntArrayList members = this.factors[factorIndex];
            for (int k = 0; k < members.size(); k++) {
                int labelIndex = members.get(k);
                correct += this.confusion[labelIndex][labelIndex];
                returned += this.returnedCounts[labelIndex];
            }
            return new int[]{correct, returned};
        }

        /** Formats one tab-separated row of the per-label statistics table. */
        private String statisticsRow(int label) {
            return label + "\t" + this.alphabet.lookupObject(label) + "\t" + this.trueCounts[label] + "\t" + this.confusion[label][label] + "\t" + this.returnedCounts[label] + "\t" + this.precision[label] + "\t" + this.recall[label] + "\t" + this.f1[label] + "\t";
        }
    }

    /**
     * Sets the output prefix that is forwarded to evaluators before they run.
     *
     * @param file new output prefix
     */
    public void setOutputPrefix(File file) {
        outputPrefix = file;
    }

    /**
     * @return the optimizer set through {@code setMaxer}, or null if none was set
     */
    public Optimizer getMaxer() {
        return maxer;
    }

    /**
     * Sets the optimizer used for training in place of the default.
     *
     * @param optimizer optimizer to use; null restores the default choice
     */
    public void setMaxer(Optimizer optimizer) {
        maxer = optimizer;
    }

    /**
     * @return current value of the global rethrow-exceptions switch
     */
    public static boolean isRethrowExceptions() {
        return DefaultAcrfTrainer.rethrowExceptions;
    }

    /**
     * Sets the global rethrow-exceptions switch shared by all trainers.
     *
     * @param z new value of the switch
     */
    public static void setRethrowExceptions(boolean z) {
        DefaultAcrfTrainer.rethrowExceptions = z;
    }

    /**
     * Trains for a single iteration with a {@link LogEvaluator} and no
     * validation or test list.
     *
     * @param acrf         model to train
     * @param instanceList training instances
     * @return result of the delegated six-argument {@code train} call
     */
    @Override // cc.mallet.grmm.learning.ACRFTrainer
    public boolean train(ACRF acrf, InstanceList instanceList) {
        ACRFEvaluator defaultEvaluator = new LogEvaluator();
        return train(acrf, instanceList, null, null, defaultEvaluator, 1);
    }

    /**
     * Trains with a {@link LogEvaluator} and no validation or test list.
     *
     * @param acrf         model to train
     * @param instanceList training instances
     * @param i            iteration limit
     * @return result of the delegated six-argument {@code train} call
     */
    @Override // cc.mallet.grmm.learning.ACRFTrainer
    public boolean train(ACRF acrf, InstanceList instanceList, int i) {
        ACRFEvaluator defaultEvaluator = new LogEvaluator();
        return train(acrf, instanceList, null, null, defaultEvaluator, i);
    }

    /**
     * Trains with the given evaluator and no validation or test list.
     *
     * @param acrf          model to train
     * @param instanceList  training instances
     * @param aCRFEvaluator evaluator to call during training
     * @param i             iteration limit
     * @return result of the delegated six-argument {@code train} call
     */
    @Override // cc.mallet.grmm.learning.ACRFTrainer
    public boolean train(ACRF acrf, InstanceList instanceList, ACRFEvaluator aCRFEvaluator, int i) {
        // Both the second and third lists are absent in this overload.
        boolean result = train(acrf, instanceList, null, null, aCRFEvaluator, i);
        return result;
    }

    /**
     * Trains with a {@link LogEvaluator} on the given lists.
     *
     * @param acrf          model to train
     * @param instanceList  training instances
     * @param instanceList2 second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3 test instances; may be null
     * @param i             iteration limit
     * @return result of the delegated six-argument {@code train} call
     */
    @Override // cc.mallet.grmm.learning.ACRFTrainer
    public boolean train(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, int i) {
        ACRFEvaluator defaultEvaluator = new LogEvaluator();
        return train(acrf, instanceList, instanceList2, instanceList3, defaultEvaluator, i);
    }

    /**
     * Trains on an objective built by {@code createOptimizable} from the
     * training list.
     *
     * @param acrf          model to train
     * @param instanceList  training instances
     * @param instanceList2 second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3 test instances; may be null
     * @param aCRFEvaluator evaluator to call during training; may be null
     * @param i             iteration limit
     * @return result of the delegated seven-argument {@code train} call
     */
    @Override // cc.mallet.grmm.learning.ACRFTrainer
    public boolean train(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, ACRFEvaluator aCRFEvaluator, int i) {
        Optimizable.ByGradientValue objective = createOptimizable(acrf, instanceList);
        return train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, i, objective);
    }

    /**
     * Builds the objective that training optimizes; subclasses may override
     * this to supply a different one.
     *
     * @param acrf         model to train
     * @param instanceList instances the objective is built from
     * @return the object returned by {@code acrf.getMaximizable(instanceList)}
     */
    protected Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList instanceList) {
        Optimizable.ByGradientValue objective = acrf.getMaximizable(instanceList);
        return objective;
    }

    /**
     * Incremental training with a {@link LogEvaluator}.
     *
     * @param acrf          model to train
     * @param instanceList  training instances
     * @param instanceList2 second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3 test instances; may be null
     * @param i             iteration limit for the final full training run
     * @return result of the delegated six-argument {@code incrementalTrain} call
     */
    public boolean incrementalTrain(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, int i) {
        ACRFEvaluator defaultEvaluator = new LogEvaluator();
        return incrementalTrain(acrf, instanceList, instanceList2, instanceList3, defaultEvaluator, i);
    }

    /**
     * Runs a short training round on an objective built from each subset
     * defined by {@code SIZE}, then trains on the full list.
     *
     * @param acrf          model to train
     * @param instanceList  training instances; subsets are split from this list
     * @param instanceList2 second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3 test instances; used only in the final full run
     * @param aCRFEvaluator evaluator to call during training
     * @param i             iteration limit for the final full training run
     * @return result of the final full training run
     */
    public boolean incrementalTrain(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, ACRFEvaluator aCRFEvaluator, int i) {
        final long startMillis = System.currentTimeMillis();
        for (int round = 0; round < SIZE.length; round++) {
            double fraction = SIZE[round];
            InstanceList[] parts = instanceList.split(new double[]{fraction, 1.0d - fraction});
            InstanceList subset = parts[0];
            logger.info("Training on subset of size " + subset.size());
            // The objective comes from the subset; the full list is what the evaluator sees.
            Optimizable.ByGradientValue subsetObjective = createOptimizable(acrf, subset);
            train(acrf, instanceList, instanceList2, (InstanceList) null, aCRFEvaluator, SUBSET_ITER, subsetObjective);
            logger.info("Subset training " + round + " finished...");
        }
        long elapsedMillis = System.currentTimeMillis() - startMillis;
        logger.info("All subset training finished.  Time = " + elapsedMillis + " ms.");
        return train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, i);
    }

    /**
     * Core training loop: runs the optimizer one step at a time, calling the
     * evaluator after every step, until the optimizer or evaluator signals
     * completion, the objective value stops changing, or the iteration limit
     * is reached. Afterwards the training time is logged, the objective's
     * report is printed (when it is a {@code MaximizableACRF}) and the
     * evaluator is run once more on the test list.
     *
     * <p>A {@code RuntimeException} in one iteration resets the optimizer and
     * retries; a second consecutive failure stops training and is treated as
     * convergence, or is rethrown when {@code rethrowExceptions} is set.
     *
     * @param acrf            model to train
     * @param instanceList    training instances, forwarded to the evaluator
     * @param instanceList2   second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3   test instances; may be null, in which case a warning is logged
     * @param aCRFEvaluator   evaluator to call during and after training; may be null
     * @param i               maximum number of iterations
     * @param byGradientValue objective to optimize
     * @return {@code true} if training stopped because of convergence (or a
     *         repeated failure), {@code false} if the iteration limit was hit
     * @throws RuntimeException the second consecutive optimizer failure, when
     *         {@code rethrowExceptions} is set
     */
    @Override // cc.mallet.grmm.learning.ACRFTrainer
    public boolean train(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, ACRFEvaluator aCRFEvaluator, int i, Optimizable.ByGradientValue byGradientValue) {
        Optimizer optimizer = createMaxer(byGradientValue);
        boolean converged = false;
        // True while the previous iteration succeeded, i.e. one reset-and-retry is still allowed.
        boolean resetOnError = true;
        long startMillis = System.currentTimeMillis();
        int totalNodes = byGradientValue instanceof ACRF.MaximizableACRF ? ((ACRF.MaximizableACRF) byGradientValue).getTotalNodes() : 0;
        // Early-stopping threshold on the change of the objective value.
        double threshold = 1.0E-5d * totalNodes;
        if (instanceList3 == null) {
            logger.warning("ACRF trainer: No test set provided.");
        }
        double previousValue = Double.NEGATIVE_INFINITY;
        int iteration;
        for (iteration = 0; iteration < i; iteration++) {
            logger.info("ACRF trainer iteration " + iteration + " at time " + (System.currentTimeMillis() - startMillis));
            try {
                // Non-short-circuit '|' so the evaluator runs even when the optimizer reports convergence.
                converged = optimizer.optimize(1) | callEvaluator(acrf, instanceList, instanceList2, instanceList3, iteration, aCRFEvaluator);
                if (!converged) {
                    // A clean iteration re-arms the retry.
                    resetOnError = true;
                }
            } catch (RuntimeException e) {
                e.printStackTrace();
                if (resetOnError) {
                    logger.warning("Exception in iteration " + iteration + ":" + e + "\n  Resetting LBFGs and trying again...");
                    if (optimizer instanceof LimitedMemoryBFGS) {
                        ((LimitedMemoryBFGS) optimizer).reset();
                    }
                    if (optimizer instanceof ConjugateGradient) {
                        ((ConjugateGradient) optimizer).reset();
                    }
                    resetOnError = false;
                } else {
                    logger.warning("Exception in iteration " + iteration + ":" + e + "\n   Quitting and saying converged...");
                    converged = true;
                    if (rethrowExceptions) {
                        throw e;
                    }
                }
            }
            if (converged) {
                break;
            }
            double value = byGradientValue.getValue();
            if (Math.abs(value - previousValue) >= threshold) {
                previousValue = value;
            } else if (resetOnError) {
                // Skip the cutoff right after a reset: the value has not had a chance to move yet.
                logger.info("ACRFTrainer saying converged:  Current value " + value + ", previous " + previousValue + "\n...threshold was " + threshold + " = 1e-5 * " + totalNodes);
                converged = true;
                break;
            }
        }
        if (iteration >= i) {
            logger.info("ACRFTrainer: Too many iterations, stopping training.  maxIter = " + i);
        }
        logger.info("ACRF training time (ms) = " + (System.currentTimeMillis() - startMillis));
        if (byGradientValue instanceof ACRF.MaximizableACRF) {
            ((ACRF.MaximizableACRF) byGradientValue).report();
        }
        if (instanceList3 != null && aCRFEvaluator != null) {
            // Disable unrolled-graph caching for the final test, then restore the previous setting.
            boolean previousCaching = acrf.isCacheUnrolledGraphs();
            acrf.setCacheUnrolledGraphs(false);
            try {
                aCRFEvaluator.test(acrf, instanceList3, "Testing");
            } finally {
                acrf.setCacheUnrolledGraphs(previousCaching);
            }
        }
        return converged;
    }

    /**
     * Chooses the optimizer for a training run: the one set via
     * {@code setMaxer} if present (returned as-is, regardless of the given
     * objective), otherwise a new LimitedMemoryBFGS over the objective.
     *
     * @param byGradientValue objective used only when building the default optimizer
     * @return optimizer to drive training
     */
    private Optimizer createMaxer(Optimizable.ByGradientValue byGradientValue) {
        if (this.maxer != null) {
            return this.maxer;
        }
        return new LimitedMemoryBFGS(byGradientValue);
    }

    /**
     * Runs the evaluator for one training iteration with unrolled-graph
     * caching temporarily disabled.
     *
     * @param acrf          model being trained
     * @param instanceList  training instances
     * @param instanceList2 second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3 test instances
     * @param i             zero-based iteration index; the evaluator receives {@code i + 1}
     * @param aCRFEvaluator evaluator to call; when null nothing happens
     * @return {@code true} if the evaluator asked for training to stop (its
     *         {@code evaluate} returned false), {@code false} otherwise
     */
    protected boolean callEvaluator(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, int i, ACRFEvaluator aCRFEvaluator) {
        if (aCRFEvaluator == null) {
            return false;
        }
        aCRFEvaluator.setOutputPrefix(this.outputPrefix);
        boolean previousCaching = acrf.isCacheUnrolledGraphs();
        acrf.setCacheUnrolledGraphs(false);
        try {
            Timing timing = new Timing();
            boolean keepGoing = aCRFEvaluator.evaluate(acrf, i + 1, instanceList, instanceList2, instanceList3);
            if (!keepGoing) {
                logger.info("ACRF trainer: evaluator returned false. Quitting.");
            }
            timing.tick("Evaluation time (iteration " + i + ")");
            return !keepGoing;
        } finally {
            // Restore on every path, including a stop request or an exception,
            // so the model is not left with caching switched off.
            acrf.setCacheUnrolledGraphs(previousCaching);
        }
    }

    /**
     * Trains for five iterations, asks every template of the model to add
     * some unsupported weights, then continues training on the same objective.
     *
     * @param acrf          model to train
     * @param instanceList  training instances
     * @param instanceList2 second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3 test instances
     * @param aCRFEvaluator evaluator to call during training
     * @param i             iteration limit for the second training run
     * @return result of the second training run
     */
    public boolean someUnsupportedTrain(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, ACRFEvaluator aCRFEvaluator, int i) {
        final Optimizable.ByGradientValue objective = createOptimizable(acrf, instanceList);

        // Warm-up run before the extra weights are added.
        train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, 5, objective);

        for (ACRF.Template each : acrf.getTemplates()) {
            each.addSomeUnsupportedWeights(instanceList);
        }
        logger.info("Some unsupporetd weights initialized.  Training...");

        boolean result = train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, i, objective);
        return result;
    }

    /**
     * Tests the model with a single evaluator.
     *
     * @param acrf          model to test
     * @param instanceList  instances to label and score
     * @param aCRFEvaluator evaluator that receives the labelings
     */
    public void test(ACRF acrf, InstanceList instanceList, ACRFEvaluator aCRFEvaluator) {
        ACRFEvaluator[] single = {aCRFEvaluator};
        test(acrf, instanceList, single);
    }

    /**
     * Computes the model's best labelings once and passes them to each
     * evaluator in turn under the description "Testing".
     *
     * @param acrf             model to test
     * @param instanceList     instances to label and score
     * @param aCRFEvaluatorArr evaluators that receive the labelings
     */
    public void test(ACRF acrf, InstanceList instanceList, ACRFEvaluator[] aCRFEvaluatorArr) {
        final List labelings = acrf.getBestLabels(instanceList);
        for (ACRFEvaluator evaluator : aCRFEvaluatorArr) {
            evaluator.setOutputPrefix(this.outputPrefix);
            evaluator.test(instanceList, labelings, "Testing");
        }
    }

    /**
     * @return the shared, fixed-seed random source of this class
     */
    public static Random getRandom() {
        return DefaultAcrfTrainer.r;
    }

    /**
     * Trains in rounds on randomly drawn subsets of growing size, then on the
     * full training list with an iteration limit of 99999.
     *
     * @param acrf          model to train
     * @param instanceList  full training list; subsets are drawn from it with the shared random source
     * @param instanceList2 second list forwarded to the evaluator (presumably validation data)
     * @param instanceList3 test instances
     * @param aCRFEvaluator evaluator to call during training
     * @param dArr          fraction of the training list to use in each round
     * @param i             iteration limit for each subset round
     */
    public void train(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, ACRFEvaluator aCRFEvaluator, double[] dArr, int i) {
        for (int round = 0; round < dArr.length; round++) {
            double proportion = dArr[round];
            // Pair the fraction with its complement so the first part really holds
            // 'proportion' of the data; {p, 1.0} would give p / (1 + p) once the
            // weights are normalized. Clamp so fractions >= 1 select everything.
            double remainder = Math.max(0.0d, 1.0d - proportion);
            InstanceList[] split = instanceList.split(r, new double[]{proportion, remainder});
            logger.info("ACRF trainer: Round " + round + ", training proportion = " + proportion);
            train(acrf, split[0], instanceList2, instanceList3, aCRFEvaluator, i);
        }
        logger.info("ACRF trainer: Training on full data");
        train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, 99999);
    }
}
