package edu.pitt.dbmi.edda.operator.ldaop;

import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.Input2CharSequence;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.PrintInputAndTarget;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2Label;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.topics.DMRTopicModel;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import edu.pitt.dbmi.edda.operator.c45bayes.C45BayesModel;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.apache.log4j.Level;

/* loaded from: input_file:edu/pitt/dbmi/edda/operator/ldaop/TopicModelWorker.class */
/**
 * Builds (or reuses) a MALLET {@link ParallelTopicModel} over a set of documents and
 * produces an outgoing {@link ExampleSet} via one of the {@code AttributeBuilder}
 * implementations.
 *
 * <p>Documents are read from the input example set when one has been supplied through
 * {@link #setInputExampleSet(ExampleSet)}; otherwise they are read from the configured
 * label/directory pairs. When no {@link TopicModelAdapter} has been supplied, a new topic
 * model is estimated; otherwise the model and pipe held by the supplied adapter are reused.
 *
 * <p>Typical use: configure through the setters, call {@link #process()}, then read
 * {@link #getOutgoingExampleSet()} and {@link #getTopicModelAdapter()}.
 */
public class TopicModelWorker {
    private static final Logger logger = Logger.getLogger(TopicModelWorker.class.getName());

    /** Token pattern handed to {@link CharSequence2TokenSequence}; compiled once. */
    private static final Pattern TOKEN_PATTERN = Pattern.compile("\\S+");

    private ExampleSet inputExampleSet;
    private List<String[]> labelDirectoryPairs;
    private InstanceList instances;
    private TopicModelAdapter topicModelAdapter;
    private ParallelTopicModel topicModel;
    private ExampleSet outgoingExampleSet;

    // Values forwarded to the ParallelTopicModel in buildModel().
    private int numTopics = 10;
    private double alpha = 1.0d;
    private double beta = 0.01d;
    // Same value as the log4j Level.TRACE_INT constant (5000) the declaration used to reference.
    private int numberOfIterations = 5000;
    private int burnInPeriod = 500;
    private int numberOfThreads = 1;
    private int randomSeed = 10;
    private int optimizeInterval = 10;
    private boolean isSymmetricAlpha = true;
    private int temperingInterval = 0;

    // Values forwarded to the AttributeBuilder in buildExampleSet().
    private int inferencerIterations = 0;
    private int inferencerThinning = 0;
    private int inferencerBurnInPeriod = 0;

    // Output selection and diagnostics switches.
    private boolean isOutputingDocumentThetas = true;
    private boolean isOutputingKulbachLeiblerDivergences = false;
    private int numberMostProbableWordsForDisplay = 10;
    private boolean isGeneratingDiagnostics = false;

    /**
     * Instantiates a worker and returns; no processing is performed.
     *
     * @param strArr command-line arguments (ignored)
     */
    public static void main(String[] strArr) {
        new TopicModelWorker();
    }

    /**
     * Establishes the topic model adapter and then generates the outgoing example set.
     *
     * @throws OperatorException if an {@link IOException} occurs while establishing the
     *         topic model adapter; the original exception is attached as the cause
     */
    public void process() throws OperatorException {
        try {
            establishTopicModelAdapter();
            generateOutgoingExampleSet();
        } catch (IOException e) {
            // Surface the failure to the caller instead of returning normally with no result.
            throw new OperatorException("Topic model processing failed: " + e.getMessage(), e);
        }
    }

    /**
     * Builds a new adapter when none has been supplied; otherwise adopts the supplied
     * adapter's topic model and runs the input through the adapter's pipe.
     */
    private void establishTopicModelAdapter() throws IOException {
        if (this.topicModelAdapter == null) {
            this.topicModelAdapter = buildTopicModelAdapter();
        } else {
            this.topicModel = this.topicModelAdapter.getParallelTopicModel();
            this.instances = generateInstances(this.topicModelAdapter.getPipe());
        }
    }

