package cc.mallet.topics;

import bsh.org.objectweb.asm.Constants;
import cc.mallet.types.IDSorter;
import cc.mallet.util.Randoms;
import com.meaningcloud.LangRequest;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntObjectHashMap;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;

/* loaded from: input_file:cc/mallet/topics/MultinomialHMM.class */
public class MultinomialHMM {
    // Number of topics; taken from the first constructor argument.
    int numTopics;
    // Number of hidden states; taken from the third constructor argument.
    int numStates;
    // One more than the largest document ID read by loadTopicsFromFile.
    int numDocs;
    // Number of sequences, counted as changes of sequence ID between consecutive documents.
    int numSequences;
    // Per-topic smoothing values added to state-topic counts (see recacheStateTopicDistribution).
    double[] alpha;
    // Sum of the alpha values in use.
    double alphaSum;
    // beta and betaSum are not referenced anywhere else in the visible code.
    double beta;
    double betaSum;
    // Smoothing value added to state-to-state transition counts in sampleState.
    double gamma;
    // gamma * numStates; computed in initialize().
    double gammaSum;
    // Smoothing value added to initial-state counts; set to 1000.0 in initialize().
    double pi;
    // numStates * pi; computed in initialize().
    double sumPi;
    // Document ID -> (topic ID -> number of tokens of that topic in the document).
    TIntObjectHashMap<TIntIntHashMap> documentTopics;
    // Sequence ID of each document, indexed by document ID.
    int[] documentSequenceIDs;
    // Currently assigned state of each document, indexed by document ID.
    int[] documentStates;
    // stateTopicCounts[state][topic]: tokens of a topic in documents currently assigned to a state.
    int[][] stateTopicCounts;
    // stateTopicTotals[state]: total tokens in documents currently assigned to a state.
    int[] stateTopicTotals;
    // stateStateTransitions[from][to]: transitions between states of consecutive documents in a sequence.
    int[][] stateStateTransitions;
    // Per-state transition totals; used with gammaSum as a denominator in sampleState.
    int[] stateTransitionTotals;
    // initialStateCounts[state]: sequence-initial documents currently assigned to a state.
    int[] initialStateCounts;
    // maxTokensPerTopic[topic]: largest count of a topic within any single document.
    int[] maxTokensPerTopic;
    // Largest total token count of any single document.
    int maxDocLength;
    // topicLogGammaCache[state][topic][n]: sum for k = 1..n of
    // log(alpha[topic] + k - 1 + stateTopicCounts[state][topic]).
    double[][][] topicLogGammaCache;
    // docLogGammaCache[state][n]: sum for k = 1..n of log(alphaSum + k - 1 + stateTopicTotals[state]).
    double[][] docLogGammaCache;
    // Per-topic descriptive strings filled in by loadAlphaFromFile.
    String[] topicKeys;
    // Random number source used by the sampler; created lazily in initialize() if unset.
    Randoms random;
    // Assigned in the static initializer; guards the hand-expanded assertion checks in sampleState.
    static final /* synthetic */ boolean $assertionsDisabled;
    // Number of sweeps over all documents performed by sample().
    int numIterations = 1000;
    // burninPeriod, saveSampleInterval, optimizeInterval and showTopicsInterval are stored
    // but not read anywhere in the visible code.
    int burninPeriod = 200;
    int saveSampleInterval = 10;
    int optimizeInterval = 0;
    int showTopicsInterval = 50;
    // Number formatter; the constructor limits it to 5 fraction digits.
    NumberFormat formatter = NumberFormat.getInstance();

