package com.owc.operator.validation;

import com.owc.license.ProductInformation;
import com.owc.operator.validation.ValidationOperator;
import com.owc.tools.ExampleSetAppender;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.extension.PluginInitJackhammerExtension;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.studio.concurrency.internal.ConcurrencyExecutionServiceProvider;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.container.Pair;
import com.rapidminer.tools.math.AverageVector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator;

/* loaded from: input_file:com/owc/operator/validation/CrossValidationOperator.class */
public class CrossValidationOperator extends ValidationOperator {
    public static String PARAMETER_NUMBER_OF_VALIDATIONS = "number_of_validations";
    public static String PARAMETER_LEAVE_ONE_OUT = "leave_one_out";
    public static String PARAMETER_SAMPLING_TYPE = "sampling_type";
    private InputPort exampleSetInput;
    private OutputPort trainingSetInnerOutput;
    private OutputPort testSetInnerOutput;
    private InputPort testResultSetInnerInput;
    private OutputPort exampleSetOutput;
    private OutputPort testResultSetOutput;
    private InputPort performanceInnerInput;
    private OutputPort performanceOutput;
    private InputPort modelInnerInput;
    private OutputPort modelInnerOutput;
    private OutputPort modelOutput;
    private double[] loggingValuesPerformance;
    private double[] loggingValuesStandardDeviation;