    /**
     * Chooses the attribute builder matching the two output flags and stores the example
     * set it produces. When neither flag is set, the outgoing example set is left untouched.
     */
    private void generateOutgoingExampleSet() {
        boolean wantThetas = isOutputingDocumentThetas();
        boolean wantDivergences = isOutputingKulbachLeiblerDivergences();
        if (wantThetas && wantDivergences) {
            AttributeBuilderBoth attributeBuilderBoth = new AttributeBuilderBoth();
            attributeBuilderBoth.setTopicModelAdapter(this.topicModelAdapter);
            this.outgoingExampleSet = buildExampleSet(attributeBuilderBoth);
        } else if (wantThetas) {
            AttributeBuilderThetas attributeBuilderThetas = new AttributeBuilderThetas();
            attributeBuilderThetas.setTopicModelAdapter(this.topicModelAdapter);
            this.outgoingExampleSet = buildExampleSet(attributeBuilderThetas);
        } else if (wantDivergences) {
            AttributeBuilderKulbackLeibler attributeBuilderKulbackLeibler = new AttributeBuilderKulbackLeibler();
            attributeBuilderKulbackLeibler.setTopicModelAdapter(this.topicModelAdapter);
            this.outgoingExampleSet = buildExampleSet(attributeBuilderKulbackLeibler);
        }
    }

    /**
     * Builds the pipe, reads the instances, estimates a new topic model on them, logs the
     * per-class topic averages and wraps the model in a new {@link TopicModelAdapter}.
     */
    private TopicModelAdapter buildTopicModelAdapter() throws IOException {
        Pipe buildPipe = buildPipe();
        this.instances = generateInstances(buildPipe);
        this.topicModel = buildModel();
        this.topicModel.addInstances(this.instances);
        this.topicModel.estimate();
        displayAveragesAcrossLabels();
        AttributeBuilderThetas attributeBuilderThetas = new AttributeBuilderThetas();
        attributeBuilderThetas.setTopicModel(this.topicModel);
        TopicModelAdapter topicModelAdapter = new TopicModelAdapter(this.topicModel, buildExampleSet(attributeBuilderThetas));
        topicModelAdapter.setPipe(buildPipe);
        topicModelAdapter.setNumberMostProbableWordsForDisplay(getNumberMostProbableWordsForDisplay());
        return topicModelAdapter;
    }

    /**
     * Creates a {@link ParallelTopicModel} configured from this worker's settings. The
     * topic display is set to (numTopics, numTopics) when diagnostics are enabled and to
     * (0, 0) otherwise.
     */
    private ParallelTopicModel buildModel() {
        ParallelTopicModel parallelTopicModel = new ParallelTopicModel(getNumTopics(), this.alpha, this.beta);
        parallelTopicModel.printLogLikelihood = false;
        parallelTopicModel.setNumThreads(getNumberOfThreads());
        parallelTopicModel.setNumIterations(getNumberOfIterations());
        parallelTopicModel.setBurninPeriod(getBurnInPeriod());
        parallelTopicModel.setRandomSeed(getRandomSeed());
        parallelTopicModel.setOptimizeInterval(getOptimizeInterval());
        parallelTopicModel.setSymmetricAlpha(isSymmetricAlpha());
        parallelTopicModel.setTemperingInterval(getTemperingInterval());
        if (isGeneratingDiagnostics()) {
            parallelTopicModel.setTopicDisplay(getNumTopics(), getNumTopics());
        } else {
            parallelTopicModel.setTopicDisplay(0, 0);
        }
        return parallelTopicModel;
    }

    /**
     * Passes the current instances and inferencer settings to the builder, asks it to
     * create its attributes and returns the example set it exposes.
     *
     * @param attributeBuilder builder to configure and run
     * @return the example set reported by the builder after attribute creation
     */
    private ExampleSet buildExampleSet(AttributeBuilder attributeBuilder) {
        attributeBuilder.setInstances(this.instances);
        attributeBuilder.setInferencerIterations(getInferencerIterations());
        attributeBuilder.setInferencerThinning(getInferencerThinning());
        attributeBuilder.setInferencerBurnInPeriod(getInferencerBurnInPeriod());
        attributeBuilder.setGeneratingDiagnostics(isGeneratingDiagnostics());
        attributeBuilder.createAttributes();
        return attributeBuilder.getExampleSet();
    }