    /**
     * Creates a model, loads per-document topic counts from a state file and
     * sizes the log-gamma caches from the loaded data.
     *
     * <p>All alpha values start out equal and sum to the number of topics.
     *
     * @param i   number of topics
     * @param str path of the topic state file, passed to {@link #loadTopicsFromFile(String)}
     * @param i2  number of hidden states
     * @throws IOException if the state file cannot be read
     */
    public MultinomialHMM(int i, String str, int i2) throws IOException {
        this.formatter.setMaximumFractionDigits(5);
        System.out.println("LDA HMM: " + i);

        this.documentTopics = new TIntObjectHashMap<>();
        this.numTopics = i;
        this.alphaSum = i;
        this.alpha = new double[i];
        Arrays.fill(this.alpha, this.alphaSum / this.numTopics);
        this.topicKeys = new String[this.numTopics];

        // Sets numDocs as a side effect, so it must run before the per-document arrays are sized.
        loadTopicsFromFile(str);

        this.documentStates = new int[this.numDocs];
        this.documentSequenceIDs = new int[this.numDocs];

        // Find the largest per-topic count and the longest document; these bound the cache sizes.
        this.maxTokensPerTopic = new int[this.numTopics];
        this.maxDocLength = 0;
        for (int doc = 0; doc < this.numDocs; doc++) {
            if (!this.documentTopics.containsKey(doc)) {
                continue;
            }
            TIntIntHashMap topicCounts = this.documentTopics.get(doc);
            int docLength = 0;
            for (int topic : topicCounts.keys()) {
                int count = topicCounts.get(topic);
                if (count > this.maxTokensPerTopic[topic]) {
                    this.maxTokensPerTopic[topic] = count;
                }
                docLength += count;
            }
            if (docLength > this.maxDocLength) {
                this.maxDocLength = docLength;
            }
        }

        this.numStates = i2;
        this.initialStateCounts = new int[i2];

        // The innermost dimension is ragged (one length per topic), so it is left
        // open here and allocated per topic below.
        this.topicLogGammaCache = new double[i2][this.numTopics][];
        for (int state = 0; state < i2; state++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.topicLogGammaCache[state][topic] = new double[this.maxTokensPerTopic[topic] + 1];
            }
        }
        System.out.println(this.maxDocLength);
        this.docLogGammaCache = new double[i2][this.maxDocLength + 1];
    }

    /**
     * Sets gamma, the smoothing value added to transition counts in sampleState.
     * initialize() multiplies it by the number of states to derive gammaSum.
     *
     * @param d the new gamma value
     */
    public void setGamma(double d) {
        this.gamma = d;
    }

    /**
     * Sets how many sweeps over all documents sample() performs.
     *
     * @param i the number of iterations
     */
    public void setNumIterations(int i) {
        this.numIterations = i;
    }

    /**
     * Stores the burn-in period. The value is not read anywhere in the visible code.
     *
     * @param i the burn-in period
     */
    public void setBurninPeriod(int i) {
        this.burninPeriod = i;
    }

    /**
     * Stores the topic display interval in showTopicsInterval. The value is not
     * read anywhere in the visible code.
     *
     * @param i the display interval
     */
    public void setTopicDisplayInterval(int i) {
        this.showTopicsInterval = i;
    }

    /**
     * Replaces the sampler's random number source with a new Randoms built from the given seed.
     *
     * @param i the seed
     */
    public void setRandomSeed(int i) {
        this.random = new Randoms(i);
    }

    /**
     * Stores the optimize interval. The value is not read anywhere in the visible code.
     *
     * @param i the optimize interval
     */
    public void setOptimizeInterval(int i) {
        this.optimizeInterval = i;
    }

    /**
     * Resets all sampler counts and assigns an initial state to every document.
     *
     * <p>Must be called after the sequence IDs have been loaded and before
     * {@link #sample()}. The number of sequences is recomputed here from
     * {@code documentSequenceIDs} rather than being reset to zero, so the value
     * loaded earlier is not lost and the method may be called more than once.
     */
    public void initialize() {
        if (this.random == null) {
            this.random = new Randoms();
        }

        this.gammaSum = this.gamma * this.numStates;
        this.stateTopicCounts = new int[this.numStates][this.numTopics];
        this.stateTopicTotals = new int[this.numStates];
        this.stateStateTransitions = new int[this.numStates][this.numStates];
        this.stateTransitionTotals = new int[this.numStates];
        // Allocated in the constructor; cleared here so that a repeated call starts from zero
        // like the other count arrays.
        Arrays.fill(this.initialStateCounts, 0);

        this.pi = 1000.0d;
        this.sumPi = this.numStates * this.pi;

        // Count sequences the same way sampleState detects sequence starts: a document
        // begins a new sequence when its ID differs from its predecessor's (-1 before doc 0).
        this.numSequences = 0;
        int previousSequenceID = -1;
        for (int doc = 0; doc < this.numDocs; doc++) {
            int sequenceID = this.documentSequenceIDs[doc];
            if (sequenceID != previousSequenceID) {
                this.numSequences++;
            }
            previousSequenceID = sequenceID;
        }

        // recacheStateTopicDistribution only refreshes the topics present in the map it
        // is given, so build a map containing every topic to fill the caches completely.
        TIntIntHashMap allTopics = new TIntIntHashMap();
        for (int topic = 0; topic < this.numTopics; topic++) {
            allTopics.put(topic, 1);
        }
        for (int state = 0; state < this.numStates; state++) {
            recacheStateTopicDistribution(state, allTopics);
        }

        // Initial pass: each document is sampled given only its predecessor.
        for (int doc = 0; doc < this.numDocs; doc++) {
            sampleState(doc, this.random, true);
        }
    }

    /**
     * Rebuilds the cumulative log tables of one state after its counts changed.
     *
     * <p>Only the per-topic tables of topics that are keys of the given map are
     * refreshed; the per-document-length table of the state is always refreshed.
     *
     * @param i              the state whose caches are rebuilt
     * @param tIntIntHashMap map whose keys select the topics to refresh (values are ignored)
     */
    private void recacheStateTopicDistribution(int i, TIntIntHashMap tIntIntHashMap) {
        final int[] countsForState = this.stateTopicCounts[i];
        final double[][] topicTables = this.topicLogGammaCache[i];

        for (int topic : tIntIntHashMap.keys()) {
            final double[] table = topicTables[topic];
            table[0] = 0.0d;
            int n = 1;
            while (n < table.length) {
                // Each entry extends the previous one by one more log term.
                double term = Math.log(((this.alpha[topic] + n) - 1.0d) + countsForState[topic]);
                table[n] = table[n - 1] + term;
                n++;
            }
        }

        final double[] lengthTable = this.docLogGammaCache[i];
        lengthTable[0] = 0.0d;
        int n = 1;
        while (n < lengthTable.length) {
            double term = Math.log(((this.alphaSum + n) - 1.0d) + this.stateTopicTotals[i]);
            lengthTable[n] = lengthTable[n - 1] + term;
            n++;
        }
    }

    /**
     * Runs the sampler for {@code numIterations} sweeps over all documents.
     *
     * <p>After each sweep the elapsed milliseconds are printed. Every 10th
     * iteration three snapshot files are written to the working directory:
     * {@code state_state_matrix.N}, {@code state_topics.N} and {@code states.N}
     * (one state per line, in document order), where N is the iteration number.
     * The total running time is printed at the end.
     *
     * @throws IOException if a snapshot file cannot be opened
     */
    public void sample() throws IOException {
        long startTime = System.currentTimeMillis();

        for (int iteration = 1; iteration <= this.numIterations; iteration++) {
            long iterationStart = System.currentTimeMillis();
            for (int doc = 0; doc < this.numDocs; doc++) {
                sampleState(doc, this.random, false);
            }
            System.out.print((System.currentTimeMillis() - iterationStart) + LangRequest.DEFAULT_SELECTION);

            if (iteration % 10 == 0) {
                System.out.println("<" + iteration + "> ");

                // try-with-resources closes each writer even if producing the content throws.
                try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("state_state_matrix." + iteration)))) {
                    out.print(stateTransitionMatrix());
                }
                try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("state_topics." + iteration)))) {
                    out.print(stateTopics());
                }
                try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("states." + iteration)))) {
                    for (int doc = 0; doc < this.documentStates.length; doc++) {
                        out.println(this.documentStates[doc]);
                    }
                }
            }
            System.out.flush();
        }

        // Break the elapsed time down into days / hours / minutes / seconds.
        long totalSeconds = Math.round((System.currentTimeMillis() - startTime) / 1000.0d);
        long totalMinutes = totalSeconds / 60;
        long seconds = totalSeconds % 60;
        long totalHours = totalMinutes / 60;
        long minutes = totalMinutes % 60;
        long days = totalHours / 24;
        long hours = totalHours % 24;

        System.out.print("\nTotal time: ");
        if (days != 0) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /**
     * Reads per-token topic assignments and accumulates, for every document,
     * how many tokens were assigned to each topic.
     *
     * <p>Lines starting with {@code #} are skipped. Every other line is split on
     * {@code LangRequest.DEFAULT_SELECTION}; field 0 is the document ID and
     * field 4 the topic ID. Fields 1 and 2 must also parse as integers but are
     * otherwise unused. {@code numDocs} becomes one more than the largest
     * document ID seen. Files whose name ends in {@code .gz} are decompressed.
     *
     * @param str path of the state file
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if one of the parsed fields is not an integer
     */
    public void loadTopicsFromFile(String str) throws IOException {
        this.numDocs = 0;

        // The file stream is its own resource so it is closed even when the GZIP
        // header check in the GZIPInputStream constructor fails.
        try (FileInputStream fileIn = new FileInputStream(str);
             BufferedReader reader = new BufferedReader(new InputStreamReader(
                     str.endsWith(".gz") ? new GZIPInputStream(fileIn) : fileIn))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("#")) {
                    continue;
                }
                String[] fields = line.split(LangRequest.DEFAULT_SELECTION);
                int doc = Integer.parseInt(fields[0]);
                // Parsed for validation only; the values themselves are not used.
                Integer.parseInt(fields[1]);
                Integer.parseInt(fields[2]);
                int topic = Integer.parseInt(fields[4]);

                if (!this.documentTopics.containsKey(doc)) {
                    this.documentTopics.put(doc, new TIntIntHashMap());
                }
                TIntIntHashMap topicCounts = this.documentTopics.get(doc);
                if (topicCounts.containsKey(topic)) {
                    topicCounts.increment(topic);
                } else {
                    topicCounts.put(topic, 1);
                }

                if (doc >= this.numDocs) {
                    this.numDocs = doc + 1;
                }
            }
        }
        System.out.println("loaded topics, " + this.numDocs + " documents");
    }

    /**
     * Reads the topic keys file, resetting alpha for each listed topic and
     * storing its descriptive key string.
     *
     * <p>Empty lines are skipped. Every other line is split on whitespace:
     * field 0 is the topic ID, field 1 is ignored, and the remaining fields are
     * joined (each followed by {@code LangRequest.DEFAULT_SELECTION}) into the
     * topic's key string. Each listed topic's alpha is set to 1.0 regardless of
     * the file contents, and {@code alphaSum} is rebuilt as the sum over the
     * listed topics.
     *
     * @param str path of the topic keys file
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a topic ID is not an integer
     */
    public void loadAlphaFromFile(String str) throws IOException {
        this.alphaSum = 0.0d;

        try (BufferedReader reader = new BufferedReader(new FileReader(new File(str)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.equals("")) {
                    continue;
                }
                String[] fields = line.split("\\s+");
                int topic = Integer.parseInt(fields[0]);

                this.alpha[topic] = 1.0d;
                this.alphaSum += this.alpha[topic];

                StringBuilder keys = new StringBuilder();
                for (int f = 2; f < fields.length; f++) {
                    keys.append(fields[f]).append(LangRequest.DEFAULT_SELECTION);
                }
                this.topicKeys[topic] = keys.toString();
            }
        }
        System.out.println("loaded alpha");
    }

    /**
     * Reads one sequence ID per line (the first tab-separated field) and stores
     * it for the document with the same index as the line.
     *
     * <p>{@code numSequences} is recounted from scratch as the number of times
     * the ID changes from one stored line to the next. If the file has more
     * lines than there are documents, the extra lines are not stored (they
     * would not fit) but are still counted, so the mismatch warning is printed
     * instead of the load failing with an index error.
     *
     * @param str path of the sequence metadata file
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a stored line's first field is not an integer
     */
    public void loadSequenceIDsFromFile(String str) throws IOException {
        int lineCount = 0;
        int previousSequenceID = -1;
        // Start from zero so that loading twice does not double the count.
        this.numSequences = 0;

        try (BufferedReader reader = new BufferedReader(new FileReader(new File(str)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (lineCount < this.documentSequenceIDs.length) {
                    int sequenceID = Integer.parseInt(line.split("\\t")[0]);
                    this.documentSequenceIDs[lineCount] = sequenceID;
                    if (sequenceID != previousSequenceID) {
                        this.numSequences++;
                    }
                    previousSequenceID = sequenceID;
                }
                lineCount++;
            }
        }

        if (lineCount != this.numDocs) {
            System.out.println("Warning: number of documents with topics (" + this.numDocs + ") is not equal to number of docs with sequence IDs (" + lineCount + ")");
        }
        System.out.println("loaded sequence");
    }

    /**
     * Resamples the hidden state of one document.
     *
     * <p>Documents without loaded topic counts are skipped. Otherwise the
     * method (1) removes the document from the counts of its current state
     * unless initializing, (2) computes an unnormalized log score for every
     * state from the transition / initial-state counts and the cached
     * topic tables, (3) draws a new state, and (4) adds the document to the
     * counts of the new state.
     *
     * <p>A document starts a sequence when its sequence ID differs from its
     * predecessor's and ends one when it differs from its successor's. During
     * initialization the successor is never consulted.
     *
     * <p>{@code stateTransitionTotals[s]} is maintained as the number of
     * transitions leaving state {@code s}: it is decremented and incremented
     * together with {@code stateStateTransitions[s][next]} and is used as the
     * denominator for that quantity. The initialization path therefore
     * increments the total of the previous (source) state.
     *
     * @param i       document index
     * @param randoms random source used for the draw
     * @param z       {@code true} for the initial pass, in which no old
     *                assignment is removed and only the predecessor is used
     */
    private void sampleState(int i, Randoms randoms, boolean z) {
        if (!this.documentTopics.containsKey(i)) {
            return;
        }
        final int doc = i;
        final boolean initializing = z;
        final TIntIntHashMap topicCounts = this.documentTopics.get(doc);
        final int[] topics = topicCounts.keys();
        final int oldState = this.documentStates[doc];

        // --- 1. Take the document's tokens out of its current state. ---
        int docLength = 0;
        for (int topic : topics) {
            int count = topicCounts.get(topic);
            if (!initializing) {
                this.stateTopicCounts[oldState][topic] -= count;
            }
            docLength += count;
        }
        if (!initializing) {
            this.stateTopicTotals[oldState] -= docLength;
            recacheStateTopicDistribution(oldState, topicCounts);
        }

        // --- Work out where the document sits within its sequence. ---
        final int previousSequenceID = doc > 0 ? this.documentSequenceIDs[doc - 1] : -1;
        final int sequenceID = this.documentSequenceIDs[doc];
        int nextSequenceID = -1;
        if (!initializing && doc < this.numDocs - 1) {
            nextSequenceID = this.documentSequenceIDs[doc + 1];
        }
        final boolean startsSequence = previousSequenceID != sequenceID;
        final boolean endsSequence = sequenceID != nextSequenceID;

        // Neighbouring states do not change during this call, so read them once.
        // -1 marks "no such neighbour in this sequence" and is never used as an index.
        final int previousState = startsSequence ? -1 : this.documentStates[doc - 1];
        final int nextState = (initializing || endsSequence) ? -1 : this.documentStates[doc + 1];

        final double[] stateLogLikelihoods = new double[this.numStates];
        final double[] samplingDistribution = new double[this.numStates];

        // --- 2a. Transition / initial-state part of the score. ---
        if (initializing) {
            if (startsSequence) {
                for (int state = 0; state < this.numStates; state++) {
                    stateLogLikelihoods[state] = Math.log((this.initialStateCounts[state] + this.pi) / ((this.numSequences - 1) + this.sumPi));
                }
            } else {
                for (int state = 0; state < this.numStates; state++) {
                    stateLogLikelihoods[state] = Math.log(this.stateStateTransitions[previousState][state] + this.gamma);
                    if (Double.isInfinite(stateLogLikelihoods[state])) {
                        System.out.println("infinite end");
                    }
                }
            }
        } else if (startsSequence && endsSequence) {
            // Single-document sequence: only the initial-state count is involved.
            this.initialStateCounts[oldState]--;
            for (int state = 0; state < this.numStates; state++) {
                stateLogLikelihoods[state] = Math.log((this.initialStateCounts[state] + this.pi) / ((this.numSequences - 1) + this.sumPi));
            }
        } else if (startsSequence) {
            // First document of a longer sequence: initial state plus outgoing transition.
            this.initialStateCounts[oldState]--;
            this.stateStateTransitions[oldState][nextState]--;
            if (!$assertionsDisabled && this.stateStateTransitions[oldState][nextState] < 0) {
                throw new AssertionError();
            }
            this.stateTransitionTotals[oldState]--;
            for (int state = 0; state < this.numStates; state++) {
                stateLogLikelihoods[state] = Math.log(((this.stateStateTransitions[state][nextState] + this.gamma) * (this.initialStateCounts[state] + this.pi)) / ((this.numSequences - 1) + this.sumPi));
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite beginning");
                }
            }
        } else if (endsSequence) {
            // Last document of a sequence: only the incoming transition is involved.
            this.stateStateTransitions[previousState][oldState]--;
            if (!$assertionsDisabled && this.stateStateTransitions[previousState][oldState] < 0) {
                throw new AssertionError();
            }
            for (int state = 0; state < this.numStates; state++) {
                stateLogLikelihoods[state] = Math.log(this.stateStateTransitions[previousState][state] + this.gamma);
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite end");
                }
            }
        } else {
            // Interior document: both the incoming and the outgoing transition are involved.
            this.stateStateTransitions[oldState][nextState]--;
            if (this.stateStateTransitions[oldState][nextState] < 0) {
                System.out.println(printStateTransitions());
                System.out.println(oldState + " -> " + nextState);
                System.out.println(sequenceID);
            }
            if (!$assertionsDisabled && this.stateStateTransitions[oldState][nextState] < 0) {
                throw new AssertionError();
            }
            this.stateTransitionTotals[oldState]--;
            this.stateStateTransitions[previousState][oldState]--;
            if (!$assertionsDisabled && this.stateStateTransitions[previousState][oldState] < 0) {
                throw new AssertionError();
            }
            for (int state = 0; state < this.numStates; state++) {
                if (previousState == state && state == nextState) {
                    // Choosing this state adds a self-transition that also feeds the outgoing count.
                    stateLogLikelihoods[state] = Math.log(((this.stateStateTransitions[previousState][state] + this.gamma) * ((this.stateStateTransitions[state][nextState] + 1) + this.gamma)) / ((this.stateTransitionTotals[state] + 1) + this.gammaSum));
                } else if (previousState == state) {
                    stateLogLikelihoods[state] = Math.log(((this.stateStateTransitions[previousState][state] + this.gamma) * (this.stateStateTransitions[state][nextState] + this.gamma)) / ((this.stateTransitionTotals[state] + 1) + this.gammaSum));
                } else {
                    stateLogLikelihoods[state] = Math.log(((this.stateStateTransitions[previousState][state] + this.gamma) * (this.stateStateTransitions[state][nextState] + this.gamma)) / (this.stateTransitionTotals[state] + this.gammaSum));
                }
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite middle: " + doc);
                    System.out.println(previousState + " -> " + state + " -> " + nextState);
                    System.out.println(this.stateStateTransitions[previousState][state] + " -> " + this.stateStateTransitions[state][nextState] + " / " + this.stateTransitionTotals[state]);
                }
            }
        }

        // --- 2b. Add the topic part of the score from the cached tables. ---
        double maxLogLikelihood = Double.NEGATIVE_INFINITY;
        for (int state = 0; state < this.numStates; state++) {
            // NOTE(review): this is integer division, kept exactly as found; the purpose of
            // this penalty on the transition total is not evident from the code - confirm intent.
            stateLogLikelihoods[state] -= this.stateTransitionTotals[state] / 10;

            double[][] topicTables = this.topicLogGammaCache[state];
            for (int topic : topics) {
                stateLogLikelihoods[state] += topicTables[topic][topicCounts.get(topic)];
            }
            stateLogLikelihoods[state] -= this.docLogGammaCache[state][docLength];

            if (stateLogLikelihoods[state] > maxLogLikelihood) {
                maxLogLikelihood = stateLogLikelihoods[state];
            }
        }

        // --- 3. Exponentiate relative to the maximum (avoids underflow) and draw. ---
        double distributionSum = 0.0d;
        for (int state = 0; state < this.numStates; state++) {
            samplingDistribution[state] = Math.exp(stateLogLikelihoods[state] - maxLogLikelihood);
            distributionSum += samplingDistribution[state];
            if (Double.isNaN(samplingDistribution[state])) {
                System.out.println(stateLogLikelihoods[state]);
            }
            if (!$assertionsDisabled && Double.isNaN(samplingDistribution[state])) {
                throw new AssertionError();
            }
        }
        final int newState = randoms.nextDiscrete(samplingDistribution, distributionSum);
        this.documentStates[doc] = newState;

        // --- 4. Put the document's tokens into the new state. ---
        for (int topic : topics) {
            this.stateTopicCounts[newState][topic] += topicCounts.get(topic);
        }
        this.stateTopicTotals[newState] += docLength;
        recacheStateTopicDistribution(newState, topicCounts);

        if (initializing) {
            if (startsSequence) {
                this.initialStateCounts[newState]++;
            } else {
                this.stateStateTransitions[previousState][newState]++;
                // The new transition leaves previousState, so that is the total that grows.
                this.stateTransitionTotals[previousState]++;
            }
            return;
        }

        if (startsSequence && endsSequence) {
            this.initialStateCounts[newState]++;
        } else if (startsSequence) {
            this.initialStateCounts[newState]++;
            this.stateStateTransitions[newState][nextState]++;
            this.stateTransitionTotals[newState]++;
        } else if (endsSequence) {
            this.stateStateTransitions[previousState][newState]++;
        } else {
            this.stateStateTransitions[previousState][newState]++;
            this.stateStateTransitions[newState][nextState]++;
            this.stateTransitionTotals[newState]++;
        }
    }

    /**
     * Builds a human-readable report with one section per state.
     *
     * <p>Each section lists up to four topics taken from the front of the
     * topic array after sorting by the topic's share of the state's tokens
     * (the ordering is whatever {@code IDSorter}'s natural ordering defines;
     * presumably largest share first - confirm against IDSorter). It then
     * shows the state's initial-state count over the number of sequences, its
     * transition total, and its row of transition counts with the
     * self-transition in brackets.
     *
     * <p>Shares are computed in floating point, and a state with no tokens
     * gets a share of 0 for every topic instead of dividing by zero.
     *
     * @return the formatted report
     */
    public String printStateTransitions() {
        StringBuilder report = new StringBuilder();
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        // Never read past the end of the array when there are fewer than four topics.
        int topicsToShow = Math.min(4, this.numTopics);

        for (int state = 0; state < this.numStates; state++) {
            int total = this.stateTopicTotals[state];
            for (int topic = 0; topic < this.numTopics; topic++) {
                double share = total == 0 ? 0.0d : (double) this.stateTopicCounts[state][topic] / total;
                sortedTopics[topic] = new IDSorter(topic, share);
            }
            Arrays.sort(sortedTopics);

            report.append("\n").append(state).append("\n");
            for (int rank = 0; rank < topicsToShow; rank++) {
                int topic = sortedTopics[rank].getID();
                report.append(this.stateTopicCounts[state][topic]).append("\t").append(this.topicKeys[topic]).append("\n");
            }
            report.append("\n");

            report.append("[").append(this.initialStateCounts[state]).append("/").append(this.numSequences).append("] ");
            report.append("[").append(this.stateTransitionTotals[state]).append("]");
            for (int target = 0; target < this.numStates; target++) {
                report.append("\t");
                if (state == target) {
                    report.append("[").append(this.stateStateTransitions[state][target]).append("]");
                } else {
                    report.append(this.stateStateTransitions[state][target]);
                }
            }
            report.append("\n");
        }
        return report.toString();
    }

    /**
     * Renders the state-to-state transition counts as text: one line per
     * source state, every count followed by a tab.
     *
     * @return the tab-separated matrix, one row per line
     */
    public String stateTransitionMatrix() {
        StringBuilder text = new StringBuilder();
        int from = 0;
        while (from < this.numStates) {
            int[] row = this.stateStateTransitions[from];
            for (int to = 0; to < this.numStates; to++) {
                text.append(row[to]).append('\t');
            }
            text.append('\n');
            from++;
        }
        return text.toString();
    }

    /**
     * Renders the per-state topic counts as text: one line per state, every
     * topic count followed by a tab.
     *
     * @return the tab-separated counts, one state per line
     */
    public String stateTopics() {
        StringBuilder text = new StringBuilder();
        int state = 0;
        while (state < this.numStates) {
            int[] counts = this.stateTopicCounts[state];
            for (int topic = 0; topic < this.numTopics; topic++) {
                text.append(counts[topic]).append('\t');
            }
            text.append('\n');
            state++;
        }
        return text.toString();
    }

    /**
     * Command-line entry point: builds a model from the given files, initializes
     * it and runs the sampler with gamma 1.0 and random seed 1.
     *
     * <p>With the wrong number of arguments the usage text is printed to
     * standard error and the process exits with a non-zero status, so calling
     * scripts can detect the failure.
     *
     * @param strArr number of topics, LDA state file, LDA keys file, sequence metadata file
     * @throws IOException if any input file cannot be read or a snapshot cannot be written
     */
    public static void main(String[] strArr) throws IOException {
        if (strArr.length != 4) {
            System.err.println("Usage: MultinomialHMM [num topics] [lda state file] [lda keys file] [sequence metadata file]");
            System.exit(1);
        }

        int numTopics = Integer.parseInt(strArr[0]);
        // NOTE(review): the number of states is taken from an unrelated bytecode-opcode
        // constant; this looks like a decompiler substituting a named constant for an
        // integer literal - confirm the intended value and replace it with a plain number.
        MultinomialHMM hmm = new MultinomialHMM(numTopics, strArr[1], Constants.FCMPG);
        hmm.setGamma(1.0d);
        hmm.setRandomSeed(1);
        hmm.loadAlphaFromFile(strArr[2]);
        hmm.loadSequenceIDsFromFile(strArr[3]);
        hmm.initialize();
        hmm.sample();
    }

    // Records once, at class initialization, whether assertions are disabled for this
    // class; the hand-expanded assertion checks in sampleState consult this flag.
    static {
        $assertionsDisabled = !MultinomialHMM.class.desiredAssertionStatus();
    }
}
