package com.rapidminer.extension.operator.models.tresholds;

import bsh.org.objectweb.asm.Constants;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.GroupedModel;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.performance.BinaryClassificationPerformance;
import com.rapidminer.operator.performance.BinominalClassificationPerformanceEvaluator;
import com.rapidminer.operator.performance.MultiClassificationPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.MDTransformationRule;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.operator.postprocessing.Threshold;

/* loaded from: input_file:com/rapidminer/extension/operator/models/tresholds/OptimizeThreshold.class */
public class OptimizeThreshold extends Operator {
    InputPort exaInput;
    InputPort modInput;
    OutputPort exaOutput;
    OutputPort modOutput;
    OutputPort perfOutput;
    public static final String PARAMETER_PERFORMANCE_MEASURE = "performance_measure";

    public OptimizeThreshold(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.exaInput = getInputPorts().createPort("exa", ExampleSet.class);
        this.modInput = getInputPorts().createPort("mod", PredictionModel.class);
        this.exaOutput = getOutputPorts().createPort("exa");
        this.modOutput = getOutputPorts().createPort("mod");
        this.perfOutput = getOutputPorts().createPort("perf");
        getTransformer().addPassThroughRule(this.exaInput, this.exaOutput);
        if (this.modInput.isConnected()) {
            getTransformer().addGenerationRule(this.modOutput, PredictionModel.class);
        } else {
            getTransformer().addGenerationRule(this.modOutput, GroupedModel.class);
        }
        getTransformer().addGenerationRule(this.perfOutput, PerformanceVector.class);
        this.modInput.addPrecondition(new SimplePrecondition(this.modInput, null, false) { // from class: com.rapidminer.extension.operator.models.tresholds.OptimizeThreshold.1
            protected boolean isMandatory() {
                return false;
            }
        });
    }

    public void doWork() throws OperatorException {
        ExampleSet data = this.exaInput.getData(ExampleSet.class);
        if (this.modInput.isConnected() && data.getAttributes().getPredictedLabel() == null) {
            data = this.modInput.getData(PredictionModel.class).apply(data);
        } else if (!this.modInput.isConnected() && data.getAttributes().getPredictedLabel() == null) {
            throw new UserError(this, Constants.FCMPL, new Object[]{this.modInput.getName()});
        }
        ThresholdModel thresholdModel = null;
        PerformanceVector performanceVector = null;
        double d = 0.0d;
        while (true) {
            double d2 = d;
            if (d2 >= 1.0d) {
                break;
            }
            data.getAttributes().getLabel().getMapping().getNegativeIndex();
            ThresholdModel thresholdModel2 = new ThresholdModel(new Threshold(d2, data.getAttributes().getLabel().getMapping().getNegativeString(), data.getAttributes().getLabel().getMapping().getPositiveString()));
            thresholdModel2.apply(data);
            PerformanceVector doWork = new BinominalClassificationPerformanceEvaluator(getOperatorDescription()).doWork(data);
            if (performanceVector == null || doWork.getMainCriterion().getFitness() > performanceVector.getMainCriterion().getFitness()) {
                performanceVector = doWork;
                thresholdModel = thresholdModel2;
            }
            d = d2 + 0.01d;
        }
        thresholdModel.apply(data);
        this.exaOutput.deliver(data);
        if (this.modInput.isConnected()) {
            GroupedModel groupedModel = new GroupedModel(data);
            groupedModel.addModel(this.modInput.getData(PredictionModel.class));
            groupedModel.addModel(thresholdModel);
            this.modOutput.deliver(groupedModel);
        } else {
            this.modOutput.deliver(thresholdModel);
        }
        this.perfOutput.deliver(performanceVector);
    }

    public static String[] getCriteria() {
        String[] strArr = BinaryClassificationPerformance.NAMES;
        String[] strArr2 = MultiClassificationPerformance.NAMES;
        String[] strArr3 = new String[strArr.length + strArr2.length];
        int i = 0;
        for (String str : strArr) {
            strArr3[i] = str;
            i++;
        }
        for (String str2 : strArr2) {
            strArr3[i] = str2;
            i++;
        }
        return strArr3;
    }
}
