package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.math.Math;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.neighbor.CoverTree;
import smile.neighbor.KDTree;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;

/* loaded from: input_file:smile/classification/KNN.class */
/**
 * K-nearest neighbor classifier.
 * <p>
 * An object is classified by a vote among its {@code k} nearest neighbors in
 * the training set. The predicted class is the one with the most votes, and
 * ties go to the smallest class label. With {@code k = 1} the object simply
 * receives the label of its single nearest neighbor.
 * <p>
 * Class labels must be exactly the integers {@code 0, 1, ..., c-1} with
 * {@code c >= 2}. The posteriori estimate for a class is the fraction of the
 * {@code k} neighbors carrying that label.
 *
 * @param <T> the type of input object.
 */
public class KNN<T> implements SoftClassifier<T>, Serializable {
    private static final long serialVersionUID = 1;

    // NOTE: field names are intentionally unchanged. This class is Serializable
    // with a fixed serialVersionUID, so renaming fields would break the
    // deserialization of previously saved models.
    /** Search structure over the training samples. */
    private KNNSearch<T, T> knn;
    /** Class label of each training sample, indexed by Neighbor.index. */
    private int[] y;
    /** Number of neighbors that vote. */
    private int k;
    /** Number of classes. */
    private int c;

    /**
     * Trainer that builds {@link KNN} classifiers with a fixed distance and k.
     *
     * @param <T> the type of input object.
     */
    public static class Trainer<T> extends ClassifierTrainer<T> {
        /** Number of neighbors that vote. */
        private int k;
        /** Distance function between objects. */
        private Distance<T> distance;

        /**
         * Constructor.
         *
         * @param distance the distance function between objects.
         * @param k the number of neighbors that vote.
         * @throws IllegalArgumentException if {@code k < 1}.
         */
        public Trainer(Distance<T> distance, int k) {
            if (k < 1) {
                throw new IllegalArgumentException("Invalid k of k-NN: " + k);
            }
            this.distance = distance;
            this.k = k;
        }

        /**
         * Trains a KNN classifier on the given samples.
         *
         * @param x the training samples.
         * @param y the class labels, in {@code 0..c-1}.
         * @return the trained classifier.
         */
        @Override // smile.classification.ClassifierTrainer
        public KNN<T> train(T[] x, int[] y) {
            return new KNN<>(x, y, this.distance, this.k);
        }
    }

    /**
     * Creates a classifier backed by an existing nearest neighbor search
     * structure.
     *
     * @param knn a search structure over the training samples. Its
     *            {@code Neighbor.index} values must index into {@code y}.
     * @param y the class label of each training sample, in {@code 0..c-1}.
     * @param k the number of neighbors that vote.
     * @throws IllegalArgumentException if {@code k < 1}, or if the labels are
     *         not exactly {@code 0..c-1} with {@code c >= 2}.
     */
    public KNN(KNNSearch<T, T> knn, int[] y, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        this.c = countClasses(y);
        this.knn = knn;
        this.y = y;
        this.k = k;
    }

    /**
     * Creates a 1-NN classifier.
     *
     * @param x the training samples.
     * @param y the class labels, in {@code 0..c-1}.
     * @param distance the distance function between objects.
     */
    public KNN(T[] x, int[] y, Distance<T> distance) {
        this(x, y, distance, 1);
    }

