package cc.mallet.grmm.types;

import cc.mallet.grmm.util.Flops;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.util.Maths;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/grmm/types/TableFactor.class */
/**
 * A discrete factor whose values are stored directly (not in log space) in the
 * {@link Matrix} table {@code probs}.
 *
 * <p>Log-space accessors ({@code logValue}, {@code setLogValue}, {@code setLogValues},
 * {@code getLogValueMatrix}) convert on the fly with {@link Math#log} / {@link Math#exp}.
 * Arithmetic operations report their approximate floating-point work to {@link Flops}.
 *
 * <p>The backing matrix distinguishes <em>single indices</em> (positions in the full
 * assignment space) from <em>locations</em> (positions in the matrix's storage); see
 * {@link #rawValue(int)}, which converts between the two.
 */
public class TableFactor extends AbstractTableFactor {

    /**
     * Multiplies together every factor in {@code factorArr}.
     *
     * @param factorArr the factors to multiply
     * @return the product, as computed by {@link #multiplyAll(Collection)}
     */
    public static DiscreteFactor multiplyAll(Factor[] factorArr) {
        return multiplyAll(Arrays.asList(factorArr));
    }

    /**
     * Multiplies together every {@link Factor} in {@code collection}.
     *
     * <p>If the collection holds exactly one factor whose duplicate is already an
     * {@link AbstractTableFactor}, that duplicate is returned. Otherwise a fresh
     * {@code TableFactor} over the union of all factors' variables is created and each factor
     * is multiplied into it with {@code multiplyBy}.
     *
     * @param collection the factors to multiply; every element must be a {@link Factor}
     * @return a new factor holding the product (never one of the inputs themselves)
     */
    public static AbstractTableFactor multiplyAll(Collection collection) {
        if (collection.size() == 1) {
            Factor only = ((Factor) collection.iterator().next()).duplicate();
            // The original code cast unconditionally, so a lone non-table factor threw
            // ClassCastException. The same factor inside a larger collection is handled by the
            // multiplyBy path below, so fall through to that path instead.
            if (only instanceof AbstractTableFactor) {
                return (AbstractTableFactor) only;
            }
        }
        // Declared as HashVarSet so that constructor overload resolution matches the original.
        HashVarSet allVars = new HashVarSet();
        for (Object factor : collection) {
            allVars.addAll(((Factor) factor).varSet());
        }
        TableFactor product = new TableFactor(allVars);
        for (Object factor : collection) {
            product.multiplyBy((Factor) factor);
        }
        return product;
    }

    /** Creates a factor over a single variable; initialization is delegated to the superclass. */
    public TableFactor(Variable variable) {
        super(variable);
    }

    /** Creates a factor over a single variable with the given (non-log) values. */
    public TableFactor(Variable variable, double[] dArr) {
        super(variable, dArr);
    }

    /** Creates an empty factor; state is initialized by the superclass's no-arg constructor. */
    public TableFactor() {
    }

    /** Creates a factor over the variables in the given map (delegates to the superclass). */
    public TableFactor(BidirectionalIntObjectMap bidirectionalIntObjectMap) {
        super(bidirectionalIntObjectMap);
    }

    /** Creates a factor over the given variables (delegates to the superclass). */
    public TableFactor(Variable[] variableArr) {
        super(variableArr);
    }

    /** Creates a factor over the variables in {@code collection} (delegates to the superclass). */
    public TableFactor(Collection collection) {
        super(collection);
    }

    /** Creates a factor over the given variables with the given (non-log) values. */
    public TableFactor(Variable[] variableArr, double[] dArr) {
        super(variableArr, dArr);
    }

    /** Creates a factor over {@code varSet} with the given (non-log) values. */
    public TableFactor(VarSet varSet, double[] dArr) {
        super(varSet, dArr);
    }

    /** Creates a factor over the given variables backed by {@code matrix}. */
    public TableFactor(Variable[] variableArr, Matrix matrix) {
        super(variableArr, matrix);
    }

    /**
     * Copy constructor.
     *
     * <p>The value table is replaced with a clone of {@code abstractTableFactor.getValueMatrix()}.
     * This gives the copy its own storage, and takes the values from the source's
     * value-space accessor rather than from its raw {@code probs} field.
     */
    public TableFactor(AbstractTableFactor abstractTableFactor) {
        super(abstractTableFactor);
        this.probs = (Matrix) abstractTableFactor.getValueMatrix().cloneMatrix();
    }

