package com.rapidminer.extension.converters.operator.model;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.column.Column;
import com.rapidminer.belt.table.MixedRowWriter;
import com.rapidminer.belt.table.Writers;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.utils.ExampleSetBuilder;
import com.rapidminer.example.utils.ExampleSets;
import com.rapidminer.h2o.model.GeneralizedLinearModel;
import com.rapidminer.h2o.model.GeneralizedLinearModel_v2;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.LinearRegressionModel;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:com/rapidminer/extension/converters/operator/model/LinearRegressionModel2ExampleSet.class */
/**
 * Converts a linear model into a table with one row per coefficient.
 *
 * <p>Two layouts are produced:
 * <ul>
 *   <li>H2O GLM models ({@link GeneralizedLinearModel_v2}, {@link GeneralizedLinearModel}) produce a Belt
 *       {@link IOTable} with the columns listed in {@link #attNames}.</li>
 *   <li>{@link LinearRegressionModel} produces an {@link ExampleSet} with the columns listed in
 *       {@link #LIN_REG_COLUMN_NAMES}, including a significance code derived from the p-value.</li>
 * </ul>
 * Any other {@link PredictionModel} is rejected with a {@link UserError}. The input model is passed through
 * unchanged.
 */
public class LinearRegressionModel2ExampleSet extends Operator {

    /** Value-type code for nominal columns ({@code Ontology.NOMINAL}) used with AttributeFactory / AttributeMetaData. */
    private static final int VALUE_TYPE_NOMINAL = 1;

    /** Value-type code for numerical columns ({@code Ontology.NUMERICAL}) used with AttributeFactory / AttributeMetaData. */
    private static final int VALUE_TYPE_NUMERICAL = 2;

    /** Column names of the table produced for a {@link LinearRegressionModel}, in output order. */
    private static final String[] LIN_REG_COLUMN_NAMES = {
            "Attribute", "Coefficient", "Std. Error", "Std. Coefficients", "Tolerance", "t-Stat", "p-Value", "Code"};

    /** Value types matching {@link #LIN_REG_COLUMN_NAMES} position by position. */
    private static final int[] LIN_REG_VALUE_TYPES = {
            VALUE_TYPE_NOMINAL, VALUE_TYPE_NUMERICAL, VALUE_TYPE_NUMERICAL, VALUE_TYPE_NUMERICAL,
            VALUE_TYPE_NUMERICAL, VALUE_TYPE_NUMERICAL, VALUE_TYPE_NUMERICAL, VALUE_TYPE_NOMINAL};

    private final InputPort modelInputPort;
    private final OutputPort exampleSetOutputPort;
    private final OutputPort originalModelOutputPort;

    /** Column names of the table produced for H2O GLM models, in output order. */
    final List<String> attNames =
            Arrays.asList("Attribute", "Coefficient", "Std. Coefficients", "Std. Error", "z-Value", "p-Value");

    /** Belt column types matching {@link #attNames} position by position. */
    final List<Column.TypeId> attTypes = Arrays.asList(Column.TypeId.NOMINAL, Column.TypeId.REAL,
            Column.TypeId.REAL, Column.TypeId.REAL, Column.TypeId.REAL, Column.TypeId.REAL);

    public LinearRegressionModel2ExampleSet(OperatorDescription operatorDescription) {
        super(operatorDescription);
        modelInputPort = getInputPorts().createPort("linear model input", PredictionModel.class);
        exampleSetOutputPort = getOutputPorts().createPort("example set");
        originalModelOutputPort = getOutputPorts().createPassThroughPort("original model output");
        modelInputPort.addPrecondition(new SimplePrecondition(modelInputPort, new MetaData(PredictionModel.class)));
        getTransformer().addGenerationRule(exampleSetOutputPort, ExampleSet.class);
        getTransformer().addPassThroughRule(modelInputPort, originalModelOutputPort);
        // Replaces the generic ExampleSet meta data above with the concrete columns doWork() will produce.
        getTransformer().addRule(this::deliverCoefficientTableMetaData);
    }

