package smile.validation;

import java.util.Arrays;
import smile.math.Histogram;
import smile.math.Math;
import smile.sort.QuickSort;

/**
 * Group k-fold cross validation. This is a k-fold variant in which all samples sharing a group
 * label are placed in the same test fold. As a result, no group ever appears in both the
 * training set and the test set of the same split.
 *
 * <p>Groups are assigned to test folds greedily. They are visited from the largest to the
 * smallest, and each one goes to the fold that currently holds the fewest test samples. This
 * keeps the test folds roughly balanced in size.
 */
public class GroupKFold {
    /** The number of folds. */
    public final int k;
    /** Sample indices of the training set of each fold; {@code train[i]} belongs to fold {@code i}. */
    public final int[][] train;
    /** Sample indices of the test set of each fold; {@code test[i]} belongs to fold {@code i}. */
    public final int[][] test;

    /**
     * Result of distributing groups over test folds. It is static so that it does not capture an
     * enclosing {@code GroupKFold} instance.
     */
    private static final class TestFolds {
        /** Number of test samples in each fold. */
        private final int[] numTestSamplesPerFold;
        /** Test fold index assigned to each group id. */
        private final int[] groupToTestFoldIndex;

        private TestFolds(int[] numTestSamplesPerFold, int[] groupToTestFoldIndex) {
            this.numTestSamplesPerFold = numTestSamplesPerFold;
            this.groupToTestFoldIndex = groupToTestFoldIndex;
        }
    }

    /**
     * Constructor.
     *
     * @param n the number of samples.
     * @param k the number of folds. It must be at least 2 and no greater than the number of
     *          distinct groups.
     * @param groups the group label of each sample. Its length must be {@code n}. The labels must
     *               be encoded as {@code 0..numGroups-1}, with every value in that range present.
     * @throws IllegalArgumentException if {@code n} is negative, {@code k < 2}, the length of
     *         {@code groups} differs from {@code n}, {@code k} exceeds the number of groups, or
     *         the group labels are not a dense {@code 0..numGroups-1} encoding.
     */
    public GroupKFold(int n, int k, int[] groups) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        // k = 0 used to pass validation and then crash on an empty array, and k = 1 yields a
        // fold whose training set is empty. Neither is a meaningful cross validation.
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }
        if (groups.length != n) {
            throw new IllegalArgumentException("Groups array must be of size n, but length is " + groups.length);
        }

        int[] unique = Math.unique(groups);
        int numGroups = unique.length;
        if (k > numGroups) {
            throw new IllegalArgumentException("Number of splits mustn't be greater than number of groups");
        }

        // After sorting, a dense encoding means unique[g] == g for every position. This also
        // rejects negative labels and gaps, which makes direct indexing by group id safe below.
        Arrays.sort(unique);
        for (int g = 0; g < numGroups; g++) {
            if (unique[g] != g) {
                throw new IllegalArgumentException("Invalid encoding of groups, all group indices between [0, numGroups) have to exist");
            }
        }

        this.k = k;
        this.train = new int[k][];
        this.test = new int[k][];

        TestFolds folds = calculateTestFolds(groups, numGroups, k);
        for (int fold = 0; fold < k; fold++) {
            int testSize = folds.numTestSamplesPerFold[fold];
            int[] trainIndex = new int[n - testSize];
            int[] testIndex = new int[testSize];

            int trainPos = 0;
            int testPos = 0;
            for (int i = 0; i < n; i++) {
                if (folds.groupToTestFoldIndex[groups[i]] == fold) {
                    testIndex[testPos++] = i;
                } else {
                    trainIndex[trainPos++] = i;
                }
            }

            this.train[fold] = trainIndex;
            this.test[fold] = testIndex;
        }
    }

    /**
     * Assigns every group to a test fold, balancing the number of test samples per fold.
     *
     * @param groups the group label of each sample, already validated to be dense in
     *               {@code 0..numGroups-1}.
     * @param numGroups the number of distinct groups.
     * @param k the number of folds.
     * @return the per-fold test sizes and the group-to-fold mapping.
     */
    private static TestFolds calculateTestFolds(int[] groups, int numGroups, int k) {
        // Count the samples per group. Labels were validated as 0..numGroups-1, so direct
        // indexing is safe and avoids a floating-point histogram round trip.
        int[] groupSizes = new int[numGroups];
        for (int g : groups) {
            groupSizes[g]++;
        }

        // QuickSort.sort sorts groupSizes ascending in place. order[j] is the original group id
        // of the value now at groupSizes[j].
        int[] order = QuickSort.sort(groupSizes);

        int[] foldSizes = new int[k];
        int[] groupToFold = new int[numGroups];
        // Visit groups from largest to smallest. Each goes to the fold with the fewest samples.
        for (int j = numGroups - 1; j >= 0; j--) {
            int fold = Math.whichMin(foldSizes);
            foldSizes[fold] += groupSizes[j];
            groupToFold[order[j]] = fold;
        }
        return new TestFolds(foldSizes, groupToFold);
    }
}
