/*
 * Decompiled with CFR 0.152.
 */
package com.owc.operator.validation;

import com.owc.license.ProductInformation;
import com.owc.operator.validation.ValidationOperator;
import com.owc.tools.ExampleSetAppender;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.extension.PluginInitJackhammerExtension;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.Value;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MDTransformationRule;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterHandler;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.parameter.conditions.ParameterCondition;
import com.rapidminer.studio.concurrency.internal.ConcurrencyExecutionServiceProvider;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.container.Pair;
import com.rapidminer.tools.math.AverageVector;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;

/**
 * Cross validation with a training subprocess (subprocess 0) and a testing subprocess (subprocess 1).
 *
 * <p>The input example set is split into {@code number_of_validations} folds, or one fold per example
 * when {@code leave_one_out} is enabled. For every fold, the training subprocess builds a model on all
 * other folds and the testing subprocess evaluates it on the held-out fold. The fold performance vectors
 * are combined with {@code buildAverages}, in fold order. The optional per-fold test result sets are
 * merged, also in fold order. If the model port or the result extender is connected, an additional model
 * is trained on the complete example set and delivered.
 *
 * <p>Folds are executed in parallel when {@code checkParallelizability()} returns {@code true}, otherwise
 * sequentially in this operator.
 */
public class CrossValidationOperator
extends ValidationOperator {
    /** Parameter key: number of folds; ignored when leave-one-out is enabled. */
    public static final String PARAMETER_NUMBER_OF_VALIDATIONS = "number_of_validations";
    /** Parameter key: if true, use one fold per example. */
    public static final String PARAMETER_LEAVE_ONE_OUT = "leave_one_out";
    /** Parameter key: index into {@code SplittedExampleSet.SAMPLING_NAMES}. */
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";

    /** Number of individual criteria (besides the main criterion) exposed as logging values. */
    private static final int LOGGED_CRITERIA = 3;

    // Port creation order is significant for the operator's port layout; keep it stable.
    private final InputPort exampleSetInput = this.getInputPorts().createPort("example set", ExampleSet.class);
    private final OutputPort trainingSetInnerOutput = (OutputPort)this.getSubprocess(0).getInnerSources().createPort("training set");
    private final OutputPort testSetInnerOutput = (OutputPort)this.getSubprocess(1).getInnerSources().createPort("test set");
    private final InputPort testResultSetInnerInput = (InputPort)this.getSubprocess(1).getInnerSinks().createPort("test set results");
    private final OutputPort exampleSetOutput = (OutputPort)this.getOutputPorts().createPort("example set");
    private final OutputPort testResultSetOutput = (OutputPort)this.getOutputPorts().createPort("test result set");
    private final InputPort performanceInnerInput = this.getSubprocess(1).getInnerSinks().createPort("performance", PerformanceVector.class);
    private final OutputPort performanceOutput = (OutputPort)this.getOutputPorts().createPort("performance");
    private final InputPort modelInnerInput = this.getSubprocess(0).getInnerSinks().createPort("model", Model.class);
    private final OutputPort modelInnerOutput = (OutputPort)this.getSubprocess(1).getInnerSources().createPort("model");
    private final OutputPort modelOutput = (OutputPort)this.getOutputPorts().createPort("model");

    /** Index 0: main criterion; indices 1..LOGGED_CRITERIA: criteria 0..LOGGED_CRITERIA-1. */
    private final double[] loggingValuesPerformance = new double[LOGGED_CRITERIA + 1];
    /** Same layout as {@link #loggingValuesPerformance}, holding the fold standard deviations. */
    private final double[] loggingValuesStandardDeviation = new double[LOGGED_CRITERIA + 1];

    /**
     * Creates the operator, wires the meta data transformation rules and registers the logging values.
     *
     * @param description the operator description
     */
    public CrossValidationOperator(OperatorDescription description) {
        super(description);
        // Meta data flow: input -> training subprocess -> model -> testing subprocess -> outer outputs.
        // The rules are applied in registration order, so this order must be preserved.
        this.getTransformer().addPassThroughRule(this.exampleSetInput, this.trainingSetInnerOutput);
        this.getTransformer().addRule(this.inputExtender.makePassThroughRule());
        this.getTransformer().addRule((MDTransformationRule)new SubprocessTransformRule(this.getSubprocess(0)));
        this.getTransformer().addPassThroughRule(this.modelInnerInput, this.modelInnerOutput);
        this.getTransformer().addPassThroughRule(this.exampleSetInput, this.testSetInnerOutput);
        this.getTransformer().addRule(this.resultExtender.makePassThroughRule());
        this.getTransformer().addRule((MDTransformationRule)new SubprocessTransformRule(this.getSubprocess(1)));
        this.getTransformer().addPassThroughRule(this.modelInnerInput, this.modelOutput);
        this.getTransformer().addPassThroughRule(this.performanceInnerInput, this.performanceOutput);
        this.getTransformer().addPassThroughRule(this.exampleSetInput, this.exampleSetOutput);
        this.getTransformer().addPassThroughRule(this.testResultSetInnerInput, this.testResultSetOutput);
        // The "test set results" inner port expects an example set but may stay unconnected.
        this.testResultSetInnerInput.addPrecondition((Precondition)new SimplePrecondition(this.testResultSetInnerInput, (MetaData)new ExampleSetMetaData()){

            protected boolean isMandatory() {
                return false;
            }
        });
        this.addValue((Value)new ValueDouble("performance main criterion", "The micro average of the main criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely."){

            @Override
            public double getDoubleValue() {
                return CrossValidationOperator.this.loggingValuesPerformance[0];
            }
        });
        this.addValue((Value)new ValueDouble("std deviation main criterion", "The standard deviation over all folds of the main criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely."){

            @Override
            public double getDoubleValue() {
                return CrossValidationOperator.this.loggingValuesStandardDeviation[0];
            }
        });
        for (int i = 1; i <= LOGGED_CRITERIA; ++i) {
            final int index = i;
            this.addValue((Value)new ValueDouble("performance " + i, "The micro average of the " + i + ". main criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely."){

                @Override
                public double getDoubleValue() {
                    return CrossValidationOperator.this.loggingValuesPerformance[index];
                }
            });
            this.addValue((Value)new ValueDouble("std deviation " + i, "The standard deviation over all folds of the " + i + ". criterion of the performance vector delivered by the testing subprocess. Available only after the entire Operator is executed completely."){

                @Override
                public double getDoubleValue() {
                    return CrossValidationOperator.this.loggingValuesStandardDeviation[index];
                }
            });
        }
    }

    /**
     * Reads the parameters, splits the input example set into folds and runs the validation either in
     * parallel or sequentially.
     *
     * @param isLicensed license flag passed by {@code ValidationOperator} (not used here)
     * @throws OperatorException if a subprocess fails, or if no fold can be built
     *         (leave-one-out on an empty example set)
     */
    @Override
    public void doWork(boolean isLicensed) throws OperatorException {
        ExampleSet set = (ExampleSet)this.exampleSetInput.getData(ExampleSet.class);
        int numberOfValidations = this.getParameterAsInt(PARAMETER_NUMBER_OF_VALIDATIONS);
        if (this.getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) {
            numberOfValidations = set.size();
        }
        if (numberOfValidations < 1) {
            // Without this check an empty set under leave-one-out fails later with an
            // IndexOutOfBoundsException / NullPointerException during aggregation.
            throw new OperatorException("Cross validation needs at least one fold, but " + numberOfValidations
                    + " folds were requested (is the input example set empty?).");
        }
        int samplingType = this.getParameterAsInt(PARAMETER_SAMPLING_TYPE);
        boolean useLocalRandomSeed = this.getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED);
        int localRandomSeed = this.getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED);
        SplittedExampleSet splittedSet = new SplittedExampleSet(set, numberOfValidations, samplingType, useLocalRandomSeed, localRandomSeed);
        List<IOObject> arbitraryInputObjects = this.inputExtender.getDataOrNull(IOObject.class);
        if (this.checkParallelizability()) {
            this.performParallelValidation(set, numberOfValidations, splittedSet, arbitraryInputObjects);
        } else {
            this.performSequentialValidation(set, numberOfValidations, splittedSet, arbitraryInputObjects);
        }
    }

    /**
     * Runs all folds one after another in this operator, then (if requested) trains the full model.
     */
    private void performSequentialValidation(ExampleSet set, int numberOfValidations, SplittedExampleSet splittedSet, List<IOObject> arbitraryInputObjects) throws OperatorException {
        List<Pair<PerformanceVector, ExampleSet>> foldResults = new ArrayList<Pair<PerformanceVector, ExampleSet>>(numberOfValidations);
        for (int fold = 0; fold < numberOfValidations; ++fold) {
            splittedSet.selectAllSubsetsBut(fold);
            ExampleSet trainSet = (ExampleSet)this.getDataCopy(splittedSet);
            splittedSet.selectSingleSubset(fold);
            ExampleSet testSet = (ExampleSet)this.getDataCopy(splittedSet);
            Pair<Model, List<IOObject>> trainResults = this.train(trainSet, arbitraryInputObjects);
            foldResults.add(this.test(testSet, trainResults.getFirst(), trainResults.getSecond()));
        }
        if (this.isFullModelRequested()) {
            splittedSet.selectAllSubsets();
            ExampleSet fullSet = (ExampleSet)this.getDataCopy(splittedSet);
            Pair<Model, List<IOObject>> fullResults = this.train(fullSet, arbitraryInputObjects);
            this.modelOutput.deliver(fullResults.getFirst());
            this.resultExtender.deliver(fullResults.getSecond());
        }
        // Aggregate in fold order, so the merged test result set matches the parallel path.
        PerformanceVector performance = null;
        List<ExampleSet> testResultSets = new LinkedList<ExampleSet>();
        for (Pair<PerformanceVector, ExampleSet> foldResult : foldResults) {
            performance = accumulate(performance, foldResult.getFirst());
            if (foldResult.getSecond() != null) {
                testResultSets.add(foldResult.getSecond());
            }
        }
        this.deliverAggregatedResults(set, performance, testResultSets);
    }

    /**
     * Runs every fold (and, if requested, the full model training) as a separate task on an operator
     * clone, using the concurrency execution service.
     */
    private void performParallelValidation(ExampleSet set, int numberOfValidations, SplittedExampleSet splittedSet, List<IOObject> arbitraryInputObjects) throws OperatorException {
        List<Callable<ValidationOperator.RunResult>> tasks = new ArrayList<Callable<ValidationOperator.RunResult>>(numberOfValidations + 1);
        for (int fold = 0; fold < numberOfValidations; ++fold) {
            CrossValidationOperator copy = (CrossValidationOperator)this.cloneOperator(this.getName(), true);
            Callable<ValidationOperator.RunResult> foldTask = this.createFoldTask(copy, fold, splittedSet, arbitraryInputObjects);
            tasks.add(ConcurrencyExecutionServiceProvider.INSTANCE.getService().prepareOperatorTask(this.getProcess(), copy, fold + 1, fold + 1 == numberOfValidations, foldTask));
        }
        if (this.isFullModelRequested()) {
            CrossValidationOperator copy = (CrossValidationOperator)this.cloneOperator(this.getName(), true);
            Callable<ValidationOperator.RunResult> fullTask = this.createFullModelTask(copy, splittedSet, arbitraryInputObjects);
            tasks.add(ConcurrencyExecutionServiceProvider.INSTANCE.getService().prepareOperatorTask(this.getProcess(), copy, numberOfValidations + 1, true, fullTask));
        }
        List<ValidationOperator.RunResult> results = ConcurrencyExecutionServiceProvider.INSTANCE.getService().executeOperatorTasks(this, tasks);
        PerformanceVector performance = null;
        List<ExampleSet> testResultSets = new LinkedList<ExampleSet>();
        for (ValidationOperator.RunResult result : results) {
            if (result.model != null) {
                this.modelOutput.deliver(result.model);
            }
            if (result.results != null) {
                this.resultExtender.deliver(result.results);
            }
            performance = accumulate(performance, result.performance);
            if (result.testSet != null) {
                testResultSets.add(result.testSet);
            }
        }
        this.deliverAggregatedResults(set, performance, testResultSets);
    }

    /**
     * Creates the task that trains and tests a single fold on the given operator clone.
     */
    private Callable<ValidationOperator.RunResult> createFoldTask(final CrossValidationOperator copy, final int fold, final SplittedExampleSet splittedSet, final List<IOObject> arbitraryInputObjects) {
        return new Callable<ValidationOperator.RunResult>(){

            @Override
            public ValidationOperator.RunResult call() throws Exception {
                ExampleSet trainSet;
                // splittedSet is shared by all tasks: selecting a subset and copying it must be atomic.
                synchronized (splittedSet) {
                    splittedSet.selectAllSubsetsBut(fold);
                    trainSet = (ExampleSet)CrossValidationOperator.this.getDataCopy(splittedSet);
                }
                Pair<Model, List<IOObject>> trainResults = copy.train(trainSet, arbitraryInputObjects);
                // Drop the reference to the training copy before the test copy is created.
                trainSet = null;
                ExampleSet testSet;
                synchronized (splittedSet) {
                    splittedSet.selectSingleSubset(fold);
                    testSet = (ExampleSet)CrossValidationOperator.this.getDataCopy(splittedSet);
                }
                Pair<PerformanceVector, ExampleSet> testResults = copy.test(testSet, trainResults.getFirst(), trainResults.getSecond());
                return new ValidationOperator.RunResult(testResults.getFirst(), null, null, testResults.getSecond());
            }
        };
    }

    /**
     * Creates the task that trains the final model on the complete example set using the given clone.
     */
    private Callable<ValidationOperator.RunResult> createFullModelTask(final CrossValidationOperator copy, final SplittedExampleSet splittedSet, final List<IOObject> arbitraryInputObjects) {
        return new Callable<ValidationOperator.RunResult>(){

            @Override
            public ValidationOperator.RunResult call() throws Exception {
                ExampleSet trainSet;
                synchronized (splittedSet) {
                    splittedSet.selectAllSubsets();
                    trainSet = (ExampleSet)CrossValidationOperator.this.getDataCopy(splittedSet);
                }
                Pair<Model, List<IOObject>> result = copy.train(trainSet, arbitraryInputObjects);
                return new ValidationOperator.RunResult(null, result.getSecond(), result.getFirst(), null);
            }
        };
    }

    /**
     * Returns whether a model on the complete example set has to be trained, i.e. whether its outputs
     * are connected.
     */
    private boolean isFullModelRequested() {
        return this.modelOutput.isConnected() || this.resultExtender.isConnected(this.getOutputPorts());
    }

    /**
     * Folds {@code next} into {@code accumulated} via {@code buildAverages}.
     *
     * @return {@code next} if nothing was accumulated yet, otherwise the (mutated) {@code accumulated};
     *         a {@code null} {@code next} is ignored
     */
    private static PerformanceVector accumulate(PerformanceVector accumulated, PerformanceVector next) {
        if (next == null) {
            return accumulated;
        }
        if (accumulated == null) {
            return next;
        }
        accumulated.buildAverages(next);
        return accumulated;
    }

    /**
     * Updates the logging values and delivers the shared outputs of both execution paths.
     *
     * @param set the original input example set, passed through unchanged
     * @param performance the averaged performance vector
     * @param testResultSets the per-fold test result sets, merged only if both test-result ports are connected
     */
    private void deliverAggregatedResults(ExampleSet set, PerformanceVector performance, List<ExampleSet> testResultSets) throws OperatorException {
        this.rememberLoggingValues(performance);
        if (this.testResultSetOutput.isConnected() && this.testResultSetInnerInput.isConnected()) {
            this.testResultSetOutput.deliver(ExampleSetAppender.merge(testResultSets, this));
        }
        this.exampleSetOutput.deliver(set);
        this.performanceOutput.deliver(performance);
    }

    /**
     * Stores the main criterion and the first {@value #LOGGED_CRITERIA} criteria of {@code vector}
     * for the logging values. Entries the vector does not provide are reset to {@code NaN}, so values
     * from a previous execution are not reported again.
     */
    private void rememberLoggingValues(PerformanceVector vector) {
        for (int i = 0; i < this.loggingValuesPerformance.length; ++i) {
            this.loggingValuesPerformance[i] = Double.NaN;
            this.loggingValuesStandardDeviation[i] = Double.NaN;
        }
        if (vector == null) {
            return;
        }
        if (vector.getMainCriterion() != null) {
            this.loggingValuesPerformance[0] = vector.getMainCriterion().getMikroAverage();
            this.loggingValuesStandardDeviation[0] = vector.getMainCriterion().getMakroStandardDeviation();
        }
        int available = Math.min(vector.getSize(), LOGGED_CRITERIA);
        for (int i = 0; i < available; ++i) {
            this.loggingValuesPerformance[i + 1] = vector.getCriterion(i).getMikroAverage();
            this.loggingValuesStandardDeviation[i + 1] = vector.getCriterion(i).getMakroStandardDeviation();
        }
    }

    /**
     * Executes the training subprocess on {@code trainSet}.
     *
     * @return the learned model and the objects collected by the result extender (may be {@code null})
     */
    private Pair<Model, List<IOObject>> train(ExampleSet trainSet, List<IOObject> arbitraryInputObjects) throws OperatorException {
        this.trainingSetInnerOutput.deliver(trainSet);
        this.inputExtender.deliver(arbitraryInputObjects);
        this.getSubprocess(0).execute();
        Model model = (Model)this.modelInnerInput.getData(Model.class);
        List<IOObject> results = this.resultExtender.getDataOrNull(IOObject.class);
        return new Pair<Model, List<IOObject>>(model, results);
    }

    /**
     * Executes the testing subprocess on {@code testSet} with the given model and extender objects.
     *
     * @return the performance vector and the optional test result set ({@code null} if not delivered)
     */
    private Pair<PerformanceVector, ExampleSet> test(ExampleSet testSet, Model model, List<IOObject> resultObjects) throws OperatorException {
        this.testSetInnerOutput.deliver(testSet);
        this.modelInnerOutput.deliver(model);
        this.resultExtender.deliver(resultObjects);
        this.getSubprocess(1).execute();
        PerformanceVector performance = (PerformanceVector)this.performanceInnerInput.getData(PerformanceVector.class);
        ExampleSet testResults = (ExampleSet)this.testResultSetInnerInput.getDataOrNull(ExampleSet.class);
        return new Pair<PerformanceVector, ExampleSet>(performance, testResults);
    }

    /**
     * Returns the leave-one-out switch, the fold count, the sampling type and the random generator
     * parameters, followed by the parameters of {@code ValidationOperator}. The first super parameter
     * is moved to the very front.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = new LinkedList<ParameterType>();
        types.add(new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored", false, false));

        ParameterType numberOfValidationsType = new ParameterTypeInt(PARAMETER_NUMBER_OF_VALIDATIONS, "Number of subsets for the crossvalidation.", 2, Integer.MAX_VALUE, 10);
        numberOfValidationsType.registerDependencyCondition(this.leaveOneOutDisabled());
        numberOfValidationsType.setExpert(false);
        types.add(numberOfValidationsType);

        ParameterType samplingTypeType = new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 3);
        samplingTypeType.setExpert(false);
        samplingTypeType.registerDependencyCondition(this.leaveOneOutDisabled());
        types.add(samplingTypeType);

        for (ParameterType randomType : RandomGenerator.getRandomGeneratorParameters((Operator)this)) {
            randomType.registerDependencyCondition(this.leaveOneOutDisabled());
            // Seeds matter for every randomized sampling type: 1 (shuffled), 2 (stratified) and
            // 3 (the default category, "automatic" in SplittedExampleSet.SAMPLING_NAMES), which
            // also samples randomly. Previously the seed was hidden under the default setting.
            randomType.registerDependencyCondition((ParameterCondition)new EqualTypeCondition((ParameterHandler)this, PARAMETER_SAMPLING_TYPE, SplittedExampleSet.SAMPLING_NAMES, false, new int[]{1, 2, 3}));
            types.add(randomType);
        }

        List<ParameterType> superTypes = super.getParameterTypes();
        types.addAll(superTypes);
        if (!superTypes.isEmpty()) {
            ParameterType leading = superTypes.get(0);
            if (types.remove(leading)) {
                types.add(0, leading);
            }
        }
        return types;
    }

    /** Condition that shows a parameter only while {@code leave_one_out} is switched off. */
    private ParameterCondition leaveOneOutDisabled() {
        return new BooleanParameterCondition((ParameterHandler)this, PARAMETER_LEAVE_ONE_OUT, false, false);
    }

    @Override
    public ProductInformation getProductInformation() {
        return PluginInitJackhammerExtension.PRODUCT_INFORMATION;
    }
}

