package cc.mallet.grmm.inference;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Runs two-pass message passing over a {@link JunctionTree}: evidence is first collected toward
 * the root and then distributed back out from it. How each message is formed, and how a belief is
 * read off a cluster potential, is delegated to a pluggable {@link MessageStrategy}. Sum-product
 * and max-product strategies are provided.
 *
 * <p>NOTE(review): decompiler metadata indicates this class was originally package-private; it is
 * kept {@code public} here so existing callers are unaffected.
 */
public class JunctionTreePropagation implements Serializable {
    private static final Logger logger = MalletLogger.getLogger(JunctionTreePropagation.class.getName());

    private static final long serialVersionUID = 1;

    /** Version tag written after the default serialized fields. */
    private static final int CURRENT_SERIAL_VERSION = 1;

    /** Messages sent by this instance; transient, so it restarts at 0 after deserialization. */
    private transient int totalMessagesSent = 0;

    /** Decides how each message is computed and how beliefs are extracted. */
    private MessageStrategy strategy;

    /**
     * Message strategy that summarizes a cluster onto a separator with {@code extractMax},
     * i.e. max-product propagation.
     */
    public static class MaxProductMessageStrategy implements MessageStrategy, Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        /**
         * Sends a max-product message from cluster {@code varSet} to cluster {@code varSet2}.
         *
         * @param junctionTree tree whose potentials are read and updated in place
         * @param varSet source cluster
         * @param varSet2 destination cluster, whose potential is replaced
         */
        @Override
        public void sendMessage(JunctionTree junctionTree, VarSet varSet, VarSet varSet2) {
            Set sepset = junctionTree.getSepset(varSet, varSet2);
            Factor message = junctionTree.getCPF(varSet).extractMax(sepset);
            absorbMessage(junctionTree, varSet, varSet2, message);
        }

        /** Returns {@code factor.extractMax(varSet)}. */
        @Override
        public Factor extractBelief(Factor factor, VarSet varSet) {
            return factor.extractMax(varSet);
        }