    /**
     * Creates a k-NN classifier. If {@code distance} is a {@link Metric}, the
     * samples are indexed with a {@link CoverTree}. Otherwise a
     * {@link LinearSearch} (brute force) is used.
     *
     * @param x the training samples.
     * @param y the class labels, in {@code 0..c-1}.
     * @param distance the distance function between objects.
     * @param k the number of neighbors that vote.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, if {@code k < 1}, or if the labels are not exactly
     *         {@code 0..c-1} with {@code c >= 2}.
     */
    public KNN(T[] x, int[] y, Distance<T> distance, int k) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        this.c = countClasses(y);
        this.y = y;
        this.k = k;
        if (distance instanceof Metric) {
            // Metric<T> is a subtype of Distance<T>, so this cast is type-safe.
            this.knn = new CoverTree<>(x, (Metric<T>) distance);
        } else {
            this.knn = new LinearSearch<>(x, distance);
        }
    }

    /**
     * Learns a 1-NN classifier on real-valued vectors with Euclidean distance.
     *
     * @param x the training samples.
     * @param y the class labels, in {@code 0..c-1}.
     * @return the trained classifier.
     */
    public static KNN<double[]> learn(double[][] x, int[] y) {
        return learn(x, y, 1);
    }

    /**
     * Learns a k-NN classifier on real-valued vectors with Euclidean distance.
     * A {@link KDTree} is used when the dimension is below 10. Otherwise a
     * {@link CoverTree} is used.
     *
     * @param x the training samples.
     * @param y the class labels, in {@code 0..c-1}.
     * @param k the number of neighbors that vote.
     * @return the trained classifier.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, if {@code k < 1}, if {@code x} is empty, or if the labels
     *         are not exactly {@code 0..c-1} with {@code c >= 2}.
     */
    public static KNN<double[]> learn(double[][] x, int[] y, int k) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        if (x.length == 0) {
            // Without this check, x[0] below throws ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Empty training data");
        }

        // k-d trees lose their advantage as the dimension grows, so a cover
        // tree is used for 10 or more dimensions.
        KNNSearch<double[], double[]> search;
        if (x[0].length < 10) {
            search = new KDTree<>(x, x);
        } else {
            search = new CoverTree<>(x, new EuclideanDistance());
        }
        return new KNN<>(search, y, k);
    }

    /**
     * Predicts the class label of an object.
     *
     * @param x the object to classify.
     * @return the predicted class label.
     */
    @Override // smile.classification.Classifier
    public int predict(T x) {
        return predict(x, null);
    }

    /**
     * Predicts the class label of an object and optionally estimates the
     * posteriori probabilities.
     *
     * @param x the object to classify.
     * @param posteriori if not {@code null}, an array of length {@code c}.
     *                   On return it holds the fraction of the {@code k}
     *                   neighbors carrying each class label.
     * @return the class with the most votes. Ties go to the smallest label.
     * @throws IllegalArgumentException if {@code posteriori} is non-null and
     *         its length is not the number of classes.
     */
    @Override // smile.classification.SoftClassifier
    public int predict(T x, double[] posteriori) {
        if (posteriori != null && posteriori.length != this.c) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.c));
        }

        Neighbor<T, T>[] neighbors = this.knn.knn(x, this.k);

        if (this.k == 1) {
            int label = this.y[neighbors[0].index];
            // The original returned before filling posteriori when k == 1.
            if (posteriori != null) {
                Arrays.fill(posteriori, 0.0);
                posteriori[label] = 1.0;
            }
            return label;
        }

        int[] count = new int[this.c];
        for (int i = 0; i < this.k; i++) {
            count[this.y[neighbors[i].index]]++;
        }

        if (posteriori != null) {
            for (int i = 0; i < this.c; i++) {
                // Cast first: integer division would truncate every estimate to 0 or 1.
                posteriori[i] = (double) count[i] / this.k;
            }
        }

        // Argmax. A strict '>' keeps the smallest label on ties.
        int best = 0;
        for (int i = 1; i < this.c; i++) {
            if (count[i] > count[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Validates the class labels and returns the number of classes.
     *
     * @param y the class labels.
     * @return the number of classes {@code c}.
     * @throws IllegalArgumentException if a label is negative, if the labels
     *         are not exactly {@code 0..c-1}, or if there are fewer than two
     *         classes.
     */
    private static int countClasses(int[] y) {
        int[] unique = Math.unique(y);
        Arrays.sort(unique);
        for (int i = 0; i < unique.length; i++) {
            if (unique[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + unique[i]);
            }
            // The values are sorted and distinct. If every earlier position
            // matched its index, unique[i] != i means label i never occurs.
            // This also rejects label sets that do not start at 0, which would
            // otherwise overflow the vote array in predict().
            if (unique[i] != i) {
                throw new IllegalArgumentException("Missing class: " + i);
            }
        }
        if (unique.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return unique.length;
    }
}
