package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Labeling;
import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.zip.GZIPOutputStream;
import org.jfree.chart.ChartPanel;

/* loaded from: input_file:cc/mallet/topics/PolylingualTopicModel.class */
public class PolylingualTopicModel implements Serializable {
    // ---------------------------------------------------------------------
    // Command-line options. None of these are assigned in the visible part
    // of this file; presumably they are initialised in a static initialiser
    // further down -- TODO confirm.
    // ---------------------------------------------------------------------
    static CommandOption.SpacedStrings languageInputFiles;
    static CommandOption.String outputModelFilename;
    static CommandOption.String inputModelFilename;
    static CommandOption.String inferencerFilename;
    static CommandOption.String evaluatorFilename;
    static CommandOption.String stateFile;
    static CommandOption.String topicKeysFile;
    static CommandOption.String docTopicsFile;
    static CommandOption.Double docTopicsThreshold;
    static CommandOption.Integer docTopicsMax;
    static CommandOption.Integer outputModelIntervalOption;
    static CommandOption.Integer outputStateIntervalOption;
    static CommandOption.Integer numTopicsOption;
    static CommandOption.Integer numIterationsOption;
    static CommandOption.Integer randomSeedOption;
    static CommandOption.Integer topWordsOption;
    static CommandOption.Integer showTopicsIntervalOption;
    static CommandOption.Integer optimizeIntervalOption;
    static CommandOption.Integer optimizeBurnInOption;
    static CommandOption.Double alphaOption;
    static CommandOption.Double betaOption;
    // Number of languages; 1 after construction, then the length of the array given to addInstances.
    int numLanguages;
    // Training documents; addInstances appends one TopicAssignment per retained document.
    protected ArrayList<TopicAssignment> data;
    // Alphabet whose size defines numTopics.
    protected LabelAlphabet topicAlphabet;
    // Set to 0 by the constructor; not otherwise used in the visible code.
    protected int numStopwords;
    protected int numTopics;
    // Instance names to leave out of training; null means "exclude nothing".
    HashSet<String> testingIDs;
    // Entries of languageTypeTopicCounts pack a topic id into the low topicBits bits
    // (extracted with topicMask) and the count into the bits above.
    protected int topicMask;
    protected int topicBits;
    // Per-language data alphabets and their sizes, filled by addInstances.
    protected Alphabet[] alphabets;
    protected int[] vocabularySizes;
    // Per-topic document-topic hyperparameters and their total.
    protected double[] alpha;
    protected double alphaSum;
    // Per-language topic-word hyperparameter, and that value times the vocabulary size.
    protected double[] betas;
    protected double[] betaSums;
    // Per language: the largest number of training tokens of any single word type.
    protected int[] languageMaxTypeCounts;
    public static final double DEFAULT_BETA = 0.01d;
    // Per-language values recomputed by cacheValues() and updated incrementally while sampling.
    protected double[] languageSmoothingOnlyMasses;
    protected double[][] languageCachedCoefficients;
    // Reset to 0 at the start of every iteration in estimate(int); never incremented in the visible code.
    int topicTermCount;
    int betaTopicCount;
    int smoothingOnlyCount;
    // Not referenced in the visible code.
    protected int[] oneDocTopicCounts;
    // Indexed [language][type][slot]; each slot is a packed (count, topic) value, see topicMask/topicBits.
    protected int[][][] languageTypeTopicCounts;
    // Indexed [language][topic]: number of tokens of that language currently assigned to the topic.
    protected int[][] languageTokensPerTopic;
    // Histograms collected while sampling and passed to Dirichlet.learnParameters in estimate(int).
    protected int[] docLengthCounts;
    protected int[][] topicDocCounts;
    // Iteration counter advanced by estimate(int); starts at 1.
    protected int iterationsSoFar;
    public int numIterations;
    public int burninPeriod;
    public int saveSampleInterval;
    public int optimizeInterval;
    public int showTopicsInterval;
    public int wordsPerTopic;
    // Stored by setModelOutput; not read in the visible code.
    protected int saveModelInterval;
    protected String modelFilename;
    // When non-zero, estimate(int) writes the sampling state every saveStateInterval iterations.
    protected int saveStateInterval;
    protected String stateFilename;
    protected Randoms random;
    // Used by printTopWords to format alpha values (at most 5 fraction digits).
    protected NumberFormat formatter;
    protected boolean printLogLikelihood;
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;
    // Compiler-generated assertion flag surfaced by the decompiler.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:cc/mallet/topics/PolylingualTopicModel$TopicAssignment.class */
    /**
     * One training document: its per-language instances together with the
     * per-language sequences of topic assignments (same array index = same language).
     */
    public class TopicAssignment implements Serializable {
        /** The document's instance in each language. */
        public Instance[] instances;
        /** Topic assigned to every token, one sequence per language. */
        public LabelSequence[] topicSequences;
        /** Not populated by the constructor. */
        public Labeling topicDistribution;

        /**
         * @param instanceArr      the document's instance in each language
         * @param labelSequenceArr the matching topic-assignment sequences
         */
        public TopicAssignment(Instance[] instanceArr, LabelSequence[] labelSequenceArr) {
            topicSequences = labelSequenceArr;
            instances = instanceArr;
        }
    }

    /**
     * Creates a model with {@code i} topics and an alpha sum equal to the
     * number of topics.
     *
     * @param i number of topics
     */
    public PolylingualTopicModel(int i) {
        this(i, (double) i);
    }

    /**
     * Creates a model with {@code i} topics, the given alpha sum and a
     * freshly constructed {@link Randoms} generator.
     *
     * @param i number of topics
     * @param d sum of the alpha hyperparameters
     */
    public PolylingualTopicModel(int i, double d) {
        this(i, d, new Randoms());
    }

    /**
     * Builds a label alphabet containing the entries "topic0" .. "topic(i-1)".
     *
     * @param i number of topic labels to register
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int i) {
        final LabelAlphabet topics = new LabelAlphabet();
        int topic = 0;
        while (topic < i) {
            topics.lookupIndex("topic" + topic);
            topic++;
        }
        return topics;
    }

    /**
     * Creates a model with {@code i} automatically named topics.
     *
     * @param i       number of topics
     * @param d       sum of the alpha hyperparameters
     * @param randoms random number generator used for sampling
     */
    public PolylingualTopicModel(int i, double d, Randoms randoms) {
        this(newLabelAlphabet(i), d, randoms);
    }

