package cc.mallet.grmm.learning.extract;

import bsh.EvalError;
import cc.mallet.extract.ExtractionEvaluator;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFEvaluator;
import cc.mallet.grmm.learning.ACRFTrainer;
import cc.mallet.grmm.learning.AcrfSerialEvaluator;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.grmm.learning.GenericAcrfData2TokenSequence;
import cc.mallet.grmm.learning.MultiSegmentationEvaluatorACRF;
import cc.mallet.pipe.Input2CharSequence;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.iterator.FileListIterator;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.BshInterpreter;
import cc.mallet.util.CommandOption;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/* loaded from: input_file:cc/mallet/grmm/learning/extract/AcrfExtractorTui.class */
/**
 * Command-line driver that trains an {@link ACRFExtractor} from text training data,
 * serializes it to {@code extor.ser.gz} under {@code --output-prefix}, and, when
 * {@code --testing} is given, evaluates it on the test data.
 *
 * <p>Several options ({@code --trainer}, {@code --inferencer}, {@code --max-inferencer},
 * {@code --eval}, {@code --extraction-eval}) and every line of the {@code --model-file}
 * are evaluated as Java expressions by a BeanShell interpreter configured in
 * {@link #setupInterpreter()}.
 */
public class AcrfExtractorTui {
    private static final Logger logger = MalletLogger.getLogger(AcrfExtractorTui.class.getName());
    private static CommandOption.File outputPrefix = new CommandOption.File(AcrfExtractorTui.class, "output-prefix", "FILENAME", true, null, "Directory to write saved model to.", null);
    private static CommandOption.File modelFile = new CommandOption.File(AcrfExtractorTui.class, "model-file", "FILENAME", true, null, "Text file describing model structure.", null);
    private static CommandOption.File trainFile = new CommandOption.File(AcrfExtractorTui.class, "training", "FILENAME", true, null, "File containing training data.", null);
    private static CommandOption.File testFile = new CommandOption.File(AcrfExtractorTui.class, "testing", "FILENAME", true, null, "File containing testing data.", null);
    private static CommandOption.Integer numLabelsOption = new CommandOption.Integer(AcrfExtractorTui.class, "num-labels", "INT", true, -1, "If supplied, number of labels on each line of input file.  Otherwise, the token ---- must separate labels from features.", null);
    private static CommandOption.String trainerOption = new CommandOption.String(AcrfExtractorTui.class, "trainer", "STRING", true, "ACRFExtractorTrainer", "Specification of trainer type.", null);
    private static CommandOption.String inferencerOption = new CommandOption.String(AcrfExtractorTui.class, "inferencer", "STRING", true, "LoopyBP", "Specification of inferencer.", null);
    private static CommandOption.String maxInferencerOption = new CommandOption.String(AcrfExtractorTui.class, "max-inferencer", "STRING", true, "LoopyBP.createForMaxProduct()", "Specification of inferencer.", null);
    private static CommandOption.String evalOption = new CommandOption.String(AcrfExtractorTui.class, "eval", "STRING", true, "LOG", "Evaluator to use.  Java code grokking performed.", null);
    private static CommandOption.String extractionEvalOption = new CommandOption.String(AcrfExtractorTui.class, "extraction-eval", "STRING", true, "PerDocumentF1", "Evaluator to use.  Java code grokking performed.", null);
    private static CommandOption.Integer checkpointIterations = new CommandOption.Integer(AcrfExtractorTui.class, "checkpoint", "INT", true, -1, "Save a copy after every ___ iterations.", null);
    static CommandOption.Boolean cacheUnrolledGraph = new CommandOption.Boolean(AcrfExtractorTui.class, "cache-graphs", "true|false", true, true, "Whether to use memory-intensive caching.", null);
    static CommandOption.Boolean perTemplateTrain = new CommandOption.Boolean(AcrfExtractorTui.class, "per-template-train", "true|false", true, false, "Whether to pretrain templates before joint training.", null);
    static CommandOption.Integer pttIterations = new CommandOption.Integer(AcrfExtractorTui.class, "per-template-iterations", "INTEGER", false, 100, "How many training iterations for each step of per-template-training.", null);
    static CommandOption.Integer randomSeedOption = new CommandOption.Integer(AcrfExtractorTui.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null);
    static CommandOption.Boolean useTokenText = new CommandOption.Boolean(AcrfExtractorTui.class, "use-token-text", "true|false", true, true, "If true, first feature in list is assumed to be token identity, and is treated specially.", null);
    // Boolean option: its argument placeholder is "true|false", matching the other Boolean options.
    private static CommandOption.Boolean labelsAtEnd = new CommandOption.Boolean(AcrfExtractorTui.class, "labels-at-end", "true|false", true, false, "If true, then label is at end of each line, rather than beginning.", null);
    static CommandOption.Boolean trainingIsList = new CommandOption.Boolean(AcrfExtractorTui.class, "training-is-list", "true|false", true, false, "If true, training option gives list of files to read for training.", null);
    private static CommandOption.File dataDir = new CommandOption.File(AcrfExtractorTui.class, "data-dir", "FILENAME", true, null, "If training-is-list, base directory in which training files located.", null);
    // Must be initialized after the option fields above; it is used by the create*/parse* helpers below.
    private static BshInterpreter interpreter = setupInterpreter();

