package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import gnu.trove.TIntHashSet;
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.Iterator;
import java.util.Locale;
import java.util.TreeSet;

/**
 * Computes per-topic (and, for some metrics, per-top-word) diagnostic scores for a trained
 * {@link ParallelTopicModel}.
 *
 * <p>All statistics are computed eagerly by the constructor and stored in {@link #diagnostics};
 * they can be rendered with {@link #toString()} or {@link #toXML()}.
 */
public class TopicModelDiagnostics {
    int numTopics;
    int numTopWords;

    /** Index into {@link #DEFAULT_DOC_PROPORTIONS} of the 0.02 threshold. */
    public static final int TWO_PERCENT_INDEX = 1;
    /** Index into {@link #DEFAULT_DOC_PROPORTIONS} of the 0.5 threshold. */
    public static final int FIFTY_PERCENT_INDEX = 6;
    /** Smoothed topic-proportion thresholds at which documents are counted for each topic. */
    public static final double[] DEFAULT_DOC_PROPORTIONS = {0.01d, 0.02d, 0.05d, 0.1d, 0.2d, 0.3d, 0.5d};

    ArrayList<TreeSet<IDSorter>> topicSortedWords;
    String[][] topicTopWords;
    ParallelTopicModel model;
    Alphabet alphabet;
    /** Per topic: [i][j] = number of documents in which top words i and j both occur with that topic. */
    int[][][] topicCodocumentMatrices;
    /** Per topic: number of documents in which the topic has the largest token count. */
    int[] numRank1Documents;
    /** Per topic: number of documents with at least one token assigned to the topic. */
    int[] numNonZeroDocuments;
    /** Per topic and threshold index: number of documents whose smoothed proportion reaches the threshold. */
    int[][] numDocumentsAtProportions;
    /** Per topic: sum over documents of c * log(c), where c is the topic's token count in the document. */
    double[] sumCountTimesLogCount;
    /** Corpus-wide token count for each word type. */
    int[] wordTypeCounts;
    int numTokens = 0;
    ArrayList<TopicScores> diagnostics = new ArrayList<>();

    /**
     * A named diagnostic: one score per topic and, optionally, one score per top word of each topic.
     * Kept as an inner class to preserve the existing API.
     */
    public class TopicScores {
        public String name;
        public double[] scores;
        public double[][] topicWordScores;
        /** Whether per-word scores are meaningful for this metric (set by any per-word setter). */
        public boolean wordScoresDefined = false;

        public TopicScores(String name, int numTopics, int numWords) {
            this.name = name;
            this.scores = new double[numTopics];
            this.topicWordScores = new double[numTopics][numWords];
        }

        public void setTopicScore(int topic, double score) {
            scores[topic] = score;
        }

        public void addToTopicScore(int topic, double score) {
            scores[topic] += score;
        }

        public void setTopicWordScore(int topic, int wordPosition, double score) {
            topicWordScores[topic][wordPosition] = score;
            wordScoresDefined = true;
        }
    }

    /**
     * Computes all diagnostics for {@code model}.
     *
     * @param model       a trained topic model
     * @param numTopWords how many highest-weight words per topic receive per-word scores
     */
    public TopicModelDiagnostics(ParallelTopicModel model, int numTopWords) {
        this.numTopics = model.getNumTopics();
        this.numTopWords = numTopWords;
        this.model = model;
        this.alphabet = model.getAlphabet();
        this.topicSortedWords = model.getSortedWords();
        this.topicTopWords = new String[numTopics][numTopWords];

        // Record the top words of each topic; slots beyond the topic's vocabulary stay null.
        for (int topic = 0; topic < numTopics; topic++) {
            Iterator<IDSorter> it = topicSortedWords.get(topic).iterator();
            for (int position = 0; position < numTopWords && it.hasNext(); position++) {
                topicTopWords[topic][position] = (String) alphabet.lookupObject(it.next().getID());
            }
        }

        collectDocumentStatistics();

        diagnostics.add(getTokensPerTopic(model.tokensPerTopic));
        diagnostics.add(getDocumentEntropy(model.tokensPerTopic));
        diagnostics.add(getWordLengthScores());
        diagnostics.add(getCoherence());
        diagnostics.add(getDistanceFromUniform());
        diagnostics.add(getDistanceFromCorpus());
        diagnostics.add(getEffectiveNumberOfWords());
        diagnostics.add(getTokenDocumentDiscrepancies());
        diagnostics.add(getRank1Percent());
        diagnostics.add(getDocumentPercentRatio(FIFTY_PERCENT_INDEX, TWO_PERCENT_INDEX));
        diagnostics.add(getDocumentPercent(5));
        diagnostics.add(getExclusivity());
    }

