package umontreal.ssj.probdistmulti;

import optimization.Uncmin_f77;
import optimization.Uncmin_methods;
import umontreal.ssj.util.Num;

/* loaded from: input_file:umontreal/ssj/probdistmulti/DirichletDist.class */
/**
 * Dirichlet distribution with parameters {@code alpha[0..d-1]}, all strictly positive.
 * With {@code alpha0 = alpha[0] + ... + alpha[d-1]}, the density evaluated here is
 * <pre>
 *   f(x) = Gamma(alpha0) / prod_i Gamma(alpha[i]) * prod_i x[i]^(alpha[i] - 1)
 * </pre>
 * when every {@code x[i]} lies in the open interval (0, 1), and 0 otherwise.
 * The mean is {@code alpha[i] / alpha0}, and the covariance is
 * {@code -alpha[i] alpha[j] / (alpha0^2 (alpha0 + 1))} off the diagonal.
 */
public class DirichletDist extends ContinuousDistributionMulti {
    /**
     * Substitute for {@code log(x)} in {@link #getMLE} when an observation is not positive.
     * {@code exp(LOGMIN)} is about 1e-308, near the bottom of the double range.
     */
    private static final double LOGMIN = -709.1d;

    /** Distribution parameters; validated as strictly positive and finite by {@link #setParams}. */
    protected double[] alpha;

    /**
     * Objective function minimized by {@link #getMLE}: the negated Dirichlet log-likelihood
     * of {@code n} observations, expressed through the per-coordinate means of {@code log(x)}.
     * Parameter vectors are read from index 1 onward (index 0 is ignored), matching the
     * 1-based arrays that getMLE allocates for Uncmin.
     */
    private static final class Optim implements Uncmin_methods {
        /** Value returned at infeasible points, i.e. when some parameter is {@code <= 0}. */
        private static final double PENALTY = 1.0E200d;

        /** {@code logP[j]} is the sample mean of {@code log(x[.][j])}. */
        final double[] logP;
        /** Number of observations. */
        final int n;

        public Optim(double[] logP, int n) {
            this.n = n;
            this.logP = logP.clone();
        }

        @Override
        public double f_to_minimize(double[] a) {
            double sumA = 0.0d;
            double sumLnGamma = 0.0d;
            double sumLogTerm = 0.0d;
            for (int i = 1; i < a.length; i++) {
                if (a[i] <= 0.0d) {
                    return PENALTY;
                }
                sumA += a[i];
                sumLnGamma += Num.lnGamma(a[i]);
                sumLogTerm += (a[i] - 1.0d) * logP[i - 1];
            }
            // -n * [lnGamma(alpha0) - sum lnGamma(alpha_i) + sum (alpha_i - 1) mean(log x_i)]
            return -n * (Num.lnGamma(sumA) - sumLnGamma + sumLogTerm);
        }

        /**
         * Intentionally empty: no analytic gradient is supplied.
         * NOTE(review): assumes optif0_f77 uses finite differences and never relies on this.
         */
        @Override
        public void gradient(double[] a, double[] g) {
        }

        /**
         * Intentionally empty: no analytic Hessian is supplied.
         * NOTE(review): assumes optif0_f77 uses finite differences and never relies on this.
         */
        @Override
        public void hessian(double[] a, double[][] h) {
        }
    }

    /**
     * Creates a Dirichlet distribution with the given parameters (copied).
     *
     * @param alpha the parameters, each strictly positive and finite
     * @throws IllegalArgumentException if some {@code alpha[i]} is not positive or not finite
     */
    public DirichletDist(double[] alpha) {
        setParams(alpha);
    }

    @Override
    public double density(double[] x) {
        return density_(alpha, x);
    }

    @Override
    public double[] getMean() {
        return getMean_(alpha);
    }

    @Override
    public double[][] getCovariance() {
        return getCovariance_(alpha);
    }

    @Override
    public double[][] getCorrelation() {
        return getCorrelation_(alpha);
    }