    /**
     * Delivers meta data whose columns match exactly what {@link #doWork()} emits for the connected model type.
     * Does nothing while no model meta data is available, leaving the generic ExampleSet meta data in place.
     */
    private void deliverCoefficientTableMetaData() {
        MetaData inputMetaData = modelInputPort.getMetaData();
        if (inputMetaData == null) {
            return;
        }
        Class<?> modelClass = inputMetaData.getObjectClass();
        ExampleSetMetaData exampleSetMetaData = new ExampleSetMetaData();
        if (producesLinRegTable(modelClass)) {
            for (int i = 0; i < LIN_REG_COLUMN_NAMES.length; i++) {
                exampleSetMetaData.addAttribute(new AttributeMetaData(LIN_REG_COLUMN_NAMES[i], LIN_REG_VALUE_TYPES[i]));
            }
        } else {
            // GLM layout; also used when only the generic PredictionModel class is known.
            for (int i = 0; i < attNames.size(); i++) {
                int valueType = attTypes.get(i) == Column.TypeId.NOMINAL ? VALUE_TYPE_NOMINAL : VALUE_TYPE_NUMERICAL;
                exampleSetMetaData.addAttribute(new AttributeMetaData(attNames.get(i), valueType));
            }
        }
        exampleSetOutputPort.deliverMD(exampleSetMetaData);
    }

    /**
     * Returns whether a model of the given class is converted by {@link #convertLinReg}. This mirrors the
     * dispatch order in {@link #doWork()}: the GLM types are checked first.
     */
    private static boolean producesLinRegTable(Class<?> modelClass) {
        if (GeneralizedLinearModel_v2.class.isAssignableFrom(modelClass)
                || GeneralizedLinearModel.class.isAssignableFrom(modelClass)) {
            return false;
        }
        return LinearRegressionModel.class.isAssignableFrom(modelClass);
    }

    /**
     * Converts the input model into a coefficient table and passes the model through.
     *
     * @throws UserError if the model is neither an H2O GLM nor a {@link LinearRegressionModel}
     */
    @Override
    public void doWork() throws OperatorException {
        PredictionModel data = modelInputPort.getData(PredictionModel.class);
        if (data instanceof GeneralizedLinearModel_v2) {
            exampleSetOutputPort.deliver(convertGLM((GeneralizedLinearModel_v2) data));
        } else if (data instanceof GeneralizedLinearModel) {
            exampleSetOutputPort.deliver(convertGLM((GeneralizedLinearModel) data));
        } else if (data instanceof LinearRegressionModel) {
            exampleSetOutputPort.deliver(convertLinReg((LinearRegressionModel) data));
        } else {
            throw new UserError(this, "converters.no_linear_regression", new Object[]{data.getClass().toString()});
        }
        originalModelOutputPort.deliver(data);
    }

    /** Builds the GLM coefficient table for a {@link GeneralizedLinearModel_v2}. */
    private IOTable convertGLM(GeneralizedLinearModel_v2 generalizedLinearModel_v2) {
        return glmCoefficientTable(generalizedLinearModel_v2.getCoefficientNames(),
                generalizedLinearModel_v2.getCoefficients(), generalizedLinearModel_v2.getStdCoefficients(),
                generalizedLinearModel_v2.getStdErr(), generalizedLinearModel_v2.getZValues(),
                generalizedLinearModel_v2.getPValues());
    }

    /** Builds the GLM coefficient table for a {@link GeneralizedLinearModel}. */
    private IOTable convertGLM(GeneralizedLinearModel generalizedLinearModel) {
        return glmCoefficientTable(generalizedLinearModel.getCoefficientNames(),
                generalizedLinearModel.getCoefficients(), generalizedLinearModel.getStdCoefficients(),
                generalizedLinearModel.getStdErr(), generalizedLinearModel.getZValues(),
                generalizedLinearModel.getPValues());
    }