    /**
     * Parses the command line, trains an extractor, writes it to
     * {@code <output-prefix>/extor.ser.gz} and, if test data was supplied, evaluates it.
     *
     * @param strArr command-line arguments
     * @throws IOException if reading the data/model files or writing the extractor fails
     * @throws EvalError if a BeanShell option expression cannot be evaluated
     */
    public static void main(String[] strArr) throws IOException, EvalError {
        doProcessOptions(AcrfExtractorTui.class, strArr);
        Timing timing = new Timing();

        GenericAcrfData2TokenSequence dataPipe = numLabelsOption.wasInvoked()
                ? new GenericAcrfData2TokenSequence(numLabelsOption.value)
                : new GenericAcrfData2TokenSequence();
        if (!useTokenText.value) {
            dataPipe.setFeaturesIncludeToken(false);
            dataPipe.setIncludeTokenText(false);
        }
        dataPipe.setLabelsAtEnd(labelsAtEnd.value);

        // With --training-is-list the instances come from a FileListIterator, so their
        // contents are presumably read into text by Input2CharSequence first; otherwise
        // the LineGroupIterator already yields text and a Noop suffices.
        Pipe[] pipes = new Pipe[] {
            trainingIsList.value ? new Input2CharSequence() : new Noop(),
            dataPipe
        };
        SerialPipes tokenPipe = new SerialPipes(pipes);

        Iterator<Instance> trainSource = constructIterator(trainFile.value, dataDir.value, trainingIsList.value);
        Iterator<Instance> testSource = testFile.wasInvoked()
                ? constructIterator(testFile.value, dataDir.value, trainingIsList.value)
                : null;

        ACRF.Template[] templates = parseModelFile(modelFile.value);
        ACRFExtractorTrainer trainer = createTrainer(trainerOption.value);
        ACRFEvaluator evaluator = createEvaluator(evalOption.value);
        ExtractionEvaluator extractionEvaluator = createExtractionEvaluator(extractionEvalOption.value);

        trainer.setPipes(tokenPipe, new TokenSequence2FeatureVectorSequence())
                .setDataSource(trainSource, testSource)
                .setEvaluator(evaluator)
                .setTemplates(templates)
                .setInferencer(createInferencer(inferencerOption.value))
                .setViterbiInferencer(createInferencer(maxInferencerOption.value))
                .setCheckpointDirectory(outputPrefix.value)
                .setNumCheckpointIterations(checkpointIterations.value)
                .setCacheUnrolledGraphs(cacheUnrolledGraph.value)
                .setUsePerTemplateTrain(perTemplateTrain.value)
                .setPerTemplateIterations(pttIterations.value);

        logger.info("Starting training...");
        ACRFExtractor extractor = trainer.trainExtractor();
        timing.tick("Training");

        FileUtils.writeGzippedObject(new File(outputPrefix.value, "extor.ser.gz"), extractor);
        timing.tick("Serializing");

        InstanceList testingData = trainer.getTestingData();
        if (testingData != null) {
            evaluator.test(extractor.getAcrf(), testingData, "Final results");
        }
        if (extractionEvaluator != null && testingData != null) {
            extractionEvaluator.evaluate(extractor.extract(testingData));
            timing.tick("Evaluting");
        }
        System.out.println("Total time (ms) = " + timing.elapsedTime());
    }

