package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetFactory;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import com.meaningcloud.LangRequest;
import gnu.trove.TIntIntHashMap;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:cc/mallet/topics/NPTopicModel.class */
/**
 * A non-parametric variant of LDA trained by collapsed Gibbs sampling. The set of active topics
 * grows when a token is assigned to a brand-new topic and shrinks when a topic loses its last token.
 *
 * <p>NOTE(review): the sampling weights resemble a Chinese-restaurant-franchise / HDP-style
 * approximation; confirm against the model's original description before relying on that reading.
 */
public class NPTopicModel implements Serializable {
    private static final Logger logger = MalletLogger.getLogger(NPTopicModel.class.getName());

    /** Data alphabet of the training instances; set by {@link #addInstances}. */
    protected Alphabet alphabet;
    /** Largest topic id ever allocated; a newly created topic receives {@code maxTopic + 1}. */
    protected int maxTopic;
    /** Number of currently active topics (the number of keys in {@link #docsPerTopic}). */
    protected int numTopics;
    /** Vocabulary size, i.e. the size of {@link #alphabet}. */
    protected int numTypes;
    /** Weight of the global topic distribution in the per-document score (first constructor argument). */
    protected double alpha;
    /** Concentration used both in the document-topic normalizer and in the new-topic weight. */
    protected double gamma;
    /** Topic-word smoothing added to every type/topic count. */
    protected double beta;
    /** {@code beta * numTypes}; the smoothing mass added to every topic's token count. */
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01d;
    /** For each word type, a map from topic id to the number of tokens of that type assigned to it. */
    protected TIntIntHashMap[] typeTopicCounts;
    /** Number of (document, topic) pairs with at least one token; the sum of {@link #docsPerTopic}'s values. */
    protected int totalDocTopics = 0;
    public int showTopicsInterval = 50;
    public int wordsPerTopic = 10;
    protected boolean printLogLikelihood = false;
    /** Per-document token sequence together with its current topic assignments. */
    protected ArrayList<TopicAssignment> data = new ArrayList<>();
    protected LabelAlphabet topicAlphabet = AlphabetFactory.labelAlphabetOfSize(1);
    protected Randoms random = new Randoms();
    /** Topic id -> total number of tokens assigned to that topic. */
    protected TIntIntHashMap tokensPerTopic = new TIntIntHashMap();
    /** Topic id -> number of documents with at least one token assigned to that topic. */
    protected TIntIntHashMap docsPerTopic = new TIntIntHashMap();
    protected NumberFormat formatter = NumberFormat.getInstance();

    /**
     * Creates an untrained model.
     *
     * @param alpha stored in {@link #alpha}
     * @param gamma stored in {@link #gamma}
     * @param beta  stored in {@link #beta}
     */
    public NPTopicModel(double alpha, double gamma, double beta) {
        this.alpha = alpha;
        this.gamma = gamma;
        this.beta = beta;
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Non-Parametric LDA");
    }

    /**
     * Configures progress output during {@link #sample}.
     *
     * @param i  print the top words every {@code i} iterations (0 disables)
     * @param i2 number of words shown per topic
     */
    public void setTopicDisplay(int i, int i2) {
        this.showTopicsInterval = i;
        this.wordsPerTopic = i2;
    }

    /** Replaces the random source with one seeded by {@code i}, for reproducible runs. */
    public void setRandomSeed(int i) {
        this.random = new Randoms(i);
    }

