package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import com.meaningcloud.LangRequest;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TObjectDoubleHashMap;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import net.didion.jwnl.dictionary.file.DictionaryFile;

/* loaded from: input_file:cc/mallet/topics/HierarchicalLDA.class */
public class HierarchicalLDA {
    // Training documents; assigned by initialize() and read by the samplers and printers.
    InstanceList instances;
    // Second instance list passed to initialize(); stored there but not read anywhere else in this file.
    InstanceList testing;
    // Root of the topic tree; every document path starts here.
    NCRPNode rootNode;
    // Scratch reference used by initialize() while assigning paths and tokens.
    NCRPNode node;
    // Depth of the tree, i.e. the number of nodes on each document's root-to-leaf path.
    int numLevels;
    // Number of training documents (instances.size() at initialize() time).
    int numDocuments;
    // Size of the data alphabet (number of distinct word types).
    int numTypes;
    // eta * numTypes, computed once in initialize().
    double etaSum;
    // levels[doc][position] is the tree level currently assigned to that token.
    int[][] levels;
    // documentLeaves[doc] is the leaf node at the end of that document's path.
    NCRPNode[] documentLeaves;
    // Random source supplied to initialize(); drives all sampling.
    Randoms random;
    // Decompiler-exposed assertion switch; assigned in the static initializer of this class.
    static final /* synthetic */ boolean $assertionsDisabled;
    // Number of nodes created so far; also the source of each node's nodeID.
    int totalNodes = 0;
    // File that the no-argument printState() writes to.
    String stateFile = "hlda.state";
    // When true, estimate() prints progress marks to System.out.
    boolean showProgress = true;
    // estimate() prints the tree every this many iterations.
    int displayTopicsInterval = 50;
    // Number of top words printed per node by printNode().
    int numWordsToDisplay = 10;
    // Added to the per-level token counts in sampleTopics(); also the Dirichlet parameter in empiricalLikelihood().
    double alpha = 10.0d;
    // Weight given to opening a new child in NCRPNode.select() and calculateNCRP().
    double gamma = 1.0d;
    // Added to every per-type count when word probabilities are computed.
    double eta = 0.1d;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:cc/mallet/topics/HierarchicalLDA$NCRPNode.class */
    /**
     * One node of the topic tree. A node records how many documents have a
     * path running through it ({@code customers}) and the word counts that
     * have been assigned to it. It is an inner class and reads the enclosing
     * sampler's settings ({@code numLevels}, {@code gamma}, {@code random},
     * {@code numTypes}) directly.
     */
    public class NCRPNode {
        /** Number of documents whose path passes through this node. */
        int customers;
        /** Child nodes, one level below this one. */
        ArrayList<NCRPNode> children;
        /** Parent node, or {@code null} for the root. */
        NCRPNode parent;
        /** Depth of this node; the root is level 0. */
        int level;
        /** Total number of tokens currently assigned to this node. */
        int totalTokens;
        /** Tokens assigned to this node, indexed by word type. */
        int[] typeCounts;
        /** Identifier taken from the enclosing sampler's running node counter. */
        public int nodeID;

        /**
         * Creates a node and assigns it the next node id.
         *
         * @param nCRPNode parent node, or {@code null} for a root
         * @param i        number of word types (length of {@code typeCounts})
         * @param i2       level of the new node
         */
        public NCRPNode(NCRPNode nCRPNode, int i, int i2) {
            this.customers = 0;
            this.parent = nCRPNode;
            this.children = new ArrayList<>();
            this.level = i2;
            this.totalTokens = 0;
            this.typeCounts = new int[i];
            this.nodeID = HierarchicalLDA.this.totalNodes;
            HierarchicalLDA.this.totalNodes++;
        }

        /**
         * Creates a parentless node at level 0.
         *
         * @param hierarchicalLDA unused; kept so existing call sites still compile
         * @param i               number of word types
         */
        public NCRPNode(HierarchicalLDA hierarchicalLDA, int i) {
            this(null, i, 0);
        }