    /**
     * Returns the shared CommandOption BeanShell interpreter with the GRMM packages imported,
     * so option strings may name classes such as {@code LoopyBP} or
     * {@code ACRFExtractorTrainer} without package qualification.
     *
     * @throws RuntimeException wrapping the {@link EvalError} if an import fails
     */
    private static BshInterpreter setupInterpreter() {
        BshInterpreter bsh = CommandOption.getInterpreter();
        // These must be the real cc.mallet packages; the old edu.umass.cs.mallet names
        // no longer exist, and importing them left the default option values unresolvable.
        String[] packages = {
            "cc.mallet.extract",
            "cc.mallet.grmm.inference",
            "cc.mallet.grmm.learning",
            "cc.mallet.grmm.learning.templates",
            "cc.mallet.grmm.learning.extract"
        };
        try {
            for (String pkg : packages) {
                bsh.eval("import " + pkg + ".*");
            }
        } catch (EvalError e) {
            throw new RuntimeException(e);
        }
        return bsh;
    }

    /**
     * Builds the raw instance iterator for a data file.
     *
     * @param source the data file, or (when {@code isFileList}) the file listing data files
     * @param baseDir base directory passed to {@link FileListIterator} when {@code isFileList}
     * @param isFileList whether {@code source} is a list of files rather than the data itself
     * @return a {@link FileListIterator} when {@code isFileList}, otherwise a
     *         {@link LineGroupIterator} grouping lines separated by blank lines
     */
    private static Iterator<Instance> constructIterator(File source, File baseDir, boolean isFileList) throws IOException {
        if (isFileList) {
            return new FileListIterator(source, baseDir, null, null, true);
        }
        // NOTE(review): this reader is handed to a lazily consumed iterator and is never
        // closed here; it stays open for the life of the process.
        return new LineGroupIterator(new FileReader(source), Pattern.compile("^\\s*$"), true);
    }

    /**
     * Creates the ACRF evaluator described by {@code str}.
     *
     * <p>If {@code str} contains {@code '('} it is evaluated as a BeanShell expression.
     * Otherwise it is a whitespace-separated token spec: {@code LOG}, {@code SERIAL <spec>...},
     * or {@code SEGMENT <int> <start> <continue> ...} (case-insensitive).
     *
     * @throws EvalError if the BeanShell expression fails
     * @throws RuntimeException if the token spec is malformed
     */
    public static ACRFEvaluator createEvaluator(String str) throws EvalError {
        if (str.indexOf('(') >= 0) {
            return (ACRFEvaluator) interpreter.eval(str);
        }
        // Trim first so leading whitespace does not produce an empty first token.
        return createEvaluator(new LinkedList<String>(Arrays.asList(str.trim().split("\\s+"))));
    }

    /**
     * Creates the extraction evaluator named by {@code str}: either a BeanShell expression
     * (if it contains {@code '('}) or a class name, with the {@code Evaluator} suffix
     * appended when absent (e.g. {@code PerDocumentF1} becomes {@code new PerDocumentF1Evaluator ()}).
     */
    private static ExtractionEvaluator createExtractionEvaluator(String str) throws EvalError {
        String expr;
        if (str.indexOf('(') >= 0) {
            expr = str;
        } else if (str.endsWith("Evaluator")) {
            expr = "new " + str + " ()";
        } else {
            expr = "new " + str + "Evaluator ()";
        }
        return (ExtractionEvaluator) interpreter.eval(expr);
    }

    /**
     * Recursively consumes tokens from the front of {@code tokens} to build one evaluator.
     * A SERIAL evaluator consumes all remaining tokens as child specs.
     */
    private static ACRFEvaluator createEvaluator(LinkedList<String> tokens) {
        String kind = nextEvalToken(tokens, "an evaluator name");
        if (kind.equalsIgnoreCase("LOG")) {
            return new DefaultAcrfTrainer.LogEvaluator();
        }
        if (kind.equalsIgnoreCase("SERIAL")) {
            ArrayList<ACRFEvaluator> children = new ArrayList<ACRFEvaluator>();
            while (!tokens.isEmpty()) {
                children.add(createEvaluator(tokens));
            }
            return new AcrfSerialEvaluator(children);
        }
        if (kind.equalsIgnoreCase("SEGMENT")) {
            return createSegmentEvaluator(tokens);
        }
        throw new RuntimeException("Error in --eval " + evalOption.value + ": illegal evaluator " + kind);
    }