    /**
     * Makes one pass over the model's training data and (re)computes token counts, per-topic
     * document statistics and the top-word co-document matrices. All accumulators are reset
     * first, so calling this more than once does not double count.
     */
    public void collectDocumentStatistics() {
        topicCodocumentMatrices = new int[numTopics][numTopWords][numTopWords];
        wordTypeCounts = new int[alphabet.size()];
        numTokens = 0;
        numRank1Documents = new int[numTopics];
        numNonZeroDocuments = new int[numTopics];
        numDocumentsAtProportions = new int[numTopics][DEFAULT_DOC_PROPORTIONS.length];
        sumCountTimesLogCount = new double[numTopics];

        // Alphabet ids of each topic's top words, by rank (-1 = empty rank), and the same ids as a set.
        int[][] topicTopWordIds = new int[numTopics][numTopWords];
        TIntHashSet[] topicTopWordSets = new TIntHashSet[numTopics];
        // Per-document scratch: top words of each topic that occurred assigned to that topic.
        TIntHashSet[] docTopWordsByTopic = new TIntHashSet[numTopics];
        // Per-document scratch: token count of each topic.
        int[] docTopicCounts = new int[numTopics];

        for (int topic = 0; topic < numTopics; topic++) {
            TIntHashSet topWordSet = new TIntHashSet();
            for (int position = 0; position < numTopWords; position++) {
                // -1 is never a valid alphabet index, so empty ranks can never match a document word.
                int type = -1;
                String word = topicTopWords[topic][position];
                if (word != null) {
                    type = alphabet.lookupIndex(word);
                    topWordSet.add(type);
                }
                topicTopWordIds[topic][position] = type;
            }
            topicTopWordSets[topic] = topWordSet;
            docTopWordsByTopic[topic] = new TIntHashSet();
        }

        for (TopicAssignment document : model.getData()) {
            FeatureSequence tokens = (FeatureSequence) document.instance.getData();
            LabelSequence topics = document.topicSequence;
            int docLength = tokens.size();

            for (int position = 0; position < docLength; position++) {
                int type = tokens.getIndexAtPosition(position);
                int topic = topics.getIndexAtPosition(position);
                numTokens++;
                wordTypeCounts[type]++;
                docTopicCounts[topic]++;
                if (topicTopWordSets[topic].contains(type)) {
                    docTopWordsByTopic[topic].add(type);
                }
            }

            if (docLength == 0) {
                continue;
            }

            int maxTopic = -1;
            int maxCount = -1;
            for (int topic = 0; topic < numTopics; topic++) {
                int count = docTopicCounts[topic];
                if (count == 0) {
                    continue;
                }
                numNonZeroDocuments[topic]++;
                // Strict '>' keeps the lowest-numbered topic on ties.
                if (count > maxCount) {
                    maxTopic = topic;
                    maxCount = count;
                }
                sumCountTimesLogCount[topic] += count * Math.log(count);

                // Alpha-smoothed share of this document assigned to the topic.
                double proportion = (model.alpha[topic] + count) / (model.alphaSum + docLength);
                for (int level = 0;
                        level < DEFAULT_DOC_PROPORTIONS.length && proportion >= DEFAULT_DOC_PROPORTIONS[level];
                        level++) {
                    numDocumentsAtProportions[topic][level]++;
                }

                addCodocumentCounts(topic, topicTopWordIds[topic], docTopWordsByTopic[topic]);

                // Reset the scratch state for the next document.
                docTopWordsByTopic[topic].clear();
                docTopicCounts[topic] = 0;
            }
            if (maxTopic > -1) {
                numRank1Documents[maxTopic]++;
            }
        }
    }

    /** Adds one document's co-occurrences of a topic's top words to that topic's co-document matrix. */
    private void addCodocumentCounts(int topic, int[] topWordIds, TIntHashSet presentWords) {
        int[][] matrix = topicCodocumentMatrices[topic];
        for (int first = 0; first < numTopWords; first++) {
            if (!presentWords.contains(topWordIds[first])) {
                continue;
            }
            matrix[first][first]++;
            for (int second = first + 1; second < numTopWords; second++) {
                if (presentWords.contains(topWordIds[second])) {
                    matrix[first][second]++;
                    matrix[second][first]++;
                }
            }
        }
    }