        /** Appends a new, empty child one level below this node and returns it. */
        public NCRPNode addChild() {
            NCRPNode child = new NCRPNode(this, this.typeCounts.length, this.level + 1);
            this.children.add(child);
            return child;
        }

        /** Returns true when this node sits on the deepest level of the tree. */
        public boolean isLeaf() {
            return this.level == HierarchicalLDA.this.numLevels - 1;
        }

        /**
         * Grows a chain of new children below this node down to the deepest
         * level and returns the last one (this node itself if it is a leaf).
         */
        public NCRPNode getNewLeaf() {
            NCRPNode tail = this;
            for (int depth = this.level; depth < HierarchicalLDA.this.numLevels - 1; depth++) {
                tail = tail.addChild();
            }
            return tail;
        }

        /**
         * Removes one customer from this node and from each ancestor up to
         * the root, detaching any node whose count drops to zero. Intended to
         * be called on a leaf.
         */
        public void dropPath() {
            NCRPNode current = this;
            for (int step = 0; step < HierarchicalLDA.this.numLevels; step++) {
                current.customers--;
                NCRPNode up = current.parent;
                // The root has no parent to detach from; the unguarded call used to
                // throw a NullPointerException whenever the root's count reached zero
                // (single-document corpus, or a one-level tree).
                if (current.customers == 0 && up != null) {
                    up.remove(current);
                }
                current = up;
            }
        }

        /** Detaches the given child from this node. */
        public void remove(NCRPNode nCRPNode) {
            this.children.remove(nCRPNode);
        }

        /**
         * Adds one customer to this node and to each ancestor up to the root.
         * Intended to be called on a leaf.
         */
        public void addPath() {
            NCRPNode current = this;
            current.customers++;
            for (int step = 1; step < HierarchicalLDA.this.numLevels; step++) {
                current = current.parent;
                current.customers++;
            }
        }

        /**
         * Samples one of the existing children with weight proportional to its
         * customer count; never creates a new child.
         */
        public NCRPNode selectExisting() {
            double[] weights = new double[this.children.size()];
            double total = 0.0d;
            int slot = 0;
            for (NCRPNode child : this.children) {
                weights[slot] = child.customers / (HierarchicalLDA.this.gamma + this.customers);
                total += weights[slot];
                slot++;
            }
            // These weights leave out the gamma share reserved for a new child, so they
            // add up to less than one. The total is passed explicitly, using the
            // two-argument nextDiscrete that this file already uses for unnormalised
            // weights. NOTE(review): the one-argument overload presumably expects a
            // normalised distribution - confirm against Randoms.
            return this.children.get(HierarchicalLDA.this.random.nextDiscrete(weights, total));
        }

        /**
         * Samples either a brand-new child (weight gamma) or an existing child
         * (weight equal to its customer count), creating the child if needed.
         */
        public NCRPNode select() {
            double[] weights = new double[this.children.size() + 1];
            // Slot 0 stands for "open a new child".
            weights[0] = HierarchicalLDA.this.gamma / (HierarchicalLDA.this.gamma + this.customers);
            int slot = 1;
            for (NCRPNode child : this.children) {
                weights[slot] = child.customers / (HierarchicalLDA.this.gamma + this.customers);
                slot++;
            }
            int choice = HierarchicalLDA.this.random.nextDiscrete(weights);
            return choice == 0 ? addChild() : this.children.get(choice - 1);
        }

