package com.rapidminer.operator.validation;

import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.performance.PerformanceCriterion;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.PortPairExtender;
import com.rapidminer.operator.ports.metadata.MDInteger;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.AverageVector;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/validation/RandomSentenceSplitValidationChain.class */
/**
 * Split validation that keeps all examples sharing a "batch" value in the same partition.
 *
 * <p>A set of distinct integers is drawn at random; every example whose special "batch"
 * attribute parses to one of the drawn integers goes to the training subset (partition 0), all
 * other examples go to the test subset (partition 1). The draw is repeated until the per-label
 * counts of the training subset are close enough to {@code split_ratio} times the counts of the
 * full set (see {@link #isStratified}). The model is then learned on partition 0 and evaluated on
 * partition 1.
 *
 * <p>NOTE(review): judging by the class name, a batch presumably identifies a sentence; nothing
 * in this class confirms that.
 */
public class RandomSentenceSplitValidationChain extends ValidationChain {
    /** Parameter key for the relative size of the training set. */
    public static final String PARAMETER_SPLIT_RATIO = "split_ratio";
    /** Parameter key for the sampling type; registered in {@link #getParameterTypes()} only. */
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    /** Parameter key for the local random seed. */
    public static final String PARAMETER_USE_LOCAL_RANDOM_SEED = "local_random_seed";

    /** Name of the special attribute whose integer value decides the partition of an example. */
    private static final String BATCH_ROLE = "batch";
    /** User error code raised when the training subset turns out to be empty. */
    private static final int EMPTY_TRAINING_SET_ERROR = 117;

    private final InputPort trainingSetInput;
    private final OutputPort trainingProcessExampleSetOutput;
    private final InputPort trainingProcessModelInput;
    private final PortPairExtender throughExtender;
    private final OutputPort applyProcessModelOutput;
    private final OutputPort applyProcessExampleSetOutput;
    private final PortPairExtender applyProcessPerformancePortExtender;
    private final OutputPort modelOutput;
    private final OutputPort exampleSetOutput;

    // Snapshot of the most recent performance vector, written by setResult(...).
    // No reader is visible in this class.
    private double lastMainPerformance = Double.NaN;
    private double lastMainVariance = Double.NaN;
    private double lastMainDeviation = Double.NaN;
    private double lastFirstPerformance = Double.NaN;
    private double lastSecondPerformance = Double.NaN;
    private double lastThirdPerformance = Double.NaN;

    /**
     * Looks up the operator's ports by name and wires the two port pair extenders.
     *
     * @param operatorDescription description handed on to the super class
     */
    public RandomSentenceSplitValidationChain(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.trainingSetInput = getInputPorts().getPortByName("training");
        this.trainingProcessExampleSetOutput = getSubprocess(0).getInnerSources().getPortByName("training");
        this.trainingProcessModelInput = getSubprocess(0).getInnerSinks().getPortByName("model");
        this.throughExtender = new PortPairExtender("through", getSubprocess(0).getInnerSinks(), getSubprocess(1).getInnerSources());
        this.applyProcessModelOutput = getSubprocess(1).getInnerSources().getPortByName("model");
        this.applyProcessExampleSetOutput = getSubprocess(1).getInnerSources().getPortByName("test set");
        this.applyProcessPerformancePortExtender = new PortPairExtender("averagable", getSubprocess(1).getInnerSinks(), getOutputPorts(), new MetaData(AverageVector.class));
        this.modelOutput = getOutputPorts().getPortByName("model");
        this.exampleSetOutput = getOutputPorts().getPortByName("training");
    }

    /**
     * Runs the validation on the input example set, optionally learns a model on the complete
     * set (only when the model output is connected), passes the example set through, and records
     * every {@link PerformanceVector} found among the averagable results.
     *
     * @throws OperatorException if estimation or learning fails
     */
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.trainingSetInput.getData();
        estimatePerformance(exampleSet);
        if (this.modelOutput.isConnected()) {
            learn(exampleSet);
            this.modelOutput.deliver(this.trainingProcessModelInput.getData());
        }
        this.exampleSetOutput.deliver(exampleSet);

