package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.Randoms;
import com.meaningcloud.LangRequest;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TObjectIntHashMap;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:cc/mallet/topics/LDAHyper.class */
public class LDAHyper implements Serializable {
    // Documents added through addInstances, each paired with its per-token topic assignments.
    protected ArrayList<Topication> data;
    // Data alphabet of the added instances; fixed by the first call to initializeForTypes.
    protected Alphabet alphabet;
    // Topic labels; numTopics is taken from its size in the constructor.
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    // Alphabet size as last seen by initializeForTypes.
    protected int numTypes;
    // Per-topic alpha values, initialised to alphaSum / numTopics.
    protected double[] alpha;
    protected double alphaSum;
    protected double beta;
    // beta * numTypes, recomputed whenever numTypes changes.
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01d;
    // Sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum); updated incrementally while sampling.
    protected double smoothingOnlyMass;
    // alpha[t] / (tokensPerTopic[t] + betaSum) between documents; temporarily includes the
    // document's own topic counts while that document is being sampled.
    protected double[] cachedCoefficients;
    // The three counters below are reset to 0 at the start of every sweep in estimate(int);
    // no other use is visible in this part of the file.
    int topicTermCount;
    int betaTopicCount;
    int smoothingOnlyCount;
    // Optional instances; when non-null, estimate(int) prints likelihood figures with the topic report.
    protected InstanceList testing;
    // Scratch per-document topic counts used by oldSampleTopicsForOneDoc.
    protected int[] oneDocTopicCounts;
    // typeTopicCounts[type] maps topic -> number of tokens of that type assigned to the topic.
    protected TIntIntHashMap[] typeTopicCounts;
    // tokensPerTopic[topic] = total number of tokens assigned to the topic.
    protected int[] tokensPerTopic;
    // docLengthCounts[n] = number of sampled documents of length n (histogram for alpha optimisation).
    protected int[] docLengthCounts;
    // topicDocCounts[topic][n] = number of sampled documents with n tokens in the topic.
    protected int[][] topicDocCounts;
    // Number of sweeps completed so far; advanced by estimate(int).
    public int iterationsSoFar;
    public int numIterations;
    public int burninPeriod;
    public int saveSampleInterval;
    public int optimizeInterval;
    public int showTopicsInterval;
    public int wordsPerTopic;
    // Set by setModelOutput; where they are consumed is not visible in this part of the file.
    protected int outputModelInterval;
    protected String outputModelFilename;
    // Set by setSaveState; estimate(int) writes a state file every saveStateInterval sweeps.
    protected int saveStateInterval;
    protected String stateFilename;
    protected Randoms random;
    // Number format (at most 5 fraction digits) used when printing alpha values.
    protected NumberFormat formatter;
    // When true, estimate(int) prints modelLogLikelihood() every 10th sweep.
    protected boolean printLogLikelihood;
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    // NOTE(review): no use of NULL_INTEGER is visible in this part of the file.
    private static final int NULL_INTEGER = -1;
    // Decompiler-exposed flag that guards the assertion checks in this class.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:cc/mallet/topics/LDAHyper$Topication.class */
    /**
     * Bundles one training instance with the model that owns it, the topic
     * assigned to each of its tokens, and an optional topic distribution.
     */
    public class Topication implements Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 0;

        public Instance instance;
        public LDAHyper model;
        public LabelSequence topicSequence;
        /** Not set by the constructor; stays {@code null} until assigned elsewhere. */
        public Labeling topicDistribution;

        /**
         * @param instance      the document
         * @param lDAHyper      the model the document belongs to
         * @param labelSequence the topic assigned to each token of the document
         */
        public Topication(Instance instance, LDAHyper lDAHyper, LabelSequence labelSequence) {
            this.instance = instance;
            this.model = lDAHyper;
            this.topicSequence = labelSequence;
        }

        /** Writes a version number followed by the four fields, in declaration order. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeObject(instance);
            out.writeObject(model);
            out.writeObject(topicSequence);
            out.writeObject(topicDistribution);
        }

        /** Reads back what {@code writeObject} wrote; the version number is consumed but not checked. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.readInt();
            instance = (Instance) in.readObject();
            model = (LDAHyper) in.readObject();
            topicSequence = (LabelSequence) in.readObject();
            topicDistribution = (Labeling) in.readObject();
        }
    }

    /**
     * Creates a model with {@code i} topics, an alpha sum equal to the number
     * of topics and the default beta.
     *
     * @param i number of topics
     */
    public LDAHyper(int i) {
        this(i, (double) i, DEFAULT_BETA);
    }

    /**
     * Creates a model with a freshly constructed {@code Randoms} generator.
     *
     * @param i  number of topics
     * @param d  alpha sum (divided evenly over the topics)
     * @param d2 beta
     */
    public LDAHyper(int i, double d, double d2) {
        this(i, d, d2, new Randoms());
    }

    /**
     * Builds a label alphabet containing the entries "topic0" .. "topic(i-1)",
     * looked up in ascending order.
     *
     * @param i number of topic labels to create
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int i) {
        final LabelAlphabet topics = new LabelAlphabet();
        int topic = 0;
        while (topic < i) {
            topics.lookupIndex("topic" + topic);
            topic++;
        }
        return topics;
    }

    /**
     * Creates a model whose topic alphabet holds {@code i} generated labels
     * ("topic0", "topic1", ...).
     *
     * @param i       number of topics
     * @param d       alpha sum (divided evenly over the topics)
     * @param d2      beta
     * @param randoms random number generator used by the model
     */
    public LDAHyper(int i, double d, double d2, Randoms randoms) {
        this(newLabelAlphabet(i), d, d2, randoms);
    }

    /**
     * Primary constructor: installs default run settings, then sizes the
     * per-topic structures from the given topic alphabet.
     *
     * @param labelAlphabet one label per topic; its size becomes the number of topics
     * @param d             alpha sum; each topic starts with {@code d / numTopics}
     * @param d2            beta
     * @param randoms       random number generator used by the model
     */
    public LDAHyper(LabelAlphabet labelAlphabet, double d, double d2, Randoms randoms) {
        // Sampler bookkeeping starts out empty.
        smoothingOnlyMass = 0.0d;
        topicTermCount = 0;
        betaTopicCount = 0;
        smoothingOnlyCount = 0;
        testing = null;

        // Default schedule for estimate().
        iterationsSoFar = 0;
        numIterations = 1000;
        burninPeriod = 20;
        saveSampleInterval = 5;
        optimizeInterval = 20;
        showTopicsInterval = 10;
        wordsPerTopic = 7;

        // Periodic output is off by default.
        outputModelInterval = 0;
        saveStateInterval = 0;
        stateFilename = null;
        printLogLikelihood = false;

        data = new ArrayList<>();

        // Topic-dependent state.
        topicAlphabet = labelAlphabet;
        numTopics = labelAlphabet.size();
        alphaSum = d;
        alpha = new double[numTopics];
        Arrays.fill(alpha, d / numTopics);
        beta = d2;
        random = randoms;
        oneDocTopicCounts = new int[numTopics];
        tokensPerTopic = new int[numTopics];

        formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(5);

        System.err.println("LDA: " + numTopics + " topics");
    }

    /** @return the data alphabet, or {@code null} if no alphabet has been installed yet */
    public Alphabet getAlphabet() {
        return alphabet;
    }

    /** @return the alphabet holding one label per topic */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** @return the number of topics in the model */
    public int getNumTopics() {
        return numTopics;
    }

    /** @return the model's own document list (not a copy) */
    public ArrayList<Topication> getData() {
        return data;
    }

    /**
     * Looks up the count stored for a (type, topic) pair.
     *
     * @param i  type (feature) index
     * @param i2 topic index
     * @return the value the type's topic map holds for the topic
     */
    public int getCountFeatureTopic(int i, int i2) {
        final TIntIntHashMap countsForType = typeTopicCounts[i];
        return countsForType.get(i2);
    }