    public PolylingualTopicModel(LabelAlphabet labelAlphabet, double d, Randoms randoms) {
        this.numLanguages = 1;
        this.numStopwords = 0;
        this.testingIDs = null;
        this.topicTermCount = 0;
        this.betaTopicCount = 0;
        this.smoothingOnlyCount = 0;
        this.iterationsSoFar = 1;
        this.numIterations = 1000;
        this.burninPeriod = 5;
        this.saveSampleInterval = 5;
        this.optimizeInterval = 10;
        this.showTopicsInterval = 10;
        this.wordsPerTopic = 7;
        this.saveModelInterval = 0;
        this.saveStateInterval = 0;
        this.stateFilename = null;
        this.printLogLikelihood = false;
        this.data = new ArrayList<>();
        this.topicAlphabet = labelAlphabet;
        this.numTopics = labelAlphabet.size();
        if (Integer.bitCount(this.numTopics) == 1) {
            this.topicMask = this.numTopics - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        } else {
            this.topicMask = (Integer.highestOneBit(this.numTopics) * 2) - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        }
        this.alphaSum = d;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, d / this.numTopics);
        this.random = randoms;
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("Polylingual LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /**
     * Replaces the set of held-out instance names with the lines of
     * {@code file} (one name per line, stored verbatim).
     *
     * <p>The reader is closed even when reading fails part-way through.
     *
     * @param file text file listing one instance name per line
     * @throws IOException if the file cannot be opened or read
     */
    public void loadTestingIDs(File file) throws IOException {
        this.testingIDs = new HashSet<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                this.testingIDs.add(line);
            }
        }
    }

    /** @return the alphabet holding one label per topic */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** @return the number of topics in this model */
    public int getNumTopics() {
        return numTopics;
    }

    /** @return the live (not copied) list of training documents */
    public ArrayList<TopicAssignment> getData() {
        return data;
    }

    /**
     * Sets how many iterations the no-argument {@code estimate()} requests.
     *
     * @param i number of iterations
     */
    public void setNumIterations(int i) {
        numIterations = i;
    }

    /**
     * Sets the burn-in period consulted by {@code estimate(int)} before
     * hyperparameter optimisation and histogram collection begin.
     *
     * @param i burn-in length in iterations
     */
    public void setBurninPeriod(int i) {
        burninPeriod = i;
    }

    /**
     * Configures the periodic top-word report printed during estimation.
     *
     * @param i  print every {@code i} iterations (0 disables the report)
     * @param i2 word limit handed to {@code printTopWords}
     */
    public void setTopicDisplay(int i, int i2) {
        wordsPerTopic = i2;
        showTopicsInterval = i;
    }

    /**
     * Replaces the random number generator with one seeded by {@code i}.
     *
     * @param i seed value
     */
    public void setRandomSeed(int i) {
        random = new Randoms(i);
    }

    /**
     * Sets how often hyperparameters are re-estimated during estimation.
     *
     * @param i interval in iterations (0 disables optimisation)
     */
    public void setOptimizeInterval(int i) {
        optimizeInterval = i;
    }

    /**
     * Stores the model-output settings.
     *
     * @param i   save interval in iterations
     * @param str file name for the saved model
     */
    public void setModelOutput(int i, String str) {
        modelFilename = str;
        saveModelInterval = i;
    }

    /**
     * Configures periodic state dumps; {@code estimate(int)} writes to
     * {@code str + '.' + iteration} every {@code i} iterations.
     *
     * @param i   save interval in iterations (0 disables state dumps)
     * @param str file name prefix for the dumps
     */
    public void setSaveState(int i, String str) {
        stateFilename = str;
        saveStateInterval = i;
    }

    /**
     * Loads the training corpus, one {@link InstanceList} per language, and
     * assigns every token a uniformly random initial topic.
     *
     * <p>Document {@code k} of every list is treated as the same document in
     * different languages. Instances whose name is in {@code testingIDs} are
     * left out. The number of documents is taken from language 0; a list of a
     * different size only produces a warning on {@code System.err}.
     *
     * <p>Fix relative to the decompiled original: the type-topic count tables
     * are allocated as jagged arrays ({@code new int[n][][]} and
     * {@code new int[v][]}); the decompiler had emitted one-dimensional
     * allocations that do not type-check against the field.
     *
     * @param instanceListArr one instance list per language, aligned by position
     */
    public void addInstances(InstanceList[] instanceListArr) {
        this.numLanguages = instanceListArr.length;
        this.languageTokensPerTopic = new int[this.numLanguages][this.numTopics];
        this.alphabets = new Alphabet[this.numLanguages];
        this.vocabularySizes = new int[this.numLanguages];
        this.betas = new double[this.numLanguages];
        this.betaSums = new double[this.numLanguages];
        this.languageMaxTypeCounts = new int[this.numLanguages];
        this.languageTypeTopicCounts = new int[this.numLanguages][][];
        int numInstances = instanceListArr[0].size();

        // Pass 1: per language, count tokens of every word type so that each
        // type's packed count array can be sized to min(numTopics, tokens of that type).
        for (int language = 0; language < this.numLanguages; language++) {
            if (instanceListArr[language].size() != numInstances) {
                System.err.println("Warning: language " + language + " has " + instanceListArr[language].size() + " instances, lang 0 has " + numInstances);
            }
            this.alphabets[language] = instanceListArr[language].getDataAlphabet();
            this.vocabularySizes[language] = this.alphabets[language].size();
            this.betas[language] = 0.01d;
            this.betaSums[language] = this.betas[language] * this.vocabularySizes[language];
            this.languageTypeTopicCounts[language] = new int[this.vocabularySizes[language]][];
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            int[] typeTotals = new int[this.vocabularySizes[language]];
            Iterator<Instance> it = instanceListArr[language].iterator();
            while (it.hasNext()) {
                Instance instance = it.next();
                if (this.testingIDs == null || !this.testingIDs.contains(instance.getName())) {
                    FeatureSequence tokens = (FeatureSequence) instance.getData();
                    for (int position = 0; position < tokens.getLength(); position++) {
                        int type = tokens.getIndexAtPosition(position);
                        typeTotals[type] = typeTotals[type] + 1;
                    }
                }
            }
            for (int type = 0; type < this.vocabularySizes[language]; type++) {
                if (typeTotals[type] > this.languageMaxTypeCounts[language]) {
                    this.languageMaxTypeCounts[language] = typeTotals[type];
                }
                // A type can never be spread over more topics than it has tokens.
                typeTopicCounts[type] = new int[Math.min(this.numTopics, typeTotals[type])];
            }
        }

        // Pass 2: build one TopicAssignment per retained document and seed the counts.
        for (int doc = 0; doc < numInstances; doc++) {
            if (this.testingIDs == null || !this.testingIDs.contains(instanceListArr[0].get(doc).getName())) {
                Instance[] instances = new Instance[this.numLanguages];
                LabelSequence[] topicSequences = new LabelSequence[this.numLanguages];
                for (int language = 0; language < this.numLanguages; language++) {
                    int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
                    int[] tokensPerTopic = this.languageTokensPerTopic[language];
                    instances[language] = instanceListArr[language].get(doc);
                    FeatureSequence tokens = (FeatureSequence) instances[language].getData();
                    topicSequences[language] = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
                    int[] topics = topicSequences[language].getFeatures();
                    for (int position = 0; position < tokens.size(); position++) {
                        int[] packedCounts = typeTopicCounts[tokens.getIndexAtPosition(position)];
                        int topic = this.random.nextInt(this.numTopics);
                        topics[position] = topic;
                        tokensPerTopic[topic] = tokensPerTopic[topic] + 1;

                        // Find this topic's slot, or the first empty slot (value 0).
                        int slot = 0;
                        while (packedCounts[slot] > 0 && (packedCounts[slot] & this.topicMask) != topic) {
                            slot++;
                        }
                        int count = packedCounts[slot] >> this.topicBits;
                        if (count == 0) {
                            packedCounts[slot] = (1 << this.topicBits) + topic;
                        } else {
                            packedCounts[slot] = ((count + 1) << this.topicBits) + topic;
                            // Keep the array sorted in descending order of packed value.
                            while (slot > 0 && packedCounts[slot] > packedCounts[slot - 1]) {
                                int swapped = packedCounts[slot];
                                packedCounts[slot] = packedCounts[slot - 1];
                                packedCounts[slot - 1] = swapped;
                                slot--;
                            }
                        }
                    }
                }
                this.data.add(new TopicAssignment(instances, topicSequences));
            }
        }
        initializeHistograms();
        this.languageSmoothingOnlyMasses = new double[this.numLanguages];
        this.languageCachedCoefficients = new double[this.numLanguages][this.numTopics];
        cacheValues();
    }

    /**
     * Sizes the document-length and topic-document histograms from the longest
     * document (token count summed over all languages) and logs the maximum
     * and total token counts to {@code System.err}.
     */
    private void initializeHistograms() {
        int longestDoc = 0;
        int tokenTotal = 0;
        final int docCount = this.data.size();
        for (int doc = 0; doc < docCount; doc++) {
            LabelSequence[] sequences = this.data.get(doc).topicSequences;
            int docLength = 0;
            for (int lang = 0; lang < sequences.length; lang++) {
                docLength += sequences[lang].getLength();
            }
            longestDoc = Math.max(longestDoc, docLength);
            tokenTotal += docLength;
        }
        System.err.println("max tokens: " + longestDoc);
        System.err.println("total tokens: " + tokenTotal);
        final int bins = longestDoc + 1;
        this.docLengthCounts = new int[bins];
        this.topicDocCounts = new int[this.numTopics][bins];
    }

    /**
     * Recomputes, for every language, the smoothing-only mass
     * (sum over topics of alpha*beta / (tokensPerTopic + betaSum)) and the
     * per-topic cached coefficient alpha / (tokensPerTopic + betaSum).
     */
    private void cacheValues() {
        for (int lang = 0; lang < this.numLanguages; lang++) {
            final int[] tokensPerTopic = this.languageTokensPerTopic[lang];
            final double[] coefficients = this.languageCachedCoefficients[lang];
            double smoothingMass = 0.0d;
            for (int topic = 0; topic < this.numTopics; topic++) {
                smoothingMass += (this.alpha[topic] * this.betas[lang]) / (tokensPerTopic[topic] + this.betaSums[lang]);
                coefficients[topic] = this.alpha[topic] / (tokensPerTopic[topic] + this.betaSums[lang]);
            }
            this.languageSmoothingOnlyMasses[lang] = smoothingMass;
        }
    }

    /** Zeroes the document-length histogram and every row of the topic-document histogram. */
    private void clearHistograms() {
        Arrays.fill(this.docLengthCounts, 0);
        for (int[] topicRow : this.topicDocCounts) {
            Arrays.fill(topicRow, 0);
        }
    }

    /**
     * Runs estimation for the configured {@code numIterations}.
     *
     * @throws IOException if a periodic state dump cannot be written
     */
    public void estimate() throws IOException {
        estimate(numIterations);
    }

    /**
     * Runs Gibbs sampling sweeps until {@code iterationsSoFar} exceeds its
     * starting value plus {@code i} (the bound is inclusive).
     *
     * <p>Each pass optionally prints the top words, dumps the sampler state
     * and re-optimises the hyperparameters before resampling every document.
     * Per-iteration wall-clock milliseconds go to {@code System.out}; whenever
     * {@code (iterationsSoFar + 1)} is a multiple of 10 the line also carries
     * the cumulative time and the model log-likelihood.
     *
     * @param i number of iterations to add to the current iteration counter
     * @throws IOException if a periodic state dump cannot be written
     */
    public void estimate(int i) throws IOException {
        final int lastIteration = this.iterationsSoFar + i;
        long totalMillis = 0;
        for (; this.iterationsSoFar <= lastIteration; this.iterationsSoFar++) {
            final long startedAt = System.currentTimeMillis();

            final boolean showTopics = this.showTopicsInterval != 0
                    && this.iterationsSoFar != 0
                    && this.iterationsSoFar % this.showTopicsInterval == 0;
            if (showTopics) {
                System.out.println();
                printTopWords(System.out, this.wordsPerTopic, false);
            }

            final boolean saveState = this.saveStateInterval != 0
                    && this.iterationsSoFar % this.saveStateInterval == 0;
            if (saveState) {
                printState(new File(this.stateFilename + '.' + this.iterationsSoFar));
            }

            final boolean optimize = this.iterationsSoFar > this.burninPeriod
                    && this.optimizeInterval != 0
                    && this.iterationsSoFar % this.optimizeInterval == 0;
            if (optimize) {
                this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
                optimizeBetas();
                clearHistograms();
                cacheValues();
            }

            this.smoothingOnlyCount = 0;
            this.betaTopicCount = 0;
            this.topicTermCount = 0;

            for (int doc = 0; doc < this.data.size(); doc++) {
                // Histograms are only collected after burn-in, on sample iterations.
                final boolean collectHistograms = this.iterationsSoFar >= this.burninPeriod
                        && this.iterationsSoFar % this.saveSampleInterval == 0;
                sampleTopicsForOneDoc(this.data.get(doc), collectHistograms);
            }

            final long iterationMillis = System.currentTimeMillis() - startedAt;
            totalMillis += iterationMillis;
            if ((this.iterationsSoFar + 1) % 10 != 0) {
                System.out.print(iterationMillis + " ");
            } else {
                System.out.println(iterationMillis + "\t" + totalMillis + "\t" + modelLogLikelihood());
            }
        }
    }

    /**
     * Re-estimates the symmetric topic-word hyperparameter of every language.
     *
     * <p>For each language two histograms are built: how many (type, topic)
     * pairs have each count value, and how many topics have each total token
     * count. Both are handed to {@code Dirichlet.learnSymmetricConcentration}
     * to obtain the new beta sum; beta is that sum divided by the vocabulary size.
     */
    public void optimizeBetas() {
        for (int lang = 0; lang < this.numLanguages; lang++) {
            int[] countHistogram = new int[this.languageMaxTypeCounts[lang] + 1];
            final int[][] typeTopicCounts = this.languageTypeTopicCounts[lang];
            final int[] tokensPerTopic = this.languageTokensPerTopic[lang];

            // Histogram of the non-zero per-(type, topic) counts.
            for (int type = 0; type < this.vocabularySizes[lang]; type++) {
                for (int packed : typeTopicCounts[type]) {
                    if (packed <= 0) {
                        break; // the remaining slots are unused
                    }
                    int count = packed >> this.topicBits;
                    countHistogram[count] = countHistogram[count] + 1;
                }
            }

            // Histogram of topic sizes, sized by the largest topic.
            int largestTopic = 0;
            for (int topic = 0; topic < this.numTopics; topic++) {
                largestTopic = Math.max(largestTopic, tokensPerTopic[topic]);
            }
            int[] topicSizeHistogram = new int[largestTopic + 1];
            for (int topic = 0; topic < this.numTopics; topic++) {
                int size = tokensPerTopic[topic];
                topicSizeHistogram[size] = topicSizeHistogram[size] + 1;
            }

            this.betaSums[lang] = Dirichlet.learnSymmetricConcentration(countHistogram, topicSizeHistogram, this.vocabularySizes[lang], this.betaSums[lang]);
            this.betas[lang] = this.betaSums[lang] / this.vocabularySizes[lang];
        }
    }

    /**
     * Resamples the topic of every token of one document, language by language.
     *
     * <p>The sampling weight of each topic is split into three parts that are
     * tracked incrementally: a smoothing-only mass (alpha * beta), a
     * topic-beta mass (document count * beta) and a topic-term mass
     * (cached coefficient * type-topic count). The document-topic counts are
     * shared by all languages of the document; the word statistics are per language.
     *
     * @param topicAssignment the document (instances and current topic assignments) to resample
     * @param z               when true, the document's final topic counts and length are
     *                        added to {@code topicDocCounts} / {@code docLengthCounts}
     */
    protected void sampleTopicsForOneDoc(TopicAssignment topicAssignment, boolean z) {
        // iArr: tokens of this document assigned to each topic, summed over all languages.
        int[] iArr = new int[this.numTopics];
        // iArr2: dense list of the topics that currently have a non-zero count in iArr.
        int[] iArr2 = new int[this.numTopics];
        for (int i = 0; i < this.numLanguages; i++) {
            int[] features = topicAssignment.topicSequences[i].getFeatures();
            int length = topicAssignment.topicSequences[i].getLength();
            for (int i2 = 0; i2 < length; i2++) {
                int i3 = features[i2];
                iArr[i3] = iArr[i3] + 1;
            }
        }
        // Build the dense index in ascending topic order.
        int i4 = 0;
        for (int i5 = 0; i5 < this.numTopics; i5++) {
            if (iArr[i5] != 0) {
                iArr2[i4] = i5;
                i4++;
            }
        }
        // i6: number of valid entries in iArr2.
        int i6 = i4;
        for (int i7 = 0; i7 < this.numLanguages; i7++) {
            int[] features2 = topicAssignment.topicSequences[i7].getFeatures();
            int length2 = topicAssignment.topicSequences[i7].getLength();
            FeatureSequence featureSequence = (FeatureSequence) topicAssignment.instances[i7].getData();
            // Per-language statistics and hyperparameters:
            // iArr3 = packed type-topic counts, iArr4 = tokens per topic,
            // d = beta, d2 = betaSum, d3 = smoothing-only mass, dArr = cached coefficients.
            int[][] iArr3 = this.languageTypeTopicCounts[i7];
            int[] iArr4 = this.languageTokensPerTopic[i7];
            double d = this.betas[i7];
            double d2 = this.betaSums[i7];
            double d3 = this.languageSmoothingOnlyMasses[i7];
            double[] dArr = this.languageCachedCoefficients[i7];
            // d4: topic-beta mass; also fold the document counts into the cached
            // coefficients of the topics present in this document.
            // NOTE(review): these coefficients are not restored to their alpha-only
            // values anywhere in this method, so entries for topics absent from a
            // later document keep the values left by earlier documents -- confirm
            // whether that is intended.
            double d4 = 0.0d;
            for (int i8 = 0; i8 < i6; i8++) {
                int i9 = iArr2[i8];
                int i10 = iArr[i9];
                d4 += (d * i10) / (iArr4[i9] + d2);
                dArr[i9] = (this.alpha[i9] + i10) / (iArr4[i9] + d2);
            }
            // dArr2: per-slot topic-term scores for the current token's type.
            double[] dArr2 = new double[this.numTopics];
            for (int i11 = 0; i11 < length2; i11++) {
                int indexAtPosition = featureSequence.getIndexAtPosition(i11);
                // i12: the token's current topic; -1 is skipped (presumably "unassigned",
                // matching the NULL_INTEGER constant -- confirm).
                int i12 = features2[i11];
                if (i12 != -1) {
                    int[] iArr5 = iArr3[indexAtPosition];
                    // Remove the old topic's terms from both masses before changing its counts.
                    double d5 = d3 - ((this.alpha[i12] * d) / (iArr4[i12] + d2));
                    double d6 = d4 - ((d * iArr[i12]) / (iArr4[i12] + d2));
                    iArr[i12] = iArr[i12] - 1;
                    if (iArr[i12] == 0) {
                        // The topic left the document: delete it from the dense index
                        // by shifting the following entries one position to the left.
                        int i13 = 0;
                        while (iArr2[i13] != i12) {
                            i13++;
                        }
                        while (i13 < i6) {
                            if (i13 < iArr2.length - 1) {
                                iArr2[i13] = iArr2[i13 + 1];
                            }
                            i13++;
                        }
                        i6--;
                    }
                    iArr4[i12] = iArr4[i12] - 1;
                    // Add the old topic's terms back using the decremented counts.
                    double d7 = d5 + ((this.alpha[i12] * d) / (iArr4[i12] + d2));
                    double d8 = d6 + ((d * iArr[i12]) / (iArr4[i12] + d2));
                    dArr[i12] = (this.alpha[i12] + iArr[i12]) / (iArr4[i12] + d2);
                    // Walk the type's packed counts: decrement the old topic's entry the
                    // first time it is seen (z2 marks that it has been handled) and
                    // accumulate the topic-term mass d9 over all other entries.
                    int i14 = 0;
                    boolean z2 = false;
                    double d9 = 0.0d;
                    while (i14 < iArr5.length && iArr5[i14] > 0) {
                        int i15 = iArr5[i14] & this.topicMask;
                        int i16 = iArr5[i14] >> this.topicBits;
                        if (z2 || i15 != i12) {
                            double d10 = dArr[i15] * i16;
                            d9 += d10;
                            dArr2[i14] = d10;
                            i14++;
                        } else {
                            int i17 = i16 - 1;
                            if (i17 == 0) {
                                iArr5[i14] = 0;
                            } else {
                                iArr5[i14] = (i17 << this.topicBits) + i12;
                            }
                            // Restore descending order by moving the decremented entry
                            // towards the end; i14 is not advanced, so whatever now sits
                            // in this slot is scored on the next pass of the loop.
                            for (int i18 = i14; i18 < iArr5.length - 1 && iArr5[i18] < iArr5[i18 + 1]; i18++) {
                                int i19 = iArr5[i18];
                                iArr5[i18] = iArr5[i18 + 1];
                                iArr5[i18 + 1] = i19;
                            }
                            z2 = true;
                        }
                    }
                    // Draw a point in [0, total mass) and locate the bucket it falls into.
                    double nextUniform = this.random.nextUniform() * (d7 + d8 + d9);
                    // i20: the newly sampled topic (-1 until one is chosen).
                    int i20 = -1;
                    if (nextUniform < d9) {
                        // Topic-term bucket: scan the per-slot scores.
                        int i21 = -1;
                        while (nextUniform > 0.0d) {
                            i21++;
                            nextUniform -= dArr2[i21];
                        }
                        i20 = iArr5[i21] & this.topicMask;
                        // Increment that slot's count and move it towards the front to keep the order.
                        iArr5[i21] = (((iArr5[i21] >> this.topicBits) + 1) << this.topicBits) + i20;
                        while (i21 > 0 && iArr5[i21] > iArr5[i21 - 1]) {
                            int i22 = iArr5[i21];
                            iArr5[i21] = iArr5[i21 - 1];
                            iArr5[i21 - 1] = i22;
                            i21--;
                        }
                    } else {
                        double d11 = nextUniform - d9;
                        if (d11 >= d8) {
                            // Smoothing-only bucket: scan all topics in index order,
                            // subtracting alpha / (tokensPerTopic + betaSum).
                            i20 = 0;
                            double d12 = (d11 - d8) / d;
                            double d13 = this.alpha[0];
                            int i23 = iArr4[0];
                            while (true) {
                                nextUniform = d12 - (d13 / (i23 + d2));
                                if (nextUniform <= 0.0d) {
                                    break;
                                }
                                i20++;
                                d12 = nextUniform;
                                d13 = this.alpha[i20];
                                i23 = iArr4[i20];
                            }
                        } else {
                            // Topic-beta bucket: scan only the topics present in the document.
                            nextUniform = d11 / d;
                            int i24 = 0;
                            while (true) {
                                if (i24 >= i6) {
                                    break;
                                }
                                int i25 = iArr2[i24];
                                nextUniform -= iArr[i25] / (iArr4[i25] + d2);
                                if (nextUniform <= 0.0d) {
                                    i20 = i25;
                                    break;
                                }
                                i24++;
                            }
                        }
                        // Record the new topic in the type's packed counts: find its slot
                        // or the first empty one, then increment and re-sort.
                        // NOTE(review): if the topic-beta scan above selected nothing, i20 is
                        // still -1 here, so -1 is written into the packed counts before the
                        // fallback below replaces it -- confirm this cannot happen in practice.
                        int i26 = 0;
                        while (iArr5[i26] > 0 && (iArr5[i26] & this.topicMask) != i20) {
                            i26++;
                        }
                        if (iArr5[i26] == 0) {
                            iArr5[i26] = (1 << this.topicBits) + i20;
                        } else {
                            iArr5[i26] = (((iArr5[i26] >> this.topicBits) + 1) << this.topicBits) + i20;
                            while (i26 > 0 && iArr5[i26] > iArr5[i26 - 1]) {
                                int i27 = iArr5[i26];
                                iArr5[i26] = iArr5[i26 - 1];
                                iArr5[i26 - 1] = i27;
                                i26--;
                            }
                        }
                    }
                    if (i20 == -1) {
                        // No topic was selected: report it and fall back to the last topic.
                        // NOTE(review): the message prints nextUniform twice; the first value
                        // was presumably meant to be the original draw -- confirm.
                        System.err.println("PolylingualTopicModel sampling error: " + nextUniform + " " + nextUniform + " " + d7 + " " + d8 + " " + d9);
                        i20 = this.numTopics - 1;
                    }
                    features2[i11] = i20;
                    // Remove the new topic's old terms from both masses ...
                    double d14 = d7 - ((this.alpha[i20] * d) / (iArr4[i20] + d2));
                    double d15 = d8 - ((d * iArr[i20]) / (iArr4[i20] + d2));
                    int i28 = i20;
                    iArr[i28] = iArr[i28] + 1;
                    if (iArr[i20] == 1) {
                        // The topic is new to this document: insert it into the dense
                        // index, keeping the index in ascending topic order.
                        int i29 = i6;
                        while (i29 > 0 && iArr2[i29 - 1] > i20) {
                            iArr2[i29] = iArr2[i29 - 1];
                            i29--;
                        }
                        iArr2[i29] = i20;
                        i6++;
                    }
                    int i30 = i20;
                    iArr4[i30] = iArr4[i30] + 1;
                    // ... and add them back with the incremented counts.
                    dArr[i20] = (this.alpha[i20] + iArr[i20]) / (iArr4[i20] + d2);
                    d3 = d14 + ((this.alpha[i20] * d) / (iArr4[i20] + d2));
                    d4 = d15 + ((d * iArr[i20]) / (iArr4[i20] + d2));
                    // Persist the updated smoothing-only mass for this language.
                    this.languageSmoothingOnlyMasses[i7] = d3;
                }
            }
        }
        if (z) {
            // Record this document's topic counts and its total length in the histograms.
            int i31 = 0;
            for (int i32 = 0; i32 < i6; i32++) {
                int i33 = iArr2[i32];
                int[] iArr6 = this.topicDocCounts[i33];
                int i34 = iArr[i33];
                iArr6[i34] = iArr6[i34] + 1;
                i31 += iArr[i33];
            }
            int[] iArr7 = this.docLengthCounts;
            int i35 = i31;
            iArr7[i35] = iArr7[i35] + 1;
        }
    }

    /**
     * Writes the per-topic top-word report to {@code file}.
     *
     * <p>The stream is closed even if writing the report throws.
     *
     * @param file destination file (created or truncated)
     * @param i    word limit, passed through to the stream overload
     * @param z    passed through to the stream overload
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTopWords(File file, int i, boolean z) throws IOException {
        try (PrintStream out = new PrintStream(file)) {
            printTopWords(out, i, z);
        }
    }

    /**
     * Prints, for every topic, its index and alpha, followed by one line per
     * language with the language index, the topic's token count, beta and the
     * topic's leading words.
     *
     * <p>Words come out in the iteration order of a {@code TreeSet} of
     * {@code IDSorter}s (presumably highest count first -- confirm against
     * {@code IDSorter.compareTo}). The word counter starts at 1 and stops
     * before {@code i}, so at most {@code i - 1} words are printed per line.
     *
     * @param printStream destination stream
     * @param i           word limit (exclusive, counted from 1)
     * @param z           not used by this method
     */
    public void printTopWords(PrintStream printStream, int i, boolean z) {
        // Collect one sorted word set per (language, topic).
        TreeSet[][] sortedWords = new TreeSet[this.numLanguages][this.numTopics];
        for (int lang = 0; lang < this.numLanguages; lang++) {
            TreeSet[] perTopic = sortedWords[lang];
            int[][] typeTopicCounts = this.languageTypeTopicCounts[lang];
            for (int topic = 0; topic < this.numTopics; topic++) {
                perTopic[topic] = new TreeSet();
            }
            for (int type = 0; type < this.vocabularySizes[lang]; type++) {
                for (int packed : typeTopicCounts[type]) {
                    if (packed <= 0) {
                        break; // the remaining slots are unused
                    }
                    perTopic[packed & this.topicMask].add(new IDSorter(type, packed >> this.topicBits));
                }
            }
        }

        for (int topic = 0; topic < this.numTopics; topic++) {
            printStream.println(topic + "\t" + this.formatter.format(this.alpha[topic]));
            for (int lang = 0; lang < this.numLanguages; lang++) {
                printStream.print(" " + lang + "\t" + this.languageTokensPerTopic[lang][topic] + "\t" + this.betas[lang] + "\t");
                TreeSet topicWords = sortedWords[lang][topic];
                Alphabet vocabulary = this.alphabets[lang];
                Iterator words = topicWords.iterator();
                int rank = 1;
                while (words.hasNext() && rank < i) {
                    IDSorter entry = (IDSorter) words.next();
                    printStream.print(vocabulary.lookupObject(entry.getID()) + " ");
                    rank++;
                }
                printStream.println();
            }
        }
    }

    /**
     * Writes the document-topic proportions to {@code file} as UTF-8.
     *
     * <p>The writer is now closed when done. Previously it was never closed
     * or flushed, so buffered output could be lost and the file handle leaked.
     *
     * @param file destination file (created or truncated)
     * @throws IOException if the file cannot be opened or UTF-8 is unsupported
     */
    public void printDocumentTopics(File file) throws IOException {
        try (PrintWriter writer = new PrintWriter(file, "UTF-8")) {
            printDocumentTopics(writer);
        }
    }

    /**
     * Prints the document-topic proportions with a threshold of 0 and no
     * limit on the number of topics per document.
     *
     * @param printWriter destination writer (not closed by this method)
     */
    public void printDocumentTopics(PrintWriter printWriter) {
        final double noThreshold = 0.0d;
        final int noLimit = -1;
        printDocumentTopics(printWriter, noThreshold, noLimit);
    }

    /**
     * Prints one line per document: the document index followed by
     * "topic proportion" pairs in the order produced by sorting the
     * {@code IDSorter}s (presumably highest proportion first -- confirm
     * against {@code IDSorter.compareTo}).
     *
     * <p>Fixes relative to the original: the proportion is computed with
     * floating-point division (integer division truncated every proportion to
     * 0 or 1), and a document without tokens prints no pairs instead of
     * failing with a division by zero.
     *
     * @param printWriter destination writer (neither flushed nor closed here)
     * @param d           pairs stop being printed at the first proportion below this value
     * @param i           maximum pairs per document; negative or larger than the
     *                    number of topics means "all topics"
     */
    public void printDocumentTopics(PrintWriter printWriter, double d, int i) {
        printWriter.print("#doc source topic proportion ...\n");
        int[] topicCounts = new int[this.numTopics];
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            sortedTopics[topic] = new IDSorter(topic, topic);
        }
        final int limit = (i < 0 || i > this.numTopics) ? this.numTopics : i;
        for (int doc = 0; doc < this.data.size(); doc++) {
            printWriter.print(doc);
            printWriter.print(' ');

            // Count the topic assignments of the document over all languages.
            int docLength = 0;
            for (int lang = 0; lang < this.numLanguages; lang++) {
                LabelSequence topicSequence = this.data.get(doc).topicSequences[lang];
                int[] topics = topicSequence.getFeatures();
                int length = topicSequence.getLength();
                docLength += length;
                for (int position = 0; position < length; position++) {
                    int topic = topics[position];
                    topicCounts[topic] = topicCounts[topic] + 1;
                }
            }

            // An empty document has no meaningful proportions; print none.
            if (docLength > 0) {
                for (int topic = 0; topic < this.numTopics; topic++) {
                    sortedTopics[topic].set(topic, (double) topicCounts[topic] / docLength);
                }
                Arrays.sort(sortedTopics);
                for (int rank = 0; rank < limit && sortedTopics[rank].getWeight() >= d; rank++) {
                    printWriter.print(sortedTopics[rank].getID() + " " + sortedTopics[rank].getWeight() + " ");
                }
            }
            printWriter.print(" \n");
            Arrays.fill(topicCounts, 0);
        }
    }

    /**
     * Writes the sampler state to {@code file} as gzip-compressed UTF-8 text.
     *
     * <p>The stream is closed even if writing the state throws, so the file
     * handle is always released.
     *
     * @param file destination file (created or truncated)
     * @throws IOException if the file cannot be opened or the gzip header cannot be written
     */
    public void printState(File file) throws IOException {
        try (PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file))), false, "UTF-8")) {
            printState(out);
        }
    }

    /**
     * Prints a header line followed by one space-separated line per token:
     * document index, language index, position, type index, type and topic.
     *
     * @param printStream destination stream (not closed by this method)
     */
    public void printState(PrintStream printStream) {
        printStream.println("#doc lang pos typeindex type topic");
        for (int doc = 0; doc < this.data.size(); doc++) {
            for (int lang = 0; lang < this.numLanguages; lang++) {
                FeatureSequence tokens = (FeatureSequence) this.data.get(doc).instances[lang].getData();
                LabelSequence topics = this.data.get(doc).topicSequences[lang];
                for (int position = 0; position < topics.getLength(); position++) {
                    final int type = tokens.getIndexAtPosition(position);
                    final int topic = topics.getIndexAtPosition(position);
                    final char separator = ' ';
                    printStream.print(doc);
                    printStream.print(separator);
                    printStream.print(lang);
                    printStream.print(separator);
                    printStream.print(position);
                    printStream.print(separator);
                    printStream.print(type);
                    printStream.print(separator);
                    printStream.print(this.alphabets[lang].lookupObject(type));
                    printStream.print(separator);
                    printStream.print(topic);
                    printStream.println();
                }
            }
        }
    }

    /**
     * Computes a log-likelihood score for the current topic assignments.
     * <p>
     * The first pass accumulates, per document, log-gamma terms over the topic
     * counts pooled across all languages (using {@code alpha} / {@code alphaSum}).
     * The second pass accumulates, per language, log-gamma terms over the
     * type/topic counts and the per-topic token totals (using {@code betas}).
     * <p>
     * If the running total ever becomes NaN the method prints a short diagnostic
     * to standard output and terminates the JVM with exit status 1.
     *
     * @return the accumulated log-likelihood value
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0d;
        final int[] topicCounts = new int[this.numTopics];
        final double[] alphaLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            alphaLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }

        // Document/topic part: counts are pooled over every language of a document.
        for (int doc = 0; doc < this.data.size(); doc++) {
            int docLength = 0;
            for (int language = 0; language < this.numLanguages; language++) {
                LabelSequence topicSequence = this.data.get(doc).topicSequences[language];
                int[] assignedTopics = topicSequence.getFeatures();
                docLength += topicSequence.getLength();
                for (int position = 0; position < topicSequence.getLength(); position++) {
                    topicCounts[assignedTopics[position]]++;
                }
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (topicCounts[topic] <= 0) {
                    continue;
                }
                logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + topicCounts[topic]) - alphaLogGammas[topic];
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + docLength);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood = logLikelihood + (this.data.size() * Dirichlet.logGammaStirling(this.alphaSum));

        // Topic/word part, evaluated separately for each language.
        for (int language = 0; language < this.numLanguages; language++) {
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            int[] tokensPerTopic = this.languageTokensPerTopic[language];
            double beta = this.betas[language];
            int nonZeroTypeTopics = 0;

            for (int type = 0; type < this.vocabularySizes[language]; type++) {
                int[] packedCounts = typeTopicCounts[type];
                // Entries are scanned until the first non-positive value; the count is
                // stored in the bits above topicBits of each packed entry.
                int slot = 0;
                while (slot < packedCounts.length && packedCounts[slot] > 0) {
                    int count = packedCounts[slot] >> this.topicBits;
                    nonZeroTypeTopics++;
                    logLikelihood += Dirichlet.logGammaStirling(beta + count);
                    if (Double.isNaN(logLikelihood)) {
                        System.out.println(count);
                        System.exit(1);
                    }
                    slot++;
                }
            }

            // NOTE(review): the normalizer below scales beta by numTopics rather than
            // using betaSums[language]; confirm that this is intended.
            for (int topic = 0; topic < this.numTopics; topic++) {
                logLikelihood -= Dirichlet.logGammaStirling((beta * this.numTopics) + tokensPerTopic[topic]);
                if (Double.isNaN(logLikelihood)) {
                    System.out.println("after topic " + topic + " " + tokensPerTopic[topic]);
                    System.exit(1);
                }
            }
            logLikelihood += Dirichlet.logGammaStirling(beta * this.numTopics) - (Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);
        }

        if (Double.isNaN(logLikelihood)) {
            System.out.println("at the end");
            System.exit(1);
        }
        return logLikelihood;
    }

    /**
     * Builds a {@link TopicInferencer} from the statistics stored for one language.
     *
     * @param i index of the language whose counts, alphabet and beta values are used
     * @return a new inferencer constructed from that language's state and the shared alpha
     */
    public TopicInferencer getInferencer(int i) {
        final int[][] typeTopicCounts = this.languageTypeTopicCounts[i];
        final int[] tokensPerTopic = this.languageTokensPerTopic[i];
        final Alphabet alphabet = this.alphabets[i];
        final double[] sharedAlpha = this.alpha;
        final double beta = this.betas[i];
        final double betaSum = this.betaSums[i];
        return new TopicInferencer(typeTopicCounts, tokensPerTopic, alphabet, sharedAlpha, beta, betaSum);
    }

    /**
     * Custom serialization hook: writes a leading int (0) followed by the model's
     * fields in a fixed order. {@link #readObject(ObjectInputStream)} consumes the
     * fields in exactly the same order, so the two methods must be kept in sync.
     *
     * @param objectOutputStream stream the fields are written to
     * @throws IOException if any write fails
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        final ObjectOutputStream out = objectOutputStream;

        // Leading int; readObject reads and discards it.
        out.writeInt(0);

        // Corpus and topic bookkeeping.
        out.writeInt(this.numLanguages);
        out.writeObject(this.data);
        out.writeObject(this.topicAlphabet);
        out.writeInt(this.numTopics);
        out.writeObject(this.testingIDs);
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeObject(this.alphabets);
        out.writeObject(this.vocabularySizes);

        // Hyperparameters.
        out.writeObject(this.alpha);
        out.writeDouble(this.alphaSum);
        out.writeObject(this.betas);
        out.writeObject(this.betaSums);

        // Per-language count tables and cached values.
        out.writeObject(this.languageMaxTypeCounts);
        out.writeObject(this.languageTypeTopicCounts);
        out.writeObject(this.languageTokensPerTopic);
        out.writeObject(this.languageSmoothingOnlyMasses);
        out.writeObject(this.languageCachedCoefficients);

        // Histograms.
        out.writeObject(this.docLengthCounts);
        out.writeObject(this.topicDocCounts);

        // Run configuration.
        out.writeInt(this.numIterations);
        out.writeInt(this.burninPeriod);
        out.writeInt(this.saveSampleInterval);
        out.writeInt(this.optimizeInterval);
        out.writeInt(this.showTopicsInterval);
        out.writeInt(this.wordsPerTopic);
        out.writeInt(this.saveStateInterval);
        out.writeObject(this.stateFilename);
        out.writeInt(this.saveModelInterval);
        out.writeObject(this.modelFilename);

        // Miscellaneous helpers and flags.
        out.writeObject(this.random);
        out.writeObject(this.formatter);
        out.writeBoolean(this.printLogLikelihood);
    }

    /**
     * Custom deserialization hook: reads and discards a leading int, then restores
     * the model's fields in the same order in which
     * {@link #writeObject(ObjectOutputStream)} wrote them.
     *
     * @param objectInputStream stream the fields are read from
     * @throws IOException            if any read fails
     * @throws ClassNotFoundException if a serialized object's class cannot be resolved
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        final ObjectInputStream in = objectInputStream;

        // Leading int written by writeObject; its value is not used here.
        in.readInt();

        // Corpus and topic bookkeeping.
        this.numLanguages = in.readInt();
        this.data = (ArrayList) in.readObject();
        this.topicAlphabet = (LabelAlphabet) in.readObject();
        this.numTopics = in.readInt();
        this.testingIDs = (HashSet) in.readObject();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.alphabets = (Alphabet[]) in.readObject();
        this.vocabularySizes = (int[]) in.readObject();

        // Hyperparameters.
        this.alpha = (double[]) in.readObject();
        this.alphaSum = in.readDouble();
        this.betas = (double[]) in.readObject();
        this.betaSums = (double[]) in.readObject();

        // Per-language count tables and cached values.
        this.languageMaxTypeCounts = (int[]) in.readObject();
        this.languageTypeTopicCounts = (int[][][]) in.readObject();
        this.languageTokensPerTopic = (int[][]) in.readObject();
        this.languageSmoothingOnlyMasses = (double[]) in.readObject();
        this.languageCachedCoefficients = (double[][]) in.readObject();

        // Histograms.
        this.docLengthCounts = (int[]) in.readObject();
        this.topicDocCounts = (int[][]) in.readObject();

        // Run configuration.
        this.numIterations = in.readInt();
        this.burninPeriod = in.readInt();
        this.saveSampleInterval = in.readInt();
        this.optimizeInterval = in.readInt();
        this.showTopicsInterval = in.readInt();
        this.wordsPerTopic = in.readInt();
        this.saveStateInterval = in.readInt();
        this.stateFilename = (String) in.readObject();
        this.saveModelInterval = in.readInt();
        this.modelFilename = (String) in.readObject();

        // Miscellaneous helpers and flags.
        this.random = (Randoms) in.readObject();
        this.formatter = (NumberFormat) in.readObject();
        this.printLogLikelihood = in.readBoolean();
    }

    /**
     * Serializes this model to {@code file} using Java object serialization.
     * <p>
     * I/O failures are reported on standard error and otherwise ignored; no
     * exception is propagated to the caller.
     *
     * @param file destination file
     */
    public void write(File file) {
        // try-with-resources closes both streams even when writeObject fails part-way,
        // which the previous explicit close() call did not.
        try (FileOutputStream fileOut = new FileOutputStream(file);
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(this);
        } catch (IOException e) {
            System.err.println("Problem serializing PolylingualTopicModel to file " + file + ": " + e);
        }
    }

    /**
     * Restores a model previously saved with Java object serialization and then
     * calls {@code initializeHistograms()} on it.
     * <p>
     * SECURITY NOTE: this uses native Java deserialization; only load model files
     * from sources you trust.
     *
     * @param file file containing the serialized model
     * @return the deserialized model
     * @throws Exception if the file cannot be opened or read, or if the stored
     *                   object cannot be reconstructed
     */
    public static PolylingualTopicModel read(File file) throws Exception {
        PolylingualTopicModel model;
        // Both streams are closed even if the header check or readObject throws;
        // previously the file handle leaked on any failure.
        try (FileInputStream fileIn = new FileInputStream(file);
                ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            model = (PolylingualTopicModel) objectIn.readObject();
        }
        model.initializeHistograms();
        return model;
    }

    /**
     * Command-line entry point.
     * <p>
     * Parses the options declared on this class, then either restores a saved
     * model ({@code --input-model}) or builds a new one from the per-language
     * instance lists ({@code --language-inputs}). After configuring and running
     * {@code estimate()}, it writes whichever outputs were requested: topic keys,
     * sampling state, document/topic proportions, one inferencer per language
     * and the serialized model.
     *
     * @param strArr command-line arguments
     * @throws IOException if reading the input data or writing an output file fails
     */
    public static void main(String[] strArr) throws IOException {
        CommandOption.setSummary(PolylingualTopicModel.class, "A tool for estimating, saving and printing diagnostics for topic models over comparable corpora.");
        CommandOption.process(PolylingualTopicModel.class, strArr);

        PolylingualTopicModel polylingualTopicModel = null;
        if (inputModelFilename.value != null) {
            try {
                polylingualTopicModel = read(new File(inputModelFilename.value));
            } catch (Exception e) {
                System.err.println("Unable to restore saved topic model " + inputModelFilename.value + ": " + e);
                System.exit(1);
            }
        } else {
            // Without a saved model we need at least one input file; fail with a clear
            // message instead of a NullPointerException / ArrayIndexOutOfBoundsException.
            if (languageInputFiles.value == null || languageInputFiles.value.length == 0) {
                System.err.println("No training data: specify --language-inputs or --input-model.");
                System.exit(1);
                return;
            }
            InstanceList[] instanceListArr = new InstanceList[languageInputFiles.value.length];
            for (int i = 0; i < instanceListArr.length; i++) {
                instanceListArr[i] = InstanceList.load(new File(languageInputFiles.value[i]));
                if (instanceListArr[i] != null) {
                    System.out.println(i + " is not null");
                } else {
                    System.out.println(i + " is null");
                }
            }
            System.out.println("Data loaded.");
            // Only the first instance of the first language is inspected.
            if (instanceListArr[0].size() > 0 && instanceListArr[0].get(0) != null && !(instanceListArr[0].get(0).getData() instanceof FeatureSequence)) {
                System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
                System.exit(1);
            }
            polylingualTopicModel = new PolylingualTopicModel(numTopicsOption.value, alphaOption.value);
            if (randomSeedOption.value != 0) {
                polylingualTopicModel.setRandomSeed(randomSeedOption.value);
            }
            polylingualTopicModel.addInstances(instanceListArr);
        }

        polylingualTopicModel.setTopicDisplay(showTopicsIntervalOption.value, topWordsOption.value);
        polylingualTopicModel.setNumIterations(numIterationsOption.value);
        polylingualTopicModel.setOptimizeInterval(optimizeIntervalOption.value);
        polylingualTopicModel.setBurninPeriod(optimizeBurnInOption.value);
        if (outputStateIntervalOption.value != 0) {
            polylingualTopicModel.setSaveState(outputStateIntervalOption.value, stateFile.value);
        }
        if (outputModelIntervalOption.value != 0) {
            polylingualTopicModel.setModelOutput(outputModelIntervalOption.value, outputModelFilename.value);
        }

        polylingualTopicModel.estimate();

        if (topicKeysFile.value != null) {
            polylingualTopicModel.printTopWords(new File(topicKeysFile.value), topWordsOption.value, false);
        }
        if (stateFile.value != null) {
            polylingualTopicModel.printState(new File(stateFile.value));
        }
        if (docTopicsFile.value != null) {
            // The writer is closed (and flushed) even if printing fails.
            try (PrintWriter docTopicsWriter = new PrintWriter(new FileWriter(new File(docTopicsFile.value)))) {
                polylingualTopicModel.printDocumentTopics(docTopicsWriter, docTopicsThreshold.value, docTopicsMax.value);
            }
        }
        if (inferencerFilename.value != null) {
            // One inferencer file per language, named "<inferencer-filename>.<language index>".
            for (int language = 0; language < polylingualTopicModel.numLanguages; language++) {
                String inferencerPath = inferencerFilename.value + "." + language;
                try (FileOutputStream fileOut = new FileOutputStream(inferencerPath);
                        ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
                    objectOut.writeObject(polylingualTopicModel.getInferencer(language));
                } catch (Exception e2) {
                    // Best effort: report the failure (getMessage() alone may be null)
                    // and continue with the remaining languages.
                    System.err.println("Unable to write inferencer " + inferencerPath + ": " + e2);
                }
            }
        }
        if (outputModelFilename.value != null) {
            if (!$assertionsDisabled && polylingualTopicModel == null) {
                throw new AssertionError();
            }
            polylingualTopicModel.write(new File(outputModelFilename.value));
        }
    }

    // Static initializer: creates the command-line option descriptors that main() reads
    // through their `value` fields. Every option is constructed with this class as owner,
    // followed by its option name, an argument placeholder, a boolean flag, a default
    // value, a help text and a trailing null.
    // NOTE(review): the meaning of the boolean flag and of the trailing null is defined by
    // CommandOption, which is not visible here — confirm before changing them.
    static {
        // NOTE(review): $assertionsDisabled looks like the compiler-generated assertion
        // flag surfaced by decompilation — confirm before editing.
        $assertionsDisabled = !PolylingualTopicModel.class.desiredAssertionStatus();
        // Input / output file options; all default to null.
        languageInputFiles = new CommandOption.SpacedStrings(PolylingualTopicModel.class, "language-inputs", "FILENAME [FILENAME ...]", true, null, "Filenames for polylingual topic model. Each language should have its own file, with the same number of instances in each file. If a document is missing in one language, there should be an empty instance.", null);
        outputModelFilename = new CommandOption.String(PolylingualTopicModel.class, "output-model", "FILENAME", true, null, "The filename in which to write the binary topic model at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
        inputModelFilename = new CommandOption.String(PolylingualTopicModel.class, "input-model", "FILENAME", true, null, "The filename from which to read the binary topic model to which the --input will be appended, allowing incremental training.  By default this is null, indicating that no file will be read.", null);
        inferencerFilename = new CommandOption.String(PolylingualTopicModel.class, "inferencer-filename", "FILENAME", true, null, "A topic inferencer applies a previously trained topic model to new documents.  By default this is null, indicating that no file will be written.", null);
        evaluatorFilename = new CommandOption.String(PolylingualTopicModel.class, "evaluator-filename", "FILENAME", true, null, "A held-out likelihood evaluator for new documents.  By default this is null, indicating that no file will be written.", null);
        stateFile = new CommandOption.String(PolylingualTopicModel.class, "output-state", "FILENAME", true, null, "The filename in which to write the Gibbs sampling state after at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
        topicKeysFile = new CommandOption.String(PolylingualTopicModel.class, "output-topic-keys", "FILENAME", true, null, "The filename in which to write the top words for each topic and any Dirichlet parameters.  By default this is null, indicating that no file will be written.", null);
        docTopicsFile = new CommandOption.String(PolylingualTopicModel.class, "output-doc-topics", "FILENAME", true, null, "The filename in which to write the topic proportions per document, at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
        // Filters applied when writing document/topic proportions.
        docTopicsThreshold = new CommandOption.Double(PolylingualTopicModel.class, "doc-topics-threshold", "DECIMAL", true, 0.0d, "When writing topic proportions per document with --output-doc-topics, do not print topics with proportions less than this threshold value.", null);
        docTopicsMax = new CommandOption.Integer(PolylingualTopicModel.class, "doc-topics-max", "INTEGER", true, -1, "When writing topic proportions per document with --output-doc-topics, do not print more than INTEGER number of topics.  A negative value indicates that all topics should be printed.", null);
        // Periodic output intervals; main() skips the corresponding setup when the value is 0.
        outputModelIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "output-model-interval", "INTEGER", true, 0, "The number of iterations between writing the model (and its Gibbs sampling state) to a binary file.  You must also set the --output-model to use this option, whose argument will be the prefix of the filenames.", null);
        outputStateIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "output-state-interval", "INTEGER", true, 0, "The number of iterations between writing the sampling state to a text file.  You must also set the --output-state to use this option, whose argument will be the prefix of the filenames.", null);
        // Sampler configuration.
        numTopicsOption = new CommandOption.Integer(PolylingualTopicModel.class, "num-topics", "INTEGER", true, 10, "The number of topics to fit.", null);
        numIterationsOption = new CommandOption.Integer(PolylingualTopicModel.class, "num-iterations", "INTEGER", true, 1000, "The number of iterations of Gibbs sampling.", null);
        randomSeedOption = new CommandOption.Integer(PolylingualTopicModel.class, "random-seed", "INTEGER", true, 0, "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);
        topWordsOption = new CommandOption.Integer(PolylingualTopicModel.class, "num-top-words", "INTEGER", true, 20, "The number of most probable words to print for each topic after model estimation.", null);
        showTopicsIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "show-topics-interval", "INTEGER", true, 50, "The number of iterations between printing a brief summary of the topics so far.", null);
        optimizeIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "optimize-interval", "INTEGER", true, 0, "The number of iterations between reestimating dirichlet hyperparameters.", null);
        // NOTE(review): the default below is taken from a JFreeChart UI constant, which is
        // unrelated to topic modeling — presumably a decompiler constant substitution for a
        // plain integer literal; confirm the intended value and drop the dependency.
        optimizeBurnInOption = new CommandOption.Integer(PolylingualTopicModel.class, "optimize-burn-in", "INTEGER", true, ChartPanel.DEFAULT_MINIMUM_DRAW_HEIGHT, "The number of iterations to run before first estimating dirichlet hyperparameters.", null);
        // NOTE(review): the option names for alphaOption and betaOption come from constants
        // of a RapidMiner LDA operator class — presumably another decompiler substitution
        // for string literals; confirm the actual option names.
        alphaOption = new CommandOption.Double(PolylingualTopicModel.class, com.rapidminer.extension.operator.text_processing.mallet.LDA.PARAMETER_ALPHA, "DECIMAL", true, 50.0d, "Alpha parameter: smoothing over topic distribution.", null);
        betaOption = new CommandOption.Double(PolylingualTopicModel.class, com.rapidminer.extension.operator.text_processing.mallet.LDA.PARAMETER_BETA, "DECIMAL", true, 0.01d, "Beta parameter: smoothing over unigram distribution.", null);
    }
}