    /** Returns the top-word co-document matrix of {@code topic} (the internal array, not a copy). */
    public int[][] getCodocumentMatrix(int topic) {
        return topicCodocumentMatrices[topic];
    }

    /** Score: number of tokens assigned to each topic. */
    public TopicScores getTokensPerTopic(int[] tokensPerTopic) {
        TopicScores scores = new TopicScores("tokens", numTopics, numTopWords);
        for (int topic = 0; topic < numTopics; topic++) {
            scores.setTopicScore(topic, tokensPerTopic[topic]);
        }
        return scores;
    }

    /**
     * Score: -(sum of c log c) / N + log N per topic, where c ranges over per-document token counts
     * and N = {@code tokensPerTopic[topic]}. This is the entropy of the topic's tokens across
     * documents, assuming N equals the sum of the per-document counts.
     */
    public TopicScores getDocumentEntropy(int[] tokensPerTopic) {
        TopicScores scores = new TopicScores("document_entropy", numTopics, numTopWords);
        for (int topic = 0; topic < numTopics; topic++) {
            int total = tokensPerTopic[topic];
            scores.setTopicScore(topic, (-sumCountTimesLogCount[topic] / total) + Math.log(total));
        }
        return scores;
    }

    /** Score: KL divergence of each topic's word distribution from the uniform distribution over the alphabet. */
    public TopicScores getDistanceFromUniform() {
        int[] tokensPerTopic = model.tokensPerTopic;
        TopicScores scores = new TopicScores("uniform_dist", numTopics, numTopWords);
        scores.wordScoresDefined = true;
        int numTypes = alphabet.size();

        for (int topic = 0; topic < numTopics; topic++) {
            double topicScore = 0.0;
            int position = 0;
            for (IDSorter info : topicSortedWords.get(topic)) {
                double count = info.getWeight();
                double score = (count / tokensPerTopic[topic])
                        * Math.log((count * numTypes) / tokensPerTopic[topic]);
                if (position < numTopWords) {
                    scores.setTopicWordScore(topic, position, score);
                }
                topicScore += score;
                position++;
            }
            scores.setTopicScore(topic, topicScore);
        }
        return scores;
    }

    /** Score: 1 / sum(p^2) over each topic's word probabilities. */
    public TopicScores getEffectiveNumberOfWords() {
        int[] tokensPerTopic = model.tokensPerTopic;
        TopicScores scores = new TopicScores("eff_num_words", numTopics, numTopWords);

        for (int topic = 0; topic < numTopics; topic++) {
            double sumSquaredProbabilities = 0.0;
            for (IDSorter info : topicSortedWords.get(topic)) {
                double probability = info.getWeight() / tokensPerTopic[topic];
                sumSquaredProbabilities += probability * probability;
            }
            scores.setTopicScore(topic, 1.0 / sumSquaredProbabilities);
        }
        return scores;
    }

    /** Score: KL divergence of each topic's word distribution from the corpus-wide word distribution. */
    public TopicScores getDistanceFromCorpus() {
        int[] tokensPerTopic = model.tokensPerTopic;
        TopicScores scores = new TopicScores("corpus_dist", numTopics, numTopWords);
        scores.wordScoresDefined = true;

        for (int topic = 0; topic < numTopics; topic++) {
            // Floating-point ratio: integer division here would truncate the scale factor.
            double coefficient = (double) numTokens / tokensPerTopic[topic];
            double topicScore = 0.0;
            int position = 0;
            for (IDSorter info : topicSortedWords.get(topic)) {
                int type = info.getID();
                double count = info.getWeight();
                // (count / N_t) * log( (count / N_t) / (n_w / N) )
                double score = (count / tokensPerTopic[topic])
                        * Math.log((coefficient * count) / wordTypeCounts[type]);
                if (position < numTopWords) {
                    scores.setTopicWordScore(topic, position, score);
                }
                topicScore += score;
                position++;
            }
            scores.setTopicScore(topic, topicScore);
        }
        return scores;
    }

