package com.rapidminer.operator.bahsic;

import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.Kernel;
import java.util.Vector;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;

/* loaded from: input_file:com/rapidminer/operator/bahsic/HSICEstimator.class */
/**
 * Builds kernel matrices and HSIC (Hilbert-Schmidt Independence Criterion) statistics against a
 * fixed target matrix {@code y}.
 *
 * <p>The constructor computes the n x n kernel matrix {@code L} of the rows of {@code y} once
 * (with {@code kernely}). It also caches the row sums, column sums and total sum of {@code L}, so
 * that the centring and unbiased-HSIC computations below can reuse them.
 */
public class HSICEstimator {
    /** Kernel supplied for the x side; stored but not used by the code in this class. */
    Kernel kernelx;
    /** Kernel used to build {@link #lMatrix} from {@link #y}. */
    Kernel kernely;
    /** n x n kernel matrix of the rows of {@link #y}; its diagonal is zeroed unless keepDiagonal was true. */
    RealMatrix lMatrix;
    /** Target data; one sample per row. */
    RealMatrix y;
    /** Number of rows (samples) n of {@link #y}. */
    int yrows;
    /** Number of columns of {@link #y}. */
    int ycols;
    /** Sum of all entries of {@link #lMatrix}. */
    double ltotalSum;
    /** Entry i is the sum of row i of {@link #lMatrix}. */
    Vector<Double> lrowSums = new Vector<>();
    /** Entry j is the sum of column j of {@link #lMatrix}. */
    Vector<Double> lcolSums = new Vector<>();

    /**
     * Builds the kernel matrix of {@code yData} and caches its sums.
     *
     * @param kernelX kernel for the x side (stored only)
     * @param kernelY kernel used to build {@link #lMatrix} over all columns of {@code yData}
     * @param yData target data, one sample per row
     * @param keepDiagonal if {@code false}, the diagonal of {@link #lMatrix} is set to zero before
     *     the sums are computed
     */
    public HSICEstimator(Kernel kernelX, Kernel kernelY, RealMatrix yData, boolean keepDiagonal) {
        this.kernelx = kernelX;
        this.kernely = kernelY;
        this.y = yData;
        this.yrows = yData.getRowDimension();
        this.ycols = yData.getColumnDimension();

        // Use every column of y when building L.
        Vector<Boolean> allColumns = new Vector<>(this.ycols);
        for (int c = 0; c < this.ycols; c++) {
            allColumns.add(Boolean.TRUE);
        }
        this.lMatrix = computeKernelMatrix(yData, kernelY, allColumns);

        if (!keepDiagonal) {
            // Presumably so unbiasedHSIC sees a zero-diagonal L, matching the zero-diagonal K it uses.
            for (int i = 0; i < this.yrows; i++) {
                this.lMatrix.setEntry(i, i, 0.0d);
            }
        }

        this.ltotalSum = 0.0d;
        for (int i = 0; i < this.yrows; i++) {
            double rowSum = 0.0d;
            double colSum = 0.0d;
            for (int j = 0; j < this.yrows; j++) {
                rowSum += this.lMatrix.getEntry(i, j);
                colSum += this.lMatrix.getEntry(j, i);
            }
            this.ltotalSum += rowSum;
            this.lrowSums.add(rowSum);
            this.lcolSums.add(colSum);
        }
    }

