package com.rapidminer.extension.operator.text_processing.modelling.mallet;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.topics.TopicModelDiagnostics;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import com.rapidminer.datatable.DataTable;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.utils.ExampleSetBuilder;
import com.rapidminer.example.utils.ExampleSets;
import com.rapidminer.extension.operator.text_processing.modelling.AbstractDocumentModel;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.text.Document;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;

/* loaded from: input_file:com/rapidminer/extension/operator/text_processing/modelling/mallet/LDAModel.class */
/**
 * Document model that wraps a trained Mallet {@link ParallelTopicModel} (LDA).
 *
 * <p>The model can be applied to an {@link ExampleSet} or to a collection of {@link Document}s. In both
 * cases the result carries a document id, a predicted topic and one confidence column per topic. The class
 * also collects performance criteria in a {@link PerformanceVector} and per-topic diagnostic scores that
 * can be exported as a {@link DataTable}.
 */
public class LDAModel extends AbstractDocumentModel {
    private static final long serialVersionUID = 7869437298910665701L;
    // The trained topic model all inference and topic counts are read from.
    private ParallelTopicModel topicModel;
    // When true, applyOnDocumentsWithConvertedInstances appends the documents' meta data as extra columns.
    private boolean addMetaData;
    // Seed handed to the inferencer; -1 means "use the topic model's own randomSeed" (see getProbs).
    private int randomSeed;
    // Sampling parameters passed to TopicInferencer.getSampledDistribution; initialised in setAdvancedParameters.
    private int numIterations;
    private int thinning;
    private int burnIn;
    // Name of the text attribute used by applyOnExampleSet when no attribute is passed explicitly.
    private String textAttName;
    // Criteria collected by the constructor, addScores, calculatePerplexity and addAlphaAndBetaToPerformanceVector.
    private PerformanceVector performanceVector = new PerformanceVector();
    // Optional diagnostics text appended to toString(); null until setTopicDiagnosticString is called.
    private String topicDiagnosticString = null;
    // Score name -> score array as stored by addScores; createDataTable indexes each array by topic number.
    private HashMap<String, double[]> topicScores = new HashMap<>();

    public LDAModel(ParallelTopicModel parallelTopicModel, boolean z) {
        this.topicModel = parallelTopicModel;
        this.addMetaData = z;
        this.performanceVector.addCriterion(new EstimatedPerformance("LogLikelihood", parallelTopicModel.modelLogLikelihood(), 1, true));
    }

    /**
     * Reports whether this model needs tokenized input.
     *
     * @return always {@code false}
     */
    @Override
    public boolean requieresTokenization() {
        return false;
    }

    /**
     * Builds the meta data describing the example set produced when the model is applied to documents.
     *
     * <p>The layout is: {@code documentid} (role {@code id}), {@code prediction(Topic)} (role
     * {@code prediction}), {@code text}, followed by one {@code confidence(Topic_k)} attribute per topic.
     * The integer literals are value-type codes; presumably RapidMiner Ontology constants (confirm).
     *
     * @param i the number of topics, i.e. the number of confidence attributes to declare
     * @return the meta data with {@code 3 + i} attributes
     */
    public static ExampleSetMetaData createExampleSetMetaData(int i) {
        ExampleSetMetaData metaData = new ExampleSetMetaData();
        metaData.addAttribute(new AttributeMetaData("documentid", "id", 3, new String[0]));
        metaData.addAttribute(new AttributeMetaData("prediction(Topic)", "prediction", 1, new String[0]));
        metaData.addAttribute(new AttributeMetaData("text", 5));
        int topic = 0;
        while (topic < i) {
            String topicLabel = "Topic_" + topic;
            String attributeName = "confidence(" + topicLabel + ")";
            String roleName = "confidence_" + topicLabel;
            metaData.addAttribute(new AttributeMetaData(attributeName, roleName, 4, new String[0]));
            topic++;
        }
        return metaData;
    }