    /**
     * Score: Jensen-Shannon divergence, over each topic's top words, between the normalized token
     * counts and the normalized document counts (diagonal of the co-document matrix).
     */
    public TopicScores getTokenDocumentDiscrepancies() {
        TopicScores scores = new TopicScores("token-doc-diff", numTopics, numTopWords);
        scores.wordScoresDefined = true;

        for (int topic = 0; topic < numTopics; topic++) {
            int[][] matrix = topicCodocumentMatrices[topic];
            double[] wordCounts = new double[numTopWords];
            double[] docCounts = new double[numTopWords];
            double wordCountSum = 0.0;
            double docCountSum = 0.0;

            Iterator<IDSorter> it = topicSortedWords.get(topic).iterator();
            for (int position = 0; it.hasNext() && position < numTopWords; position++) {
                wordCounts[position] = it.next().getWeight();
                docCounts[position] = matrix[position][position];
                wordCountSum += wordCounts[position];
                docCountSum += docCounts[position];
            }

            double topicScore = 0.0;
            for (int position = 0; position < numTopWords; position++) {
                double p = wordCounts[position] / wordCountSum;
                double q = docCounts[position] / docCountSum;
                double mean = 0.5 * (p + q);
                double score = 0.0;
                if (p > 0.0) {
                    score += 0.5 * p * Math.log(p / mean);
                }
                if (q > 0.0) {
                    score += 0.5 * q * Math.log(q / mean);
                }
                scores.setTopicWordScore(topic, position, score);
                topicScore += score;
            }
            scores.setTopicScore(topic, topicScore);
        }
        return scores;
    }

    /** Score: mean character length of each topic's top words (0 for a topic with no top words). */
    public TopicScores getWordLengthScores() {
        TopicScores scores = new TopicScores("word-length", numTopics, numTopWords);
        scores.wordScoresDefined = true;

        for (int topic = 0; topic < numTopics; topic++) {
            String[] words = topicTopWords[topic];
            int totalLength = 0;
            int numWords = 0;
            while (numWords < words.length && words[numWords] != null) {
                int length = words[numWords].length();
                totalLength += length;
                scores.setTopicWordScore(topic, numWords, length);
                numWords++;
            }
            // Average over the words actually present, in floating point.
            scores.setTopicScore(topic, numWords > 0 ? (double) totalLength / numWords : 0.0);
        }
        return scores;
    }

    /**
     * Score: z-scores of top-word lengths relative to the mean and sample standard deviation over
     * all topics' top words; the topic score is the sum of its words' z-scores.
     */
    public TopicScores getWordLengthStandardDeviation() {
        TopicScores scores = new TopicScores("word-length-sd", numTopics, numTopWords);
        scores.wordScoresDefined = true;

        double lengthSum = 0.0;
        int numWords = 0;
        for (int topic = 0; topic < numTopics; topic++) {
            String[] words = topicTopWords[topic];
            for (int position = 0; position < words.length && words[position] != null; position++) {
                lengthSum += words[position].length();
                numWords++;
            }
        }
        double mean = lengthSum / numWords;

        double squaredDeviationSum = 0.0;
        for (int topic = 0; topic < numTopics; topic++) {
            String[] words = topicTopWords[topic];
            for (int position = 0; position < words.length && words[position] != null; position++) {
                int length = words[position].length();
                squaredDeviationSum += (length - mean) * (length - mean);
            }
        }
        double standardDeviation = Math.sqrt(squaredDeviationSum / (numWords - 1));

        for (int topic = 0; topic < numTopics; topic++) {
            String[] words = topicTopWords[topic];
            for (int position = 0; position < words.length && words[position] != null; position++) {
                double z = (words[position].length() - mean) / standardDeviation;
                scores.addToTopicScore(topic, z);
                scores.setTopicWordScore(topic, position, z);
            }
        }
        return scores;
    }

