package com.rapidminer.extension.operator.text_processing.mallet;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.IDSorter;
import cc.mallet.types.InstanceList;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.utils.ExampleSetBuilder;
import com.rapidminer.example.utils.ExampleSets;
import com.rapidminer.extension.Utility.ParameterReplacementProcessXMLFilter;
import com.rapidminer.extension.metadata.TopicModelMetaData;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.GenerateNewExampleSetMDRule;
import com.rapidminer.operator.ports.metadata.GenerateNewMDRule;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.text.Document;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.studio.internal.Resources;
import com.rapidminer.tools.RandomGenerator;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;

/* loaded from: input_file:com/rapidminer/extension/operator/text_processing/mallet/LDA.class */
public class LDA extends Operator {
    InputPort docInput;
    OutputPort exaOutput;
    OutputPort topicOutput;
    OutputPort modelOutput;
    OutputPort performanceOutput;
    public static final String PARAMETER_NUMTOPICS = "number_of_topics";
    public static final String PARAMETER_ALPHA = "alpha";
    public static final String PARAMETER_BETA_HEURISTICS = "use_beta_heuristics";
    public static final String PARAMETER_BETA = "beta";
    public static final String PARAMETER_ALPHA_HEURISTICS = "use_alpha_heuristics";
    public static final String PARAMETER_ITERATIONS = "iterations";
    public static final String PARAMETER_REPRODUCEABLE = "reproducible";
    public static final String PARAMETER_TOP_WORDS = "top_words_per_topic";
    public static final String PARAMETER_META_DATA = "include_meta_data";

    public LDA(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.docInput = getInputPorts().createPort("col", IOObjectCollection.class);
        this.exaOutput = getOutputPorts().createPort("exa");
        this.topicOutput = getOutputPorts().createPort("top");
        this.modelOutput = getOutputPorts().createPort("mod");
        this.performanceOutput = getOutputPorts().createPort("per");
        getTransformer().addRule(new GenerateNewExampleSetMDRule(this.exaOutput) { // from class: com.rapidminer.extension.operator.text_processing.mallet.LDA.1
            public void transformMD() {
                try {
                    LDA.this.exaOutput.deliverMD(LDAModel.createExampleSetMetaData(LDA.this.getParameterAsInt(LDA.PARAMETER_NUMTOPICS)));
                } catch (UndefinedParameterError e) {
                    LDA.this.exaOutput.deliverMD(LDAModel.createExampleSetMetaData(1));
                }
            }
        });
        getTransformer().addRule(new GenerateNewExampleSetMDRule(this.topicOutput) { // from class: com.rapidminer.extension.operator.text_processing.mallet.LDA.2
            public void transformMD() {
                ExampleSetMetaData exampleSetMetaData = new ExampleSetMetaData();
                exampleSetMetaData.addAttribute(new AttributeMetaData("topicId", 3));
                exampleSetMetaData.addAttribute(new AttributeMetaData("word", 1));
                exampleSetMetaData.addAttribute(new AttributeMetaData("weight", 4));
                LDA.this.topicOutput.deliverMD(exampleSetMetaData);
            }
        });
        getTransformer().addRule(new GenerateNewMDRule(this.modelOutput, LDAModel.class) { // from class: com.rapidminer.extension.operator.text_processing.mallet.LDA.3
            public void transformMD() {
                TopicModelMetaData topicModelMetaData = new TopicModelMetaData(LDAModel.class);
                try {
                    topicModelMetaData.putMetaData("numTopics", Integer.valueOf(LDA.this.getParameterAsInt(LDA.PARAMETER_NUMTOPICS)));
                } catch (UndefinedParameterError e) {
                    topicModelMetaData.putMetaData("numTopics", 1);
                }
                LDA.this.modelOutput.deliverMD(topicModelMetaData);
            }
        });
        getTransformer().addRule(new GenerateNewMDRule(this.performanceOutput, new MetaData(PerformanceVector.class)));
    }

