package smile.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.sort.QuickSort;

/* loaded from: input_file:smile/clustering/XMeans.class */
public class XMeans extends KMeans implements Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) XMeans.class);
    private static final double LOG2PI = Math.log(6.283185307179586d);

    /* JADX WARN: Type inference failed for: r0v109, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r1v36, types: [double[], double[][]] */
    public XMeans(double[][] dArr, int i) {
        if (i < 2) {
            throw new IllegalArgumentException("Invalid parameter kmax = " + i);
        }
        int length = dArr.length;
        int length2 = dArr[0].length;
        this.k = 1;
        this.size = new int[this.k];
        this.size[0] = length;
        this.y = new int[length];
        this.centroids = new double[this.k][length2];
        for (double[] dArr2 : dArr) {
            for (int i2 = 0; i2 < length2; i2++) {
                double[] dArr3 = this.centroids[0];
                int i3 = i2;
                dArr3[i3] = dArr3[i3] + dArr2[i2];
            }
        }
        for (int i4 = 0; i4 < length2; i4++) {
            double[] dArr4 = this.centroids[0];
            int i5 = i4;
            dArr4[i5] = dArr4[i5] / length;
        }
        double[] dArr5 = new double[this.k];
        for (double[] dArr6 : dArr) {
            dArr5[0] = dArr5[0] + Math.squaredDistance(dArr6, this.centroids[0]);
        }
        this.distortion = dArr5[0];
        logger.info(String.format("X-Means distortion with %d clusters: %.5f", Integer.valueOf(this.k), Double.valueOf(this.distortion)));
        BBDTree bBDTree = new BBDTree(dArr);
        while (this.k < i) {
            ArrayList arrayList = new ArrayList();
            double[] dArr7 = new double[this.k];
            KMeans[] kMeansArr = new KMeans[this.k];
            for (int i6 = 0; i6 < this.k; i6++) {
                if (this.size[i6] < 25) {
                    logger.info("Cluster {} too small to split: {} samples", Integer.valueOf(i6), Integer.valueOf(this.size[i6]));
                } else {
                    ?? r0 = new double[this.size[i6]];
                    int i7 = 0;
                    for (int i8 = 0; i8 < length; i8++) {
                        if (this.y[i8] == i6) {
                            int i9 = i7;
                            i7++;
                            r0[i9] = dArr[i8];
                        }
                    }
                    kMeansArr[i6] = new KMeans((double[][]) r0, 2, 100, 4);
                    double bic = bic(2, this.size[i6], length2, kMeansArr[i6].distortion, kMeansArr[i6].size);
                    double bic2 = bic(this.size[i6], length2, dArr5[i6]);
                    dArr7[i6] = bic - bic2;
                    logger.info(String.format("Cluster %3d\tBIC: %.5f\tBIC after split: %.5f\timprovement: %.5f", Integer.valueOf(i6), Double.valueOf(bic2), Double.valueOf(bic), Double.valueOf(dArr7[i6])));
                }
            }
            int[] sort = QuickSort.sort(dArr7);
            for (int i10 = 0; i10 < this.k; i10++) {
                if (dArr7[sort[i10]] <= 0.0d) {
                    arrayList.add(this.centroids[sort[i10]]);
                }
            }
            int size = arrayList.size();
            int i11 = this.k;
            while (true) {
                i11--;
                if (i11 < 0) {
                    break;
                }
                if (dArr7[i11] > 0.0d) {
                    if (((arrayList.size() + i11) - size) + 1 < i) {
                        logger.info("Split cluster {}", Integer.valueOf(sort[i11]));
                        arrayList.add(kMeansArr[sort[i11]].centroids[0]);
                        arrayList.add(kMeansArr[sort[i11]].centroids[1]);
                    } else {
                        arrayList.add(this.centroids[sort[i11]]);
                    }
                }
            }
            if (arrayList.size() == this.k) {
                return;
            }
            this.k = arrayList.size();
            double[][] dArr8 = new double[this.k][length2];
            this.size = new int[this.k];
            this.centroids = new double[this.k];
            for (int i12 = 0; i12 < this.k; i12++) {
                this.centroids[i12] = (double[]) arrayList.get(i12);
            }
            this.distortion = Double.MAX_VALUE;
            for (int i13 = 0; i13 < 100; i13++) {
                double clustering = bBDTree.clustering(this.centroids, dArr8, this.size, this.y);
                for (int i14 = 0; i14 < this.k; i14++) {
                    if (this.size[i14] > 0) {
                        for (int i15 = 0; i15 < length2; i15++) {
                            this.centroids[i14][i15] = dArr8[i14][i15] / this.size[i14];
                        }
                    }
                }
                if (this.distortion <= clustering) {
                    break;
                }
                this.distortion = clustering;
            }
            dArr5 = new double[this.k];
            for (int i16 = 0; i16 < length; i16++) {
                int i17 = this.y[i16];
                dArr5[i17] = dArr5[i17] + Math.squaredDistance(dArr[i16], this.centroids[this.y[i16]]);
            }
            logger.info(String.format("X-Means distortion with %d clusters: %.5f", Integer.valueOf(this.k), Double.valueOf(this.distortion)));
        }
    }

    private double bic(int i, int i2, double d) {
        return (((((-i) * LOG2PI) + (((-i) * i2) * Math.log(d / (i - 1)))) + (-(i - 1))) / 2.0d) - ((0.5d * (i2 + 1)) * Math.log(i));
    }

    private double bic(int i, int i2, int i3, double d, int[] iArr) {
        double d2 = d / (i2 - i);
        double d3 = 0.0d;
        for (int i4 = 0; i4 < i; i4++) {
            d3 += logLikelihood(i, i2, iArr[i4], i3, d2);
        }
        return d3 - ((0.5d * (i + (i * i3))) * Math.log(i2));
    }

    private static double logLikelihood(int i, int i2, int i3, int i4, double d) {
        double d2 = (-i3) * LOG2PI;
        double log = (-i3) * i4 * Math.log(d);
        double d3 = -(i3 - i);
        return (((d2 + log) + d3) / 2.0d) + (i3 * Math.log(i3)) + ((-i3) * Math.log(i2));
    }

    @Override // smile.clustering.KMeans
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("X-Means distortion: %.5f%n", Double.valueOf(this.distortion)));
        sb.append(String.format("Clusters of %d data points of dimension %d:%n", Integer.valueOf(this.y.length), Integer.valueOf(this.centroids[0].length)));
        for (int i = 0; i < this.k; i++) {
            int round = (int) Math.round((1000.0d * this.size[i]) / this.y.length);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)%n", Integer.valueOf(i), Integer.valueOf(this.size[i]), Integer.valueOf(round / 10), Integer.valueOf(round % 10)));
        }
        return sb.toString();
    }
}