    /**
     * Applies the topic model to every example of an example set and writes the results into new columns.
     *
     * <p>The new attributes (document id, predicted topic, one confidence per topic) are added to the given
     * example set in place; the document id is the zero-based row index.
     *
     * @param exampleSet the example set holding the text to score; it is modified
     * @param attribute the attribute whose name selects the text column, or {@code null} to use the
     *        configured text attribute name
     * @param z if {@code true} the stored training probabilities are used instead of running inference
     * @return the same example set with id, prediction and confidence roles assigned
     */
    public ExampleSet applyOnExampleSet(ExampleSet exampleSet, @Nullable Attribute attribute, boolean z) {
        String textAttributeName = (attribute != null) ? attribute.getName() : this.textAttName;
        InstanceList instances = MalletHelper.convertToInstanceList(exampleSet, exampleSet.getAttributes().get(textAttributeName));

        List<Attribute> newAttributes = generateNewAttributes(false);
        for (Attribute newAttribute : newAttributes) {
            newAttribute.setTableIndex(exampleSet.getAttributes().size());
            exampleSet.getAttributes().addRegular(newAttribute);
            exampleSet.getExampleTable().addAttribute(newAttribute);
        }

        int row = 0;
        for (Instance instance : instances) {
            Example example = exampleSet.getExample(row);
            double[] probs = getProbs(instance, row, z);
            int predictedTopic = getPredictionId(probs);
            example.setValue(newAttributes.get(0), row);
            example.setValue(newAttributes.get(1), newAttributes.get(1).getMapping().mapString("Topic_" + predictedTopic));
            // Confidence attributes start at index 2 because no text attribute was generated.
            for (int topic = 0; topic < probs.length; topic++) {
                example.setValue(newAttributes.get(2 + topic), probs[topic]);
            }
            row++;
        }
        return setRoles(exampleSet);
    }

    /**
     * Applies the model to a collection of documents, always running inference on the converted instances.
     *
     * @param iOObjectCollection the documents to score
     * @return an example set with document id, predicted topic, text and per-topic confidences
     */
    @Override
    public ExampleSet applyOnDocuments(IOObjectCollection<Document> iOObjectCollection) {
        InstanceList instances = MalletHelper.convertDocsToInstances(iOObjectCollection);
        return applyOnDocumentsWithConvertedInstances(iOObjectCollection, instances, false);
    }

    /**
     * Builds a new example set with one row per instance: document id, predicted topic, text and one
     * confidence per topic, optionally followed by the documents' meta data.
     *
     * <p>Column layout: index 0 = document id (row index), 1 = prediction, 2 = text,
     * 3..3+numTopics-1 = confidences, remaining columns = meta data (only when meta data is enabled).
     *
     * @param iOObjectCollection the documents the instances were created from; used to look up meta data
     * @param instanceList the Mallet instances to score; each instance's source is cast to {@code String}
     *        and its name to {@code Integer} (used as the document index in the collection)
     * @param z if {@code true} the stored training probabilities are used instead of running inference
     * @return the populated example set with id, prediction and confidence roles assigned
     */
    public ExampleSet applyOnDocumentsWithConvertedInstances(IOObjectCollection<Document> iOObjectCollection, InstanceList instanceList, boolean z) {
        List<Attribute> attributes = generateNewAttributes(true);
        if (this.addMetaData) {
            attributes.addAll(MalletHelper.getMetaDataHashMap(iOObjectCollection).values());
        }
        ExampleSetBuilder builder = ExampleSets.from(attributes);
        int row = 0;
        for (Instance instance : instanceList) {
            double[] probs = getProbs(instance, row, z);
            double[] values = new double[attributes.size()];
            values[0] = row;
            values[2] = attributes.get(2).getMapping().mapString((String) instance.getSource());
            System.arraycopy(probs, 0, values, 3, probs.length);
            values[1] = attributes.get(1).getMapping().mapString("Topic_" + getPredictionId(probs));
            if (this.addMetaData) {
                Document document = iOObjectCollection.getElement(((Integer) instance.getName()).intValue(), false);
                // Meta data columns follow the fixed columns and the confidences.
                for (int column = 3 + probs.length; column < attributes.size(); column++) {
                    Attribute metaAttribute = attributes.get(column);
                    values[column] = encodeMetaDataValue(metaAttribute, document.getMetaDataValue(metaAttribute.getName()));
                }
            }
            builder.addRow(values);
            row++;
        }
        return setRoles(builder.build());
    }