    public CrossValidationOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.exampleSetInput = getInputPorts().createPort("example set", ExampleSet.class);
        this.trainingSetInnerOutput = getSubprocess(0).getInnerSources().createPort("training set");
        this.testSetInnerOutput = getSubprocess(1).getInnerSources().createPort("test set");
        this.testResultSetInnerInput = getSubprocess(1).getInnerSinks().createPort("test set results");
        this.exampleSetOutput = getOutputPorts().createPort("example set");
        this.testResultSetOutput = getOutputPorts().createPort("test result set");
        this.performanceInnerInput = getSubprocess(1).getInnerSinks().createPort("performance", PerformanceVector.class);
        this.performanceOutput = getOutputPorts().createPort("performance");
        this.modelInnerInput = getSubprocess(0).getInnerSinks().createPort("model", Model.class);
        this.modelInnerOutput = getSubprocess(1).getInnerSources().createPort("model");
        this.modelOutput = getOutputPorts().createPort("model");
        this.loggingValuesPerformance = new double[4];
        this.loggingValuesStandardDeviation = new double[4];
        getTransformer().addPassThroughRule(this.exampleSetInput, this.trainingSetInnerOutput);
        getTransformer().addRule(this.inputExtender.makePassThroughRule());
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0)));
        getTransformer().addPassThroughRule(this.modelInnerInput, this.modelInnerOutput);
        getTransformer().addPassThroughRule(this.exampleSetInput, this.testSetInnerOutput);
        getTransformer().addRule(this.resultExtender.makePassThroughRule());
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(1)));
        getTransformer().addPassThroughRule(this.modelInnerInput, this.modelOutput);
        getTransformer().addPassThroughRule(this.performanceInnerInput, this.performanceOutput);
        getTransformer().addPassThroughRule(this.exampleSetInput, this.exampleSetOutput);
        getTransformer().addPassThroughRule(this.testResultSetInnerInput, this.testResultSetOutput);
        this.testResultSetInnerInput.addPrecondition(new SimplePrecondition(this.testResultSetInnerInput, new ExampleSetMetaData()) { // from class: com.owc.operator.validation.CrossValidationOperator.1
            protected boolean isMandatory() {
                return false;
            }
        });
        addValue(new ValueDouble("performance main criterion", "The micro average of the main criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely.") { // from class: com.owc.operator.validation.CrossValidationOperator.2
            public double getDoubleValue() {
                return CrossValidationOperator.this.loggingValuesPerformance[0];
            }
        });
        addValue(new ValueDouble("std deviation main criterion", "The standard deviation over all folds of the main criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely.") { // from class: com.owc.operator.validation.CrossValidationOperator.3
            public double getDoubleValue() {
                return CrossValidationOperator.this.loggingValuesStandardDeviation[0];
            }
        });
        for (int i = 1; i < 4; i++) {
            final int i2 = i;
            addValue(new ValueDouble("performance " + i, "The micro average of the " + i + ". main criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely.") { // from class: com.owc.operator.validation.CrossValidationOperator.4
                public double getDoubleValue() {
                    return CrossValidationOperator.this.loggingValuesPerformance[i2];
                }
            });
            addValue(new ValueDouble("std deviation " + i, "The standard deviation over all folds of the " + i + ". criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely.") { // from class: com.owc.operator.validation.CrossValidationOperator.5
                public double getDoubleValue() {
                    return CrossValidationOperator.this.loggingValuesStandardDeviation[i2];
                }
            });
        }
    }

    @Override // com.owc.operator.LicensedOperatorChain
    public void doWork(boolean z) throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.exampleSetInput.getData(ExampleSet.class);
        int parameterAsInt = getParameterAsInt(PARAMETER_NUMBER_OF_VALIDATIONS);
        if (getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) {
            parameterAsInt = exampleSet.size();
        }
        SplittedExampleSet splittedExampleSet = new SplittedExampleSet(exampleSet, parameterAsInt, getParameterAsInt(PARAMETER_SAMPLING_TYPE), getParameterAsBoolean("use_local_random_seed"), getParameterAsInt("local_random_seed"));
        List<IOObject> dataOrNull = this.inputExtender.getDataOrNull(IOObject.class);
        if (checkParallelizability()) {
            performParallelValidation(exampleSet, parameterAsInt, splittedExampleSet, dataOrNull);
        } else {
            performSycronizedValidation(exampleSet, parameterAsInt, splittedExampleSet, dataOrNull);
        }
    }

    private void performSycronizedValidation(ExampleSet exampleSet, int i, SplittedExampleSet splittedExampleSet, List<IOObject> list) throws UndefinedParameterError, OperatorException {
        ArrayList<Pair> arrayList = new ArrayList(i);
        for (int i2 = 0; i2 < i; i2++) {
            splittedExampleSet.selectAllSubsetsBut(i2);
            ExampleSet exampleSet2 = (ExampleSet) getDataCopy((IOObject) splittedExampleSet);
            splittedExampleSet.selectSingleSubset(i2);
            ExampleSet exampleSet3 = (ExampleSet) getDataCopy((IOObject) splittedExampleSet);
            Pair<Model, List<IOObject>> train = train(exampleSet2, list);
            arrayList.add(test(exampleSet3, (Model) train.getFirst(), (List) train.getSecond()));
        }
        if (this.modelOutput.isConnected() || this.resultExtender.isConnected(getOutputPorts())) {
            splittedExampleSet.selectAllSubsets();
            Pair<Model, List<IOObject>> train2 = train((ExampleSet) getDataCopy((IOObject) splittedExampleSet), list);
            this.modelOutput.deliver((IOObject) train2.getFirst());
            this.resultExtender.deliver((List) train2.getSecond());
        }
        Pair pair = (Pair) arrayList.remove(i - 1);
        LinkedList linkedList = new LinkedList();
        if (pair.getSecond() != null) {
            linkedList.add(pair.getSecond());
        }
        PerformanceVector performanceVector = (PerformanceVector) pair.getFirst();
        for (Pair pair2 : arrayList) {
            performanceVector.buildAverages((AverageVector) pair2.getFirst());
            if (pair2.getSecond() != null) {
                linkedList.add(pair2.getSecond());
            }
        }
        rememberLoggingValues(performanceVector);
        if (this.testResultSetOutput.isConnected() && this.testResultSetInnerInput.isConnected()) {
            this.testResultSetOutput.deliver(ExampleSetAppender.merge(linkedList, this));
        }
        this.exampleSetOutput.deliver(exampleSet);
        this.performanceOutput.deliver(performanceVector);
    }

    private void rememberLoggingValues(PerformanceVector performanceVector) {
        if (performanceVector.getMainCriterion() != null) {
            this.loggingValuesPerformance[0] = performanceVector.getMainCriterion().getMikroAverage();
            this.loggingValuesStandardDeviation[0] = performanceVector.getMainCriterion().getMakroStandardDeviation();
        }
        for (int i = 0; i < 3; i++) {
            if (performanceVector.getSize() > i) {
                this.loggingValuesPerformance[i + 1] = performanceVector.getCriterion(i).getMikroAverage();
                this.loggingValuesStandardDeviation[i + 1] = performanceVector.getCriterion(i).getMakroStandardDeviation();
            }
        }
    }

    private void performParallelValidation(ExampleSet exampleSet, int i, final SplittedExampleSet splittedExampleSet, final List<IOObject> list) throws UndefinedParameterError, OperatorException {
        LinkedList linkedList = new LinkedList();
        for (int i2 = 0; i2 < i; i2++) {
            final int i3 = i2;
            final CrossValidationOperator cloneOperator = cloneOperator(getName(), true);
            linkedList.add(ConcurrencyExecutionServiceProvider.INSTANCE.getService().prepareOperatorTask(getProcess(), cloneOperator, i2 + 1, i2 + 1 == i, new Callable<ValidationOperator.RunResult>() { // from class: com.owc.operator.validation.CrossValidationOperator.6
                /* JADX WARN: Can't rename method to resolve collision */
                @Override // java.util.concurrent.Callable
                public ValidationOperator.RunResult call() throws Exception {
                    ExampleSet dataCopy;
                    ExampleSet dataCopy2;
                    synchronized (splittedExampleSet) {
                        splittedExampleSet.selectAllSubsetsBut(i3);
                        dataCopy = CrossValidationOperator.this.getDataCopy((IOObject) splittedExampleSet);
                    }
                    Pair train = cloneOperator.train(dataCopy, list);
                    synchronized (splittedExampleSet) {
                        splittedExampleSet.selectSingleSubset(i3);
                        dataCopy2 = CrossValidationOperator.this.getDataCopy((IOObject) splittedExampleSet);
                    }
                    Pair test = cloneOperator.test(dataCopy2, (Model) train.getFirst(), (List) train.getSecond());
                    return new ValidationOperator.RunResult((PerformanceVector) test.getFirst(), null, null, (ExampleSet) test.getSecond());
                }
            }));
        }
        if (this.modelOutput.isConnected() || this.resultExtender.isConnected(getOutputPorts())) {
            final CrossValidationOperator cloneOperator2 = cloneOperator(getName(), true);
            linkedList.add(ConcurrencyExecutionServiceProvider.INSTANCE.getService().prepareOperatorTask(getProcess(), cloneOperator2, i + 1, true, new Callable<ValidationOperator.RunResult>() { // from class: com.owc.operator.validation.CrossValidationOperator.7
                /* JADX WARN: Can't rename method to resolve collision */
                @Override // java.util.concurrent.Callable
                public ValidationOperator.RunResult call() throws Exception {
                    ExampleSet dataCopy;
                    synchronized (splittedExampleSet) {
                        splittedExampleSet.selectAllSubsets();
                        dataCopy = CrossValidationOperator.this.getDataCopy((IOObject) splittedExampleSet);
                    }
                    Pair train = cloneOperator2.train(dataCopy, list);
                    return new ValidationOperator.RunResult(null, (List) train.getSecond(), (Model) train.getFirst(), null);
                }
            }));
        }
        List<ValidationOperator.RunResult> executeOperatorTasks = ConcurrencyExecutionServiceProvider.INSTANCE.getService().executeOperatorTasks(this, linkedList);
        PerformanceVector performanceVector = null;
        LinkedList linkedList2 = new LinkedList();
        for (ValidationOperator.RunResult runResult : executeOperatorTasks) {
            if (runResult.model != null) {
                this.modelOutput.deliver(runResult.model);
            }
            if (runResult.results != null) {
                this.resultExtender.deliver(runResult.results);
            }
            if (runResult.performance != null) {
                if (performanceVector == null) {
                    performanceVector = runResult.performance;
                } else {
                    performanceVector.buildAverages(runResult.performance);
                }
            }
            if (runResult.testSet != null) {
                linkedList2.add(runResult.testSet);
            }
        }
        rememberLoggingValues(performanceVector);
        if (this.testResultSetOutput.isConnected() && this.testResultSetInnerInput.isConnected()) {
            this.testResultSetOutput.deliver(ExampleSetAppender.merge(linkedList2, this));
        }
        this.exampleSetOutput.deliver(exampleSet);
        this.performanceOutput.deliver(performanceVector);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Pair<Model, List<IOObject>> train(ExampleSet exampleSet, List<IOObject> list) throws OperatorException {
        this.trainingSetInnerOutput.deliver(exampleSet);
        this.inputExtender.deliver(list);
        getSubprocess(0).execute();
        return new Pair<>(this.modelInnerInput.getData(Model.class), this.resultExtender.getDataOrNull(IOObject.class));
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Pair<PerformanceVector, ExampleSet> test(ExampleSet exampleSet, Model model, List<IOObject> list) throws OperatorException {
        this.testSetInnerOutput.deliver(exampleSet);
        this.modelInnerOutput.deliver(model);
        this.resultExtender.deliver(list);
        getSubprocess(1).execute();
        return new Pair<>(this.performanceInnerInput.getData(PerformanceVector.class), this.testResultSetInnerInput.getDataOrNull(ExampleSet.class));
    }

    @Override // com.owc.operator.ParallelOperatorChain, com.owc.operator.LicensedOperatorChain
    public List<ParameterType> getParameterTypes() {
        LinkedList linkedList = new LinkedList();
        linkedList.add(new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored", false, false));
        ParameterTypeInt parameterTypeInt = new ParameterTypeInt(PARAMETER_NUMBER_OF_VALIDATIONS, "Number of subsets for the crossvalidation.", 2, BaseAbstractUnivariateIntegrator.DEFAULT_MAX_ITERATIONS_COUNT, 10);
        parameterTypeInt.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        parameterTypeInt.setExpert(false);
        linkedList.add(parameterTypeInt);
        ParameterTypeCategory parameterTypeCategory = new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 3);
        parameterTypeCategory.setExpert(false);
        parameterTypeCategory.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        linkedList.add(parameterTypeCategory);
        for (ParameterType parameterType : RandomGenerator.getRandomGeneratorParameters(this)) {
            parameterType.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
            parameterType.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_SAMPLING_TYPE, SplittedExampleSet.SAMPLING_NAMES, false, new int[]{1, 2}));
            linkedList.add(parameterType);
        }
        List<ParameterType> parameterTypes = super.getParameterTypes();
        linkedList.addAll(parameterTypes);
        ParameterType parameterType2 = parameterTypes.get(0);
        if (linkedList.remove(parameterType2)) {
            linkedList.add(0, parameterType2);
        }
        return linkedList;
    }

    @Override // com.owc.operator.LicensedOperatorChain
    public ProductInformation getProductInformation() {
        return PluginInitJackhammerExtension.PRODUCT_INFORMATION;
    }
}