    /**
     * @param i topic index
     * @return total number of tokens currently assigned to that topic
     */
    public int getCountTokensPerTopic(int i) {
        return tokensPerTopic[i];
    }

    /**
     * Sets the instances used for the likelihood figures printed by
     * {@code estimate}; {@code null} disables that output.
     */
    public void setTestingInstances(InstanceList instanceList) {
        testing = instanceList;
    }

    /** Sets the iteration count used by the no-argument {@code estimate()}. */
    public void setNumIterations(int i) {
        numIterations = i;
    }

    /** Sets the burn-in threshold that {@code estimate} compares the iteration number against. */
    public void setBurninPeriod(int i) {
        burninPeriod = i;
    }

    /**
     * Configures the periodic topic report printed by {@code estimate}.
     *
     * @param i  report every {@code i} iterations (0 disables the report)
     * @param i2 word limit passed to {@code printTopWords}
     */
    public void setTopicDisplay(int i, int i2) {
        wordsPerTopic = i2;
        showTopicsInterval = i;
    }

    /** Replaces the model's random number generator with one seeded by {@code i}. */
    public void setRandomSeed(int i) {
        final Randoms seeded = new Randoms(i);
        random = seeded;
    }

    /** Sets how often (in iterations) {@code estimate} re-learns alpha; 0 disables it. */
    public void setOptimizeInterval(int i) {
        optimizeInterval = i;
    }

    /**
     * Stores the model-output settings.
     *
     * @param i   output interval
     * @param str output file name
     */
    public void setModelOutput(int i, String str) {
        outputModelFilename = str;
        outputModelInterval = i;
    }