    /**
     * Converts one document meta data value into the double stored in a data row.
     *
     * <p>A {@code null} value is stored as {@code NaN} for every attribute type, and numerical values are
     * accepted as any {@link Number} rather than only {@link Double}.
     *
     * @param attribute the attribute the value belongs to
     * @param metaDataValue the raw meta data value, may be {@code null}
     * @return the mapped nominal index, the numeric value, the date in epoch milliseconds, or {@code NaN}
     */
    private static double encodeMetaDataValue(Attribute attribute, Object metaDataValue) {
        if (metaDataValue == null) {
            return Double.NaN;
        }
        if (attribute.isNominal()) {
            return attribute.getMapping().mapString((String) metaDataValue);
        }
        if (attribute.isNumerical()) {
            return ((Number) metaDataValue).doubleValue();
        }
        // Anything that is neither nominal nor numerical is treated as a date, as before.
        return ((Date) metaDataValue).getTime();
    }

    /**
     * Returns the topic distribution for one instance.
     *
     * @param instance the instance to infer topics for; only used when {@code z} is {@code false}
     * @param i the index of the document in the trained model; only used when {@code z} is {@code true}
     * @param z {@code true} to read the probabilities stored in the trained model, {@code false} to
     *        sample them with a fresh inferencer
     * @return one probability per topic
     */
    private double[] getProbs(Instance instance, int i, boolean z) {
        if (z) {
            return this.topicModel.getTopicProbabilities(i);
        }
        TopicInferencer inferencer = this.topicModel.getInferencer();
        // A configured seed of -1 means "fall back to the seed the topic model itself carries".
        int seed = (this.randomSeed == -1) ? this.topicModel.randomSeed : this.randomSeed;
        inferencer.setRandomSeed(seed);
        return inferencer.getSampledDistribution(instance, this.numIterations, this.thinning, this.burnIn);
    }

    /**
     * Assigns the special roles on a result example set: one confidence role per topic, the id role for
     * {@code documentid} and the predicted-label role for {@code prediction(Topic)}.
     *
     * @param exampleSet the example set containing the generated attributes; it is modified
     * @return the same example set
     */
    private ExampleSet setRoles(ExampleSet exampleSet) {
        int topic = 0;
        while (topic < this.topicModel.getNumTopics()) {
            String attributeName = "confidence(Topic_" + topic + ")";
            String roleName = "confidence_Topic_" + topic;
            exampleSet.getAttributes().setSpecialAttribute(exampleSet.getAttributes().get(attributeName), roleName);
            topic++;
        }
        Attribute idAttribute = exampleSet.getAttributes().get("documentid");
        Attribute predictionAttribute = exampleSet.getAttributes().get("prediction(Topic)");
        exampleSet.getAttributes().setId(idAttribute);
        exampleSet.getAttributes().setPredictedLabel(predictionAttribute);
        return exampleSet;
    }

    /**
     * Creates the {@code prediction(Topic)} attribute and pre-registers the values {@code Topic_0} to
     * {@code Topic_(n-1)} in its mapping, in topic order.
     *
     * @return the new prediction attribute
     */
    @Override
    public Attribute generatePredictionAttribute() {
        Attribute prediction = AttributeFactory.createAttribute("prediction(Topic)", 1);
        int topicCount = this.topicModel.getNumTopics();
        for (int topic = 0; topic < topicCount; topic++) {
            prediction.getMapping().mapString("Topic_" + topic);
        }
        return prediction;
    }

    /**
     * Creates the attributes produced by applying the model.
     *
     * <p>Order: {@code documentid}, {@code prediction(Topic)}, optionally {@code text}, then one
     * {@code confidence(Topic_k)} attribute per topic. Callers rely on these positions.
     *
     * @param z whether to include the {@code text} attribute at index 2
     * @return a new, mutable list of freshly created attributes
     */
    private List<Attribute> generateNewAttributes(boolean z) {
        // Typed list instead of a raw ArrayList, so no unchecked conversion happens on return.
        List<Attribute> attributes = new ArrayList<>();
        attributes.add(AttributeFactory.createAttribute("documentid", 3));
        attributes.add(generatePredictionAttribute());
        if (z) {
            attributes.add(AttributeFactory.createAttribute("text", 5));
        }
        for (int i = 0; i < this.topicModel.getNumTopics(); i++) {
            attributes.add(AttributeFactory.createAttribute("confidence(Topic_" + Integer.toString(i) + ")", 4));
        }
        return attributes;
    }