    /** Creates a factor over {@code varSet} backed by {@code matrix}. */
    public TableFactor(VarSet varSet, Matrix matrix) {
        super(varSet, matrix);
    }

    /** Creates a factor with the structure of {@code abstractTableFactor} and the given values. */
    public TableFactor(AbstractTableFactor abstractTableFactor, double[] dArr) {
        super(abstractTableFactor, dArr);
    }

    /** Sets every entry to 1.0, the multiplicative identity in value space. */
    @Override
    void setAsIdentity() {
        setAll(1.0d);
    }

    /** Returns an independent copy of this factor (see the copy constructor). */
    @Override
    public Factor duplicate() {
        return new TableFactor(this);
    }

    /** Creates a new, blank {@code TableFactor} over {@code variableArr}. */
    @Override
    protected AbstractTableFactor createBlankSubset(Variable[] variableArr) {
        return new TableFactor(variableArr);
    }

    /**
     * One-normalizes the value table in place via {@code Matrix.oneNormalize()}.
     *
     * @return this factor
     */
    @Override
    public Factor normalize() {
        Flops.increment(2 * this.probs.numLocations());
        this.probs.oneNormalize();
        return this;
    }

    /**
     * Returns the one-norm of the value table. This equals the sum of all entries when every
     * entry is non-negative.
     */
    @Override
    public double sum() {
        Flops.increment(this.probs.numLocations());
        return this.probs.oneNorm();
    }

    /** Returns the natural log of the value at the iterator's current assignment. */
    @Override
    public double logValue(AssignmentIterator assignmentIterator) {
        Flops.log();
        return Math.log(rawValue(assignmentIterator.indexOfCurrentAssn()));
    }

    /** Returns the natural log of the value at {@code assignment}. */
    @Override
    public double logValue(Assignment assignment) {
        Flops.log();
        return Math.log(rawValue(assignment));
    }

    /** Returns the natural log of the value at single index {@code i}. */
    @Override
    public double logValue(int i) {
        Flops.log();
        return Math.log(rawValue(i));
    }

    /** Returns the value stored for {@code assignment}. */
    @Override
    public double value(Assignment assignment) {
        return rawValue(assignment);
    }

    /** Returns the value stored at single index {@code i} (0 if that index is not stored). */
    @Override
    public double value(int i) {
        return rawValue(i);
    }

    /** Returns the value at the iterator's current assignment. */
    @Override
    public double value(AssignmentIterator assignmentIterator) {
        return rawValue(assignmentIterator.indexOfCurrentAssn());
    }