    /**
     * Loads the training documents and assigns every token a uniformly random initial topic
     * in {@code [0, i)}.
     *
     * <p>NOTE(review): the count structures are not reset, so this appears intended to be
     * called once per model; confirm before calling it repeatedly.
     *
     * @param instanceList documents whose data are {@link FeatureSequence}s
     * @param i            number of initial topic ids to draw from
     */
    public void addInstances(InstanceList instanceList, int i) {
        this.alphabet = instanceList.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * this.numTypes;
        this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
        for (int type = 0; type < this.numTypes; type++) {
            this.typeTopicCounts[type] = new TIntIntHashMap();
        }
        this.numTopics = i;
        Iterator<Instance> instances = instanceList.iterator();
        while (instances.hasNext()) {
            Instance instance = instances.next();
            TIntIntHashMap docTopicCounts = new TIntIntHashMap();
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < tokens.size(); position++) {
                int topic = this.random.nextInt(this.numTopics);
                this.tokensPerTopic.adjustOrPutValue(topic, 1, 1);
                topics[position] = topic;
                if (docTopicCounts.containsKey(topic)) {
                    docTopicCounts.adjustValue(topic, 1);
                } else {
                    // First token of this topic in this document.
                    this.docsPerTopic.adjustOrPutValue(topic, 1, 1);
                    this.totalDocTopics++;
                    docTopicCounts.put(topic, 1);
                }
                this.typeTopicCounts[tokens.getIndexAtPosition(position)].adjustOrPutValue(topic, 1, 1);
            }
            this.data.add(new TopicAssignment(instance, topicSequence));
        }
        // Ids 0..i-1 may have been handed out, so new topics must start above them.
        this.maxTopic = i - 1;
        // Not every initial id is necessarily drawn (small corpora); the sampler relies on
        // numTopics matching the number of keys in docsPerTopic.
        this.numTopics = this.docsPerTopic.size();
    }

    /**
     * Runs {@code i} Gibbs sweeps over all documents, logging timing and topic count after each
     * sweep and the top words every {@link #showTopicsInterval} sweeps.
     *
     * @param i number of sweeps
     * @throws IOException declared for compatibility; not thrown by the visible code
     */
    public void sample(int i) throws IOException {
        for (int iteration = 1; iteration <= i; iteration++) {
            long start = System.currentTimeMillis();
            for (int doc = 0; doc < this.data.size(); doc++) {
                TopicAssignment assignment = this.data.get(doc);
                sampleTopicsForOneDoc((FeatureSequence) assignment.instance.getData(), assignment.topicSequence);
            }
            logger.info(iteration + "\t" + (System.currentTimeMillis() - start) + "ms\t" + this.numTopics);
            if (this.showTopicsInterval != 0 && iteration % this.showTopicsInterval == 0) {
                logger.info("<" + iteration + "> #Topics: " + this.numTopics + "\n" + topWords(this.wordsPerTopic));
            }
        }
    }

    /**
     * Resamples the topic of every token in one document, in place.
     *
     * @param tokenSequence word-type indices of the document
     * @param topicSequence current topic of each token; overwritten with the new assignments
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence) {
        int[] topics = topicSequence.getFeatures();
        int length = tokenSequence.getLength();

        // Topic -> number of tokens of this document currently assigned to it.
        TIntIntHashMap docTopicCounts = new TIntIntHashMap();
        for (int position = 0; position < length; position++) {
            docTopicCounts.adjustOrPutValue(topics[position], 1, 1);
        }

        // scores[k] (k < numTopics) is the weight of activeTopics[k]; scores[numTopics] is the
        // weight of opening a brand-new topic. Both are rebuilt whenever the topic set changes.
        double[] scores = new double[this.numTopics + 1];
        int[] activeTopics = this.docsPerTopic.keys();

        for (int position = 0; position < length; position++) {
            int type = tokenSequence.getIndexAtPosition(position);
            int oldTopic = topics[position];
            TIntIntHashMap typeCounts = this.typeTopicCounts[type];

            // Remove the current token from every count.
            if (docTopicCounts.get(oldTopic) == 1) {
                docTopicCounts.remove(oldTopic);
                if (this.docsPerTopic.get(oldTopic) != 1) {
                    this.docsPerTopic.adjustValue(oldTopic, -1);
                    this.totalDocTopics--;
                    this.tokensPerTopic.adjustValue(oldTopic, -1);
                } else {
                    // This token was the topic's only token anywhere: retire the topic.
                    assert this.tokensPerTopic.get(oldTopic) == 1;
                    this.docsPerTopic.remove(oldTopic);
                    this.totalDocTopics--;
                    this.tokensPerTopic.remove(oldTopic);
                    this.numTopics--;
                    activeTopics = this.docsPerTopic.keys();
                    scores = new double[this.numTopics + 1];
                }
            } else {
                docTopicCounts.adjustValue(oldTopic, -1);
                this.tokensPerTopic.adjustValue(oldTopic, -1);
            }
            if (typeCounts.get(oldTopic) == 1) {
                typeCounts.remove(oldTopic);
            } else {
                typeCounts.adjustValue(oldTopic, -1);
            }

            // Score every active topic, plus the option of a new one.
            double activeSum = 0.0d;
            for (int k = 0; k < this.numTopics; k++) {
                int topic = activeTopics[k];
                scores[k] = ((docTopicCounts.get(topic)
                                + (this.alpha * (this.docsPerTopic.get(topic) / (this.totalDocTopics + this.gamma))))
                        * (typeCounts.get(topic) + this.beta))
                        / (this.tokensPerTopic.get(topic) + this.betaSum);
                activeSum += scores[k];
            }
            scores[this.numTopics] = (this.alpha * this.gamma) / (this.numTypes * (this.totalDocTopics + this.gamma));

            // Inverse-CDF draw. The do/while always selects at least slot 0 (a zero draw previously
            // indexed -1). The bound stops at the new-topic slot even if rounding leaves a positive remainder.
            double sample = this.random.nextUniform() * (activeSum + scores[this.numTopics]);
            int chosen = -1;
            do {
                chosen++;
                sample -= scores[chosen];
            } while (sample > 0.0d && chosen < this.numTopics);

            // Add the token back under the chosen topic.
            if (chosen < this.numTopics) {
                int newTopic = activeTopics[chosen];
                topics[position] = newTopic;
                typeCounts.adjustOrPutValue(newTopic, 1, 1);
                this.tokensPerTopic.adjustValue(newTopic, 1);
                if (docTopicCounts.containsKey(newTopic)) {
                    docTopicCounts.adjustValue(newTopic, 1);
                } else {
                    docTopicCounts.put(newTopic, 1);
                    this.docsPerTopic.adjustValue(newTopic, 1);
                    this.totalDocTopics++;
                }
            } else {
                // Open a new topic with a never-used id.
                this.maxTopic++;
                int newTopic = this.maxTopic;
                this.numTopics++;
                topics[position] = newTopic;
                docTopicCounts.put(newTopic, 1);
                this.docsPerTopic.put(newTopic, 1);
                this.totalDocTopics++;
                typeCounts.put(newTopic, 1);
                this.tokensPerTopic.put(newTopic, 1);
                activeTopics = this.docsPerTopic.keys();
                scores = new double[this.numTopics + 1];
            }
        }
    }

    /**
     * Formats each active topic as {@code id \t tokenCount \t words...}, one topic per line.
     * Words are ordered by {@link IDSorter}'s natural order (presumably descending count —
     * confirm). Only words with a count of at least 1 are listed.
     *
     * @param i maximum number of words per topic; values above the vocabulary size are clamped
     * @return the formatted listing
     */
    public String topWords(int i) {
        StringBuilder sb = new StringBuilder();
        IDSorter[] sortedTypes = new IDSorter[this.numTypes];
        int limit = Math.min(i, this.numTypes);
        for (int topic : this.docsPerTopic.keys()) {
            for (int type = 0; type < this.numTypes; type++) {
                sortedTypes[type] = new IDSorter(type, this.typeTopicCounts[type].get(topic));
            }
            Arrays.sort(sortedTypes);
            sb.append(topic + "\t" + this.tokensPerTopic.get(topic) + "\t");
            for (int rank = 0; rank < limit && sortedTypes[rank].getWeight() >= 1.0d; rank++) {
                // NOTE(review): LangRequest.DEFAULT_SELECTION looks like a decompiler-substituted
                // word separator; its value is not visible here — confirm it is the intended one.
                sb.append(this.alphabet.lookupObject(sortedTypes[rank].getID()) + LangRequest.DEFAULT_SELECTION);
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Writes the sampler state, gzip-compressed, to {@code file}.
     *
     * @throws IOException if the file cannot be opened or any write to it fails
     */
    public void printState(File file) throws IOException {
        try (PrintStream printStream = new PrintStream(
                new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file))))) {
            printState(printStream);
            // PrintStream swallows IOExceptions; surface them instead of leaving a truncated file.
            if (printStream.checkError()) {
                throw new IOException("Error while writing sampler state to " + file);
            }
        }
    }

    /**
     * Writes one line per token: {@code doc source pos typeindex type topic}, preceded by a header.
     * The source is printed as {@code NA} when the instance has none.
     */
    public void printState(PrintStream printStream) {
        printStream.println("#doc source pos typeindex type topic");
        for (int doc = 0; doc < this.data.size(); doc++) {
            TopicAssignment assignment = this.data.get(doc);
            FeatureSequence tokens = (FeatureSequence) assignment.instance.getData();
            LabelSequence topicSequence = assignment.topicSequence;
            Object sourceObject = assignment.instance.getSource();
            String source = sourceObject != null ? sourceObject.toString() : "NA";
            for (int position = 0; position < topicSequence.getLength(); position++) {
                int type = tokens.getIndexAtPosition(position);
                int topic = topicSequence.getIndexAtPosition(position);
                printStream.print(doc);
                printStream.print(' ');
                printStream.print(source);
                printStream.print(' ');
                printStream.print(position);
                printStream.print(' ');
                printStream.print(type);
                printStream.print(' ');
                printStream.print(this.alphabet.lookupObject(type));
                printStream.print(' ');
                printStream.print(topic);
                printStream.println();
            }
        }
    }

    /**
     * Command-line entry point: {@code NPTopicModel <instance-list-file> [initial-num-topics]}.
     * Trains for 1000 sweeps; the initial topic count defaults to 200.
     */
    public static void main(String[] strArr) throws IOException {
        if (strArr.length < 1) {
            throw new IllegalArgumentException("Usage: NPTopicModel <instance-list-file> [initial-num-topics]");
        }
        InstanceList instances = InstanceList.load(new File(strArr[0]));
        int initialTopics = strArr.length > 1 ? Integer.parseInt(strArr[1]) : 200;
        NPTopicModel model = new NPTopicModel(5.0d, 10.0d, 0.1d);
        model.addInstances(instances, initialTopics);
        model.sample(1000);
    }
}