        /**
         * Formats the most frequent word types of this node.
         *
         * @param i maximum number of words to list; capped at the vocabulary size
         * @param z when true, each word is followed by ":" and its weight
         * @return the words, each followed by the separator constant
         */
        public String getTopWords(int i, boolean z) {
            int vocabulary = HierarchicalLDA.this.numTypes;
            IDSorter[] sorted = new IDSorter[vocabulary];
            for (int type = 0; type < vocabulary; type++) {
                sorted[type] = new IDSorter(type, this.typeCounts[type]);
            }
            Arrays.sort(sorted);
            Alphabet dataAlphabet = HierarchicalLDA.this.instances.getDataAlphabet();
            // Asking for more words than the vocabulary holds used to run off the
            // end of the sorted array.
            int shown = Math.min(i, vocabulary);
            StringBuilder out = new StringBuilder();
            for (int rank = 0; rank < shown; rank++) {
                out.append(dataAlphabet.lookupObject(sorted[rank].getID()));
                if (z) {
                    out.append(":").append(sorted[rank].getWeight());
                }
                out.append(LangRequest.DEFAULT_SELECTION);
            }
            return out.toString();
        }
    }

    /**
     * Sets {@code alpha}, the value added to per-level token counts when token
     * levels are sampled.
     */
    public void setAlpha(double d) {
        alpha = d;
    }

    /**
     * Sets {@code gamma}, the weight given to opening a new child when a path
     * is chosen.
     */
    public void setGamma(double d) {
        gamma = d;
    }

    /**
     * Sets {@code eta}, the value added to every per-type count. Note that
     * {@code etaSum} is only recomputed by {@code initialize}.
     */
    public void setEta(double d) {
        eta = d;
    }

    /** Sets the file name that the no-argument {@code printState()} writes to. */
    public void setStateFile(String str) {
        stateFile = str;
    }

    /**
     * Configures the periodic tree print-out.
     *
     * @param i  print the tree every {@code i} iterations
     * @param i2 number of top words to show per node
     */
    public void setTopicDisplay(int i, int i2) {
        numWordsToDisplay = i2;
        displayTopicsInterval = i;
    }

    /** Turns the progress marks printed by {@code estimate} on or off. */
    public void setProgressDisplay(boolean z) {
        showProgress = z;
    }

    /**
     * Prepares the sampler: stores the inputs, creates the root node, draws an
     * initial path for every document and assigns each token to a uniformly
     * random level on that path.
     *
     * @param instanceList  training documents; each must hold a {@link FeatureSequence}
     * @param instanceList2 second instance list; only stored in {@code testing}
     * @param i             number of tree levels (at least 1)
     * @param randoms       random source used for all sampling
     * @throws IllegalArgumentException if the training list is empty, {@code i}
     *         is less than 1, or the data is not a {@code FeatureSequence}
     */
    public void initialize(InstanceList instanceList, InstanceList instanceList2, int i, Randoms randoms) {
        // Reject inputs that would otherwise fail later with an index or array-size error.
        if (i < 1) {
            throw new IllegalArgumentException("Number of levels must be at least 1, got " + i);
        }
        if (instanceList.size() == 0) {
            throw new IllegalArgumentException("Training instance list is empty");
        }
        this.instances = instanceList;
        this.testing = instanceList2;
        this.numLevels = i;
        this.random = randoms;
        if (!(instanceList.get(0).getData() instanceof FeatureSequence)) {
            throw new IllegalArgumentException("Input must be a FeatureSequence, using the --feature-sequence option when impoting data, for example");
        }
        this.numDocuments = instanceList.size();
        this.numTypes = instanceList.getDataAlphabet().size();
        this.etaSum = this.eta * this.numTypes;
        NCRPNode[] path = new NCRPNode[i];
        this.rootNode = new NCRPNode(this, this.numTypes);
        // Outer dimension only; each row is sized to its document below.
        // (The decompiled "new int[numDocuments]" did not type-check against int[][].)
        this.levels = new int[this.numDocuments][];
        this.documentLeaves = new NCRPNode[this.numDocuments];
        for (int doc = 0; doc < this.numDocuments; doc++) {
            FeatureSequence tokens = (FeatureSequence) instanceList.get(doc).getData();
            int docLength = tokens.getLength();
            // Draw the document's path from the root downwards, counting it as a
            // customer of every node it passes through.
            path[0] = this.rootNode;
            this.rootNode.customers++;
            for (int level = 1; level < i; level++) {
                path[level] = path[level - 1].select();
                path[level].customers++;
            }
            this.node = path[i - 1];
            this.levels[doc] = new int[docLength];
            this.documentLeaves[doc] = this.node;
            // Assign each token to a uniformly random level of the path.
            for (int position = 0; position < docLength; position++) {
                int type = tokens.getIndexAtPosition(position);
                this.levels[doc][position] = randoms.nextInt(i);
                this.node = path[this.levels[doc][position]];
                this.node.totalTokens++;
                this.node.typeCounts[type]++;
            }
        }
    }