    /**
     * Assembles the serial pipe applied to every document: UTF-8 input to character
     * sequence, split into non-whitespace tokens, tokens to feature sequence, target to
     * label. A {@link PrintInputAndTarget} stage is appended when diagnostics are enabled.
     *
     * @return the assembled pipe
     */
    public Pipe buildPipe() {
        ArrayList<Pipe> stages = new ArrayList<Pipe>();
        stages.add(new Input2CharSequence("UTF-8"));
        stages.add(new CharSequence2TokenSequence(TOKEN_PATTERN));
        stages.add(new TokenSequence2FeatureSequence());
        stages.add(new Target2Label());
        if (isGeneratingDiagnostics()) {
            stages.add(new PrintInputAndTarget());
        }
        return new SerialPipes(stages);
    }

    /**
     * Logs, at FINE level, the mean topic probabilities of the instances labelled with the
     * exclude class and of those labelled with the include class. Instances carrying any
     * other label are ignored. Nothing is computed when FINE logging is disabled.
     */
    private void displayAveragesAcrossLabels() {
        // The whole computation exists only to feed the log, so skip it when it would be discarded.
        if (!logger.isLoggable(java.util.logging.Level.FINE)) {
            return;
        }
        int topicCount = this.topicModel.getNumTopics();
        double[] excludeSums = new double[topicCount];
        double[] includeSums = new double[topicCount];
        int excludeCount = 0;
        int includeCount = 0;
        int instanceIndex = 0;
        for (Instance instance : this.instances) {
            String label = ((Label) instance.getTarget()).toString();
            if (label.equals(C45BayesModel.CONST_C45_BAYES_MODEL_CLS_EXCLUDE)) {
                excludeCount++;
                accumulateTopicProbabilities(excludeSums, instanceIndex);
            } else if (label.equals(C45BayesModel.CONST_C45_BAYES_MODEL_CLS_INCLUDE)) {
                includeCount++;
                accumulateTopicProbabilities(includeSums, instanceIndex);
            }
            // The index advances for every instance, matching the model's document numbering.
            instanceIndex++;
        }
        logClassAverages(C45BayesModel.CONST_C45_BAYES_MODEL_CLS_EXCLUDE, excludeSums, excludeCount);
        logClassAverages(C45BayesModel.CONST_C45_BAYES_MODEL_CLS_INCLUDE, includeSums, includeCount);
    }

    /** Adds the topic probabilities of the instance at {@code instanceIndex} onto {@code sums}. */
    private void accumulateTopicProbabilities(double[] sums, int instanceIndex) {
        double[] topicProbabilities = this.topicModel.getTopicProbabilities(instanceIndex);
        for (int topic = 0; topic < sums.length; topic++) {
            sums[topic] += topicProbabilities[topic];
        }
    }

    /**
     * Logs one class's averaged topic probabilities at FINE level. With no instances in the
     * class there is nothing to average, so a short notice is logged instead of a row of NaN.
     */
    private static void logClassAverages(Object classLabel, double[] sums, int count) {
        if (count == 0) {
            logger.fine(String.format(Locale.US,
                    "No instances found for %s class; average topic probabilities unavailable", classLabel));
            return;
        }
        StringBuilder message = new StringBuilder();
        message.append(String.format(Locale.US, "Average Topic Probabilities over %s class\n", classLabel));
        for (double sum : sums) {
            message.append(String.format(Locale.US, "%.3f ", Double.valueOf(sum / count)));
        }
        logger.fine(message.toString());
    }

    /**
     * Reads the documents through the given pipe, from the input example set when one is
     * set and from the label/directory pairs otherwise.
     *
     * @param pipe pipe the instance list is created with
     * @return the populated instance list
     */
    public InstanceList generateInstances(Pipe pipe) {
        if (this.inputExampleSet != null) {
            return readExampleSet(pipe);
        }
        return readDirectories(pipe);
    }