    public void doWork() throws UserError {
        IOObjectCollection<Document> data = this.docInput.getData(IOObjectCollection.class);
        new MalletHelper();
        InstanceList convertDocsToInstances = MalletHelper.convertDocsToInstances(data);
        int parameterAsInt = getParameterAsInt(PARAMETER_NUMTOPICS);
        convertDocsToInstances.getDataAlphabet().size();
        RandomGenerator randomGenerator = RandomGenerator.getRandomGenerator(getParameterAsBoolean("use_local_random_seed"), getParameterAsInt("local_random_seed"));
        ParallelTopicModel parallelTopicModel = new ParallelTopicModel(parameterAsInt, getParameterAsBoolean(PARAMETER_ALPHA_HEURISTICS) ? 50.0d / getParameterAsInt(PARAMETER_NUMTOPICS) : getParameterAsDouble(PARAMETER_ALPHA), getParameterAsBoolean(PARAMETER_BETA_HEURISTICS) ? 50.0d / convertDocsToInstances.getDataAlphabet().size() : getParameterAsDouble(PARAMETER_BETA));
        int parallelism = Resources.getConcurrencyContext(this).getParallelism();
        if (getParameterAsBoolean("reproducible")) {
            parallelTopicModel.setNumThreads(1);
        } else {
            parallelTopicModel.setNumThreads(parallelism);
        }
        parallelTopicModel.setRandomSeed(randomGenerator.nextInt());
        parallelTopicModel.addInstances(convertDocsToInstances);
        parallelTopicModel.setNumIterations(getParameterAsInt("iterations"));
        try {
            parallelTopicModel.estimate();
        } catch (IOException e) {
            e.printStackTrace();
        }
        LDAModel lDAModel = new LDAModel(parallelTopicModel, getParameterAsBoolean(PARAMETER_META_DATA));
        Alphabet dataAlphabet = convertDocsToInstances.getDataAlphabet();
        ExampleSet applyOnDocumentsWithConvertedInstances = lDAModel.applyOnDocumentsWithConvertedInstances(data, convertDocsToInstances, true);
        ArrayList arrayList = new ArrayList();
        arrayList.add(AttributeFactory.createAttribute("topicId", 3));
        arrayList.add(AttributeFactory.createAttribute("word", 1));
        arrayList.add(AttributeFactory.createAttribute("weight", 4));
        ExampleSetBuilder from = ExampleSets.from(arrayList);
        int parameterAsInt2 = getParameterAsInt(PARAMETER_TOP_WORDS);
        ArrayList<TreeSet<IDSorter>> sortedWords = parallelTopicModel.getSortedWords();
        for (int i = 0; i < parameterAsInt; i++) {
            Iterator<IDSorter> it = sortedWords.get(i).iterator();
            for (int i2 = 0; it.hasNext() && i2 < parameterAsInt2; i2++) {
                from.addRow(new double[]{i, ((Attribute) arrayList.get(1)).getMapping().mapString((String) dataAlphabet.lookupObject(r0.getID())), it.next().getWeight()});
            }
        }
        PerformanceVector performanceVector = new PerformanceVector();
        performanceVector.addCriterion(new EstimatedPerformance("LogLikelihood", parallelTopicModel.modelLogLikelihood(), 1, true));
        performanceVector.setMainCriterionName("LogLikelihood");
        this.topicOutput.deliver(from.build());
        this.exaOutput.deliver(applyOnDocumentsWithConvertedInstances);
        this.performanceOutput.deliver(performanceVector);
        this.modelOutput.deliver(lDAModel);
    }

    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMTOPICS, "Number of topics to find", 1, Integer.MAX_VALUE, 10, false));
        ParameterTypeBoolean parameterTypeBoolean = new ParameterTypeBoolean(PARAMETER_ALPHA_HEURISTICS, "Use heuristics to determine alpha.", true, false);
        ParameterTypeDouble parameterTypeDouble = new ParameterTypeDouble(PARAMETER_ALPHA, "Baysian prior on the topic distribution. A common value is 50 (per number of topics). A common value is 50/Number of topics", 0.0d, Double.MAX_VALUE, 0.1d, false);
        parameterTypeDouble.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_ALPHA_HEURISTICS, true, false));
        parameterTypes.add(parameterTypeBoolean);
        parameterTypes.add(parameterTypeDouble);
        ParameterTypeDouble parameterTypeDouble2 = new ParameterTypeDouble(PARAMETER_BETA, "Baysian prior on the word distribution. A common value is 0.01 or 50/Number of words.", 0.0d, Double.MAX_VALUE, 0.01d, false);
        ParameterTypeBoolean parameterTypeBoolean2 = new ParameterTypeBoolean(PARAMETER_BETA_HEURISTICS, "Use heuristics to determine beta.", true, false);
        parameterTypeDouble2.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_BETA_HEURISTICS, true, false));
        parameterTypes.add(parameterTypeBoolean2);
        parameterTypes.add(parameterTypeDouble2);
        parameterTypes.add(new ParameterTypeInt(PARAMETER_TOP_WORDS, "Number of top words to extract.", 1, Integer.MAX_VALUE, 5, false));
        parameterTypes.add(new ParameterTypeInt("iterations", "Number of iterations.", 1, Integer.MAX_VALUE, 1000, false));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_META_DATA, "add meta information.", true, false));
        parameterTypes.add(new ParameterTypeBoolean("reproducible", "If this parameter is set to true, only one thread will be used.", false, false));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }

    static {
        ParameterReplacementProcessXMLFilter.registerReplacement(LDA.class, "number of topics", PARAMETER_NUMTOPICS);
        ParameterReplacementProcessXMLFilter.registerReplacement(LDA.class, "top words per topic", PARAMETER_TOP_WORDS);
    }
}