    /**
     * Returns the index of the largest value in the array (the first one on ties).
     *
     * <p>The running maximum starts at negative infinity. {@code Double.MIN_VALUE} is the smallest
     * <em>positive</em> double, so using it as the seed would ignore every value that is not strictly
     * positive.
     *
     * @param dArr the per-topic scores
     * @return the index of the maximum, or {@code 0} if the array is empty or holds no comparable value
     */
    private int getPredictionId(double[] dArr) {
        int bestIndex = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] > bestValue) {
                bestIndex = i;
                bestValue = dArr[i];
            }
        }
        return bestIndex;
    }

    /**
     * Resets the inference parameters to their defaults and then applies the given overrides.
     *
     * <p>Defaults: random seed {@code -1}, 1000 iterations, thinning 10, burn-in 10. Recognised keys are
     * {@code random_seed}, {@code LDA.iterations}, {@code LDA.thinning} and {@code LDA.burnin}; each one
     * sets only its own parameter. Unknown keys are ignored.
     *
     * @param list key/value pairs; element 0 of each array is the key, element 1 the integer value as text
     * @throws NumberFormatException if a recognised key carries a value that is not an integer
     */
    @Override
    public void setAdvancedParameters(List<String[]> list) {
        this.randomSeed = -1;
        this.numIterations = 1000;
        this.thinning = 10;
        this.burnIn = 10;
        for (String[] strArr : list) {
            switch (strArr[0]) {
                case "random_seed":
                    this.randomSeed = Integer.parseInt(strArr[1]);
                    break;
                case "LDA.iterations":
                    this.numIterations = Integer.parseInt(strArr[1]);
                    getLog().log("SET NUMER OF ITERATIONS" + strArr[1]);
                    break;
                case "LDA.thinning":
                    this.thinning = Integer.parseInt(strArr[1]);
                    break;
                case "LDA.burnin":
                    this.burnIn = Integer.parseInt(strArr[1]);
                    break;
                default:
                    // Not an LDA inference parameter; leave all settings untouched.
                    break;
            }
        }
    }

    /**
     * Stores a set of per-topic diagnostic scores and adds their average as an {@code Avg(name)} criterion.
     *
     * <p>The boolean flag passed to the criterion is {@code false} for the scores named
     * {@code exclusivity}, {@code coherence} and {@code word-length}, and {@code true} for all others.
     *
     * @param topicScores the named scores, one value per topic
     */
    public void addScores(TopicModelDiagnostics.TopicScores topicScores) {
        String scoreName = topicScores.name;
        this.topicScores.put(scoreName, topicScores.scores);
        double average = getAverageScore(topicScores);
        boolean flag = !scoreName.equals("exclusivity")
                && !scoreName.equals("coherence")
                && !scoreName.equals("word-length");
        EstimatedPerformance criterion = new EstimatedPerformance("Avg(" + scoreName + ")", average, 1, flag);
        this.performanceVector.addCriterion(criterion);
    }

    /**
     * Estimates the perplexity of the topic model on the given instances, adds it as the "Perplexity"
     * criterion and makes it the main criterion of the performance vector.
     *
     * @param instanceList the instances to evaluate on; its size is passed along with the estimate
     */
    public void calculatePerplexity(InstanceList instanceList) {
        PerplexityCalculator calculator = new PerplexityCalculator(this.topicModel);
        EstimatedPerformance criterion =
                new EstimatedPerformance("Perplexity", calculator.estimatePerplexity(instanceList), instanceList.size(), true);
        this.performanceVector.addCriterion(criterion);
        this.performanceVector.setMainCriterionName("Perplexity");
    }

    /**
     * Adds the topic model's {@code alphaSum}, {@code beta} and {@code betaSum} values to the performance
     * vector as the criteria "AlphaSum", "Beta" and "BetaSum", in that order.
     */
    public void addAlphaAndBetaToPerformanceVector() {
        String[] names = {"AlphaSum", "Beta", "BetaSum"};
        double[] values = {this.topicModel.alphaSum, this.topicModel.beta, this.topicModel.betaSum};
        for (int k = 0; k < names.length; k++) {
            this.performanceVector.addCriterion(new EstimatedPerformance(names[k], values[k], 1, true));
        }
    }

    /**
     * @return the number of topics of the wrapped topic model
     */
    public int getNumberOfTopics() {
        return topicModel.getNumTopics();
    }

    /**
     * Sets the diagnostics text that {@link #toString()} appends to the model summary.
     *
     * @param str the diagnostics text, or {@code null} to append nothing
     */
    public void setTopicDiagnosticString(String str) {
        topicDiagnosticString = str;
    }

    /**
     * @return the wrapped Mallet topic model (the same instance, not a copy)
     */
    public ParallelTopicModel getTopicModel() {
        return topicModel;
    }

    /**
     * @return the performance vector holding all criteria collected so far (the same instance, not a copy)
     */
    public PerformanceVector getPerformanceVector() {
        return performanceVector;
    }

    /**
     * Computes the arithmetic mean of the given scores.
     *
     * @param topicScores the scores to average
     * @return the mean, or {@code NaN} when there are no scores (0.0 divided by 0)
     */
    private double getAverageScore(TopicModelDiagnostics.TopicScores topicScores) {
        double[] scores = topicScores.scores;
        double sum = 0.0d;
        for (int k = 0; k < scores.length; k++) {
            sum += scores[k];
        }
        return sum / scores.length;
    }

    /**
     * @return whether document meta data is added to the result when applying the model to documents
     */
    public boolean isAddMetaData() {
        return addMetaData;
    }

    /**
     * @param z whether document meta data should be added to the result when applying the model to documents
     */
    public void setAddMetaData(boolean z) {
        addMetaData = z;
    }

    /**
     * @return the name of the text attribute used when the model is applied to an example set without an
     *         explicit attribute
     */
    public String getTextAttName() {
        return textAttName;
    }

    /**
     * @param str the name of the text attribute to use when the model is applied to an example set without
     *        an explicit attribute
     */
    public void setTextAttName(String str) {
        textAttName = str;
    }

    /**
     * Renders a short summary: number of topics, {@code alphaSum} and {@code beta}, each on its own line,
     * followed by the diagnostics text when one has been set.
     *
     * @return the multi-line summary
     */
    @Override
    public String toString() {
        StringBuilder summary = new StringBuilder();
        summary.append("LDA Model with ")
                .append(Integer.toString(this.topicModel.getNumTopics()))
                .append(" topics \n");
        summary.append("alphaSum = ")
                .append(Double.toString(this.topicModel.alphaSum))
                .append("\n");
        summary.append("beta = ")
                .append(Double.toString(this.topicModel.beta))
                .append("\n");
        if (this.topicDiagnosticString != null) {
            summary.append(this.topicDiagnosticString);
        }
        return summary.toString();
    }

    /**
     * Exports the stored diagnostic scores as a table with one row per topic.
     *
     * <p>The first column, {@code topic}, holds the mapped label {@code Topic_k}; every further column is
     * one stored score, in the iteration order of the score map's key set.
     *
     * @return a data table named "Topics"
     */
    public DataTable createDataTable() {
        // Snapshot the keys once so header and rows use the same column order.
        List<String> scoreNames = new ArrayList<>(this.topicScores.keySet());
        int columnCount = scoreNames.size() + 1;

        String[] columnNames = new String[columnCount];
        columnNames[0] = "topic";
        for (int column = 1; column < columnCount; column++) {
            columnNames[column] = scoreNames.get(column - 1);
        }

        SimpleDataTable table = new SimpleDataTable("Topics", columnNames);
        for (int topic = 0; topic < this.topicModel.getNumTopics(); topic++) {
            String topicLabel = "Topic_" + topic;
            double[] rowValues = new double[columnCount];
            rowValues[0] = table.mapString(0, topicLabel);
            for (int column = 1; column < columnCount; column++) {
                rowValues[column] = this.topicScores.get(scoreNames.get(column - 1))[topic];
            }
            table.add(new SimpleDataTableRow(rowValues, topicLabel));
        }
        return table;
    }
}