    /**
     * Checks that every parameter is strictly positive and finite.
     *
     * @throws IllegalArgumentException naming the first offending index
     */
    private static void verifParam(double[] alpha) {
        for (int i = 0; i < alpha.length; i++) {
            double a = alpha[i];
            if (a <= 0.0d) {
                throw new IllegalArgumentException("alpha[" + i + "] <= 0");
            }
            // NaN and +Infinity pass the "<= 0" test but make every result NaN.
            if (Double.isNaN(a) || Double.isInfinite(a)) {
                throw new IllegalArgumentException("alpha[" + i + "] is not finite");
            }
        }
    }

    /** Returns {@code alpha[0] + ... + alpha[alpha.length - 1]}, summed left to right. */
    private static double sumOf(double[] alpha) {
        double total = 0.0d;
        for (double a : alpha) {
            total += a;
        }
        return total;
    }

    /**
     * Evaluates the density at {@code x} without validating {@code alpha}.
     * Returns 0 if any {@code x[i]} is outside (0, 1). The constraint that {@code x} sums
     * to 1 is not checked; the caller is responsible for passing a point on the simplex.
     */
    private static double density_(double[] alpha, double[] x) {
        if (alpha.length != x.length) {
            throw new IllegalArgumentException("alpha and x must have the same dimension");
        }
        double alpha0 = 0.0d;
        double sumLnGamma = 0.0d;
        double sumLogTerm = 0.0d;
        for (int i = 0; i < alpha.length; i++) {
            if (x[i] <= 0.0d || x[i] >= 1.0d) {
                return 0.0d;
            }
            alpha0 += alpha[i];
            sumLnGamma += Num.lnGamma(alpha[i]);
            sumLogTerm += (alpha[i] - 1.0d) * Math.log(x[i]);
        }
        return Math.exp(Num.lnGamma(alpha0) - sumLnGamma + sumLogTerm);
    }

    /**
     * Computes the Dirichlet density with parameters {@code alpha} at point {@code x}.
     *
     * @param alpha the parameters, each strictly positive and finite
     * @param x the point, same length as {@code alpha}
     * @return the density, or 0 if some {@code x[i]} is outside (0, 1)
     * @throws IllegalArgumentException if {@code alpha} is invalid or the lengths differ
     */
    public static double density(double[] alpha, double[] x) {
        verifParam(alpha);
        return density_(alpha, x);
    }

    /** Covariance matrix for already-validated parameters. */
    private static double[][] getCovariance_(double[] alpha) {
        int d = alpha.length;
        double[][] cov = new double[d][d];
        double alpha0 = sumOf(alpha);
        double offDiagDenom = (alpha0 * alpha0) * (alpha0 + 1.0d);
        for (int i = 0; i < d; i++) {
            for (int j = 0; j < d; j++) {
                cov[i][j] = -(alpha[i] * alpha[j]) / offDiagDenom;
            }
            double mi = alpha[i] / alpha0;
            cov[i][i] = mi * (1.0d - mi) / (alpha0 + 1.0d);
        }
        return cov;
    }

    /**
     * Returns the covariance matrix of the Dirichlet distribution with parameters {@code alpha}.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is not positive or not finite
     */
    public static double[][] getCovariance(double[] alpha) {
        verifParam(alpha);
        return getCovariance_(alpha);
    }

    /** Correlation matrix for already-validated parameters. */
    private static double[][] getCorrelation_(double[] alpha) {
        int d = alpha.length;
        double[][] corr = new double[d][d];
        double alpha0 = sumOf(alpha);
        for (int i = 0; i < d; i++) {
            for (int j = 0; j < d; j++) {
                corr[i][j] = (i == j)
                        ? 1.0d
                        : -Math.sqrt((alpha[i] * alpha[j])
                                / ((alpha0 - alpha[i]) * (alpha0 - alpha[j])));
            }
        }
        return corr;
    }

    /**
     * Returns the correlation matrix of the Dirichlet distribution with parameters {@code alpha}.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is not positive or not finite
     */
    public static double[][] getCorrelation(double[] alpha) {
        verifParam(alpha);
        return getCorrelation_(alpha);
    }