    /**
     * Reads instances from the configured label/directory pairs through the given pipe.
     *
     * @param pipe pipe the instance list is created with
     * @return the populated instance list
     */
    public InstanceList readDirectories(Pipe pipe) {
        FileIterator fileIterator = new FileIterator(this.labelDirectoryPairs, false, false);
        InstanceList instanceList = new InstanceList(pipe);
        instanceList.addThruPipe(fileIterator);
        return instanceList;
    }

    /**
     * Reads instances from the input example set through the given pipe.
     *
     * @param pipe pipe the instance list is created with
     * @return the populated instance list
     */
    public InstanceList readExampleSet(Pipe pipe) {
        ExampleSetIterator exampleSetIterator = new ExampleSetIterator(this.inputExampleSet);
        InstanceList instanceList = new InstanceList(pipe);
        instanceList.addThruPipe(exampleSetIterator);
        return instanceList;
    }

    /**
     * Estimates a 10-topic {@link DMRTopicModel} on the given instances and writes its
     * parameters to the file {@code dmr.parameters}. Best-effort: an I/O failure is logged
     * and not propagated. Not called from within this class.
     */
    private void estimateDirichletParameters(InstanceList instanceList) {
        try {
            DMRTopicModel dMRTopicModel = new DMRTopicModel(10);
            dMRTopicModel.addInstances(instanceList);
            dMRTopicModel.estimate();
            dMRTopicModel.writeParameters(new File("dmr.parameters"));
        } catch (IOException e) {
            // Fully qualified because the file imports org.apache.log4j.Level under the simple name.
            logger.log(java.util.logging.Level.SEVERE, "Could not estimate or write DMR parameters", e);
        }
    }

    /** @return the topic model currently held by this worker, possibly {@code null} */
    public ParallelTopicModel getTopicModel() {
        return this.topicModel;
    }

    /** @param parallelTopicModel topic model to hold */
    public void setTopicModel(ParallelTopicModel parallelTopicModel) {
        this.topicModel = parallelTopicModel;
    }

    /** @return whether diagnostics output is enabled */
    public boolean isGeneratingDiagnostics() {
        return this.isGeneratingDiagnostics;
    }

    /** @param z whether diagnostics output is enabled */
    public void setGeneratingDiagnostics(boolean z) {
        this.isGeneratingDiagnostics = z;
    }

    /** @return the label/directory pairs used when no input example set is supplied */
    public List<String[]> getLabelDirectoryPairs() {
        return this.labelDirectoryPairs;
    }

    /** @param list the label/directory pairs used when no input example set is supplied */
    public void setLabelDirectoryPairs(List<String[]> list) {
        this.labelDirectoryPairs = list;
    }

    /** @return the alpha value passed to the topic model constructor */
    public double getAlpha() {
        return this.alpha;
    }

    /** @param d the alpha value passed to the topic model constructor */
    public void setAlpha(double d) {
        this.alpha = d;
    }

    /** @return the beta value passed to the topic model constructor */
    public double getBeta() {
        return this.beta;
    }

    /** @param d the beta value passed to the topic model constructor */
    public void setBeta(double d) {
        this.beta = d;
    }

    /** @return the number of topics */
    public int getNumTopics() {
        return this.numTopics;
    }

    /** @param i the number of topics */
    public void setNumTopics(int i) {
        this.numTopics = i;
    }

    /** @return the example set produced by the last {@link #process()} call, possibly {@code null} */
    public ExampleSet getOutgoingExampleSet() {
        return this.outgoingExampleSet;
    }

    /** @return the number of most probable words handed to a newly built adapter */
    public int getNumberMostProbableWordsForDisplay() {
        return this.numberMostProbableWordsForDisplay;
    }

    /** @param i the number of most probable words handed to a newly built adapter */
    public void setNumberMostProbableWordsForDisplay(int i) {
        this.numberMostProbableWordsForDisplay = i;
    }

    /** @return the iteration count set on the topic model */
    public int getNumberOfIterations() {
        return this.numberOfIterations;
    }