    /**
     * Runs the sampler for the given number of iterations. Each iteration
     * resamples every document's path, then every document's token levels.
     *
     * @param i number of iterations to run
     */
    public void estimate(int i) {
        for (int iteration = 1; iteration <= i; iteration++) {
            for (int doc = 0; doc < this.numDocuments; doc++) {
                samplePath(doc, iteration);
            }
            for (int doc = 0; doc < this.numDocuments; doc++) {
                sampleTopics(doc);
            }
            if (this.showProgress) {
                System.out.print(".");
                if (iteration % 50 == 0) {
                    System.out.println(LangRequest.DEFAULT_SELECTION + iteration);
                }
            }
            // A non-positive interval disables the periodic print-out; an interval
            // of 0 previously caused a division-by-zero ArithmeticException here.
            if (this.displayTopicsInterval > 0 && iteration % this.displayTopicsInterval == 0) {
                printNodes();
            }
        }
    }

    /**
     * Resamples the tree path of one document.
     *
     * The document is first taken off its current path (customer counts and
     * token counts), every node is then scored with a log weight, one node is
     * drawn from those weights, and the document is put back on the path
     * ending at that node (growing new nodes below it when it is not a leaf).
     *
     * @param i  index of the document
     * @param i2 iteration number; only forwarded to calculateWordLikelihood
     */
    public void samplePath(int i, int i2) {
        // Recover the current path, root at index 0, by walking up from the leaf.
        NCRPNode[] nCRPNodeArr = new NCRPNode[this.numLevels];
        NCRPNode nCRPNode = this.documentLeaves[i];
        for (int i3 = this.numLevels - 1; i3 >= 0; i3--) {
            nCRPNodeArr[i3] = nCRPNode;
            nCRPNode = nCRPNode.parent;
        }
        // Take the document's customer counts off the old path.
        this.documentLeaves[i].dropPath();
        // Start each node's log weight from the customer counts and gamma.
        TObjectDoubleHashMap<NCRPNode> tObjectDoubleHashMap = new TObjectDoubleHashMap<>();
        calculateNCRP(tObjectDoubleHashMap, this.rootNode, 0.0d);
        // One type -> count map per level for this document's tokens.
        TIntIntHashMap[] tIntIntHashMapArr = new TIntIntHashMap[this.numLevels];
        for (int i4 = 0; i4 < this.numLevels; i4++) {
            tIntIntHashMapArr[i4] = new TIntIntHashMap();
        }
        int[] iArr = this.levels[i];
        FeatureSequence featureSequence = (FeatureSequence) this.instances.get(i).getData();
        // Tally the tokens per level and remove them from the nodes of the old path.
        for (int i5 = 0; i5 < iArr.length; i5++) {
            int i6 = iArr[i5];
            int indexAtPosition = featureSequence.getIndexAtPosition(i5);
            if (tIntIntHashMapArr[i6].containsKey(indexAtPosition)) {
                tIntIntHashMapArr[i6].increment(indexAtPosition);
            } else {
                tIntIntHashMapArr[i6].put(indexAtPosition, 1);
            }
            int[] iArr2 = nCRPNodeArr[i6].typeCounts;
            iArr2[indexAtPosition] = iArr2[indexAtPosition] - 1;
            // Decompiled form of an assert: counts must never go negative.
            if (!$assertionsDisabled && nCRPNodeArr[i6].typeCounts[indexAtPosition] < 0) {
                throw new AssertionError();
            }
            nCRPNodeArr[i6].totalTokens--;
            if (!$assertionsDisabled && nCRPNodeArr[i6].totalTokens < 0) {
                throw new AssertionError();
            }
        }
        // dArr[level]: log-likelihood of this document's tokens at that level under
        // a node with no counts of its own (only eta / etaSum appear). Level 0 is
        // skipped because the loop starts at 1.
        double[] dArr = new double[this.numLevels];
        for (int i7 = 1; i7 < this.numLevels; i7++) {
            int i8 = 0;
            for (int i9 : tIntIntHashMapArr[i7].keys()) {
                for (int i10 = 0; i10 < tIntIntHashMapArr[i7].get(i9); i10++) {
                    int i11 = i7;
                    dArr[i11] = dArr[i11] + Math.log((this.eta + i10) / (this.etaSum + i8));
                    i8++;
                }
            }
        }
        // Add the word log-likelihood terms to every node's weight.
        calculateWordLikelihood(tObjectDoubleHashMap, this.rootNode, 0.0d, tIntIntHashMapArr, dArr, 0, i2);
        NCRPNode[] keys = tObjectDoubleHashMap.keys(new NCRPNode[0]);
        double[] dArr2 = new double[keys.length];
        double d = 0.0d;
        // Find the largest log weight so it can be subtracted before exponentiating.
        double d2 = Double.NEGATIVE_INFINITY;
        for (int i12 = 0; i12 < keys.length; i12++) {
            if (tObjectDoubleHashMap.get(keys[i12]) > d2) {
                d2 = tObjectDoubleHashMap.get(keys[i12]);
            }
        }
        // Turn log weights into unnormalised weights and accumulate their sum.
        for (int i13 = 0; i13 < keys.length; i13++) {
            dArr2[i13] = Math.exp(tObjectDoubleHashMap.get(keys[i13]) - d2);
            d += dArr2[i13];
        }
        // Draw a node; an internal node means a new branch is grown beneath it.
        NCRPNode nCRPNode2 = keys[this.random.nextDiscrete(dArr2, d)];
        if (!nCRPNode2.isLeaf()) {
            nCRPNode2 = nCRPNode2.getNewLeaf();
        }
        nCRPNode2.addPath();
        this.documentLeaves[i] = nCRPNode2;
        // Put the token counts back, walking from the new leaf up to the root.
        for (int i14 = this.numLevels - 1; i14 >= 0; i14--) {
            for (int i15 : tIntIntHashMapArr[i14].keys()) {
                int[] iArr3 = nCRPNode2.typeCounts;
                iArr3[i15] = iArr3[i15] + tIntIntHashMapArr[i14].get(i15);
                nCRPNode2.totalTokens += tIntIntHashMapArr[i14].get(i15);
            }
            nCRPNode2 = nCRPNode2.parent;
        }
    }