    /**
     * Sums this factor's entries into {@code abstractTableFactor}, whose variables are
     * presumably a subset of this factor's.
     *
     * <p>The result is first zeroed. Each stored entry is then added to the result entry
     * selected by {@code largeIdxToSmall}.
     *
     * @return {@code abstractTableFactor}, filled in
     */
    @Override
    protected Factor marginalizeInternal(AbstractTableFactor abstractTableFactor) {
        abstractTableFactor.setAll(0.0d);
        int[] largeToSmall = largeIdxToSmall(abstractTableFactor);
        int numLocations = this.probs.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            abstractTableFactor.probs.incrementSingleValue(largeToSmall[loc], this.probs.valueAtLocation(loc));
        }
        Flops.increment(numLocations);
        return abstractTableFactor;
    }

    /** Multiplies each entry in place by the matching entry of {@code discreteFactor}. */
    @Override
    protected void multiplyByInternal(DiscreteFactor discreteFactor) {
        int[] largeToSmall = largeIdxToSmall(discreteFactor);
        int numLocations = this.probs.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            this.probs.setValueAtLocation(loc, this.probs.valueAtLocation(loc) * discreteFactor.value(largeToSmall[loc]));
        }
        Flops.increment(numLocations);
    }

    /**
     * Divides each entry in place by the matching entry of {@code discreteFactor}.
     *
     * <p>By convention, dividing by a divisor that is almost zero
     * ({@code Maths.almostEquals(divisor, 0)}) yields 0 rather than an infinity or NaN.
     */
    @Override
    protected void divideByInternal(DiscreteFactor discreteFactor) {
        int[] largeToSmall = largeIdxToSmall(discreteFactor);
        int numLocations = this.probs.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            double numerator = this.probs.valueAtLocation(loc);
            double divisor = discreteFactor.value(largeToSmall[loc]);
            double quotient = numerator / divisor;
            if (Maths.almostEquals(divisor, 0.0d)) {
                quotient = 0.0d;
            }
            this.probs.setValueAtLocation(loc, quotient);
        }
        Flops.increment(numLocations);
    }

    /** Adds the matching entry of {@code discreteFactor} to each entry, in place. */
    @Override
    protected void plusEqualsInternal(DiscreteFactor discreteFactor) {
        int[] largeToSmall = largeIdxToSmall(discreteFactor);
        int numLocations = this.probs.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            this.probs.setValueAtLocation(loc, this.probs.valueAtLocation(loc) + discreteFactor.value(largeToSmall[loc]));
        }
        Flops.increment(numLocations);
    }

    /**
     * Looks up the value for {@code assignment}.
     *
     * <p>The outcome of each of this factor's variables is read from the assignment, in
     * variable order, and converted to a single index. NOTE(review): this assumes the assignment
     * covers every variable of this factor; confirm how {@code Assignment.get} handles a
     * missing variable.
     */
    protected double rawValue(Assignment assignment) {
        int numVars = getNumVars();
        int[] outcomes = new int[numVars];
        for (int v = 0; v < numVars; v++) {
            outcomes[v] = assignment.get(getVariable(v));
        }
        return rawValue(outcomes);
    }

    /** Looks up the value for a per-variable outcome vector. */
    private double rawValue(int[] iArr) {
        return rawValue(this.probs.singleIndex(iArr));
    }

    /**
     * Returns the value at single index {@code i}.
     *
     * <p>The index is mapped to a storage location. If the matrix has no location for that
     * index, 0 is returned.
     */
    @Override
    protected double rawValue(int i) {
        int location = this.probs.location(i);
        if (location < 0) {
            return 0.0d;
        }
        return this.probs.valueAtLocation(location);
    }

    /** Raises every stored entry to the power {@code d}, in place. */
    @Override
    public void exponentiate(double d) {
        int numLocations = this.probs.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            this.probs.setValueAtLocation(loc, Math.pow(this.probs.valueAtLocation(loc), d));
        }
        Flops.pow(numLocations);
    }

    /** Stores {@code exp(d)} for {@code assignment}. */
    @Override
    public void setLogValue(Assignment assignment, double d) {
        Flops.exp();
        setRawValue(assignment, Math.exp(d));
    }

    /** Stores {@code exp(d)} at the iterator's current assignment. */
    @Override
    public void setLogValue(AssignmentIterator assignmentIterator, double d) {
        Flops.exp();
        setRawValue(assignmentIterator, Math.exp(d));
    }

    /** Stores {@code d} at the iterator's current assignment. */
    @Override
    public void setValue(AssignmentIterator assignmentIterator, double d) {
        setRawValue(assignmentIterator, d);
    }

    /** Stores {@code exp(dArr[i])} at single index {@code i}, for every {@code i}. */
    @Override
    public void setLogValues(double[] dArr) {
        Flops.exp(dArr.length);
        for (int i = 0; i < dArr.length; i++) {
            setRawValue(i, Math.exp(dArr[i]));
        }
    }

    /** Stores {@code dArr[i]} at single index {@code i}, for every {@code i}. */
    @Override
    public void setValues(double[] dArr) {
        for (int i = 0; i < dArr.length; i++) {
            setRawValue(i, dArr[i]);
        }
    }

    /** Multiplies every entry by the scalar {@code d}, in place. */
    @Override
    public void timesEquals(double d) {
        Flops.increment(this.probs.numLocations());
        this.probs.timesEquals(d);
    }

    /**
     * Adds {@code d} to the entry at storage location {@code i}.
     *
     * <p>Both the read and the write go through the same <em>location</em>. The previous code
     * wrote back with {@code setRawValue(i, ...)}, which this class uses with single indices
     * (see {@code setValues}). Locations and indices only coincide for dense matrices, so for
     * any other matrix the sum was written to the wrong cell.
     */
    @Override
    protected void plusEqualsAtLocation(int i, double d) {
        Flops.increment(1);
        this.probs.setValueAtLocation(i, this.probs.valueAtLocation(i) + d);
    }

    /** Returns the live backing matrix (not a copy). */
    @Override
    public Matrix getValueMatrix() {
        return this.probs;
    }

    /** Returns a new matrix holding the natural log of every stored entry. */
    @Override
    public Matrix getLogValueMatrix() {
        Flops.log(this.probs.numLocations());
        Matrix logs = (Matrix) this.probs.cloneMatrix();
        for (int loc = 0; loc < this.probs.numLocations(); loc++) {
            logs.setValueAtLocation(loc, Math.log(logs.valueAtLocation(loc)));
        }
        return logs;
    }

    /** Returns the entry at storage location {@code i}. */
    @Override
    public double valueAtLocation(int i) {
        return this.probs.valueAtLocation(i);
    }

    /**
     * Slices this factor to a one-variable factor.
     *
     * <p>For each outcome of {@code variable}, the value is looked up under
     * {@code assignment} extended with that outcome.
     */
    @Override
    protected Factor slice_onevar(Variable variable, Assignment assignment) {
        int numOutcomes = variable.getNumOutcomes();
        double[] vals = new double[numOutcomes];
        for (int outcome = 0; outcome < numOutcomes; outcome++) {
            vals[outcome] = value(Assignment.union(new Assignment(variable, outcome), assignment));
        }
        return new TableFactor(variable, vals);
    }

    /**
     * Slices this factor to a two-variable factor over ({@code variable}, {@code variable2}).
     *
     * <p>Each joint outcome is looked up under {@code assignment} extended with that outcome.
     */
    @Override
    protected Factor slice_twovar(Variable variable, Variable variable2, Assignment assignment) {
        int numOutcomes = variable.getNumOutcomes();
        int numOutcomes2 = variable2.getNumOutcomes();
        int[] sizes = {numOutcomes, numOutcomes2};
        Variable[] pair = {variable, variable2};
        int[] outcomes = new int[2];
        double[] vals = new double[numOutcomes * numOutcomes2];
        for (int i = 0; i < numOutcomes; i++) {
            outcomes[0] = i;
            for (int j = 0; j < numOutcomes2; j++) {
                outcomes[1] = j;
                vals[Matrixn.singleIndex(sizes, new int[]{i, j})] = value(Assignment.union(new Assignment(pair, outcomes), assignment));
            }
        }
        return new TableFactor(new Variable[]{variable, variable2}, vals);
    }

    /**
     * Slices this factor to the variables of {@code variableArr} that are not fixed by
     * {@code assignment}.
     */
    @Override
    protected Factor slice_general(Variable[] variableArr, Assignment assignment) {
        HashVarSet freeVars = new HashVarSet(variableArr);
        freeVars.removeAll(assignment.varSet());
        double[] vals = new double[freeVars.weight()];
        AssignmentIterator it = freeVars.assignmentIterator();
        while (it.hasNext()) {
            vals[it.indexOfCurrentAssn()] = value(Assignment.union(assignment, it.assignment()));
            it.advance();
        }
        return new TableFactor(freeVars, vals);
    }

    /**
     * Builds a factor over {@code varSet} from natural-log values.
     *
     * @param varSet the factor's variables
     * @param dArr   log values; the input array is not modified
     * @return a new factor storing {@code exp(dArr[i])} at each index
     */
    public static TableFactor makeFromLogValues(VarSet varSet, double[] dArr) {
        double[] vals = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            vals[i] = Math.exp(dArr[i]);
        }
        return new TableFactor(varSet, vals);
    }

    /**
     * Rescales the factor in place by dividing by {@code valueAtLocation(argmax())}.
     *
     * <p>NOTE(review): if that value is 0 (for example, an all-zero factor), the scale is
     * infinite and the entries become NaN or infinite.
     *
     * @return this factor
     */
    @Override
    public AbstractTableFactor recenter() {
        timesEquals(1.0d / valueAtLocation(argmax()));
        return this;
    }
}