        boolean performanceFound = false;
        for (IOObject result : this.applyProcessPerformancePortExtender.getData()) {
            if (result instanceof PerformanceVector) {
                setResult((PerformanceVector) result);
                performanceFound = true;
            }
        }
        if (!performanceFound) {
            getLogger().warning("No performance vector found among averagable results. Performance will not be loggable.");
        }
    }

    /**
     * Splits the example set by batch, learns on the training subset and evaluates on the test
     * subset.
     *
     * @param exampleSet the complete example set; every example needs a label and a special
     *                   "batch" attribute whose string value parses as an integer
     * @throws OperatorException if {@code split_ratio} is 1.0 or the "batch" attribute is missing
     * @throws UserError         (code 117) if the training subset would be empty
     */
    public void estimatePerformance(ExampleSet exampleSet) throws OperatorException {
        double splitRatio = getParameterAsDouble(PARAMETER_SPLIT_RATIO);
        if (splitRatio == 1.0d) {
            // Previously no split was built in this case and the method failed with a
            // NullPointerException further down.
            throw new OperatorException("Parameter '" + PARAMETER_SPLIT_RATIO + "' must be smaller than 1.0: no examples would be left for testing.");
        }

        int exampleCount = exampleSet.size();
        int drawCount = (int) (splitRatio * exampleCount);
        if (drawCount == 0) {
            // Nothing would be drawn, so the training subset would stay empty. For a non-empty
            // example set the stratification check could then never succeed.
            throw new UserError(this, EMPTY_TRAINING_SET_ERROR);
        }
        if (exampleSet.getAttributes().getSpecial(BATCH_ROLE) == null) {
            throw new OperatorException("The example set has no special attribute with role '" + BATCH_ROLE + "'.");
        }

        // Neither the label totals nor the batch ids change between attempts.
        Hashtable<String, Integer> labelTotals = countLabelCategories(exampleSet);
        int[] batchIds = readBatchIds(exampleSet, exampleCount);

        // Fetched once so that consecutive attempts continue one random sequence.
        // NOTE(review): presumably a locally seeded generator would otherwise restart on every
        // attempt and reproduce the same split - confirm against RandomGenerator.
        RandomGenerator random = RandomGenerator.getRandomGenerator(this);

        SplittedExampleSet split;
        for (int attempt = 0;; attempt++) {
            HashSet<Integer> trainingBatches = drawDistinct(random, drawCount, exampleCount);
            int[] assignment = new int[exampleCount];
            for (int position = 0; position < exampleCount; position++) {
                assignment[position] = trainingBatches.contains(batchIds[position]) ? 0 : 1;
            }
            split = new SplittedExampleSet(exampleSet, new Partition(assignment, 2));
            split.selectSingleSubset(0);
            // The accepted deviation grows by one example per failed attempt.
            if (isStratified(split, labelTotals, splitRatio, attempt)) {
                break;
            }
        }
        if (split.size() == 0) {
            throw new UserError(this, EMPTY_TRAINING_SET_ERROR);
        }

        split.selectSingleSubset(0);
        learn(split);
        split.selectSingleSubset(1);
        evaluate(split);

        // NOTE(review): the list handed to Tools is never filled in this method, so this lookup
        // presumably always yields null; kept because Tools is not visible from here.
        PerformanceVector performanceVector = Tools.getPerformanceVector(new LinkedList<AverageVector>());
        if (performanceVector != null) {
            setResult(performanceVector);
        }
        // The performance vector is deliberately not delivered to the example set output any
        // more: that port carries the example set, which doWork() delivers afterwards.
    }

    /**
     * Extends the inherited parameters by split ratio, sampling type and local random seed.
     *
     * @return the parameter type list of the super class with three entries appended
     */
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();

        ParameterTypeDouble splitRatio = new ParameterTypeDouble(PARAMETER_SPLIT_RATIO, "Relative size of the training set", 0.0d, 1.0d, 0.7d);
        splitRatio.setExpert(false);
        parameterTypes.add(splitRatio);

        parameterTypes.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_USE_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global)", -1, Integer.MAX_VALUE, -1));
        return parameterTypes;
    }

    /**
     * Resets all remembered performance values to NaN and, if a vector is given, refills them
     * from its main criterion (falling back to the first criterion) and its first three criteria.
     *
     * @param performanceVector the vector to remember; may be {@code null}
     */
    private void setResult(PerformanceVector performanceVector) {
        this.lastMainPerformance = Double.NaN;
        this.lastMainVariance = Double.NaN;
        this.lastMainDeviation = Double.NaN;
        this.lastFirstPerformance = Double.NaN;
        this.lastSecondPerformance = Double.NaN;
        this.lastThirdPerformance = Double.NaN;
        if (performanceVector == null) {
            return;
        }

        PerformanceCriterion mainCriterion = performanceVector.getMainCriterion();
        if (mainCriterion == null && performanceVector.size() > 0) {
            mainCriterion = performanceVector.getCriterion(0);
        }
        if (mainCriterion != null) {
            this.lastMainPerformance = mainCriterion.getAverage();
            this.lastMainVariance = mainCriterion.getVariance();
            this.lastMainDeviation = mainCriterion.getStandardDeviation();
        }
        this.lastFirstPerformance = averageAt(performanceVector, 0);
        this.lastSecondPerformance = averageAt(performanceVector, 1);
        this.lastThirdPerformance = averageAt(performanceVector, 2);
    }

    /** Returns the average of the criterion at {@code index}, or NaN if there is no such criterion. */
    private static double averageAt(PerformanceVector performanceVector, int index) {
        if (performanceVector.size() <= index) {
            return Double.NaN;
        }
        PerformanceCriterion criterion = performanceVector.getCriterion(index);
        return criterion == null ? Double.NaN : criterion.getAverage();
    }

    /**
     * Checks whether every label category of the full set occurs in the given subset with a
     * count inside the tolerated range around {@code ratio * total}.
     *
     * <p>The range is {@code round(0.99 * ratio * total) - attempt} to
     * {@code round(1.01 * ratio * total) + attempt}, both inclusive.
     *
     * @param subset      the currently selected training subset
     * @param labelTotals label category counts of the complete example set
     * @param ratio       the requested training ratio
     * @param attempt     zero-based number of the current attempt; widens the range
     * @return {@code true} if all categories are inside their range
     */
    private boolean isStratified(ExampleSet subset, Hashtable<String, Integer> labelTotals, double ratio, int attempt) {
        Hashtable<String, Integer> subsetCounts = countLabelCategories(subset);
        double upperRatio = ratio + (ratio / 100.0d);
        double lowerRatio = ratio - (ratio / 100.0d);

        for (String category : labelTotals.keySet()) {
            Integer observed = subsetCounts.get(category);
            if (observed == null) {
                // A category that is missing from the subset is never acceptable.
                return false;
            }
            int total = labelTotals.get(category);
            long lowerBound = Math.round(lowerRatio * total) - attempt;
            long upperBound = Math.round(upperRatio * total) + attempt;
            if (observed < lowerBound || observed > upperBound) {
                return false;
            }
        }
        return true;
    }

    /** Counts how often each label category (see {@link #labelCategory}) occurs in the given examples. */
    private static Hashtable<String, Integer> countLabelCategories(ExampleSet examples) {
        Hashtable<String, Integer> counts = new Hashtable<>();
        for (Example example : examples) {
            String category = labelCategory(example);
            Integer seen = counts.get(category);
            counts.put(category, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    /** Returns the label string of the example, reduced to the part after its first "-" if it has one. */
    private static String labelCategory(Example example) {
        String label = example.getValueAsString(example.getAttributes().getLabel());
        int dash = label.indexOf('-');
        return dash >= 0 ? label.substring(dash + 1) : label;
    }

    /** Parses the "batch" attribute of every example, in iteration order, into an int array. */
    private static int[] readBatchIds(ExampleSet examples, int exampleCount) {
        int[] batchIds = new int[exampleCount];
        int position = 0;
        for (Example example : examples) {
            batchIds[position++] = Integer.parseInt(example.getValueAsString(example.getAttributes().getSpecial(BATCH_ROLE)));
        }
        return batchIds;
    }

    /**
     * Draws random integers via {@code nextIntInRange(0, exampleCount - 1)} until {@code count}
     * distinct values have been collected.
     */
    private static HashSet<Integer> drawDistinct(RandomGenerator random, int count, int exampleCount) {
        HashSet<Integer> drawn = new HashSet<>();
        while (drawn.size() < count) {
            // Duplicates are ignored by the set, so the loop simply draws again.
            drawn.add(random.nextIntInRange(0, exampleCount - 1));
        }
        return drawn;
    }

    /** Always returns {@code null}; no test set size is computed here. */
    protected MDInteger getTestSetSize(MDInteger mDInteger) throws UndefinedParameterError {
        return null;
    }

    /** Always returns {@code null}; no training set size is computed here. */
    protected MDInteger getTrainingSetSize(MDInteger mDInteger) throws UndefinedParameterError {
        return null;
    }

    /** Reports every capability as unsupported. */
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        return false;
    }
}