    /**
     * Builds a {@link MultiSegmentationEvaluatorACRF} from the tokens following {@code SEGMENT}:
     * one integer, then (start-tag, continue-tag) pairs. This consumes ALL remaining tokens.
     */
    private static ACRFEvaluator createSegmentEvaluator(LinkedList<String> tokens) {
        String sliceToken = nextEvalToken(tokens, "an integer after SEGMENT");
        int slice;
        try {
            slice = Integer.parseInt(sliceToken);
        } catch (NumberFormatException e) {
            throw new RuntimeException("Error in --eval " + evalOption.value + ": SEGMENT expects an integer, got " + sliceToken, e);
        }
        if (tokens.size() % 2 != 0) {
            throw new RuntimeException("Error in --eval " + evalOption.value + ": Every start tag must have a continue.");
        }
        int numPairs = tokens.size() / 2;
        String[] startTags = new String[numPairs];
        String[] continueTags = new String[numPairs];
        for (int i = 0; i < numPairs; i++) {
            startTags[i] = tokens.removeFirst();
            continueTags[i] = tokens.removeFirst();
        }
        return new MultiSegmentationEvaluatorACRF(startTags, continueTags, slice);
    }

    /** Removes and returns the next --eval token, failing with a descriptive message if none remain. */
    private static String nextEvalToken(LinkedList<String> tokens, String expected) {
        if (tokens.isEmpty()) {
            throw new RuntimeException("Error in --eval " + evalOption.value + ": expected " + expected + " but the specification ended.");
        }
        return tokens.removeFirst();
    }

    /**
     * Creates the trainer named by {@code str}: a BeanShell expression (if it contains
     * {@code '('}) or a class name, with {@code Trainer} appended when absent.
     * A plain {@link ACRFTrainer} is wrapped in a new {@link ACRFExtractorTrainer}.
     *
     * @throws RuntimeException if the evaluated object is neither kind of trainer
     */
    private static ACRFExtractorTrainer createTrainer(String str) throws EvalError {
        String expr;
        if (str.indexOf('(') >= 0) {
            expr = str;
        } else if (str.endsWith("Trainer")) {
            expr = "new " + str + "()";
        } else {
            expr = "new " + str + "Trainer()";
        }
        Object trainer = interpreter.eval(expr);
        if (trainer instanceof ACRFExtractorTrainer) {
            return (ACRFExtractorTrainer) trainer;
        }
        // Accept any ACRFTrainer implementation (DefaultAcrfTrainer included), matching the cast.
        if (trainer instanceof ACRFTrainer) {
            return new ACRFExtractorTrainer().setTrainingMethod((ACRFTrainer) trainer);
        }
        throw new RuntimeException("Don't know what to do with trainer " + trainer);
    }

    /**
     * Creates an inferencer from a BeanShell expression, or from a bare class name
     * (evaluated as {@code new <name>()}).
     *
     * @throws RuntimeException if the evaluated object is not an {@link Inferencer}
     */
    private static Inferencer createInferencer(String str) throws EvalError {
        String expr = str.indexOf('(') >= 0 ? str : "new " + str + "()";
        Object inferencer = interpreter.eval(expr);
        if (inferencer instanceof Inferencer) {
            return (Inferencer) inferencer;
        }
        throw new RuntimeException("Don't know what to do with inferencer " + inferencer);
    }

    /**
     * Registers every CommandOption declared in {@code cls}, parses {@code strArr} into them,
     * and logs the resulting option values to the root logger.
     */
    public static void doProcessOptions(Class<?> cls, String[] strArr) {
        CommandOption.List options = new CommandOption.List("", new CommandOption[0]);
        options.add(cls);
        options.process(strArr);
        options.logOptions(Logger.getLogger(""));
    }

    /**
     * Reads model templates from {@code file}, one BeanShell expression per line.
     * Blank or whitespace-only lines are skipped.
     *
     * @return the templates in file order
     * @throws RuntimeException if a line does not evaluate to an {@link ACRF.Template}
     *         (the message includes the 1-based line number)
     */
    private static ACRF.Template[] parseModelFile(File file) throws IOException, EvalError {
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try {
            ArrayList<ACRF.Template> templates = new ArrayList<ACRF.Template>();
            int lineNumber = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().length() == 0) {
                    continue;
                }
                Object template = interpreter.eval(line);
                if (!(template instanceof ACRF.Template)) {
                    throw new RuntimeException("Error in " + file + " line " + lineNumber + ":\n  Object " + template + " not a template");
                }
                templates.add((ACRF.Template) template);
            }
            return templates.toArray(new ACRF.Template[templates.size()]);
        } finally {
            // Previously the reader was never closed.
            reader.close();
        }
    }
}
