package smile.manifold;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.stat.distribution.GaussianDistribution;
import smile.util.MulticoreExecutor;

/**
 * t-distributed stochastic neighbor embedding (t-SNE): a nonlinear
 * dimensionality reduction that embeds n points into a d-dimensional space.
 * <p>
 * Input affinities P are Gaussian kernels whose per-point precision (beta) is
 * found by binary search so that each row's conditional distribution reaches
 * the requested perplexity. P is then symmetrized. Embedding affinities Q use
 * the kernel {@code 1 / (1 + dist)}. The embedding is optimized by gradient
 * descent with momentum and per-coordinate adaptive gains. P is multiplied by
 * {@value #EXAGGERATION} until iteration {@code momentumSwitchIter}, at which
 * point the exaggeration is removed and the momentum is raised.
 */
public class TSNE {
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);

    /** log2(e); converts a natural-log quantity into bits. */
    private static final double LOG2_E = 1.0 / java.lang.Math.log(2.0);

    /** Factor applied to P until {@code momentumSwitchIter} is reached. */
    private static final double EXAGGERATION = 12.0;

    /** Current embedding, one row per input point. */
    private double[][] coordinates;
    /** Learning rate (overwritten by the constructor argument). */
    private double eta = 500.0;
    /** Momentum used before {@code momentumSwitchIter}. */
    private double momentum = 0.5;
    /** Momentum used from {@code momentumSwitchIter} on. */
    private double finalMomentum = 0.8;
    /** Iteration at which the momentum switches and P's exaggeration is removed. */
    private int momentumSwitchIter = 250;
    /** Lower bound for each adaptive gain. */
    private double minGain = 0.01;
    /** Total number of iterations run so far, across all calls to learn. */
    private int totalIter = 1;
    /** Input-space pairwise distance matrix (n x n). */
    private double[][] D;
    /** Per-point velocity (the step applied to the coordinates each iteration). */
    private double[][] dY;
    /** Per-point, per-coordinate adaptive gains on the learning rate. */
    private double[][] gains;
    /** Symmetrized input affinities (n x n). */
    private double[][] P;
    /** Unnormalized embedding affinities (n x n), recomputed every iteration. */
    private double[][] Q;
    /** Sum of all off-diagonal entries of Q. */
    private double Qsum;

    /**
     * Calibrates rows [start, end) of a conditional probability matrix. Each row
     * gets a Gaussian kernel whose precision is binary-searched so that the row's
     * entropy (in bits) is within {@code tol} of log2(perplexity).
     */
    public class PerplexityTask implements Callable<Void> {
        int start;
        int end;
        double[][] D;
        double[][] P;
        double[] DiSum;
        double perplexity;
        double tol;

        PerplexityTask(int start, int end, double[][] D, double[][] P, double[] DiSum, double perplexity, double tol) {
            this.start = start;
            this.end = end;
            this.D = D;
            this.P = P;
            this.DiSum = DiSum;
            this.perplexity = perplexity;
            this.tol = tol;
        }

        @Override
        public Void call() {
            for (int i = start; i < end; i++) {
                compute(i);
            }
            return null;
        }

        /** Fills and normalizes row i of P. */
        private void compute(int i) {
            int n = D.length;
            // Target entropy in bits.
            double logU = Math.log2(perplexity);
            double[] Pi = P[i];
            double[] Di = D[i];

            // Initial precision: sqrt(1 / mean distance from point i).
            double beta = Math.sqrt((n - 1) / DiSum[i]);
            double betamin = 0.0;
            double betamax = Double.POSITIVE_INFINITY;
            logger.debug("initial beta[{}] = {}", i, beta);

            double Pisum = 0.0;
            double Hdiff = Double.MAX_VALUE;
            for (int iter = 0; Math.abs(Hdiff) > tol && iter < 50; iter++) {
                Pisum = 0.0;
                double weighted = 0.0; // sum of p_j * beta * D_ij (nats)
                for (int j = 0; j < n; j++) {
                    if (j == i) {
                        // A point is never its own neighbor, whatever D[i][i] holds.
                        Pi[j] = 0.0;
                        continue;
                    }
                    double scaled = beta * Di[j];
                    double p = Math.exp(-scaled);
                    Pi[j] = p;
                    Pisum += p;
                    weighted += p * scaled;
                }

                // Shannon entropy in bits: log2(Z) + E[beta * D] / ln(2).
                // If every kernel value underflowed, the distribution is
                // maximally peaked, so treat the entropy as -inf (beta too large).
                double H = Pisum > 0.0
                        ? Math.log2(Pisum) + LOG2_E * weighted / Pisum
                        : Double.NEGATIVE_INFINITY;
                Hdiff = H - logU;

                if (Math.abs(Hdiff) > tol) {
                    if (Hdiff > 0) {
                        // Entropy too high: sharpen the kernel.
                        betamin = beta;
                        beta = Double.isInfinite(betamax) ? beta * 2.0 : (beta + betamax) / 2.0;
                    } else {
                        // Entropy too low: widen the kernel.
                        betamax = beta;
                        beta = (beta + betamin) / 2.0;
                    }
                }

                logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", Hdiff, i, beta, H, logU);
            }

            // Normalize the last evaluated kernel even when the search hit its
            // iteration cap, so every row is a proper distribution.
            if (Pisum > 0.0) {
                for (int j = 0; j < n; j++) {
                    Pi[j] /= Pisum;
                }
            }
        }
    }

    /**
     * Phase 1 of an optimization step for points [start, end). It computes each
     * point's gradient and updates its gains and velocity. It only READS the
     * coordinates; {@link #learn(int)} moves the points once every task has
     * finished, so all gradients see the same coordinates that Q was built from.
     */
    public class SNETask implements Callable<Void> {
        int start;
        int end;
        /** Scratch buffer for one point's gradient; private to this task. */
        double[] dC;

        SNETask(int start, int end) {
            this.start = start;
            this.end = end;
            this.dC = new double[coordinates[0].length];
        }

        @Override
        public Void call() {
            for (int i = start; i < end; i++) {
                compute(i);
            }
            return null;
        }

        /** Updates gains[i] and dY[i] from the gradient at point i. */
        private void compute(int i) {
            double[][] Y = coordinates;
            int n = Y.length;
            int d = Y[0].length;

            Arrays.fill(dC, 0.0);
            double[] Yi = Y[i];
            double[] Pi = P[i];
            double[] Qi = Q[i];
            double[] dYi = dY[i];
            double[] g = gains[i];

            // Gradient: 4 * sum_j (p_ij - q_ij / Qsum) * q_ij * (y_i - y_j).
            for (int j = 0; j < n; j++) {
                if (j == i) {
                    continue;
                }
                double[] Yj = Y[j];
                double q = Qi[j];
                double z = (Pi[j] - q / Qsum) * q;
                for (int k = 0; k < d; k++) {
                    dC[k] += 4.0 * (Yi[k] - Yj[k]) * z;
                }
            }

            for (int k = 0; k < d; k++) {
                // Grow the gain when the gradient disagrees in sign with the
                // current velocity, shrink it otherwise.
                g[k] = Math.signum(dC[k]) != Math.signum(dYi[k]) ? g[k] + 0.2 : g[k] * 0.8;
                if (g[k] < minGain) {
                    g[k] = minGain;
                }
                dYi[k] = momentum * dYi[k] - eta * g[k] * dC[k];
            }
        }
    }

    /**
     * Embeds X with perplexity 20, learning rate 200 and 1000 iterations.
     *
     * @param X the data (one row per point), or a precomputed n x n distance
     *          matrix; see {@link #TSNE(double[][], int, double, double, int)}.
     * @param d the dimension of the embedding.
     */
    public TSNE(double[][] X, int d) {
        this(X, d, 20.0, 200.0, 1000);
    }

    /**
     * Builds the input affinities and runs the optimizer.
     *
     * @param X the data (one row per point). If X is square
     *          ({@code X.length == X[0].length}) it is used directly as the
     *          pairwise distance matrix. NOTE(review): a data set with as many
     *          features as samples is misread this way. The computed path calls
     *          {@code Math.pdist(X, D, true, false)}; presumably the third
     *          argument selects squared distances, so a supplied matrix should
     *          likely hold squared distances too. TODO confirm.
     * @param d the dimension of the embedding.
     * @param perplexity the target perplexity of each point's input distribution.
     * @param eta the learning rate.
     * @param iterations the number of optimization iterations to run.
     * @throws IllegalArgumentException if X has no rows.
     */
    public TSNE(double[][] X, int d, double perplexity, double eta, int iterations) {
        if (X.length == 0) {
            throw new IllegalArgumentException("Empty input data");
        }
        this.eta = eta;
        int n = X.length;

        if (X.length == X[0].length) {
            D = X;
        } else {
            D = new double[n][n];
            Math.pdist(X, D, true, false);
        }

        coordinates = new double[n][d];
        dY = new double[n][d];
        gains = new double[n][d];

        // Start from small random coordinates; every gain starts at 1.
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1.0E-4);
        for (int i = 0; i < n; i++) {
            Arrays.fill(gains[i], 1.0);
            double[] Yi = coordinates[i];
            for (int j = 0; j < d; j++) {
                Yi[j] = gaussian.rand();
            }
        }

        // A loose tolerance is enough: small kernel-width differences are unimportant.
        P = expd(D, perplexity, 1.0E-3);
        Q = new double[n][n];

        // Symmetrize: p_ij = (p_j|i + p_i|j) / 2n, times the early exaggeration
        // factor. Tiny or NaN entries are clamped to 1e-16.
        double Psum = 2.0 * n;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            for (int j = 0; j < i; j++) {
                double p = EXAGGERATION * (Pi[j] + P[j][i]) / Psum;
                if (Double.isNaN(p) || p < 1.0E-16) {
                    p = 1.0E-16;
                }
                Pi[j] = p;
                P[j][i] = p;
            }
        }

        learn(iterations);
    }

    /**
     * Runs additional optimization iterations, then shifts the embedding so that
     * every coordinate column has zero mean. May be called repeatedly to continue
     * training.
     *
     * @param iterations the number of iterations to run.
     */
    public void learn(int iterations) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = Y[0].length;

        int nprocs = MulticoreExecutor.getThreadPoolSize();
        int chunk = n / nprocs;
        ArrayList<SNETask> tasks = new ArrayList<>();
        for (int t = 0; t < nprocs; t++) {
            int start = t * chunk;
            int end = t == nprocs - 1 ? n : (t + 1) * chunk;
            tasks.add(new SNETask(start, end));
        }

        for (int iter = 1; iter <= iterations; iter++, totalIter++) {
            updateQ();

            // Phase 1: gradients, gains and velocities (coordinates are read-only).
            try {
                MulticoreExecutor.run(tasks);
            } catch (Exception e) {
                logger.error("t-SNE iteration task fails: {}", e);
            }

            // Phase 2: move every point only after all gradients are known, so
            // no task ever reads a coordinate row another task is writing.
            for (int i = 0; i < n; i++) {
                double[] Yi = Y[i];
                double[] dYi = dY[i];
                for (int k = 0; k < d; k++) {
                    Yi[k] += dYi[k];
                }
            }

            if (totalIter == momentumSwitchIter) {
                // End of the early-exaggeration phase.
                momentum = finalMomentum;
                for (double[] Pi : P) {
                    for (int j = 0; j < n; j++) {
                        Pi[j] /= EXAGGERATION;
                    }
                }
            }

            if (iter % 50 == 0) {
                logger.info("Error after {} iterations: {}", totalIter, cost());
            }
        }

        // Make the solution zero-mean.
        double[] colMeans = Math.colMeans(Y);
        for (double[] Yi : Y) {
            for (int k = 0; k < d; k++) {
                Yi[k] -= colMeans[k];
            }
        }
    }

    /**
     * Recomputes Q: it sets {@code q_ij = 1 / (1 + pdist(y_i, y_j))} off the
     * diagonal and stores the sum of all off-diagonal entries in {@code Qsum}.
     */
    private void updateQ() {
        int n = coordinates.length;
        Math.pdist(coordinates, Q, true, false);
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double[] Qi = Q[i];
            for (int j = 0; j < i; j++) {
                double q = 1.0 / (1.0 + Qi[j]);
                Qi[j] = q;
                Q[j][i] = q;
                sum += q;
            }
        }
        Qsum = 2.0 * sum;
    }

    /**
     * Returns {@code 2 * sum_{i>j} p_ij * log2(p_ij / q_ij)}, where
     * {@code q_ij = Q[i][j] / Qsum} is clamped to at least 1e-16. This is the value
     * reported in the progress log; P may still carry the exaggeration factor.
     */
    private double cost() {
        int n = P.length;
        double C = 0.0;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            double[] Qi = Q[i];
            for (int j = 0; j < i; j++) {
                double p = Pi[j];
                double q = Qi[j] / Qsum;
                if (Double.isNaN(q) || q < 1.0E-16) {
                    q = 1.0E-16;
                }
                C += p * Math.log2(p / q);
            }
        }
        return 2.0 * C;
    }

    /**
     * Computes the row-wise conditional probabilities p_j|i, each calibrated to
     * the given perplexity, in parallel over row ranges.
     *
     * @param distances the n x n pairwise distance matrix.
     * @param perplexity the target perplexity.
     * @param tol the tolerance on the entropy difference, in bits.
     * @return a new n x n matrix whose rows are the conditional distributions.
     */
    private double[][] expd(double[][] distances, double perplexity, double tol) {
        int n = distances.length;
        double[][] probs = new double[n][n];
        double[] rowSums = Math.rowSums(distances);

        int nprocs = MulticoreExecutor.getThreadPoolSize();
        int chunk = n / nprocs;
        ArrayList<PerplexityTask> tasks = new ArrayList<>();
        for (int t = 0; t < nprocs; t++) {
            int start = t * chunk;
            int end = t == nprocs - 1 ? n : (t + 1) * chunk;
            tasks.add(new PerplexityTask(start, end, distances, probs, rowSums, perplexity, tol));
        }

        try {
            MulticoreExecutor.run(tasks);
        } catch (Exception e) {
            logger.error("t-SNE Gaussian kernel width search task fails: {}", e);
        }
        return probs;
    }

    /**
     * Returns the embedding, one row per input point. This is the internal array,
     * not a copy, so later calls to {@link #learn(int)} update it in place.
     *
     * @return the n x d coordinate matrix.
     */
    public double[][] getCoordinates() {
        return coordinates;
    }
}