    /**
     * Recursively stores a log weight for the given node and all of its
     * descendants, based only on customer counts and {@code gamma}.
     *
     * @param tObjectDoubleHashMap receives one entry per visited node
     * @param nCRPNode             node to start from
     * @param d                    log weight accumulated on the way down to this node
     */
    public void calculateNCRP(TObjectDoubleHashMap<NCRPNode> tObjectDoubleHashMap, NCRPNode nCRPNode, double d) {
        for (NCRPNode child : nCRPNode.children) {
            // The child must be bound to a local so its customer count can be read;
            // the decompiled code referred to an undefined register (r0) here.
            double stepIntoChild = Math.log(child.customers / (nCRPNode.customers + this.gamma));
            calculateNCRP(tObjectDoubleHashMap, child, d + stepIntoChild);
        }
        // The node's own entry uses the gamma (new-child) share at this node.
        tObjectDoubleHashMap.put(nCRPNode, d + Math.log(this.gamma / (nCRPNode.customers + this.gamma)));
    }

    /**
     * Recursively adds word log-likelihood terms to the weight of the given
     * node and all of its descendants.
     *
     * For each node the added amount is: the log-likelihood of the document's
     * tokens at every level from the root down to this node, plus the
     * new-node log-likelihood ({@code dArr}) for every level below it.
     *
     * @param tObjectDoubleHashMap node weights to adjust; must already hold an entry per node
     * @param nCRPNode             node being scored
     * @param d                    log-likelihood accumulated over this node's ancestors
     * @param tIntIntHashMapArr    per-level type -> count maps for the document
     * @param dArr                 per-level log-likelihood under a node with no counts
     * @param i                    level of {@code nCRPNode}
     * @param i2                   iteration number; not used by this method
     */
    public void calculateWordLikelihood(TObjectDoubleHashMap<NCRPNode> tObjectDoubleHashMap, NCRPNode nCRPNode, double d, TIntIntHashMap[] tIntIntHashMapArr, double[] dArr, int i, int i2) {
        // Log-likelihood of the tokens assigned to this level, given this node's counts.
        TIntIntHashMap countsAtLevel = tIntIntHashMapArr[i];
        double ownLevel = 0.0d;
        int tokensSeen = 0;
        for (int type : countsAtLevel.keys()) {
            int occurrences = countsAtLevel.get(type);
            for (int k = 0; k < occurrences; k++) {
                ownLevel += Math.log(((this.eta + nCRPNode.typeCounts[type]) + k) / ((this.etaSum + nCRPNode.totalTokens) + tokensSeen));
                tokensSeen++;
            }
        }
        // Hand the running total (ancestors plus this level) to the children.
        for (NCRPNode child : nCRPNode.children) {
            calculateWordLikelihood(tObjectDoubleHashMap, child, d + ownLevel, tIntIntHashMapArr, dArr, i + 1, i2);
        }
        // Levels below this node would be brand-new nodes on a path branching here.
        double belowLevels = 0.0d;
        for (int level = i + 1; level < this.numLevels; level++) {
            belowLevels += dArr[level];
        }
        // Include the ancestors' contribution d. It was previously passed down the
        // recursion but never added, so tokens assigned to upper levels had no
        // influence on which path was chosen.
        tObjectDoubleHashMap.adjustValue(nCRPNode, d + ownLevel + belowLevels);
    }