    /** @param i the iteration count set on the topic model */
    public void setNumberOfIterations(int i) {
        this.numberOfIterations = i;
    }

    /** @return the burn-in period set on the topic model */
    public int getBurnInPeriod() {
        return this.burnInPeriod;
    }

    /** @param i the burn-in period set on the topic model */
    public void setBurnInPeriod(int i) {
        this.burnInPeriod = i;
    }

    /** @return the thread count set on the topic model */
    public int getNumberOfThreads() {
        return this.numberOfThreads;
    }

    /** @param i the thread count set on the topic model */
    public void setNumberOfThreads(int i) {
        this.numberOfThreads = i;
    }

    /** @return whether document thetas are included in the outgoing example set */
    public boolean isOutputingDocumentThetas() {
        return this.isOutputingDocumentThetas;
    }

    /** @param z whether document thetas are included in the outgoing example set */
    public void setOutputingDocumentThetas(boolean z) {
        this.isOutputingDocumentThetas = z;
    }

    /** @return whether Kullback-Leibler divergences are included in the outgoing example set */
    public boolean isOutputingKulbachLeiblerDivergences() {
        return this.isOutputingKulbachLeiblerDivergences;
    }

    /** @param z whether Kullback-Leibler divergences are included in the outgoing example set */
    public void setOutputingKulbachLeiblerDivergences(boolean z) {
        this.isOutputingKulbachLeiblerDivergences = z;
    }

    /** @return the adapter in use, or {@code null} when none has been built or supplied */
    public TopicModelAdapter getTopicModelAdapter() {
        return this.topicModelAdapter;
    }

    /** @param topicModelAdapter adapter to reuse instead of estimating a new model */
    public void setTopicModelAdapter(TopicModelAdapter topicModelAdapter) {
        this.topicModelAdapter = topicModelAdapter;
    }

    /** @return the optimize interval set on the topic model */
    public int getOptimizeInterval() {
        return this.optimizeInterval;
    }

    /** @param i the optimize interval set on the topic model */
    public void setOptimizeInterval(int i) {
        this.optimizeInterval = i;
    }

    /** @return the random seed set on the topic model */
    public int getRandomSeed() {
        return this.randomSeed;
    }

    /** @param i the random seed set on the topic model */
    public void setRandomSeed(int i) {
        this.randomSeed = i;
    }

    /** @return the symmetric-alpha flag set on the topic model */
    public boolean isSymmetricAlpha() {
        return this.isSymmetricAlpha;
    }

    /** @param z the symmetric-alpha flag set on the topic model */
    public void setSymmetricAlpha(boolean z) {
        this.isSymmetricAlpha = z;
    }

    /** @return the tempering interval set on the topic model */
    public int getTemperingInterval() {
        return this.temperingInterval;
    }

    /** @param i the tempering interval set on the topic model */
    public void setTemperingInterval(int i) {
        this.temperingInterval = i;
    }

    /** @return the inferencer iteration count passed to the attribute builder */
    public int getInferencerIterations() {
        return this.inferencerIterations;
    }

    /** @param i the inferencer iteration count passed to the attribute builder */
    public void setInferencerIterations(int i) {
        this.inferencerIterations = i;
    }

    /** @return the inferencer thinning value passed to the attribute builder */
    public int getInferencerThinning() {
        return this.inferencerThinning;
    }

    /** @param i the inferencer thinning value passed to the attribute builder */
    public void setInferencerThinning(int i) {
        this.inferencerThinning = i;
    }

    /** @return the inferencer burn-in period passed to the attribute builder */
    public int getInferencerBurnInPeriod() {
        return this.inferencerBurnInPeriod;
    }

    /** @param i the inferencer burn-in period passed to the attribute builder */
    public void setInferencerBurnInPeriod(int i) {
        this.inferencerBurnInPeriod = i;
    }

    /** @param exampleSet example set to read documents from; {@code null} selects directory input */
    public void setInputExampleSet(ExampleSet exampleSet) {
        this.inputExampleSet = exampleSet;
    }
}