    /**
     * Score: sum over top-word pairs (i below j in rank) of log((D(i,j) + beta) / (D(j,j) + beta)),
     * where D is the topic's co-document matrix. Each word's score is its most negative pair term.
     */
    public TopicScores getCoherence() {
        TopicScores scores = new TopicScores("coherence", numTopics, numTopWords);
        scores.wordScoresDefined = true;

        for (int topic = 0; topic < numTopics; topic++) {
            int[][] matrix = topicCodocumentMatrices[topic];
            double topicScore = 0.0;
            for (int row = 0; row < numTopWords; row++) {
                double rowScore = 0.0;
                double minScore = 0.0;
                for (int col = 0; col < row; col++) {
                    double score = Math.log((matrix[row][col] + model.beta) / (matrix[col][col] + model.beta));
                    rowScore += score;
                    if (score < minScore) {
                        minScore = score;
                    }
                }
                topicScore += rowScore;
                scores.setTopicWordScore(topic, row, minScore);
            }
            scores.setTopicScore(topic, topicScore);
        }
        return scores;
    }

    /** Score: fraction of the topic's documents in which it is the largest topic (0 if it has none). */
    public TopicScores getRank1Percent() {
        TopicScores scores = new TopicScores("rank_1_docs", numTopics, numTopWords);
        for (int topic = 0; topic < numTopics; topic++) {
            scores.setTopicScore(topic, ratio(numRank1Documents[topic], numNonZeroDocuments[topic]));
        }
        return scores;
    }

    /**
     * Score: documents at proportion threshold {@code numeratorIndex} divided by documents at
     * threshold {@code denominatorIndex} (0 when the denominator is 0). Invalid indices are
     * reported on stderr and yield all-zero scores.
     */
    public TopicScores getDocumentPercentRatio(int numeratorIndex, int denominatorIndex) {
        TopicScores scores = new TopicScores("allocation_ratio", numTopics, numTopWords);
        if (!isValidProportionIndex(numeratorIndex) || !isValidProportionIndex(denominatorIndex)) {
            System.err.println("Invalid proportion indices (max " + (DEFAULT_DOC_PROPORTIONS.length - 1) + ") : "
                    + numeratorIndex + ", " + denominatorIndex);
            return scores;
        }
        for (int topic = 0; topic < numTopics; topic++) {
            scores.setTopicScore(topic, ratio(numDocumentsAtProportions[topic][numeratorIndex],
                    numDocumentsAtProportions[topic][denominatorIndex]));
        }
        return scores;
    }

    /**
     * Score: fraction of the topic's documents whose proportion reaches threshold {@code index}
     * (0 when the topic has no documents). An invalid index is reported on stderr and yields
     * all-zero scores.
     */
    public TopicScores getDocumentPercent(int index) {
        TopicScores scores = new TopicScores("allocation_count", numTopics, numTopWords);
        if (!isValidProportionIndex(index)) {
            System.err.println("Invalid proportion indices (max " + (DEFAULT_DOC_PROPORTIONS.length - 1) + ") : " + index);
            return scores;
        }
        for (int topic = 0; topic < numTopics; topic++) {
            scores.setTopicScore(topic, ratio(numDocumentsAtProportions[topic][index], numNonZeroDocuments[topic]));
        }
        return scores;
    }

    /** Returns whether {@code index} addresses an entry of {@link #DEFAULT_DOC_PROPORTIONS}. */
    private static boolean isValidProportionIndex(int index) {
        return index >= 0 && index < DEFAULT_DOC_PROPORTIONS.length;
    }

    /** Floating-point ratio, defined as 0 when the denominator is 0 (avoids integer division and its exceptions). */
    private static double ratio(int numerator, int denominator) {
        return denominator == 0 ? 0.0 : (double) numerator / denominator;
    }

    /**
     * Score: for each top word, its beta-smoothed probability in this topic divided by the sum of
     * its smoothed probabilities over all topics. The topic score is the mean over the words scored.
     */
    public TopicScores getExclusivity() {
        int[] tokensPerTopic = model.tokensPerTopic;
        TopicScores scores = new TopicScores("exclusivity", numTopics, numTopWords);
        scores.wordScoresDefined = true;

        // Smoothing-only part of the denominator: sum over topics of beta / (betaSum + N_t).
        double smoothingSum = 0.0;
        for (int topic = 0; topic < numTopics; topic++) {
            smoothingSum += model.beta / (model.betaSum + tokensPerTopic[topic]);
        }

        for (int topic = 0; topic < numTopics; topic++) {
            double topicScore = 0.0;
            int position = 0;
            Iterator<IDSorter> it = topicSortedWords.get(topic).iterator();
            while (position < numTopWords && it.hasNext()) {
                IDSorter info = it.next();
                double count = info.getWeight();

                // typeTopicCounts packs (count << topicBits | topic); a 0 entry ends the list.
                double denominator = smoothingSum;
                int[] packedCounts = model.typeTopicCounts[info.getID()];
                for (int i = 0; i < packedCounts.length && packedCounts[i] > 0; i++) {
                    int otherTopic = packedCounts[i] & model.topicMask;
                    denominator += (packedCounts[i] >> model.topicBits) / (model.betaSum + tokensPerTopic[otherTopic]);
                }

                double score = ((model.beta + count) / (model.betaSum + tokensPerTopic[topic])) / denominator;
                scores.setTopicWordScore(topic, position, score);
                topicScore += score;
                position++;
            }
            // Average over the words actually scored (a topic may have fewer than numTopWords words).
            scores.setTopicScore(topic, position > 0 ? topicScore / position : 0.0);
        }
        return scores;
    }