    /**
     * Adds {@code d} to the stored weight of the given node and of every
     * descendant reachable through nodes that have an entry in the map.
     * A node without an entry is skipped together with its subtree.
     */
    public void propagateTopicWeight(TObjectDoubleHashMap<NCRPNode> tObjectDoubleHashMap, NCRPNode nCRPNode, double d) {
        if (!tObjectDoubleHashMap.containsKey(nCRPNode)) {
            return;
        }
        for (NCRPNode child : nCRPNode.children) {
            propagateTopicWeight(tObjectDoubleHashMap, child, d);
        }
        tObjectDoubleHashMap.adjustValue(nCRPNode, d);
    }

    /**
     * Resamples the level of every token in one document, keeping the
     * document's path fixed.
     *
     * @param i index of the document
     */
    public void sampleTopics(int i) {
        FeatureSequence tokens = (FeatureSequence) this.instances.get(i).getData();
        int docLength = tokens.getLength();
        int[] docLevels = this.levels[i];

        // Path of this document, root at index 0.
        NCRPNode[] path = new NCRPNode[this.numLevels];
        NCRPNode cursor = this.documentLeaves[i];
        for (int level = this.numLevels - 1; level >= 0; level--) {
            path[level] = cursor;
            cursor = cursor.parent;
        }

        // Number of this document's tokens currently at each level.
        int[] levelCounts = new int[this.numLevels];
        for (int position = 0; position < docLength; position++) {
            levelCounts[docLevels[position]]++;
        }

        double[] levelWeights = new double[this.numLevels];
        for (int position = 0; position < docLength; position++) {
            int type = tokens.getIndexAtPosition(position);

            // Remove the token from its current level.
            int oldLevel = docLevels[position];
            levelCounts[oldLevel]--;
            NCRPNode oldNode = path[oldLevel];
            oldNode.typeCounts[type]--;
            oldNode.totalTokens--;

            // Unnormalised weight of every level for this token.
            double weightSum = 0.0d;
            for (int level = 0; level < this.numLevels; level++) {
                NCRPNode candidate = path[level];
                levelWeights[level] = ((this.alpha + levelCounts[level]) * (this.eta + candidate.typeCounts[type])) / (this.etaSum + candidate.totalTokens);
                weightSum += levelWeights[level];
            }

            // Draw the new level and add the token back there.
            int newLevel = this.random.nextDiscrete(levelWeights, weightSum);
            docLevels[position] = newLevel;
            levelCounts[newLevel]++;
            NCRPNode newNode = path[newLevel];
            newNode.typeCounts[type]++;
            newNode.totalTokens++;
        }
    }

