package cc.mallet.grmm.test;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.BetaFactor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.ModelReader;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
import gnu.trove.TDoubleArrayList;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:cc/mallet/grmm/test/TestBetaFactor.class */
/**
 * Unit tests for {@link BetaFactor}.
 *
 * <p>Covers variable bookkeeping, density evaluation, sampling, and slicing of Beta factors
 * inside a {@link FactorGraph} read by {@link ModelReader}.
 */
public class TestBetaFactor extends TestCase {

    /** Model text for {@link #testSliceInFg()}: two continuous variables with Beta factors. */
    static final String mdlstr = "VAR u1 u2 : continuous\nu1 ~ Beta 0.2 0.7\nu2 ~ Beta 1.0 0.3\n";

    /** Number of draws used when estimating a sample mean. */
    private static final int NUM_SAMPLES = 100000;

    /** Fixed RNG seed so that the sampling tests are reproducible. */
    private static final int SEED = 2343;

    /** Absolute tolerance when comparing a sample mean to its expected value. */
    private static final double MEAN_TOLERANCE = 0.01d;

    public TestBetaFactor(String name) {
        super(name);
    }

    /** A factor over a single variable reports exactly that variable in its var set. */
    public void testVarSet() {
        Variable u = new Variable(-1);
        BetaFactor factor = new BetaFactor(u, 0.5d, 0.5d);
        assertEquals(1, factor.varSet().size());
        assertTrue(factor.varSet().contains(u));
    }

    /** Evaluating the factor at a point gives the expected density value. */
    public void testValue() {
        Variable u = new Variable(-1);
        BetaFactor factor = new BetaFactor(u, 1.0d, 1.2d);
        // The expected value equals 1.2 * (1 - 0.7)^0.2 ~= 0.94321, i.e. the Beta(1, 1.2) density at 0.7.
        assertEquals(0.94321d, factor.value(new Assignment(u, 0.7d)), 1.0E-5d);
    }

    /** The empirical mean of many draws matches the Beta mean alpha / (alpha + beta). */
    public void testSample() {
        Variable u = new Variable(-1);
        double alpha = 0.7d;
        double beta = 0.5d;
        BetaFactor factor = new BetaFactor(u, alpha, beta);
        double mean = sampleMean(factor, u, new Randoms(SEED), NUM_SAMPLES);
        assertEquals(alpha / (alpha + beta), mean, MEAN_TOLERANCE);
    }

    /** The empirical mean of many draws from the four-argument factor matches the pinned value. */
    public void testSample2() {
        Variable u = new Variable(-1);
        BetaFactor factor = new BetaFactor(u, 0.7d, 0.5d, 3.0d, 8.0d);
        double mean = sampleMean(factor, u, new Randoms(SEED), NUM_SAMPLES);
        // NOTE(review): the last two constructor arguments appear to rescale the support to [3, 8].
        // In that case the analytic mean is 3 + 5 * 0.7 / 1.2 ~= 5.917. The pinned 5.92 is kept
        // because this seeded test was calibrated against it; confirm against BetaFactor.
        assertEquals(5.92d, mean, MEAN_TOLERANCE);
    }

    /** Slicing a two-variable model at full evidence leaves two factors with the expected product. */
    public void testSliceInFg() throws IOException {
        FactorGraph model = new ModelReader().readModel(new BufferedReader(new StringReader(mdlstr)));
        Variable[] observedVars = { model.findVariable("u1"), model.findVariable("u2") };
        Assignment evidence = new Assignment(observedVars, new double[] { 0.25d, 0.85d });
        FactorGraph sliced = (FactorGraph) model.slice(evidence);
        assertEquals(2, sliced.factors().size());
        assertEquals(0.6708463722d, sliced.value(new Assignment()), 1.0E-5d);
    }

    /**
     * Draws {@code count} samples of {@code u} from {@code factor} and returns their mean.
     *
     * @param factor the factor to sample from
     * @param u the variable whose sampled value is read from each draw
     * @param random source of randomness; it is advanced once per call to {@code factor.sample}
     * @param count number of draws
     * @return the arithmetic mean of the drawn values
     */
    private static double sampleMean(BetaFactor factor, Variable u, Randoms random, int count) {
        double[] draws = new double[count];
        for (int i = 0; i < count; i++) {
            draws[i] = factor.sample(random).getDouble(u);
        }
        return MatrixOps.mean(draws);
    }

    public static TestSuite suite() {
        return new TestSuite(TestBetaFactor.class);
    }

    /**
     * Runs either the whole suite, or only the named test methods when arguments are given.
     *
     * @param args optional test-method names
     */
    public static void main(String[] args) {
        TestSuite toRun;
        if (args.length == 0) {
            toRun = suite();
        } else {
            toRun = new TestSuite();
            for (String testName : args) {
                toRun.addTest(new TestBetaFactor(testName));
            }
        }
        TestRunner.run(toRun);
    }
}