        private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
            objectOutputStream.defaultWriteObject();
            objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.defaultReadObject();
            objectInputStream.readInt(); // version tag; only one format exists so far
        }
    }

    /**
     * Pluggable policy for {@link JunctionTreePropagation}.
     *
     * <p>NOTE(review): the enclosing class is {@link Serializable} and stores its strategy in a
     * non-transient field, so custom strategies presumably need to be serializable too.
     */
    public interface MessageStrategy {
        /**
         * Sends one message from cluster {@code varSet} to cluster {@code varSet2}, updating the
         * tree's separator and cluster potentials.
         */
        void sendMessage(JunctionTree junctionTree, VarSet varSet, VarSet varSet2);

        /** Reduces a cluster potential {@code factor} to a belief over {@code varSet}. */
        Factor extractBelief(Factor factor, VarSet varSet);
    }

    /**
     * Message strategy that summarizes a cluster onto a separator with {@code marginalize},
     * i.e. sum-product propagation.
     */
    public static class SumProductMessageStrategy implements MessageStrategy, Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        /**
         * Sends a sum-product message from cluster {@code varSet} to cluster {@code varSet2}.
         *
         * @param junctionTree tree whose potentials are read and updated in place
         * @param varSet source cluster
         * @param varSet2 destination cluster, whose potential is replaced
         */
        @Override
        public void sendMessage(JunctionTree junctionTree, VarSet varSet, VarSet varSet2) {
            Set sepset = junctionTree.getSepset(varSet, varSet2);
            Factor message = junctionTree.getCPF(varSet).marginalize(sepset);
            absorbMessage(junctionTree, varSet, varSet2, message);
        }

        /** Returns {@code factor.marginalize(varSet)}. */
        @Override
        public Factor extractBelief(Factor factor, VarSet varSet) {
            return factor.marginalize(varSet);
        }

        private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
            objectOutputStream.defaultWriteObject();
            objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.defaultReadObject();
            objectInputStream.readInt(); // version tag; only one format exists so far
        }
    }

    /**
     * Shared absorb step for both strategies.
     *
     * <p>The freshly computed separator message is normalized and stored as the new separator
     * potential. The destination cluster potential is then multiplied by it, divided by the
     * previous separator potential, normalized and stored back.
     */
    private static void absorbMessage(JunctionTree junctionTree, VarSet from, VarSet to, Factor message) {
        Factor targetCpf = junctionTree.getCPF(to);
        // Must be read before setSepsetPot overwrites it.
        Factor previousSepsetPot = junctionTree.getSepsetPot(from, to);
        message.normalize();
        junctionTree.setSepsetPot(message, from, to);
        Factor updated = targetCpf.multiply(message);
        updated.divideBy(previousSepsetPot);
        updated.normalize();
        junctionTree.setCPF(to, updated);
    }

    /** Creates a propagator that uses the given message strategy. */
    public JunctionTreePropagation(MessageStrategy messageStrategy) {
        this.strategy = messageStrategy;
    }

    /** Creates a propagator using {@link SumProductMessageStrategy}. */
    public static JunctionTreePropagation createSumProductInferencer() {
        return new JunctionTreePropagation(new SumProductMessageStrategy());
    }

    /** Creates a propagator using {@link MaxProductMessageStrategy}. */
    public static JunctionTreePropagation createMaxProductInferencer() {
        return new JunctionTreePropagation(new MaxProductMessageStrategy());
    }

    /** Returns the number of messages this instance has sent (not preserved across serialization). */
    public int getTotalMessagesSent() {
        return this.totalMessagesSent;
    }

    /**
     * Runs collect-then-distribute propagation from the tree's root, then calls
     * {@code junctionTree.normalizeAll()}.
     */
    public void computeMarginals(JunctionTree junctionTree) {
        propagate(junctionTree);
        junctionTree.normalizeAll();
    }

    /** Sends one message via the strategy and counts it. */
    private void sendCountedMessage(JunctionTree junctionTree, VarSet from, VarSet to) {
        this.totalMessagesSent++;
        this.strategy.sendMessage(junctionTree, from, to);
    }

    /**
     * Collect pass. Recursively processes the subtree below {@code node}, then sends a message
     * from {@code node} up to {@code parent}. Nothing is sent upward when {@code parent} is
     * {@code null} (the root).
     */
    private void collectEvidence(JunctionTree junctionTree, VarSet parent, VarSet node) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("collectEvidence " + parent + " --> " + node);
        }
        // Iterate as Object and cast, so this works whatever element type getChildren declares.
        for (Object child : junctionTree.getChildren(node)) {
            collectEvidence(junctionTree, node, (VarSet) child);
        }
        if (parent != null) {
            sendCountedMessage(junctionTree, node, parent);
        }
    }

    /**
     * Distribute pass. Sends a message from {@code node} to each child, and recurses into that
     * child immediately after its message.
     */
    private void distributeEvidence(JunctionTree junctionTree, VarSet node) {
        for (Object element : junctionTree.getChildren(node)) {
            VarSet child = (VarSet) element;
            sendCountedMessage(junctionTree, node, child);
            distributeEvidence(junctionTree, child);
        }
    }

    /** Full two-pass schedule: collect toward the root, then distribute away from it. */
    private void propagate(JunctionTree junctionTree) {
        VarSet root = (VarSet) junctionTree.getRoot();
        collectEvidence(junctionTree, null, root);
        distributeEvidence(junctionTree, root);
    }

    /**
     * Returns the normalized belief over {@code varSet}.
     *
     * <p>It is read from the potential of the cluster returned by
     * {@code junctionTree.findParentCluster(varSet)}.
     *
     * @throws IllegalStateException if {@code junctionTree} is {@code null}
     * @throws UnsupportedOperationException if no parent cluster is found for {@code varSet}
     */
    public Factor lookupMarginal(JunctionTree junctionTree, VarSet varSet) {
        if (junctionTree == null) {
            throw new IllegalStateException("Call computeMarginals() first.");
        }
        VarSet cluster = junctionTree.findParentCluster(varSet);
        if (cluster == null) {
            throw new UnsupportedOperationException("No parent cluster in " + junctionTree + " for clique " + varSet);
        }
        Factor cpf = junctionTree.getCPF(cluster);
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Lookup jt marginal: clique " + varSet + " cluster " + cluster);
            // Stringifying a full cluster potential can be costly; only do it at FINEST.
            if (logger.isLoggable(Level.FINEST)) {
                logger.finest("  cpf " + cpf);
            }
        }
        Factor belief = this.strategy.extractBelief(cpf, varSet);
        belief.normalize();
        return belief;
    }

    /**
     * Returns the normalized belief over the single {@code variable}.
     *
     * <p>It is read from the potential of the cluster returned by
     * {@code junctionTree.findParentCluster(variable)}.
     *
     * @throws IllegalStateException if {@code junctionTree} is {@code null}
     * @throws UnsupportedOperationException if no parent cluster is found for {@code variable}
     */
    public Factor lookupMarginal(JunctionTree junctionTree, Variable variable) {
        if (junctionTree == null) {
            throw new IllegalStateException("Call computeMarginals() first.");
        }
        VarSet cluster = junctionTree.findParentCluster(variable);
        // Same guard as the VarSet overload: a missing cluster must not reach getCPF(null).
        if (cluster == null) {
            throw new UnsupportedOperationException("No parent cluster in " + junctionTree + " for variable " + variable);
        }
        Factor cpf = junctionTree.getCPF(cluster);
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Lookup jt marginal: var " + variable + " cluster " + cluster);
            if (logger.isLoggable(Level.FINEST)) {
                logger.finest(" cpf " + cpf);
            }
        }
        Factor belief = this.strategy.extractBelief(cpf, new HashVarSet(new Variable[]{variable}));
        belief.normalize();
        return belief;
    }

    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        objectInputStream.readInt(); // version tag; only one format exists so far
    }
}
