package com.rapidminer.operator.lasso;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.Vector;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;

/* loaded from: input_file:com/rapidminer/operator/lasso/LARSModel.class */
/**
 * Linear prediction model fitted with Least Angle Regression (LARS), optionally with the LASSO
 * modification that removes a variable from the active set when its coefficient would cross zero.
 *
 * <p>The full coefficient path is kept in {@link #beta}; the coefficients of the selected iteration
 * are exposed as {@link AttributeWeights}. Predictions are the plain dot product of the example's
 * regular attribute values with these weights (no bias term). For a nominal label, the training
 * label's internal values 0/1 are mapped to -1/+1 and predictions are thresholded at 0.
 */
public class LARSModel extends PredictionModel {
    private static final long serialVersionUID = -6112829333480866927L;
    /** Per-attribute coefficients of the currently selected iteration. */
    protected AttributeWeights weights;
    /** Coefficient vector after each iteration; index 0 is the all-zero start vector. */
    protected Vector<RealVector> beta;
    /** Step length (gamma) taken in each iteration. */
    protected Vector<Double> gamma_A;
    /** Maximal absolute correlation C observed at the start of each iteration. */
    protected Vector<Double> C;
    /** NOTE(review): never assigned in this class; the model always reports a bias of 0. */
    protected double avgY;
    /** L1-norm threshold on the coefficient vector used to select the solution (<= 0: none). */
    protected double threshold;
    /** Whether the LASSO modification was used during fitting. */
    protected boolean lasso;
    /** Index into {@link #beta} of the coefficient vector currently used as the model. */
    private int solution;

    /**
     * Fits plain LARS (no LASSO modification) with an L1-norm threshold of 1.0 and a
     * tolerance of 1e-4.
     *
     * @param exampleSet training data; must have a label
     * @throws OperatorException if the fitting fails (see the main constructor)
     */
    LARSModel(ExampleSet exampleSet) throws OperatorException {
        this(exampleSet, false, 1.0d, 1.0E-4d);
    }

