package com.rapidminer.extension.operator_toolbox.operator.models.tresholds;

import bsh.org.objectweb.asm.Constants;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.extension.operator.models.tresholds.ThresholdModel;
import com.rapidminer.operator.GroupedModel;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.MDTransformationRule;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.operator.postprocessing.Threshold;

/* loaded from: input_file:com/rapidminer/extension/operator_toolbox/operator/models/tresholds/OptimizeThresholdCustomPerformance.class */
public class OptimizeThresholdCustomPerformance extends OperatorChain {
    InputPort exaInput;
    InputPort modInput;
    OutputPort exaOutput;
    OutputPort modOutput;
    OutputPort perfOutput;
    private final OutputPort innerExampleSource;
    private final InputPort innerPerformanceOutput;
    public static final String PARAMETER_PERFORMANCE_MEASURE = "performance_measure";

    public OptimizeThresholdCustomPerformance(OperatorDescription operatorDescription) {
        this(operatorDescription, "Nested Process");
    }

    public OptimizeThresholdCustomPerformance(OperatorDescription operatorDescription, String str) {
        super(operatorDescription, new String[]{str});
        this.exaInput = getInputPorts().createPort("exa", ExampleSet.class);
        this.modInput = getInputPorts().createPort("mod");
        this.exaOutput = getOutputPorts().createPort("exa");
        this.modOutput = getOutputPorts().createPort("mod");
        this.perfOutput = getOutputPorts().createPort("perf");
        this.innerExampleSource = getSubprocess(0).getInnerSources().createPort("scored set");
        this.innerPerformanceOutput = getSubprocess(0).getInnerSinks().createPort("Performance Vector", PerformanceVector.class);
        getTransformer().addPassThroughRule(this.exaInput, this.exaOutput);
        if (this.modInput.isConnected()) {
            getTransformer().addGenerationRule(this.modOutput, PredictionModel.class);
        } else {
            getTransformer().addGenerationRule(this.modOutput, GroupedModel.class);
        }
        getTransformer().addGenerationRule(this.perfOutput, PerformanceVector.class);
        this.modInput.addPrecondition(new SimplePrecondition(this.modInput, new MetaData(PredictionModel.class), false) { // from class: com.rapidminer.extension.operator_toolbox.operator.models.tresholds.OptimizeThresholdCustomPerformance.1
            protected boolean isMandatory() {
                return false;
            }
        });
        getTransformer().addPassThroughRule(this.exaInput, this.innerExampleSource);
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0)));
    }

    public void doWork() throws OperatorException {
        IOObject iOObject = (ExampleSet) this.exaInput.getData(ExampleSet.class);
        if (this.modInput.isConnected() && iOObject.getAttributes().getPredictedLabel() == null) {
            iOObject = this.modInput.getData(PredictionModel.class).apply(iOObject);
        } else if (!this.modInput.isConnected() && iOObject.getAttributes().getPredictedLabel() == null) {
            throw new UserError(this, Constants.FCMPL, new Object[]{this.modInput.getName()});
        }
        if (iOObject.getAttributes().getPredictedLabel().isNumerical() || iOObject.getAttributes().getLabel().getMapping().size() > 2) {
            throw new UserError(this, 114, new Object[]{"application of thresholds", iOObject.getAttributes().getPredictedLabel().getName()});
        }
        SimpleDataTable simpleDataTable = new SimpleDataTable(getName(), new String[]{"Threshold", "Performance"});
        getProcess().addDataTable(simpleDataTable);
        ThresholdModel thresholdModel = null;
        PerformanceVector performanceVector = null;
        double d = 0.0d;
        while (true) {
            double d2 = d;
            if (d2 >= 1.0d) {
                break;
            }
            ThresholdModel thresholdModel2 = new ThresholdModel(new Threshold(d2, iOObject.getAttributes().getLabel().getMapping().getNegativeString(), iOObject.getAttributes().getLabel().getMapping().getPositiveString()));
            iOObject = thresholdModel2.apply(iOObject);
            this.innerExampleSource.deliver(iOObject);
            getSubprocess(0).execute();
            PerformanceVector data = this.innerPerformanceOutput.getData(PerformanceVector.class);
            simpleDataTable.add(new SimpleDataTableRow(new double[]{d2, data.getMainCriterion().getAverage()}));
            if (performanceVector == null || data.getMainCriterion().getFitness() > performanceVector.getMainCriterion().getFitness()) {
                performanceVector = data;
                thresholdModel = thresholdModel2;
            }
            d = d2 + 0.01d;
        }
        if (thresholdModel == null) {
            throw new OperatorException("No optimal threshold could be found.");
        }
        ExampleSet apply = thresholdModel.apply(iOObject);
        this.exaOutput.deliver(apply);
        if (this.modInput.isConnected()) {
            GroupedModel groupedModel = new GroupedModel(apply);
            groupedModel.addModel(this.modInput.getData(PredictionModel.class));
            groupedModel.addModel(thresholdModel);
            this.modOutput.deliver(groupedModel);
        } else {
            this.modOutput.deliver(thresholdModel);
        }
        this.perfOutput.deliver(performanceVector);
    }
}