    /**
     * Configures periodic state dumps written by {@code estimate}.
     *
     * @param i   write the state every {@code i} iterations (0 disables it)
     * @param str file name prefix; {@code estimate} appends '.' and the iteration number
     */
    public void setSaveState(int i, String str) {
        stateFilename = str;
        saveStateInterval = i;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * @param instance an instance whose data is a {@code FeatureSequence}
     * @return the size of that feature sequence
     */
    public int instanceLength(Instance instance) {
        final FeatureSequence tokens = (FeatureSequence) instance.getData();
        return tokens.size();
    }

    /**
     * Makes the type-indexed structures match the given alphabet.
     *
     * <p>On the first call the alphabet is adopted and one empty topic-count
     * map is created per type. On later calls the same alphabet object must be
     * passed; if it has changed size since the last call, {@code typeTopicCounts}
     * is resized so that every type index below {@code numTypes} has a map, and
     * {@code betaSum} is recomputed.
     *
     * @param alphabet the data alphabet of the instances being added
     * @throws IllegalArgumentException if a different alphabet object is passed
     *         after one has already been adopted
     */
    private void initializeForTypes(Alphabet alphabet) {
        if (this.alphabet == null) {
            this.alphabet = alphabet;
            this.numTypes = alphabet.size();
            this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
            for (int type = 0; type < this.numTypes; type++) {
                this.typeTopicCounts[type] = new TIntIntHashMap();
            }
            this.betaSum = this.beta * this.numTypes;
            return;
        }
        if (alphabet != this.alphabet) {
            throw new IllegalArgumentException("Cannot change Alphabet.");
        }
        if (alphabet.size() != this.numTypes) {
            final int previousLength = this.typeTopicCounts.length;
            this.numTypes = alphabet.size();
            // Existing maps are carried over; only the new type slots get fresh maps.
            final TIntIntHashMap[] resized = Arrays.copyOf(this.typeTopicCounts, this.numTypes);
            for (int type = previousLength; type < this.numTypes; type++) {
                resized[type] = new TIntIntHashMap();
            }
            // Install the resized array; without this, numTypes would exceed the
            // array length and any access to a new type index would overrun it.
            this.typeTopicCounts = resized;
            this.betaSum = this.beta * this.numTypes;
        }
    }

    /**
     * Replaces {@code typeTopicCounts} with an array of length {@code numTypes}
     * that reuses the existing maps and adds an empty map for every further slot.
     */
    private void initializeTypeTopicCounts() {
        final TIntIntHashMap[] expanded = new TIntIntHashMap[this.numTypes];
        final TIntIntHashMap[] previous = this.typeTopicCounts;
        int type = 0;
        while (type < previous.length) {
            expanded[type] = previous[type];
            type++;
        }
        while (type < this.numTypes) {
            expanded[type] = new TIntIntHashMap();
            type++;
        }
        this.typeTopicCounts = expanded;
    }

    /**
     * Adds documents to the model, assigning every token a uniformly random
     * initial topic.
     *
     * <p>The assignments are drawn from the model's own random number generator,
     * so a generator passed to the constructor or installed by
     * {@code setRandomSeed} also governs the initial state. A fresh generator is
     * used only when the model has none.
     *
     * @param instanceList documents whose data are feature sequences over the
     *        model's alphabet
     */
    public void addInstances(InstanceList instanceList) {
        initializeForTypes(instanceList.getDataAlphabet());
        final Randoms generator = this.random != null ? this.random : new Randoms();
        final List<LabelSequence> initialTopics = new ArrayList<>(instanceList.size());
        for (Instance instance : instanceList) {
            LabelSequence topicSequence =
                    new LabelSequence(this.topicAlphabet, new int[instanceLength(instance)]);
            // Fill the sequence's own backing array, as returned by getFeatures().
            int[] assignments = topicSequence.getFeatures();
            for (int position = 0; position < assignments.length; position++) {
                assignments[position] = generator.nextInt(this.numTopics);
            }
            initialTopics.add(topicSequence);
        }
        addInstances(instanceList, initialTopics);
    }

    /**
     * Adds documents together with caller-supplied topic assignments and folds
     * those assignments into the type/topic and per-topic token counts, then
     * rebuilds the histograms and cached sampling values.
     *
     * @param instanceList documents to add
     * @param list         one topic sequence per document, in the same order;
     *                     the sizes are compared only when assertions are enabled
     */
    public void addInstances(InstanceList instanceList, List<LabelSequence> list) {
        initializeForTypes(instanceList.getDataAlphabet());
        if (!$assertionsDisabled && instanceList.size() != list.size()) {
            throw new AssertionError();
        }
        for (int doc = 0; doc < instanceList.size(); doc++) {
            final Topication entry = new Topication(instanceList.get(doc), this, list.get(doc));
            this.data.add(entry);

            final FeatureSequence tokens = (FeatureSequence) entry.instance.getData();
            final LabelSequence topics = entry.topicSequence;
            for (int position = 0; position < topics.getLength(); position++) {
                final int topic = topics.getIndexAtPosition(position);
                final int type = tokens.getIndexAtPosition(position);
                this.typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
                this.tokensPerTopic[topic]++;
            }
        }
        initializeHistogramsAndCachedValues();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Recomputes the smoothing-only mass and the cached per-topic coefficients
     * from the current counts, reports the longest document and the total token
     * count on stderr, and allocates zeroed histograms sized for the longest
     * document.
     */
    public void initializeHistogramsAndCachedValues() {
        int longestDocument = 0;
        int tokenTotal = 0;
        for (int doc = 0; doc < this.data.size(); doc++) {
            final FeatureSequence tokens = (FeatureSequence) this.data.get(doc).instance.getData();
            final int docLength = tokens.getLength();
            longestDocument = Math.max(longestDocument, docLength);
            tokenTotal += docLength;
        }

        this.smoothingOnlyMass = 0.0d;
        for (int topic = 0; topic < this.numTopics; topic++) {
            this.smoothingOnlyMass +=
                    (this.alpha[topic] * this.beta) / (this.tokensPerTopic[topic] + this.betaSum);
        }

        this.cachedCoefficients = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            this.cachedCoefficients[topic] =
                    this.alpha[topic] / (this.tokensPerTopic[topic] + this.betaSum);
        }

        System.err.println("max tokens: " + longestDocument);
        System.err.println("total tokens: " + tokenTotal);

        // Index n of each histogram counts documents with n tokens, hence the +1.
        final int histogramSize = longestDocument + 1;
        this.docLengthCounts = new int[histogramSize];
        this.topicDocCounts = new int[this.numTopics][histogramSize];
    }

    /**
     * Runs the sampler using the configured {@code numIterations}.
     *
     * @throws IOException propagated from {@code estimate(int)}
     */
    public void estimate() throws IOException {
        estimate(numIterations);
    }

    /**
     * Runs Gibbs sampling sweeps over all documents, with periodic topic
     * reports, state dumps and alpha re-estimation, then prints the total
     * elapsed time.
     *
     * <p>The loop bound is inclusive: sweeps run while
     * {@code iterationsSoFar <= start + i}, i.e. {@code i + 1} sweeps per call.
     *
     * @param i number of iterations to advance (see the note on the inclusive bound)
     * @throws IOException declared for the state dump ({@code printState})
     */
    public void estimate(int i) throws IOException {
        final long runStart = System.currentTimeMillis();
        final int lastIteration = this.iterationsSoFar + i;

        for (; this.iterationsSoFar <= lastIteration; this.iterationsSoFar++) {
            final long sweepStart = System.currentTimeMillis();

            // Periodic topic report (never on iteration 0).
            final boolean reportTopics = this.showTopicsInterval != 0
                    && this.iterationsSoFar != 0
                    && this.iterationsSoFar % this.showTopicsInterval == 0;
            if (reportTopics) {
                System.out.println();
                printTopWords(System.out, this.wordsPerTopic, false);
                if (this.testing != null) {
                    System.out.println(modelLogLikelihood() + "\t" + empiricalLikelihood(1000, this.testing) + "\t" + topicLabelMutualInformation());
                }
            }

            // Periodic state dump to "<stateFilename>.<iteration>".
            final boolean dumpState = this.saveStateInterval != 0
                    && this.iterationsSoFar % this.saveStateInterval == 0;
            if (dumpState) {
                printState(new File(this.stateFilename + '.' + this.iterationsSoFar));
            }

            // Periodic alpha re-estimation once past burn-in.
            final boolean optimizeAlpha = this.iterationsSoFar > this.burninPeriod
                    && this.optimizeInterval != 0
                    && this.iterationsSoFar % this.optimizeInterval == 0;
            if (optimizeAlpha) {
                this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
                // The alpha-dependent caches have to follow the new alpha values.
                this.smoothingOnlyMass = 0.0d;
                for (int topic = 0; topic < this.numTopics; topic++) {
                    this.smoothingOnlyMass +=
                            (this.alpha[topic] * this.beta) / (this.tokensPerTopic[topic] + this.betaSum);
                    this.cachedCoefficients[topic] =
                            this.alpha[topic] / (this.tokensPerTopic[topic] + this.betaSum);
                }
                clearHistograms();
            }

            this.smoothingOnlyCount = 0;
            this.betaTopicCount = 0;
            this.topicTermCount = 0;

            // One sweep over every document.
            final int documentCount = this.data.size();
            for (int doc = 0; doc < documentCount; doc++) {
                final FeatureSequence tokens = (FeatureSequence) this.data.get(doc).instance.getData();
                final LabelSequence topics = this.data.get(doc).topicSequence;
                final boolean recordHistograms = this.iterationsSoFar >= this.burninPeriod
                        && this.iterationsSoFar % this.saveSampleInterval == 0;
                sampleTopicsForOneDoc(tokens, topics, recordHistograms, true);
            }

            final long sweepMillis = System.currentTimeMillis() - sweepStart;
            if (sweepMillis >= 1000) {
                System.out.print((sweepMillis / 1000) + "s ");
            } else {
                System.out.print(sweepMillis + "ms ");
            }

            if (this.iterationsSoFar % 10 == 0) {
                System.out.println("<" + this.iterationsSoFar + "> ");
                if (this.printLogLikelihood) {
                    System.out.println(modelLogLikelihood());
                }
            }
            System.out.flush();
        }

        // Break the elapsed time into days / hours / minutes / seconds.
        final long totalSeconds = Math.round((System.currentTimeMillis() - runStart) / 1000.0d);
        final long seconds = totalSeconds % 60;
        final long totalMinutes = totalSeconds / 60;
        final long minutes = totalMinutes % 60;
        final long totalHours = totalMinutes / 60;
        final long hours = totalHours % 24;
        final long days = totalHours / 24;

        System.out.print("\nTotal time: ");
        if (days != 0) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Zeroes the document-length histogram and every row of the topic/document histogram. */
    private void clearHistograms() {
        Arrays.fill(docLengthCounts, 0);
        for (int[] countsForTopic : topicDocCounts) {
            Arrays.fill(countsForTopic, 0);
        }
    }

    /**
     * Older per-document sampler: for every token, draws a topic from the topics
     * that currently have a count for the token's type, weighted by
     * {@code (typeTopicCount + beta) / (tokensPerTopic + betaSum) * (docTopicCount + alpha)}.
     *
     * <p>Relies on {@code keys()} and {@code getValues()} of the type's topic map
     * returning entries in matching order.
     *
     * @param tokenSequence          the document's tokens
     * @param topicSequence          topic assignment per token; updated in place
     *                               when {@code readjustTopicsAndStats} is true
     * @param shouldSaveState        when true, the document is added to the
     *                               length and topic/document histograms
     * @param readjustTopicsAndStats when true, counts and assignments are updated;
     *                               when false, topics are drawn but nothing is stored
     * @throws IllegalStateException if removing a token drives a type/topic count negative
     */
    private void oldSampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean shouldSaveState, boolean readjustTopicsAndStats) {
        final int[] assignments = topicSequence.getFeatures();
        final int docLength = tokenSequence.getLength();

        // Per-document topic counts (left at zero when statistics are not adjusted).
        Arrays.fill(this.oneDocTopicCounts, 0);
        if (readjustTopicsAndStats) {
            for (int position = 0; position < docLength; position++) {
                this.oneDocTopicCounts[assignments[position]]++;
            }
        }

        for (int position = 0; position < docLength; position++) {
            final int type = tokenSequence.getIndexAtPosition(position);
            final int oldTopic = assignments[position];
            final TIntIntHashMap countsForType = this.typeTopicCounts[type];
            if (!$assertionsDisabled && countsForType.size() == 0) {
                throw new AssertionError();
            }

            if (readjustTopicsAndStats) {
                // Take the current token out of all counts before resampling it.
                this.oneDocTopicCounts[oldTopic]--;
                final int remaining = countsForType.adjustOrPutValue(oldTopic, -1, -1);
                if (remaining == 0) {
                    countsForType.remove(oldTopic);
                } else if (remaining == -1) {
                    throw new IllegalStateException("Token count in topic went negative.");
                }
                this.tokensPerTopic[oldTopic]--;
            }

            // Weights are stored by position so they line up with topicKeys.
            final int[] topicKeys = countsForType.keys();
            final int[] topicValues = countsForType.getValues();
            final double[] weights = new double[topicKeys.length];
            double weightSum = 0.0d;
            for (int k = 0; k < topicValues.length; k++) {
                final int topic = topicKeys[k];
                final double weight = ((topicValues[k] + this.beta) / (this.tokensPerTopic[topic] + this.betaSum))
                        * (this.oneDocTopicCounts[topic] + this.alpha[topic]);
                weightSum += weight;
                // Index by position k, not by topic id: the array has one slot per key.
                weights[k] = weight;
            }
            final int newTopic = topicKeys[this.random.nextDiscrete(weights, weightSum)];

            if (readjustTopicsAndStats) {
                assignments[position] = newTopic;
                this.oneDocTopicCounts[newTopic]++;
                this.typeTopicCounts[type].adjustOrPutValue(newTopic, 1, 1);
                this.tokensPerTopic[newTopic]++;
            }
        }

        if (shouldSaveState) {
            this.docLengthCounts[docLength]++;
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.topicDocCounts[topic][this.oneDocTopicCounts[topic]]++;
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Resamples the topic of every token in one document, updating the global
     * counts and the cached sampling masses in place.
     *
     * <p>The sampling weight of a token is split into three masses: one over the
     * topics that have a count for the token's type, one over the topics present
     * in the document, and a smoothing-only mass over all topics. A single uniform
     * draw picks a mass first and then a topic within it.
     *
     * @param featureSequence  the document's tokens
     * @param featureSequence2 topic assignment per token; overwritten in place
     * @param z                when true, the document is recorded in
     *                         {@code docLengthCounts} and {@code topicDocCounts}
     * @param z2               not read anywhere in this method body
     */
    public void sampleTopicsForOneDoc(FeatureSequence featureSequence, FeatureSequence featureSequence2, boolean z, boolean z2) {
        int[] features = featureSequence2.getFeatures();
        int length = featureSequence.getLength();
        // Topic -> number of tokens in this document currently assigned to it.
        TIntIntHashMap tIntIntHashMap = new TIntIntHashMap();
        for (int i = 0; i < length; i++) {
            tIntIntHashMap.adjustOrPutValue(features[i], 1, 1);
        }
        // d: document-topic mass, sum of beta * docCount / (tokensPerTopic + betaSum).
        // cachedCoefficients is switched to include the document's counts for its topics.
        double d = 0.0d;
        for (int i2 : tIntIntHashMap.keys()) {
            int i3 = tIntIntHashMap.get(i2);
            d += (this.beta * i3) / (this.tokensPerTopic[i2] + this.betaSum);
            this.cachedCoefficients[i2] = (this.alpha[i2] + i3) / (this.tokensPerTopic[i2] + this.betaSum);
        }
        // Scratch buffer for the per-topic term scores of the current token (indexed by position in keys).
        double[] dArr = new double[this.numTopics];
        for (int i4 = 0; i4 < length; i4++) {
            int indexAtPosition = featureSequence.getIndexAtPosition(i4);
            int i5 = features[i4];
            TIntIntHashMap tIntIntHashMap2 = this.typeTopicCounts[indexAtPosition];
            if (!$assertionsDisabled && tIntIntHashMap2.get(i5) < 0) {
                throw new AssertionError();
            }
            // Remove the token from the type/topic counts, dropping the entry when it reaches zero.
            if (tIntIntHashMap2.get(i5) == 1) {
                tIntIntHashMap2.remove(i5);
            } else {
                tIntIntHashMap2.adjustValue(i5, -1);
            }
            // Subtract the old topic's contribution from both masses, decrement the
            // counts, then add the contribution back with the updated counts.
            this.smoothingOnlyMass -= (this.alpha[i5] * this.beta) / (this.tokensPerTopic[i5] + this.betaSum);
            double d2 = d - ((this.beta * tIntIntHashMap.get(i5)) / (this.tokensPerTopic[i5] + this.betaSum));
            if (tIntIntHashMap.get(i5) == 1) {
                tIntIntHashMap.remove(i5);
            } else {
                tIntIntHashMap.adjustValue(i5, -1);
            }
            int[] iArr = this.tokensPerTopic;
            iArr[i5] = iArr[i5] - 1;
            this.smoothingOnlyMass += (this.alpha[i5] * this.beta) / (this.tokensPerTopic[i5] + this.betaSum);
            double d3 = d2 + ((this.beta * tIntIntHashMap.get(i5)) / (this.tokensPerTopic[i5] + this.betaSum));
            this.cachedCoefficients[i5] = (this.alpha[i5] + tIntIntHashMap.get(i5)) / (this.tokensPerTopic[i5] + this.betaSum);
            // d4: topic-term mass over the topics that have a count for this type.
            double d4 = 0.0d;
            int[] keys = tIntIntHashMap2.keys();
            int[] values = tIntIntHashMap2.getValues();
            for (int i6 = 0; i6 < keys.length; i6++) {
                double d5 = this.cachedCoefficients[keys[i6]] * values[i6];
                d4 += d5;
                dArr[i6] = d5;
            }
            // One uniform draw over the total of the three masses.
            double nextUniform = this.random.nextUniform() * (this.smoothingOnlyMass + d3 + d4);
            int i7 = -1;
            if (nextUniform >= d4) {
                double d6 = nextUniform - d4;
                if (d6 >= d3) {
                    // Smoothing-only mass: walk all topics (beta factored out of the sample).
                    nextUniform = (d6 - d3) / this.beta;
                    int i8 = 0;
                    while (true) {
                        if (i8 >= this.numTopics) {
                            break;
                        }
                        nextUniform -= this.alpha[i8] / (this.tokensPerTopic[i8] + this.betaSum);
                        if (nextUniform <= 0.0d) {
                            i7 = i8;
                            break;
                        }
                        i8++;
                    }
                } else {
                    // Document-topic mass: walk the topics present in this document.
                    nextUniform = d6 / this.beta;
                    int[] keys2 = tIntIntHashMap.keys();
                    int[] values2 = tIntIntHashMap.getValues();
                    for (int i9 = 0; i9 < keys2.length; i9++) {
                        i7 = keys2[i9];
                        nextUniform -= values2[i9] / (this.tokensPerTopic[i7] + this.betaSum);
                        if (nextUniform <= 0.0d) {
                            break;
                        }
                    }
                }
            } else {
                // Topic-term mass: walk the scores computed above.
                // NOTE(review): this loop has no upper bound on i10; it depends on the
                // scores summing to at least nextUniform.
                int i10 = -1;
                while (nextUniform > 0.0d) {
                    i10++;
                    nextUniform -= dArr[i10];
                }
                i7 = keys[i10];
            }
            if (i7 == -1) {
                // No topic was selected; report it and fall back to the last topic.
                // NOTE(review): nextUniform is printed twice in this message.
                System.err.println("LDAHyper sampling error: " + nextUniform + LangRequest.DEFAULT_SELECTION + nextUniform + LangRequest.DEFAULT_SELECTION + this.smoothingOnlyMass + LangRequest.DEFAULT_SELECTION + d3 + LangRequest.DEFAULT_SELECTION + d4);
                i7 = this.numTopics - 1;
            }
            // Store the new assignment and fold it back into the counts and masses.
            features[i4] = i7;
            tIntIntHashMap2.adjustOrPutValue(i7, 1, 1);
            this.smoothingOnlyMass -= (this.alpha[i7] * this.beta) / (this.tokensPerTopic[i7] + this.betaSum);
            double d7 = d3 - ((this.beta * tIntIntHashMap.get(i7)) / (this.tokensPerTopic[i7] + this.betaSum));
            tIntIntHashMap.adjustOrPutValue(i7, 1, 1);
            int[] iArr2 = this.tokensPerTopic;
            int i11 = i7;
            iArr2[i11] = iArr2[i11] + 1;
            this.cachedCoefficients[i7] = (this.alpha[i7] + tIntIntHashMap.get(i7)) / (this.tokensPerTopic[i7] + this.betaSum);
            this.smoothingOnlyMass += (this.alpha[i7] * this.beta) / (this.tokensPerTopic[i7] + this.betaSum);
            d = d7 + ((this.beta * tIntIntHashMap.get(i7)) / (this.tokensPerTopic[i7] + this.betaSum));
            if (!$assertionsDisabled && tIntIntHashMap2.get(i7) < 0) {
                throw new AssertionError();
            }
        }
        // Restore the cached coefficients of this document's topics to their alpha-only form.
        for (int i12 : tIntIntHashMap.keys()) {
            this.cachedCoefficients[i12] = this.alpha[i12] / (this.tokensPerTopic[i12] + this.betaSum);
        }
        if (z) {
            // Record the document in the histograms; only topics present in the document are counted.
            int[] iArr3 = this.docLengthCounts;
            iArr3[length] = iArr3[length] + 1;
            for (int i13 : tIntIntHashMap.keys()) {
                int[] iArr4 = this.topicDocCounts[i13];
                int i14 = tIntIntHashMap.get(i13);
                iArr4[i14] = iArr4[i14] + 1;
            }
        }
    }

    /**
     * Builds one {@code IDSorter} per type, weighted by that type's count in the
     * given topic, and sorts the array with {@code Arrays.sort}.
     *
     * @param i topic index
     * @return the sorted array, one entry per type
     */
    public IDSorter[] getSortedTopicWords(int i) {
        final IDSorter[] ranked = new IDSorter[this.numTypes];
        int type = 0;
        while (type < this.numTypes) {
            final int count = this.typeTopicCounts[type].get(i);
            ranked[type] = new IDSorter(type, count);
            type++;
        }
        Arrays.sort(ranked);
        return ranked;
    }

    /**
     * Writes the top-words report to a file.
     *
     * @param file destination file
     * @param i    word limit, passed through to the stream variant
     * @param z    output layout flag, passed through to the stream variant
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTopWords(File file, int i, boolean z) throws IOException {
        // try-with-resources closes the stream even if the report throws part-way.
        try (PrintStream out = new PrintStream(file)) {
            printTopWords(out, i, z);
        }
    }

    /**
     * Prints, for every topic, the leading words from a sorted set of the types
     * that have a count in that topic.
     *
     * <p>The word counter starts at 1 and stops before {@code i}, so at most
     * {@code i - 1} words are printed per topic.
     *
     * @param printStream destination
     * @param i           word limit (see the note above)
     * @param z           true: a "Topic n" header followed by one word and count
     *                    per line; false: a single line per topic with the topic
     *                    index, formatted alpha, token total and the words
     */
    public void printTopWords(PrintStream printStream, int i, boolean z) {
        for (int topic = 0; topic < this.numTopics; topic++) {
            // Collect every type that has an entry for this topic; the TreeSet orders them.
            final TreeSet<IDSorter> ranked = new TreeSet<>();
            for (int type = 0; type < this.numTypes; type++) {
                final TIntIntHashMap countsForType = this.typeTopicCounts[type];
                if (countsForType.containsKey(topic)) {
                    ranked.add(new IDSorter(type, countsForType.get(topic)));
                }
            }

            if (!z) {
                printStream.print(topic + "\t" + this.formatter.format(this.alpha[topic]) + "\t" + this.tokensPerTopic[topic] + "\t");
                final Iterator<IDSorter> words = ranked.iterator();
                int rank = 1;
                while (words.hasNext() && rank < i) {
                    final IDSorter word = words.next();
                    printStream.print(this.alphabet.lookupObject(word.getID()) + LangRequest.DEFAULT_SELECTION);
                    rank++;
                }
                printStream.println();
            } else {
                printStream.println("Topic " + topic);
                final Iterator<IDSorter> words = ranked.iterator();
                int rank = 1;
                while (words.hasNext() && rank < i) {
                    final IDSorter word = words.next();
                    printStream.println(this.alphabet.lookupObject(word.getID()) + "\t" + ((int) word.getWeight()));
                    rank++;
                }
            }
        }
    }

    /**
     * Writes an XML document with one {@code <topic>} element per topic, each
     * listing its leading words as {@code <word rank='n'>} elements.
     *
     * <p>Ranks start at 1 and stop before {@code i}, so at most {@code i - 1}
     * words are written per topic. Words are written without XML escaping.
     *
     * @param printWriter destination
     * @param i           word limit (see the note above)
     */
    public void topicXMLReport(PrintWriter printWriter, int i) {
        printWriter.println("<?xml version='1.0' ?>");
        printWriter.println("<topicModel>");
        for (int topic = 0; topic < this.numTopics; topic++) {
            printWriter.println("  <topic id='" + topic + "' alpha='" + this.alpha[topic] + "' totalTokens='" + this.tokensPerTopic[topic] + "'>");

            // Collect every type that has an entry for this topic; the TreeSet orders them.
            final TreeSet<IDSorter> ranked = new TreeSet<>();
            for (int type = 0; type < this.numTypes; type++) {
                final TIntIntHashMap countsForType = this.typeTopicCounts[type];
                if (countsForType.containsKey(topic)) {
                    ranked.add(new IDSorter(type, countsForType.get(topic)));
                }
            }

            final Iterator<IDSorter> words = ranked.iterator();
            int rank = 1;
            while (words.hasNext() && rank < i) {
                final IDSorter word = words.next();
                printWriter.println("    <word rank='" + rank + "'>" + this.alphabet.lookupObject(word.getID()) + "</word>");
                rank++;
            }
            printWriter.println("  </topic>");
        }
        printWriter.println("</topicModel>");
    }

    /**
     * Writes an XML report with, for every topic, its top terms, its top
     * phrases (runs of adjacent tokens sharing a topic) and a "titles"
     * attribute assembled from the strongest terms and phrases.
     *
     * <p>The {@code <term>} and {@code <phrase>} lines of a topic are buffered
     * so that the titles attribute can be written on the opening {@code <topic>}
     * tag before them. Terms and phrases are written without XML escaping.
     *
     * @param printStream destination
     * @param i           maximum number of terms and of phrases per topic; the
     *                    term count is additionally capped by the vocabulary size
     */
    public void topicXMLReportPhrases(PrintStream printStream, int i) {
        int numTopics = getNumTopics();
        TObjectIntHashMap[] phraseCounts = new TObjectIntHashMap[numTopics];
        Alphabet alphabet = getAlphabet();
        for (int topic = 0; topic < numTopics; topic++) {
            phraseCounts[topic] = new TObjectIntHashMap();
        }

        // Pass 1: collect phrase counts per topic.
        for (int doc = 0; doc < getData().size(); doc++) {
            FeatureSequence tokens = (FeatureSequence) getData().get(doc).instance.getData();
            boolean withBigrams = tokens instanceof FeatureSequenceWithBigrams;
            int prevTopic = -1;
            int prevFeature = -1;
            StringBuilder phrase = null;
            int docLength = tokens.size();
            for (int position = 0; position < docLength; position++) {
                int feature = tokens.getIndexAtPosition(position);
                int topic = getData().get(doc).topicSequence.getIndexAtPosition(position);
                boolean breaksRun = topic != prevTopic
                        || (withBigrams && ((FeatureSequenceWithBigrams) tokens).getBiIndexAtPosition(position) == -1);
                if (breaksRun) {
                    if (phrase != null) {
                        // A phrase just ended: count it under the topic it was built for.
                        String phraseText = phrase.toString();
                        if (phraseCounts[prevTopic].get(phraseText) == 0) {
                            phraseCounts[prevTopic].put(phraseText, 0);
                        }
                        phraseCounts[prevTopic].increment(phraseText);
                        prevFeature = -1;
                        prevTopic = -1;
                        phrase = null;
                    } else {
                        prevTopic = topic;
                        prevFeature = feature;
                    }
                } else if (phrase == null) {
                    phrase = new StringBuilder(alphabet.lookupObject(prevFeature).toString() + LangRequest.DEFAULT_SELECTION + alphabet.lookupObject(feature));
                } else {
                    phrase.append(LangRequest.DEFAULT_SELECTION);
                    phrase.append(alphabet.lookupObject(feature));
                }
            }
        }

        // Pass 2: write the report.
        printStream.println("<?xml version='1.0' ?>");
        printStream.println("<topics>");
        double[] termWeights = new double[alphabet.size()];
        for (int topic = 0; topic < numTopics; topic++) {
            printStream.print("  <topic id=\"" + topic + "\" alpha=\"" + this.alpha[topic] + "\" totalTokens=\"" + this.tokensPerTopic[topic] + "\" ");

            // Keep a handle on the byte buffer: its contents are what must be
            // printed later, not the PrintStream wrapper's identity string.
            ByteArrayOutputStream bodyBytes = new ByteArrayOutputStream();
            PrintStream body = new PrintStream(bodyBytes);
            AugmentableFeatureVector titleCandidates = new AugmentableFeatureVector(new Alphabet());

            // Term weights as fractions of the topic's tokens. The division is done
            // in floating point; an empty topic gets weight 0 rather than dividing by zero.
            int topicTotal = getCountTokensPerTopic(topic);
            for (int type = 0; type < alphabet.size(); type++) {
                termWeights[type] = topicTotal == 0
                        ? 0.0d
                        : (double) getCountFeatureTopic(type, topic) / topicTotal;
            }
            RankedFeatureVector rankedTerms = new RankedFeatureVector(alphabet, termWeights);
            int termLimit = Math.min(i, rankedTerms.numLocations());
            for (int rank = 0; rank < termLimit; rank++) {
                int type = rankedTerms.getIndexAtRank(rank);
                body.println("      <term weight=\"" + termWeights[type] + "\" count=\"" + getCountFeatureTopic(type, topic) + "\">" + alphabet.lookupObject(type) + "</term>");
                if (rank < 20) {
                    titleCandidates.add(alphabet.lookupObject(type), getCountFeatureTopic(type, topic));
                }
            }

            // Phrases, weighted by their share of all phrase occurrences in the topic.
            Object[] phraseKeys = phraseCounts[topic].keys();
            int[] phraseValues = phraseCounts[topic].getValues();
            double[] phraseWeights = new double[phraseKeys.length];
            for (int p = 0; p < phraseWeights.length; p++) {
                phraseWeights[p] = phraseValues[p];
            }
            double phraseTotal = MatrixOps.sum(phraseWeights);
            Alphabet phraseAlphabet = new Alphabet(phraseKeys);
            RankedFeatureVector rankedPhrases = new RankedFeatureVector(phraseAlphabet, phraseWeights);
            int phraseLimit = rankedPhrases.numLocations() < i ? rankedPhrases.numLocations() : i;
            for (int rank = 0; rank < phraseLimit; rank++) {
                int p = rankedPhrases.getIndexAtRank(rank);
                body.println("      <phrase weight=\"" + (phraseWeights[p] / phraseTotal) + "\" count=\"" + phraseValues[p] + "\">" + phraseAlphabet.lookupObject(p) + "</phrase>");
                if (rank < 20 && phraseValues[p] > 20) {
                    // Frequent phrases are boosted so they outrank single terms in the titles.
                    titleCandidates.add(phraseAlphabet.lookupObject(p), 100 * phraseValues[p]);
                }
            }

            // Titles: up to 10 candidates, skipping any whose text already appears
            // in the titles so far (the limit is extended by one for each skip).
            StringBuilder titles = new StringBuilder();
            RankedFeatureVector rankedTitles = new RankedFeatureVector(titleCandidates.getAlphabet(), titleCandidates);
            int titleLimit = 10;
            for (int rank = 0; rank < titleLimit && rank < rankedTitles.numLocations(); rank++) {
                if (titles.indexOf(rankedTitles.getObjectAtRank(rank).toString()) == -1) {
                    titles.append(rankedTitles.getObjectAtRank(rank));
                    if (rank < titleLimit - 1) {
                        titles.append(", ");
                    }
                } else {
                    titleLimit++;
                }
            }

            printStream.println("titles=\"" + titles.toString() + "\">");
            body.flush();
            printStream.print(bodyBytes.toString());
            printStream.println("  </topic>");
        }
        printStream.println("</topics>");
    }

    /**
     * Writes the per-document topic proportions to {@code file}.
     *
     * <p>Delegates to {@link #printDocumentTopics(PrintWriter)} and always closes
     * the writer afterwards so buffered output is flushed to disk and the file
     * handle is released, even when the delegate throws.
     *
     * @param file destination file; created or overwritten
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File file) throws IOException {
        PrintWriter writer = new PrintWriter(new FileWriter(file));
        try {
            printDocumentTopics(writer);
        } finally {
            // Closing also flushes; without it the tail of the report can be lost.
            writer.close();
        }
    }

    /**
     * Writes the per-document topic proportions with a threshold of {@code 0.0}
     * and a topic limit of {@code -1}, which the three-argument overload treats
     * as "all topics".
     *
     * @param printWriter destination writer; not closed by this method
     */
    public void printDocumentTopics(PrintWriter printWriter) {
        final double threshold = 0.0d;
        final int maxTopics = -1;
        printDocumentTopics(printWriter, threshold, maxTopics);
    }

    /**
     * Writes one line per document: the document index, its source (or
     * {@code "null-source"}), and then topic-id / proportion pairs in the order
     * produced by sorting the {@link IDSorter} array.
     *
     * <p>A topic's proportion is the number of tokens in the document assigned to
     * that topic divided by the document's token count. Documents with no tokens
     * get no topic pairs at all.
     *
     * @param printWriter destination writer; neither flushed nor closed here
     * @param d           minimum proportion; output for a document stops at the
     *                    first sorted topic whose weight is below this value
     * @param i           maximum number of topics per document; a negative value or
     *                    a value above the number of topics means "all topics"
     */
    public void printDocumentTopics(PrintWriter printWriter, double d, int i) {
        printWriter.print("#doc source topic proportion ...\n");

        int[] topicCounts = new int[this.numTopics];
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            sortedTopics[topic] = new IDSorter(topic, topic);
        }

        int maxTopics = (i < 0 || i > this.numTopics) ? this.numTopics : i;

        for (int doc = 0; doc < this.data.size(); doc++) {
            int[] docTopics = this.data.get(doc).topicSequence.getFeatures();

            printWriter.print(doc);
            printWriter.print(' ');
            Object source = this.data.get(doc).instance.getSource();
            if (source != null) {
                printWriter.print(source);
            } else {
                printWriter.print("null-source");
            }
            printWriter.print(' ');

            // An empty document has no proportions; skipping it avoids 0/0 = NaN
            // weights being handed to the sort.
            if (docTopics.length > 0) {
                for (int topic : docTopics) {
                    topicCounts[topic]++;
                }
                // The sorter array was reordered by the previous document's sort,
                // so every slot is re-keyed by position before sorting again.
                for (int topic = 0; topic < this.numTopics; topic++) {
                    // Floating-point division: integer division would truncate
                    // every proportion below 1 to 0.
                    sortedTopics[topic].set(topic, (double) topicCounts[topic] / docTopics.length);
                }
                Arrays.sort(sortedTopics);

                // NOTE(review): the separator constant comes from an unrelated class
                // (LangRequest); presumably it is a single space - confirm.
                for (int rank = 0; rank < maxTopics && sortedTopics[rank].getWeight() >= d; rank++) {
                    printWriter.print(sortedTopics[rank].getID() + LangRequest.DEFAULT_SELECTION + sortedTopics[rank].getWeight() + LangRequest.DEFAULT_SELECTION);
                }
                Arrays.fill(topicCounts, 0);
            }
            printWriter.print(" \n");
        }
    }

    /**
     * Writes the current token-level state to {@code file} as gzip-compressed text.
     *
     * <p>The stream is closed in a {@code finally} block so the gzip stream is
     * finished and the file handle released even if writing the state throws.
     *
     * @param file destination file; created or overwritten
     * @throws IOException if the file cannot be opened or the gzip header cannot be written
     */
    public void printState(File file) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file))));
        try {
            printState(out);
        } finally {
            out.close();
        }
    }

    /**
     * Prints a header line followed by one line per token with the columns
     * document index, document source (or {@code "NA"} when the source is null),
     * token position, type index, type, and assigned topic index, separated by
     * single spaces.
     *
     * @param printStream destination stream; not closed by this method
     */
    public void printState(PrintStream printStream) {
        printStream.println("#doc source pos typeindex type topic");

        final int numDocs = this.data.size();
        for (int doc = 0; doc < numDocs; doc++) {
            FeatureSequence tokens = (FeatureSequence) this.data.get(doc).instance.getData();
            LabelSequence topics = this.data.get(doc).topicSequence;

            Object rawSource = this.data.get(doc).instance.getSource();
            String source = (rawSource == null) ? "NA" : rawSource.toString();

            for (int pos = 0; pos < topics.getLength(); pos++) {
                int type = tokens.getIndexAtPosition(pos);
                int topic = topics.getIndexAtPosition(pos);

                // Assemble the whole row first, then emit it in one call.
                StringBuilder row = new StringBuilder();
                row.append(doc).append(' ');
                row.append(source).append(' ');
                row.append(pos).append(' ');
                row.append(type).append(' ');
                row.append(this.alphabet.lookupObject(type)).append(' ');
                row.append(topic);
                printStream.println(row.toString());
            }
        }
    }

    /**
     * Serializes this model to {@code file} using Java object serialization.
     *
     * <p>I/O failures are reported on {@code System.err} and otherwise ignored,
     * so callers get no exception. The underlying file stream is always closed,
     * including when serialization fails part-way through.
     *
     * @param file destination file; created or overwritten
     */
    public void write(File file) {
        try {
            FileOutputStream fileOut = new FileOutputStream(file);
            try {
                ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
                objectOut.writeObject(this);
                // Flushes the object stream's buffer and closes the file stream.
                objectOut.close();
            } finally {
                // No-op after a successful close above; releases the handle on failure.
                fileOut.close();
            }
        } catch (IOException e) {
            System.err.println("LDAHyper.write: Exception writing LDAHyper to file " + file + ": " + e);
        }
    }

    /**
     * Deserializes a model previously stored with {@link #write(File)}.
     *
     * <p>After the object is read, {@code initializeTypeTopicCounts()} is invoked
     * on it. I/O and class-resolution failures are reported on {@code System.err}.
     * The file stream is always closed, including when reading fails.
     *
     * @param file file containing a serialized {@code LDAHyper}
     * @return the model, or {@code null} if the object could not be read
     */
    public static LDAHyper read(File file) {
        LDAHyper model = null;
        try {
            FileInputStream fileIn = new FileInputStream(file);
            try {
                ObjectInputStream objectIn = new ObjectInputStream(fileIn);
                model = (LDAHyper) objectIn.readObject();
                model.initializeTypeTopicCounts();
            } finally {
                // Releases the file handle on both the success and the failure path.
                fileIn.close();
            }
        } catch (IOException e) {
            System.err.println("Exception reading file " + file + ": " + e);
        } catch (ClassNotFoundException e2) {
            System.err.println("Exception reading file " + file + ": " + e2);
        }
        return model;
    }

    /**
     * Custom serialization hook. Writes a leading int {@code 0} followed by the
     * model fields in a fixed order; {@code readObject} consumes them in exactly
     * the same order, so the sequence below must not be rearranged.
     *
     * @param out stream supplied by the serialization runtime
     * @throws IOException if any write fails
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        // Leading int; readObject reads and discards it.
        out.writeInt(0);

        // Documents and alphabets.
        out.writeObject(data);
        out.writeObject(alphabet);
        out.writeObject(topicAlphabet);

        // Hyperparameters and cached sampling values.
        out.writeInt(numTopics);
        out.writeObject(alpha);
        out.writeDouble(beta);
        out.writeDouble(betaSum);
        out.writeDouble(smoothingOnlyMass);
        out.writeObject(cachedCoefficients);

        // Iteration and output settings.
        out.writeInt(iterationsSoFar);
        out.writeInt(numIterations);
        out.writeInt(burninPeriod);
        out.writeInt(saveSampleInterval);
        out.writeInt(optimizeInterval);
        out.writeInt(showTopicsInterval);
        out.writeInt(wordsPerTopic);
        out.writeInt(outputModelInterval);
        out.writeObject(outputModelFilename);
        out.writeInt(saveStateInterval);
        out.writeObject(stateFilename);

        out.writeObject(random);
        out.writeObject(formatter);
        out.writeBoolean(printLogLikelihood);

        out.writeObject(docLengthCounts);
        out.writeObject(topicDocCounts);

        // Count tables are written element by element rather than as whole arrays.
        for (int type = 0; type < numTypes; type++) {
            out.writeObject(typeTopicCounts[type]);
        }
        for (int topic = 0; topic < numTopics; topic++) {
            out.writeInt(tokensPerTopic[topic]);
        }
    }

    /**
     * Custom deserialization hook. Reads the fields in exactly the order
     * {@code writeObject} emits them, then rebuilds the values that are not part
     * of the stream: {@code numTypes} from the alphabet size and {@code alphaSum}
     * from the {@code alpha} vector.
     *
     * @param objectInputStream stream supplied by the serialization runtime
     * @throws IOException            if any read fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Leading int written by writeObject; read to stay aligned, value unused.
        objectInputStream.readInt();

        this.data = (ArrayList) objectInputStream.readObject();
        this.alphabet = (Alphabet) objectInputStream.readObject();
        this.topicAlphabet = (LabelAlphabet) objectInputStream.readObject();
        this.numTopics = objectInputStream.readInt();
        this.alpha = (double[]) objectInputStream.readObject();
        this.beta = objectInputStream.readDouble();
        this.betaSum = objectInputStream.readDouble();
        this.smoothingOnlyMass = objectInputStream.readDouble();
        this.cachedCoefficients = (double[]) objectInputStream.readObject();
        this.iterationsSoFar = objectInputStream.readInt();
        this.numIterations = objectInputStream.readInt();
        this.burninPeriod = objectInputStream.readInt();
        this.saveSampleInterval = objectInputStream.readInt();
        this.optimizeInterval = objectInputStream.readInt();
        this.showTopicsInterval = objectInputStream.readInt();
        this.wordsPerTopic = objectInputStream.readInt();
        this.outputModelInterval = objectInputStream.readInt();
        this.outputModelFilename = (String) objectInputStream.readObject();
        this.saveStateInterval = objectInputStream.readInt();
        this.stateFilename = (String) objectInputStream.readObject();
        this.random = (Randoms) objectInputStream.readObject();
        this.formatter = (NumberFormat) objectInputStream.readObject();
        this.printLogLikelihood = objectInputStream.readBoolean();
        this.docLengthCounts = (int[]) objectInputStream.readObject();
        this.topicDocCounts = (int[][]) objectInputStream.readObject();

        // alphaSum is not in the stream and would otherwise stay 0 after loading,
        // although modelLogLikelihood reads it.
        // NOTE(review): assumes alphaSum is always the sum of alpha - confirm
        // against the constructor and the hyperparameter optimisation code.
        double restoredAlphaSum = 0.0d;
        if (this.alpha != null) {
            for (double a : this.alpha) {
                restoredAlphaSum += a;
            }
        }
        this.alphaSum = restoredAlphaSum;

        // numTypes is not in the stream either; it is taken from the alphabet.
        this.numTypes = this.alphabet.size();
        this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
        for (int type = 0; type < this.numTypes; type++) {
            this.typeTopicCounts[type] = (TIntIntHashMap) objectInputStream.readObject();
        }
        this.tokensPerTopic = new int[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            this.tokensPerTopic[topic] = objectInputStream.readInt();
        }
    }

    /**
     * Computes the mutual information, in bits, between the topic assigned to a
     * token and the best label of the document that token belongs to:
     * {@code H(topic) + H(label) - H(topic, label)}.
     *
     * @return the mutual information, or {@code 0.0} when there are no documents
     *         or the first document has no target alphabet
     */
    public double topicLabelMutualInformation() {
        if (this.data.isEmpty() || this.data.get(0).instance.getTargetAlphabet() == null) {
            return 0.0d;
        }

        int numLabels = this.data.get(0).instance.getTargetAlphabet().size();
        int[][] topicLabelCounts = new int[this.numTopics][numLabels];
        int[] topicCounts = new int[this.numTopics];
        int[] labelCounts = new int[numLabels];
        int total = 0;

        for (int doc = 0; doc < this.data.size(); doc++) {
            int label = this.data.get(doc).instance.getLabeling().getBestIndex();
            for (int topic : this.data.get(doc).topicSequence.getFeatures()) {
                topicLabelCounts[topic][label]++;
                topicCounts[topic]++;
                labelCounts[label]++;
                total++;
            }
        }

        double topicEntropy = 0.0d;
        double labelEntropy = 0.0d;
        double jointEntropy = 0.0d;
        double log2 = Math.log(2.0d);

        // The casts force floating-point division; int / int would truncate every
        // probability below 1 to 0 and turn the entropies into NaN.
        for (int topic = 0; topic < topicCounts.length; topic++) {
            if (topicCounts[topic] != 0) {
                double p = (double) topicCounts[topic] / total;
                topicEntropy -= (p * Math.log(p)) / log2;
            }
        }
        for (int label = 0; label < labelCounts.length; label++) {
            if (labelCounts[label] != 0) {
                double p = (double) labelCounts[label] / total;
                labelEntropy -= (p * Math.log(p)) / log2;
            }
        }
        for (int topic = 0; topic < topicCounts.length; topic++) {
            for (int label = 0; label < labelCounts.length; label++) {
                if (topicLabelCounts[topic][label] != 0) {
                    double p = (double) topicLabelCounts[topic][label] / total;
                    jointEntropy -= (p * Math.log(p)) / log2;
                }
            }
        }
        return (topicEntropy + labelEntropy) - jointEntropy;
    }

    /**
     * Estimates the log likelihood of {@code instanceList} by sampling.
     *
     * <p>For each of {@code i} samples a topic distribution is drawn from a
     * Dirichlet with parameters {@code alpha}, turned into a distribution over
     * word types using the current counts, and used to score every document.
     * Per document, the sample scores are averaged in probability space with the
     * log-sum-exp trick, and the per-document results are summed.
     *
     * @param i            number of samples to draw
     * @param instanceList documents to score; their data is cast to {@link FeatureSequence}
     * @return the summed per-document log of the sample-averaged likelihood
     */
    public double empiricalLikelihood(int i, InstanceList instanceList) {
        double[][] docSampleScores = new double[instanceList.size()][i];
        double[] typeLogProb = new double[this.numTypes];
        Dirichlet prior = new Dirichlet(this.alpha);

        int sample = 0;
        while (sample < i) {
            double[] topicWeights = prior.nextDistribution();

            // Mix the per-topic word distributions according to the sampled weights.
            Arrays.fill(typeLogProb, 0.0d);
            for (int topic = 0; topic < this.numTopics; topic++) {
                for (int type = 0; type < this.numTypes; type++) {
                    double smoothedCount = this.beta + this.typeTopicCounts[type].get(topic);
                    typeLogProb[type] += (topicWeights[topic] * smoothedCount) / (this.betaSum + this.tokensPerTopic[topic]);
                }
            }

            // Convert probabilities to logs in place.
            for (int type = 0; type < this.numTypes; type++) {
                if (!$assertionsDisabled && typeLogProb[type] <= 0.0d) {
                    throw new AssertionError();
                }
                typeLogProb[type] = Math.log(typeLogProb[type]);
            }

            // Score every document under this sample.
            for (int doc = 0; doc < instanceList.size(); doc++) {
                FeatureSequence tokens = (FeatureSequence) instanceList.get(doc).getData();
                double[] scores = docSampleScores[doc];
                int docLength = tokens.getLength();
                for (int pos = 0; pos < docLength; pos++) {
                    int type = tokens.getIndexAtPosition(pos);
                    // Types at or beyond numTypes contribute nothing to the score.
                    if (type >= this.numTypes) {
                        continue;
                    }
                    scores[sample] += typeLogProb[type];
                }
            }
            sample++;
        }

        double total = 0.0d;
        double logNumSamples = Math.log(i);
        for (int doc = 0; doc < instanceList.size(); doc++) {
            double[] scores = docSampleScores[doc];

            // Subtracting the maximum keeps exp() in range.
            double max = Double.NEGATIVE_INFINITY;
            for (double score : scores) {
                if (score > max) {
                    max = score;
                }
            }
            double sumExp = 0.0d;
            for (double score : scores) {
                sumExp += Math.exp(score - max);
            }
            total += (Math.log(sumExp) + max) - logNumSamples;
        }
        return total;
    }

    /**
     * Computes a log likelihood of the current topic assignments from two groups
     * of log-gamma terms: a per-document group over the topic counts (using
     * {@code alpha} and {@code alphaSum}) and a per-topic group over the
     * type/topic counts (using {@code beta}).
     *
     * @return the log likelihood value
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0d;
        int[] topicCounts = new int[this.numTopics];

        // logGamma(alpha[t]) is the same for every document, so compute it once.
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }

        // Document-topic part.
        for (int doc = 0; doc < this.data.size(); doc++) {
            int[] docTopics = this.data.get(doc).topicSequence.getFeatures();
            for (int topic : docTopics) {
                topicCounts[topic]++;
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + topicCounts[topic]) - topicLogGammas[topic];
                }
            }
            // Normaliser uses the document length, i.e. the number of topic assignments.
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + docTopics.length);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood += this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // Topic-type part.
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; type++) {
            for (int topic : this.typeTopicCounts[type].keys()) {
                int count = this.typeTopicCounts[type].get(topic);
                if (count > 0) {
                    nonZeroTypeTopics++;
                    logLikelihood += Dirichlet.logGammaStirling(this.beta + count);
                }
            }
        }

        // NOTE(review): these terms use beta * numTopics, while empiricalLikelihood
        // normalises with betaSum. Left unchanged to keep reported values stable -
        // confirm which normaliser is intended.
        for (int topic = 0; topic < this.numTopics; topic++) {
            logLikelihood -= Dirichlet.logGammaStirling((this.beta * this.numTopics) + this.tokensPerTopic[topic]);
        }
        return logLikelihood + (Dirichlet.logGammaStirling(this.beta * this.numTopics) - (Dirichlet.logGammaStirling(this.beta) * nonZeroTypeTopics));
    }

    /**
     * Command-line entry point: trains a model on a serialized instance list.
     *
     * <p>Arguments: {@code <training-instance-list> [num-topics] [testing-instance-list]}.
     * The number of topics defaults to 200. The model is built with the values
     * {@code 50.0} and {@code 0.01}, log-likelihood printing is enabled, and the
     * topic display is set to {@code (50, 7)} before estimation.
     *
     * @param strArr command-line arguments
     * @throws IOException              if estimation fails with an I/O error
     * @throws IllegalArgumentException if the training file argument is missing
     */
    public static void main(String[] strArr) throws IOException {
        if (strArr == null || strArr.length < 1) {
            // Fail with a readable message instead of ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException(
                    "Usage: LDAHyper <training-instance-list> [num-topics] [testing-instance-list]");
        }

        InstanceList training = InstanceList.load(new File(strArr[0]));
        int numTopics = strArr.length > 1 ? Integer.parseInt(strArr[1]) : 200;

        // The optional third argument is loaded but not used further here; the load
        // is kept so an unreadable file is still reported as before.
        InstanceList testing = strArr.length > 2 ? InstanceList.load(new File(strArr[2])) : null;

        LDAHyper lda = new LDAHyper(numTopics, 50.0d, 0.01d);
        lda.printLogLikelihood = true;
        lda.setTopicDisplay(50, 7);
        lda.addInstances(training);
        lda.estimate();
    }

    // Records once, at class initialization, whether assertions are disabled for
    // LDAHyper. The flag is read by the explicit `!$assertionsDisabled && ...`
    // check in empiricalLikelihood.
    static {
        $assertionsDisabled = !LDAHyper.class.desiredAssertionStatus();
    }
}