    /**
     * Writes the sampler state to {@code stateFile}.
     *
     * @throws FileNotFoundException declared for compatibility with existing callers
     * @throws IOException if the file cannot be opened for writing
     */
    public void printState() throws IOException, FileNotFoundException {
        // Close (and thereby flush) the buffered writer; it was previously left
        // open, so the file could end up empty or truncated.
        try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(this.stateFile)))) {
            printState(out);
        }
    }

    /**
     * Writes one line per token: the node ids of the document's path (leaf
     * first, root last), the type index, the word, and the token's level, each
     * followed by the separator constant. The writer is not closed or flushed.
     *
     * @param printWriter destination
     */
    public void printState(PrintWriter printWriter) throws IOException {
        Alphabet dataAlphabet = this.instances.getDataAlphabet();
        int doc = 0;
        for (Instance instance : this.instances) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            int docLength = tokens.getLength();
            int[] docLevels = this.levels[doc];

            // Path prefix shared by every line of this document.
            StringBuilder pathIds = new StringBuilder();
            NCRPNode cursor = this.documentLeaves[doc];
            for (int level = this.numLevels - 1; level >= 0; level--) {
                pathIds.append(cursor.nodeID).append(LangRequest.DEFAULT_SELECTION);
                cursor = cursor.parent;
            }
            String prefix = pathIds.toString();

            for (int position = 0; position < docLength; position++) {
                int type = tokens.getIndexAtPosition(position);
                printWriter.println(prefix + type + LangRequest.DEFAULT_SELECTION + dataAlphabet.lookupObject(type) + LangRequest.DEFAULT_SELECTION + docLevels[position] + LangRequest.DEFAULT_SELECTION);
            }
            doc++;
        }
    }

    /** Prints the whole tree to standard output, words only (no weights). */
    public void printNodes() {
        final boolean withWeights = false;
        printNode(this.rootNode, 0, withWeights);
    }

    /**
     * Prints the whole tree to standard output.
     *
     * @param z when true, each word is printed together with its weight
     */
    public void printNodes(boolean z) {
        final int rootIndent = 0;
        printNode(this.rootNode, rootIndent, z);
    }

    /**
     * Prints one node as "tokens/customers" followed by its top words, then
     * prints its children recursively with one more indent unit.
     *
     * @param nCRPNode node to print
     * @param i        number of indent units to put in front of the line
     * @param z        when true, each word is printed together with its weight
     */
    public void printNode(NCRPNode nCRPNode, int i, boolean z) {
        StringBuilder line = new StringBuilder();
        for (int indent = 0; indent < i; indent++) {
            line.append(DictionaryFile.COMMENT_HEADER);
        }
        line.append(nCRPNode.totalTokens).append("/").append(nCRPNode.customers).append(LangRequest.DEFAULT_SELECTION);
        line.append(nCRPNode.getTopWords(this.numWordsToDisplay, z));
        System.out.println(line);

        int childIndent = i + 1;
        for (NCRPNode child : nCRPNode.children) {
            printNode(child, childIndent, z);
        }
    }

    /**
     * Estimates the log-likelihood of the given documents by sampling.
     *
     * Each sample draws a path through existing nodes and a distribution over
     * levels, builds the resulting word distribution, and scores every
     * document under it. Per document the sample scores are averaged with a
     * log-sum-exp, and the per-document results are summed.
     *
     * @param i            number of samples
     * @param instanceList documents to score; each must hold a FeatureSequence
     * @return the summed log-likelihood estimate
     */
    public double empiricalLikelihood(int i, InstanceList instanceList) {
        NCRPNode[] path = new NCRPNode[this.numLevels];
        path[0] = this.rootNode;
        Dirichlet levelPrior = new Dirichlet(this.numLevels, this.alpha);
        double[] typeLogProb = new double[this.numTypes];
        double[][] docScores = new double[instanceList.size()][i];

        for (int sample = 0; sample < i; sample++) {
            Arrays.fill(typeLogProb, 0.0d);

            // Draw a path that only uses nodes already in the tree.
            for (int level = 1; level < this.numLevels; level++) {
                path[level] = path[level - 1].selectExisting();
            }
            double[] levelMix = levelPrior.nextDistribution();

            // Mix the per-node word distributions along the path.
            for (int type = 0; type < this.numTypes; type++) {
                for (int level = 0; level < this.numLevels; level++) {
                    NCRPNode topic = path[level];
                    typeLogProb[type] += (levelMix[level] * (this.eta + topic.typeCounts[type])) / (this.etaSum + topic.totalTokens);
                }
            }
            for (int type = 0; type < this.numTypes; type++) {
                typeLogProb[type] = Math.log(typeLogProb[type]);
            }

            // Score every document under this sample.
            for (int doc = 0; doc < instanceList.size(); doc++) {
                FeatureSequence tokens = (FeatureSequence) instanceList.get(doc).getData();
                int docLength = tokens.getLength();
                double[] scores = docScores[doc];
                for (int position = 0; position < docLength; position++) {
                    scores[sample] += typeLogProb[tokens.getIndexAtPosition(position)];
                }
            }
        }

        double total = 0.0d;
        double logSamples = Math.log(i);
        for (int doc = 0; doc < instanceList.size(); doc++) {
            double[] scores = docScores[doc];
            // Subtract the maximum before exponentiating.
            double best = Double.NEGATIVE_INFINITY;
            for (int sample = 0; sample < i; sample++) {
                if (scores[sample] > best) {
                    best = scores[sample];
                }
            }
            double expSum = 0.0d;
            for (int sample = 0; sample < i; sample++) {
                expSum += Math.exp(scores[sample] - best);
            }
            total += (Math.log(expSum) + best) - logSamples;
        }
        return total;
    }

    /**
     * Command-line entry point: loads two instance lists, initialises a
     * five-level sampler and runs 250 iterations.
     *
     * @param strArr {@code strArr[0]} is the training instance file,
     *               {@code strArr[1]} the testing instance file
     */
    public static void main(String[] strArr) {
        // Without this check a missing argument only showed up as an
        // ArrayIndexOutOfBoundsException stack trace.
        if (strArr.length < 2) {
            System.err.println("Usage: HierarchicalLDA <training instances> <testing instances>");
            return;
        }
        try {
            InstanceList training = InstanceList.load(new File(strArr[0]));
            InstanceList testing = InstanceList.load(new File(strArr[1]));
            HierarchicalLDA sampler = new HierarchicalLDA();
            sampler.initialize(training, testing, 5, new Randoms());
            sampler.estimate(250);
        } catch (Exception e) {
            // Top-level boundary: report the failure and let the program end.
            e.printStackTrace();
        }
    }

    // Sets the assertion switch from the class's desired assertion status, so the
    // explicit "!$assertionsDisabled && ..." checks in samplePath only fire when
    // assertions are enabled for this class.
    static {
        $assertionsDisabled = !HierarchicalLDA.class.desiredAssertionStatus();
    }
}