    /**
     * Fits the LARS / LASSO coefficient path on the given training data.
     *
     * @param exampleSet training data; all regular attributes are used as columns of the design
     *     matrix. A numerical label is used as is; otherwise the label's internal value v is mapped
     *     to 2 * (v - 0.5)
     * @param z {@code true} to apply the LASSO modification
     * @param d L1-norm threshold on the coefficients; when positive and exceeded, fitting stops and
     *     a coefficient vector interpolated to exactly this norm is inserted and selected
     * @param d2 tolerance: variables whose absolute correlation is within {@code d2} of the maximum
     *     join the active set, and a step length of at most {@code d2} ends the iteration
     * @throws OperatorException if 1^T G_A^{-1} 1 is not positive for the active set
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public LARSModel(ExampleSet exampleSet, boolean z, double d, double d2) throws OperatorException {
        super(exampleSet);
        this.threshold = d;
        this.lasso = z;
        int size = exampleSet.getAttributes().size();
        int size2 = exampleSet.size();
        // Design matrix X: one row per example, one column per regular attribute.
        RealMatrix createRealMatrix = MatrixUtils.createRealMatrix(size2, size);
        int i = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            int i2 = 0;
            Iterator it = exampleSet.iterator();
            while (it.hasNext()) {
                int i3 = i2;
                i2++;
                createRealMatrix.setEntry(i3, i, ((Example) it.next()).getNumericalValue(attribute));
            }
            i++;
        }
        // Response vector y (one row per example).
        RealMatrix createRealMatrix2 = MatrixUtils.createRealMatrix(size2, 1);
        int i4 = 0;
        if (exampleSet.getAttributes().getLabel().isNumerical()) {
            Iterator it2 = exampleSet.iterator();
            while (it2.hasNext()) {
                int i5 = i4;
                i4++;
                createRealMatrix2.setEntry(i5, 0, ((Example) it2.next()).getLabel());
            }
        } else {
            // Nominal label: internal value v becomes 2 * (v - 0.5), i.e. 0 -> -1 and 1 -> +1.
            Iterator it3 = exampleSet.iterator();
            while (it3.hasNext()) {
                int i6 = i4;
                i4++;
                createRealMatrix2.setEntry(i6, 0, 2.0d * (((Example) it3.next()).getLabel() - 0.5d));
            }
        }
        double[] dArr = new double[size];           // signs s_j of the active correlations
        TreeSet<Integer> treeSet = new TreeSet();   // active set A (column indices of X)
        int i7 = 0;                                 // iteration counter
        int i8 = -1;                                // LASSO: variable to drop next, or -1
        double d3 = 0.0d;                           // L1 norm of the most recent beta
        Vector vector = new Vector();               // fitted values mu after each iteration
        this.beta = new Vector<>();
        Vector vector2 = new Vector();              // correlation vectors c of each iteration
        this.C = new Vector<>();
        this.gamma_A = new Vector<>();
        vector.add(MatrixUtils.createRealMatrix(size2, 1));
        RealVector createRealVector = MatrixUtils.createRealVector(new double[size]);
        this.beta.add(createRealVector);
        while (true) {
            // Current correlations c = X^T (y - mu).
            RealMatrix multiply = createRealMatrix.transpose().multiply(createRealMatrix2.subtract((RealMatrix) vector.get(i7)));
            vector2.add(multiply);
            // C = max_j |c_j|
            double d4 = 0.0d;
            for (int i9 = 0; i9 < size; i9++) {
                if (Math.abs(multiply.getEntry(i9, 0)) > d4) {
                    d4 = Math.abs(multiply.getEntry(i9, 0));
                }
            }
            this.C.add(Double.valueOf(d4));
            // Normally every variable with |c_j| >= C - d2 becomes active. If the previous LASSO
            // step stopped at a zero crossing, that variable is removed instead.
            if (!z || i8 < 0) {
                for (int i10 = 0; i10 < size; i10++) {
                    if (Math.abs(multiply.getEntry(i10, 0)) >= d4 - d2) {
                        treeSet.add(Integer.valueOf(i10));
                    }
                }
            } else {
                treeSet.remove(Integer.valueOf(i8));
            }
            int size3 = treeSet.size();
            for (Integer num : treeSet) {
                dArr[num.intValue()] = Math.signum(multiply.getEntry(num.intValue(), 0));
            }
            if (treeSet.isEmpty()) {
                break;
            }
            // Signed active design matrix X_A with columns s_j * x_j for j in A.
            RealMatrix createRealMatrix3 = MatrixUtils.createRealMatrix(size2, size3);
            int i11 = 0;
            for (Integer num2 : treeSet) {
                for (int i12 = 0; i12 < size2; i12++) {
                    createRealMatrix3.setEntry(i12, i11, createRealMatrix.getEntry(i12, num2.intValue()) * dArr[num2.intValue()]);
                }
                i11++;
            }
            // G_A^{-1} = (X_A^T X_A)^{-1}
            RealMatrix inverse = new LUDecompositionImpl(createRealMatrix3.transpose().multiply(createRealMatrix3)).getSolver().getInverse();
            // d5 = 1^T G_A^{-1} 1 (sum of all entries of the inverse)
            double d5 = 0.0d;
            for (int i13 = 0; i13 < size3; i13++) {
                for (int i14 = 0; i14 < size3; i14++) {
                    d5 += inverse.getEntry(i13, i14);
                }
            }
            if (d5 <= 0.0d) {
                throw new OperatorException("Sum of inverse G_A matrix ist not positive.");
            }
            // A_A = (1^T G_A^{-1} 1)^{-1/2}
            double sqrt = 1.0d / Math.sqrt(d5);
            // w_A = A_A * G_A^{-1} 1 (row sums of the inverse, scaled)
            RealMatrix createRealMatrix4 = MatrixUtils.createRealMatrix(size3, 1);
            for (int i15 = 0; i15 < size3; i15++) {
                double d6 = 0.0d;
                for (int i16 = 0; i16 < size3; i16++) {
                    d6 += inverse.getEntry(i15, i16);
                }
                createRealMatrix4.setEntry(i15, 0, d6 * sqrt);
            }
            // Equiangular direction u_A = X_A w_A and its correlations a = X^T u_A.
            RealMatrix multiply2 = createRealMatrix3.multiply(createRealMatrix4);
            RealMatrix multiply3 = createRealMatrix.transpose().multiply(multiply2);
            // Step length: smallest positive (C - c_j) / (A_A - a_j) or (C + c_j) / (A_A + a_j)
            // over inactive j; once every variable is active, all j are considered.
            double d7 = Double.POSITIVE_INFINITY;
            for (int i17 = 0; i17 < size; i17++) {
                if (!treeSet.contains(Integer.valueOf(i17)) || treeSet.size() == size) {
                    double entry = (d4 - multiply.getEntry(i17, 0)) / (sqrt - multiply3.getEntry(i17, 0));
                    double entry2 = (d4 + multiply.getEntry(i17, 0)) / (sqrt + multiply3.getEntry(i17, 0));
                    if (entry > 0.0d && entry < d7) {
                        d7 = entry;
                    }
                    if (entry2 > 0.0d && entry2 < d7) {
                        d7 = entry2;
                    }
                }
            }
            if (z) {
                // LASSO: gamma_j = -beta_j / (s_j * w_j) is the step at which active coefficient j
                // would change sign (0 for inactive variables).
                Vector vector3 = new Vector(size);
                int i18 = 0;
                for (int i19 = 0; i19 < size; i19++) {
                    if (treeSet.contains(Integer.valueOf(i19))) {
                        int i20 = i18;
                        i18++;
                        vector3.add(Double.valueOf(((-1.0d) * createRealVector.getEntry(i19)) / (dArr[i19] * createRealMatrix4.getEntry(i20, 0))));
                    } else {
                        vector3.add(Double.valueOf(0.0d));
                    }
                }
                double d8 = Double.POSITIVE_INFINITY;
                i8 = -1;
                for (int i21 = 0; i21 < size; i21++) {
                    double doubleValue = ((Double) vector3.get(i21)).doubleValue();
                    if (doubleValue > 0.0d && doubleValue < d8) {
                        d8 = doubleValue;
                        i8 = i21;
                    }
                }
                // A zero crossing before the next variable would join shortens the step, and
                // that variable (i8) is dropped at the start of the next iteration.
                if (d8 >= d7 || i8 < 0) {
                    i8 = -1;
                } else {
                    d7 = d8;
                }
            }
            this.gamma_A.add(Double.valueOf(d7));
            // mu_{k+1} = mu_k + gamma * u_A
            vector.add(((RealMatrix) vector.get(i7)).add(multiply2.scalarMultiply(d7)));
            // beta_{k+1}: active coefficients move by gamma * w_j * s_j; all others are zero.
            ArrayRealVector arrayRealVector = new ArrayRealVector(size);
            int i22 = 0;
            for (int i23 = 0; i23 < size; i23++) {
                if (treeSet.contains(Integer.valueOf(i23))) {
                    arrayRealVector.setEntry(i23, createRealVector.getEntry(i23) + (createRealMatrix4.getEntry(i22, 0) * d7 * dArr[i23]));
                    i22++;
                }
            }
            this.beta.add(arrayRealVector);
            createRealVector = arrayRealVector;
            d3 = arrayRealVector.getL1Norm();
            i7++;
            // Stop on a negligible or unbounded step, when plain LARS has done more iterations
            // than there are examples or attributes, or once the L1-norm threshold is exceeded.
            if (d7 <= d2 || Double.isInfinite(d7) || ((!this.lasso && (i7 > size2 || i7 > size)) || (d > 0.0d && d < d3))) {
                break;
            }
        }
        this.solution = i7;
        if (d > 0.0d && d < d3) {
            // The last step overshot the threshold: find the last iteration whose L1 norm is within
            // d, then insert after it a beta moved towards the next one by the fraction
            // (d - |beta_s|_1) / (|beta_{s+1}|_1 - |beta_s|_1), and select that inserted vector.
            this.solution = 0;
            for (int i24 = 1; i24 < i7; i24++) {
                if (this.beta.get(i24).getL1Norm() <= d) {
                    this.solution = i24;
                }
            }
            this.beta.add(this.solution + 1, this.beta.get(this.solution).add(this.beta.get(this.solution + 1).subtract(this.beta.get(this.solution)).mapMultiply((d - this.beta.get(this.solution).getL1Norm()) / (this.beta.get(this.solution + 1).getL1Norm() - this.beta.get(this.solution).getL1Norm()))));
            this.solution++;
        }
        // Publish the selected coefficients as weights, in regular-attribute order.
        this.weights = new AttributeWeights(exampleSet);
        int i25 = 0;
        Iterator it4 = exampleSet.getAttributes().iterator();
        while (it4.hasNext()) {
            this.weights.setWeight(((Attribute) it4.next()).getName(), this.beta.get(this.solution).getEntry(i25));
            i25++;
        }
    }

    /**
     * Re-selects which coefficient vector of the fitted path is used as the model.
     *
     * <p>If {@code i > 0}, the last iteration (from 1 on) with at most {@code i} non-zero
     * coefficients is chosen and {@link #threshold} is set to its L1 norm. Otherwise
     * {@link #threshold} is set to {@code d}. If {@code d > 0}, the last iteration (from 1 on)
     * whose L1 norm is at most {@code d} is chosen. If {@code d <= 0}, the final iteration is used.
     * If no iteration qualifies, the all-zero start vector (index 0) is used.
     *
     * @param d L1-norm threshold, used when {@code i <= 0}
     * @param i maximum number of non-zero coefficients; takes precedence when positive
     */
    public void changeModel(double d, int i) {
        // Default to the final iteration. The former default of beta.size() was out of bounds
        // whenever neither a count limit nor a positive threshold was given.
        int size = this.beta.size() - 1;
        if (i > 0) {
            size = 0;
            for (int i2 = 1; i2 < this.beta.size(); i2++) {
                // mapSignum().getL1Norm() counts the non-zero coefficients.
                if (this.beta.get(i2).mapSignum().getL1Norm() <= i) {
                    size = i2;
                    this.threshold = this.beta.get(i2).getL1Norm();
                }
            }
        } else {
            this.threshold = d;
            if (d > 0.0d) {
                size = 0;
                for (int i3 = 1; i3 < this.beta.size(); i3++) {
                    if (this.beta.get(i3).getL1Norm() <= d) {
                        size = i3;
                    }
                }
            }
        }
        // Keep toString() consistent with the coefficients actually in use.
        this.solution = size;
        // Work on a copy so weights previously handed out via getWeights() are not modified.
        this.weights = (AttributeWeights) this.weights.clone();
        // NOTE(review): assumes getAttributeNames() iterates in the same order the attributes had
        // when the weights were created in the constructor (beta index order) — confirm.
        int i4 = 0;
        Iterator it = this.weights.getAttributeNames().iterator();
        while (it.hasNext()) {
            this.weights.setWeight((String) it.next(), this.beta.get(size).getEntry(i4));
            i4++;
        }
    }