    /**
     * Estimates the parameters by maximum likelihood from {@code n} observations
     * {@code x[0..n-1]}, each using its first {@code d} components. Non-positive components
     * contribute {@code LOGMIN} instead of {@code log(x)}. The optimizer starts from a
     * method-of-moments guess based on the first coordinate. If that guess is not a positive
     * finite number (e.g. zero variance in the first coordinate), a scale of 1 is used instead.
     *
     * @param x the observations; {@code x[i][j]} is component j of observation i
     * @param n number of observations to use
     * @param d dimension of the distribution
     * @return the estimated {@code alpha[0..d-1]}
     * @throws IllegalArgumentException if {@code n <= 0}, {@code d <= 0}, or {@code x} is
     *         smaller than {@code n} rows of {@code d} components
     */
    public static double[] getMLE(double[][] x, int n, int d) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        if (d <= 0) {
            throw new IllegalArgumentException("d <= 0");
        }
        if (x == null || x.length < n) {
            throw new IllegalArgumentException("x must contain at least n observations");
        }
        for (int i = 0; i < n; i++) {
            if (x[i] == null || x[i].length < d) {
                throw new IllegalArgumentException("x[" + i + "] must have at least d components");
            }
        }

        // Sample means of log(x_j) (the sufficient statistics) and of x_j.
        double[] logP = new double[d];
        double[] mean = new double[d];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                double xij = x[i][j];
                logP[j] += (xij > 0.0d) ? Math.log(xij) : LOGMIN;
                mean[j] += xij;
            }
        }
        for (int j = 0; j < d; j++) {
            logP[j] /= n;
            mean[j] /= n;
        }

        // Only the first coordinate's variance feeds the starting point.
        double var0 = 0.0d;
        for (int i = 0; i < n; i++) {
            double dev = x[i][0] - mean[0];
            var0 += dev * dev;
        }
        var0 /= n;

        // Moment estimate of alpha0 from the Beta marginal: var = m(1-m)/(alpha0+1).
        double alpha0 = mean[0] * (1.0d - mean[0]) / var0 - 1.0d;
        if (Double.isNaN(alpha0) || Double.isInfinite(alpha0) || alpha0 <= 0.0d) {
            // Degenerate sample: the moment guess would start the optimizer at an infeasible point.
            alpha0 = 1.0d;
        }

        // Uncmin_f77 works with 1-based arrays; element 0 is unused.
        double[] start = new double[d + 1];
        for (int j = 1; j <= d; j++) {
            start[j] = mean[j - 1] * alpha0;
        }
        double[] xpls = new double[d + 1];
        double[] fpls = new double[d + 1];
        double[] gpls = new double[d + 1];
        int[] itrmcd = new int[2];
        double[][] hess = new double[d + 1][d + 1];
        double[] udiag = new double[d + 1];
        Uncmin_f77.optif0_f77(d, start, new Optim(logP, n), xpls, fpls, gpls, itrmcd, hess, udiag);

        double[] estimate = new double[d];
        System.arraycopy(xpls, 1, estimate, 0, d);
        return estimate;
    }

    /** Mean vector for already-validated parameters: {@code alpha[i] / alpha0}. */
    private static double[] getMean_(double[] alpha) {
        double alpha0 = sumOf(alpha);
        double[] mean = new double[alpha.length];
        for (int i = 0; i < alpha.length; i++) {
            mean[i] = alpha[i] / alpha0;
        }
        return mean;
    }

    /**
     * Returns the mean vector of the Dirichlet distribution with parameters {@code alpha}.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is not positive or not finite
     */
    public static double[] getMean(double[] alpha) {
        verifParam(alpha);
        return getMean_(alpha);
    }

    /**
     * Returns the parameter array of this distribution. This is the internal array, not a copy;
     * modifying it changes this distribution without re-validation.
     */
    public double[] getAlpha() {
        return alpha;
    }

    /** Returns the parameter {@code alpha[i]}. */
    public double getAlpha(int i) {
        return alpha[i];
    }

    /**
     * Replaces the parameters with a copy of {@code alpha} and sets the dimension to its length.
     * The parameters are validated before any field is modified, so a rejected call leaves this
     * object unchanged.
     *
     * @throws IllegalArgumentException if some {@code alpha[i]} is not positive or not finite
     */
    public void setParams(double[] alpha) {
        verifParam(alpha);
        this.alpha = alpha.clone();
        this.dimension = alpha.length;
    }
}