    /**
     * Returns {@code H L H}, where {@code H = I - (1/n) 1 1^T} is the n x n centring matrix and
     * {@code L} is {@link #lMatrix}.
     *
     * <p>This forms {@code H} explicitly and performs two dense matrix products. {@link
     * #computeHLHFast()} yields the mathematically same matrix from the cached sums.
     *
     * @return the centred kernel matrix
     * @throws OperatorException declared for API compatibility; not thrown by this implementation
     */
    public RealMatrix computeHLH() throws OperatorException {
        final int n = this.yrows;
        final double diagonal = 1.0d - (1.0d / n);
        final double offDiagonal = -1.0d / n;
        RealMatrix h = MatrixUtils.createRealMatrix(n, n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                h.setEntry(i, j, i == j ? diagonal : offDiagonal);
            }
        }
        return h.multiply(this.lMatrix.multiply(h));
    }

    /**
     * Returns {@code H L H} using the closed form
     * {@code L_ij - (rowSum_i + colSum_j) / n + totalSum / n^2}, built from the sums cached by the
     * constructor in a single O(n^2) pass.
     *
     * @return the centred kernel matrix
     * @throws OperatorException declared for API compatibility; not thrown by this implementation
     */
    public RealMatrix computeHLHFast() throws OperatorException {
        final int n = this.yrows;
        final double invN = 1.0d / n;
        // Computed in double: the int product n * n overflows for n > 46340.
        final double invN2 = 1.0d / ((double) n * n);
        final double totalTerm = this.ltotalSum * invN2;
        RealMatrix result = MatrixUtils.createRealMatrix(n, n);
        for (int i = 0; i < n; i++) {
            final double rowSum = this.lrowSums.get(i);
            for (int j = 0; j < n; j++) {
                double centred = this.lMatrix.getEntry(i, j) - (rowSum + this.lcolSums.get(j)) * invN;
                result.setEntry(i, j, centred + totalTerm);
            }
        }
        return result;
    }

    /**
     * Builds the n x n kernel matrix over the rows of {@code data}.
     *
     * <p>Each lower-triangle entry {@code (i, j)} is
     * {@code kernel.calculate_K(allColumnIndices, row_i, maskedIndices, row_j)}. It is mirrored to
     * {@code (j, i)}. NOTE(review): the feature mask is applied only to the second argument, so
     * the mirrored matrix is exact only if the kernel is symmetric under this masking — confirm
     * for masks that are not all-true.
     *
     * @param data one sample per row
     * @param kernel kernel to evaluate; assumed not to modify its array arguments (the same row
     *     array is reused across calls) — TODO confirm for custom kernels
     * @param featureMask per-column flag; see {@link #getYIndices(Vector)}
     * @return symmetric n x n kernel matrix
     */
    public RealMatrix computeKernelMatrix(RealMatrix data, Kernel kernel, Vector<Boolean> featureMask) {
        final int n = data.getRowDimension();
        final int d = data.getColumnDimension();
        RealMatrix result = MatrixUtils.createRealMatrix(n, n);
        int[] xIndices = new int[d];
        for (int c = 0; c < d; c++) {
            xIndices[c] = c;
        }
        int[] yIndices = getYIndices(featureMask);
        // Copy the data once instead of copying two rows for every (i, j) pair.
        double[][] rows = data.getData();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double value = kernel.calculate_K(xIndices, rows[i], yIndices, rows[j]);
                result.setEntry(i, j, value);
                result.setEntry(j, i, value);
            }
        }
        return result;
    }

    /**
     * Returns {@code tr(a b) / (n - 1)^2} for two n x n matrices.
     *
     * @param a first n x n matrix
     * @param b second n x n matrix
     * @return the scaled trace of the product (non-finite when n == 1)
     * @throws OperatorException if either matrix is not square or their sizes differ
     */
    public double biasedHSIC(RealMatrix a, RealMatrix b) throws OperatorException {
        final int n = a.getRowDimension();
        if (n != a.getColumnDimension() || n != b.getRowDimension() || n != b.getColumnDimension()) {
            throw new OperatorException("Dimensions do not match");
        }
        // Computed in double: the int product overflows for large n.
        final double denominator = (double) (n - 1) * (n - 1);
        return traceOfProduct(a, b) / denominator;
    }

    /** Returns {@code tr(a b)} for square n x n matrices without forming the product (O(n^2)). */
    private static double traceOfProduct(RealMatrix a, RealMatrix b) {
        final int n = a.getRowDimension();
        double trace = 0.0d;
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < n; k++) {
                trace += a.getEntry(i, k) * b.getEntry(k, i);
            }
        }
        return trace;
    }

    /**
     * Computes
     * {@code (tr(K~ L) + (1^T K~ 1)(1^T L 1) / ((n-1)(n-2)) - 2 (1^T K~ L 1) / (n-2)) / (n(n-3))},
     * where {@code K~} is {@code kMatrix} with its diagonal treated as zero and {@code L} is
     * {@link #lMatrix}.
     *
     * <p>This has the form of the unbiased HSIC estimator (Song et al., 2012). That estimator also
     * expects {@code L} to have a zero diagonal, i.e. this instance should have been constructed
     * with {@code keepDiagonal == false}. NOTE(review): confirm that callers do so.
     *
     * <p>{@code kMatrix} is not modified.
     *
     * @param kMatrix n x n kernel matrix, where n equals the number of rows of {@link #y}
     * @return the statistic
     * @throws OperatorException if {@code kMatrix} is not n x n, or if n < 4 (the normalisation
     *     divides by {@code n - 2} and {@code n - 3})
     */
    public double unbiasedHSIC(RealMatrix kMatrix) throws OperatorException {
        final int n = kMatrix.getRowDimension();
        if (n != kMatrix.getColumnDimension() || n != this.yrows) {
            throw new OperatorException("Dimensions do not match");
        }
        if (n < 4) {
            throw new OperatorException("Unbiased HSIC requires at least 4 samples");
        }
        double traceKL = 0.0d; // tr(K~ L)
        double kTotal = 0.0d; // 1^T K~ 1
        double kl = 0.0d; // 1^T K~ L 1 = sum_k colSum(K~)_k * rowSum(L)_k
        for (int k = 0; k < n; k++) {
            double colSum = 0.0d;
            for (int i = 0; i < n; i++) {
                if (i == k) {
                    continue; // K~ has a zero diagonal
                }
                double kik = kMatrix.getEntry(i, k);
                colSum += kik;
                traceKL += kik * this.lMatrix.getEntry(k, i);
            }
            kTotal += colSum;
            kl += colSum * this.lrowSums.get(k);
        }
        // Computed in double: the int products overflow for large n.
        final double nd = n;
        double numerator =
                traceKL + (kTotal * this.ltotalSum) / ((nd - 1) * (nd - 2)) - (2.0d * kl) / (nd - 2);
        return numerator / (nd * (nd - 3));
    }

    /**
     * Returns the outer product {@code v v^T} of a column vector.
     *
     * @param column n x 1 matrix
     * @return n x n matrix with entry {@code (i, j) = v_i * v_j}
     * @throws OperatorException if {@code column} does not have exactly one column
     */
    public RealMatrix outerProduct(RealMatrix column) throws OperatorException {
        if (column.getColumnDimension() != 1) {
            throw new OperatorException("Vector expected, not a matrix");
        }
        final int n = column.getRowDimension();
        double[] v = column.getColumn(0);
        RealMatrix result = MatrixUtils.createRealMatrix(n, n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double product = v[i] * v[j];
                result.setEntry(i, j, product);
                result.setEntry(j, i, product);
            }
        }
        return result;
    }

    /**
     * Converts a per-column mask into the index array passed as the second argument's indices
     * to {@code Kernel.calculate_K}.
     *
     * <p>Entry i is {@code i} when {@code mask.get(i)} is true, and {@code mask.size() + 1} (larger
     * than any real column index) otherwise. Presumably this makes the kernel's index matching
     * skip masked columns. TODO confirm against {@code Kernel.calculate_K}.
     *
     * @param mask per-column flags; must not contain {@code null}
     * @return index array of length {@code mask.size()}
     */
    public int[] getYIndices(Vector<Boolean> mask) {
        final int size = mask.size();
        final int excluded = size + 1;
        int[] indices = new int[size];
        for (int i = 0; i < size; i++) {
            indices[i] = mask.get(i) ? i : excluded;
        }
        return indices;
    }
}