    /**
     * Describes the model: iteration statistics followed by the linear formula
     * {@code name * weight + ... + 0 (bias)}.
     */
    public String toString() {
        String nl = Tools.getLineSeparator();
        StringBuilder result = new StringBuilder();
        result.append(super.toString()).append(nl).append(nl);
        result.append(this.beta.size() - 1)
                .append(" iterations in total, using BETA of iteration no. ")
                .append(this.solution)
                .append(nl)
                .append(nl);
        for (String name : this.weights.getAttributeNames()) {
            result.append(name).append(" * ").append(this.weights.getWeight(name)).append(" + ").append(nl);
        }
        result.append("0 (bias)");
        return result.toString();
    }

    /** Returns the coefficients of the currently selected iteration (internal instance). */
    public AttributeWeights getWeights() {
        return this.weights;
    }

    /** Returns the number of entries in the weights, or 0 if no weights exist. */
    public int size() {
        if (this.weights == null) {
            return 0;
        }
        return this.weights.getSize();
    }

    /** Returns the full coefficient path (internal, mutable list); index 0 is all zeros. */
    public Vector<RealVector> getBeta() {
        return this.beta;
    }

    /** Returns the step length of every iteration (internal, mutable list). */
    public Vector<Double> getGamma() {
        return this.gamma_A;
    }

    /** Returns the maximal absolute correlation of every iteration (internal, mutable list). */
    public Vector<Double> getC() {
        return this.C;
    }

    /**
     * Writes the linear prediction {@code sum_j value_j * weight_j} of every example into the
     * prediction attribute. For a nominal prediction attribute, the result is mapped to 1.0 if
     * positive and to 0.0 otherwise.
     *
     * @param exampleSet examples to score; a label is not required
     * @param attribute predicted-label attribute to fill
     * @return the same example set
     */
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute attribute) throws OperatorException {
        // Decide from the prediction attribute rather than exampleSet's label. Data being scored
        // may have no label attribute, and dereferencing getLabel() would then fail.
        boolean z = !attribute.isNumerical();
        Iterator it = exampleSet.iterator();
        while (it.hasNext()) {
            Example example = (Example) it.next();
            double d = 0.0d;
            for (Attribute attribute2 : exampleSet.getAttributes()) {
                d += example.getValue(attribute2) * this.weights.getWeight(attribute2.getName());
            }
            if (z) {
                d = d > 0.0d ? 1.0d : 0.0d;
            }
            example.setValue(attribute, d);
        }
        return exampleSet;
    }
}