    /**
     * Writes one row per coefficient name using the {@link #attNames} layout.
     *
     * <p>Every statistics array except {@code coefficients} may be {@code null}. Missing arrays are written as NaN.
     * Non-null arrays are expected to be at least as long as {@code names}.
     */
    private IOTable glmCoefficientTable(String[] names, double[] coefficients, double[] stdCoefficients,
            double[] stdErr, double[] zValues, double[] pValues) {
        MixedRowWriter writer = Writers.mixedRowWriter(attNames, attTypes, false);
        for (int i = 0; i < names.length; i++) {
            writer.move();
            writer.set(0, names[i]);
            writer.set(1, coefficients[i]);
            writer.set(2, nanIfAbsent(stdCoefficients, i));
            writer.set(3, nanIfAbsent(stdErr, i));
            writer.set(4, nanIfAbsent(zValues, i));
            writer.set(5, nanIfAbsent(pValues, i));
        }
        return new IOTable(writer.create());
    }

    /** Returns {@code values[index]}, or NaN when the model did not provide the statistic at all. */
    private static double nanIfAbsent(double[] values, int index) {
        return values != null ? values[index] : Double.NaN;
    }

    /**
     * Builds the coefficient table for a {@link LinearRegressionModel} using the {@link #LIN_REG_COLUMN_NAMES}
     * layout. If the model uses an intercept, a row named "(Intercept)" is appended after the selected attributes.
     */
    private ExampleSet convertLinReg(LinearRegressionModel linearRegressionModel) {
        Attribute[] attributes = new Attribute[LIN_REG_COLUMN_NAMES.length];
        for (int c = 0; c < attributes.length; c++) {
            attributes[c] = AttributeFactory.createAttribute(LIN_REG_COLUMN_NAMES[c], LIN_REG_VALUE_TYPES[c]);
        }

        String[] names = linearRegressionModel.getSelectedAttributeNames();
        if (linearRegressionModel.usesIntercept()) {
            names = Arrays.copyOf(names, names.length + 1);
            names[names.length - 1] = "(Intercept)";
        }
        final String[] rowNames = names;

        // Numeric columns, in the order of LIN_REG_COLUMN_NAMES[1..6].
        double[][] numericColumns = {
                linearRegressionModel.getCoefficients(),
                linearRegressionModel.getStandardErrors(),
                linearRegressionModel.getStandardizedCoefficients(),
                linearRegressionModel.getTolerances(),
                linearRegressionModel.getTStats(),
                linearRegressionModel.getProbabilities()};
        double[] probabilities = numericColumns[numericColumns.length - 1];

        String[] codes = new String[rowNames.length];
        for (int row = 0; row < rowNames.length; row++) {
            codes[row] = significanceStars(probabilities[row]);
        }

        ExampleSetBuilder builder = ExampleSets.from(Arrays.asList(attributes));
        builder.withBlankSize(rowNames.length);
        Attribute nameAttribute = attributes[0];
        builder.withColumnFiller(nameAttribute, row -> nameAttribute.getMapping().mapString(rowNames[row]));
        for (int c = 0; c < numericColumns.length; c++) {
            double[] column = numericColumns[c];
            builder.withColumnFiller(attributes[c + 1], row -> column[row]);
        }
        Attribute codeAttribute = attributes[attributes.length - 1];
        builder.withColumnFiller(codeAttribute, row -> codeAttribute.getMapping().mapString(codes[row]));
        return builder.build();
    }

    /**
     * Maps a p-value to a significance code. The thresholds are 0.001 ("****"), 0.01 ("***"), 0.05 ("**") and
     * 0.1 ("*"). Otherwise, including NaN, the result is the empty string.
     */
    private static String significanceStars(double pValue) {
        if (pValue < 0.001d) {
            return "****";
        }
        if (pValue < 0.01d) {
            return "***";
        }
        if (pValue < 0.05d) {
            return "**";
        }
        if (pValue < 0.1d) {
            return "*";
        }
        return "";
    }
}