    /** Tab-separated text report: one line of topic scores, then one line per top word. */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder();
        Formatter formatter = new Formatter(out, Locale.US);

        for (int topic = 0; topic < numTopics; topic++) {
            formatter.format("Topic %d", topic);
            for (TopicScores metric : diagnostics) {
                formatter.format("\t%s=%.4f", metric.name, metric.scores[topic]);
            }
            formatter.format("\n");

            String[] words = topicTopWords[topic];
            for (int position = 0; position < words.length && words[position] != null; position++) {
                formatter.format("  %s", words[position]);
                for (TopicScores metric : diagnostics) {
                    if (metric.wordScoresDefined) {
                        formatter.format("\t%s=%.4f", metric.name, metric.topicWordScores[topic][position]);
                    }
                }
                out.append("\n");
            }
        }
        return out.toString();
    }

    /** XML report with one {@code <topic>} element per topic and one {@code <word>} element per top word. */
    public String toXML() {
        int[] tokensPerTopic = model.tokensPerTopic;
        StringBuilder out = new StringBuilder();
        Formatter formatter = new Formatter(out, Locale.US);

        out.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
        out.append("<model>\n");

        for (int topic = 0; topic < numTopics; topic++) {
            int[][] matrix = topicCodocumentMatrices[topic];
            formatter.format("<topic id='%d'", topic);
            for (TopicScores metric : diagnostics) {
                formatter.format(" %s='%.4f'", metric.name, metric.scores[topic]);
            }
            out.append(">\n");

            TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);
            int numWords = Math.min(numTopWords, sortedWords.size());
            double cumulative = 0.0;
            Iterator<IDSorter> it = sortedWords.iterator();
            for (int rank = 0; rank < numWords; rank++) {
                IDSorter info = it.next();
                double probability = info.getWeight() / tokensPerTopic[topic];
                cumulative += probability;
                formatter.format("<word rank='%d' count='%.0f' prob='%.5f' cumulative='%.5f' docs='%d'",
                        rank + 1, info.getWeight(), probability, cumulative, matrix[rank][rank]);
                for (TopicScores metric : diagnostics) {
                    if (metric.wordScoresDefined) {
                        formatter.format(" %s='%.4f'", metric.name, metric.topicWordScores[topic][rank]);
                    }
                }
                formatter.format(">%s</word>\n", escapeXml(topicTopWords[topic][rank]));
            }
            out.append("</topic>\n");
        }
        out.append("</model>\n");
        return out.toString();
    }

    /** Escapes text for XML character content ('&' first so the other entities are not re-escaped). */
    private static String escapeXml(String text) {
        return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;");
    }

    /**
     * Trains a model and, when an output path is given, writes the XML diagnostics.
     * Arguments: instances file, number of topics, optional XML output file.
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println("Usage: TopicModelDiagnostics <instances file> <num topics> [xml output file]");
            return;
        }
        InstanceList instances = InstanceList.load(new File(args[0]));
        ParallelTopicModel topicModel = new ParallelTopicModel(Integer.parseInt(args[1]), 5.0d, 0.01d);
        topicModel.addInstances(instances);
        topicModel.setNumIterations(1000);
        topicModel.estimate();

        TopicModelDiagnostics report = new TopicModelDiagnostics(topicModel, 20);
        if (args.length == 3) {
            // The XML prolog declares UTF-8, so encode the file as UTF-8 rather than the platform default.
            try (PrintWriter out = new PrintWriter(args[2], "UTF-8")) {
                out.println(report.toXML());
            }
        }
    }
}
