package de.dwslab.toolbox.classification.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.ExecutionUnit;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.StackingModel;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.InputPortExtender;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPortExtender;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.GeneratePredictionModelTransformationRule;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeInt;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:de/dwslab/toolbox/classification/meta/XAbstractStacking.class */
/**
 * Abstract base class for stacking learners that train the stacking (meta-level)
 * model on out-of-fold predictions of the base learners.
 *
 * <p>The training set is split into {@link #PARAMETER_NUMBER_OF_FOLDS} folds. For
 * every fold the base learner subprocess is executed on the remaining folds and
 * the resulting models are applied to the held-out fold. The collected predictions
 * become temporary regular attributes named {@code base_prediction<i>} on which the
 * stacking model is learned. Finally the base learners are executed once more on
 * the complete training set and bundled with the stacking model in a
 * {@link StackingModel}.
 *
 * <p>Note for subclasses: the constructor calls {@link #getBaseModelLearnerProcess()},
 * so that method must already work while the superclass constructor is running.
 */
public abstract class XAbstractStacking extends OperatorChain implements Learner {
    /** Key of the integer parameter holding the number of folds. */
    public static final String PARAMETER_NUMBER_OF_FOLDS = "number_of_folds";

    /** Name prefix of the temporary attributes that hold the base model predictions. */
    private static final String BASE_PREDICTION_PREFIX = "base_prediction";

    /** Input port receiving the training example set. */
    protected InputPort exampleSetInput;
    /** Inner source ports that hand the training data to the base learner subprocess. */
    protected OutputPortExtender baseInputExtender;
    /** Inner sink ports that collect the models produced by the base learner subprocess. */
    protected InputPortExtender baseModelExtender;
    /** Output port delivering the learned model. */
    protected OutputPort modelOutput;

    /**
     * Creates the operator, its ports and its meta data transformation rules.
     *
     * @param description the operator description passed on to {@link OperatorChain}
     * @param subprocessNames the subprocess names passed on to {@link OperatorChain}
     */
    public XAbstractStacking(OperatorDescription description, String... subprocessNames) {
        super(description, subprocessNames);
        this.exampleSetInput = getInputPorts().createPort("training set", ExampleSet.class);
        this.baseInputExtender = new OutputPortExtender("training set", getBaseModelLearnerProcess().getInnerSources());
        this.baseModelExtender = new InputPortExtender("base model", getBaseModelLearnerProcess().getInnerSinks(),
                new PredictionModelMetaData(PredictionModel.class, new ExampleSetMetaData()), 2);
        this.modelOutput = getOutputPorts().createPort("model");
        this.baseInputExtender.start();
        this.baseModelExtender.start();
        getTransformer().addRule(this.baseInputExtender.makePassThroughRule(this.exampleSetInput));
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0)));
        getTransformer().addRule(new GeneratePredictionModelTransformationRule(this.exampleSetInput, this.modelOutput, PredictionModel.class));
    }

    /**
     * Returns the name given to the resulting {@link StackingModel}.
     *
     * @return the model name
     */
    protected abstract String getModelName();

    /**
     * Returns the subprocess that contains the base learners. It is also called from
     * the constructor of this class.
     *
     * @return the execution unit holding the base learners
     */
    protected abstract ExecutionUnit getBaseModelLearnerProcess();

    /**
     * Learns the stacking model on the example set that carries the base model
     * predictions as regular attributes.
     *
     * @param exampleSet the stacking training set
     * @return the learned stacking model
     * @throws OperatorException if learning fails
     */
    protected abstract Model getStackingModel(ExampleSet exampleSet) throws OperatorException;

    /**
     * Tells whether the original regular attributes stay in the stacking training set
     * next to the base model predictions.
     *
     * @return {@code true} to keep the original regular attributes
     */
    public abstract boolean keepOldAttributes();

    /**
     * Reads the training set from the input port, learns the model and delivers it
     * to the model output port.
     *
     * @throws OperatorException if reading the input or learning fails
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet trainingSet = (ExampleSet) this.exampleSetInput.getData(ExampleSet.class);
        this.modelOutput.deliver(learn(trainingSet));
    }

    /**
     * Learns a stacking model from out-of-fold base model predictions.
     *
     * @param exampleSet the labelled training set
     * @return a {@link StackingModel} combining base models trained on the complete
     *         training set with the stacking model
     * @throws OperatorException if the example set has no label (user error 105) or
     *         if a subprocess, a model application or the stacking learner fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        // The prediction attributes are clones of the label, so fail early and with
        // a proper user error instead of a NullPointerException after all folds ran.
        Attribute label = exampleSet.getAttributes().getLabel();
        if (label == null) {
            throw new UserError(this, 105);
        }

        ExampleSet stackingSet = (ExampleSet) exampleSet.clone();
        int numberOfFolds = getParameterAsInt(PARAMETER_NUMBER_OF_FOLDS);
        // The split must be created before the regular attributes are cleared below.
        // NOTE(review): sampling type 0 with local seed 0 is presumably linear
        // sampling -- confirm against the SplittedExampleSet constants.
        SplittedExampleSet folds = new SplittedExampleSet(stackingSet, numberOfFolds, 0, true, 0);
        if (!keepOldAttributes()) {
            stackingSet.getAttributes().clearRegular();
        }

        Map<Integer, Map<Integer, Double>> predictionsByModel = collectOutOfFoldPredictions(folds, numberOfFolds);
        folds.selectAllSubsets();

        List<Attribute> temporaryAttributes = addPredictionAttributes(stackingSet, label, predictionsByModel);
        Model stackingModel;
        try {
            stackingModel = getStackingModel(stackingSet);
        } finally {
            // Clean up even when the stacking learner fails so that the temporary
            // columns do not stay behind in the example table.
            removeTemporaryAttributes(stackingSet, temporaryAttributes);
        }

        // Train the final base models on the complete training set.
        this.baseInputExtender.deliverToAll(exampleSet, false);
        getBaseModelLearnerProcess().execute();
        List<Model> baseModels = this.baseModelExtender.getData(Model.class, true);
        return new StackingModel(exampleSet, getModelName(), baseModels, stackingModel, keepOldAttributes());
    }

    /**
     * Runs the base learners once per fold and records their predictions for the
     * held-out examples.
     *
     * @param folds the split training data; its subset selection is modified
     * @param numberOfFolds the number of folds in {@code folds}
     * @return for every base model position, a map from parent example index to the
     *         predicted value
     * @throws OperatorException if the subprocess or a model application fails
     */
    private Map<Integer, Map<Integer, Double>> collectOutOfFoldPredictions(SplittedExampleSet folds, int numberOfFolds)
            throws OperatorException {
        Map<Integer, Map<Integer, Double>> predictionsByModel = new HashMap<Integer, Map<Integer, Double>>();
        for (int fold = 0; fold < numberOfFolds; fold++) {
            folds.selectAllSubsetsBut(fold);
            this.baseInputExtender.deliverToAll(folds, false);
            getBaseModelLearnerProcess().execute();
            List<Model> foldModels = this.baseModelExtender.getData(Model.class, true);

            // All models of this fold are applied to the same held-out subset.
            folds.selectSingleSubset(fold);
            int modelIndex = 0;
            for (Model model : foldModels) {
                ExampleSet predicted = model.apply(folds);
                Attribute predictedLabel = predicted.getAttributes().getPredictedLabel();

                Map<Integer, Double> predictionsByRow = predictionsByModel.get(Integer.valueOf(modelIndex));
                if (predictionsByRow == null) {
                    predictionsByRow = new HashMap<Integer, Double>();
                    predictionsByModel.put(Integer.valueOf(modelIndex), predictionsByRow);
                }
                for (int row = 0; row < predicted.size(); row++) {
                    predictionsByRow.put(Integer.valueOf(folds.getActualParentIndex(row)),
                            Double.valueOf(predicted.getExample(row).getValue(predictedLabel)));
                }

                // The values are copied, so drop the prediction (and confidence)
                // attributes again; otherwise one set of columns per fold and model
                // would pile up.
                PredictionModel.removePredictedLabel(predicted);
                modelIndex++;
            }
        }
        return predictionsByModel;
    }

    /**
     * Adds one regular attribute per base model to the stacking set and fills it with
     * the recorded predictions.
     *
     * @param stackingSet the example set receiving the attributes
     * @param label the label attribute the new attributes are cloned from
     * @param predictionsByModel the predictions as returned by
     *        {@link #collectOutOfFoldPredictions(SplittedExampleSet, int)}
     * @return the attributes that were added, in model order
     */
    private List<Attribute> addPredictionAttributes(ExampleSet stackingSet, Attribute label,
            Map<Integer, Map<Integer, Double>> predictionsByModel) {
        List<Attribute> addedAttributes = new LinkedList<Attribute>();
        for (int modelIndex = 0; modelIndex < predictionsByModel.size(); modelIndex++) {
            Attribute prediction = (Attribute) label.clone();
            prediction.setName(BASE_PREDICTION_PREFIX + modelIndex);
            stackingSet.getAttributes().addRegular(prediction);
            stackingSet.getExampleTable().addAttribute(prediction);

            Map<Integer, Double> predictionsByRow = predictionsByModel.get(Integer.valueOf(modelIndex));
            for (Map.Entry<Integer, Double> entry : predictionsByRow.entrySet()) {
                stackingSet.getExample(entry.getKey().intValue()).setValue(prediction, entry.getValue().doubleValue());
            }
            addedAttributes.add(prediction);
        }
        return addedAttributes;
    }

    /**
     * Removes the predicted label and the temporary prediction attributes from the
     * stacking set and from its example table.
     *
     * @param stackingSet the example set to clean up
     * @param temporaryAttributes the attributes added for the base model predictions
     */
    private void removeTemporaryAttributes(ExampleSet stackingSet, List<Attribute> temporaryAttributes) {
        PredictionModel.removePredictedLabel(stackingSet);
        for (Attribute temporary : temporaryAttributes) {
            stackingSet.getAttributes().remove(temporary);
            stackingSet.getExampleTable().removeAttribute(temporary);
        }
    }

    /**
     * Not supported by this learner.
     *
     * @return never returns normally
     * @throws OperatorException always, as user error 912
     */
    public PerformanceVector getEstimatedPerformance() throws OperatorException {
        throw new UserError(this, 912, new Object[]{getName(), "estimation of performance not supported."});
    }

    /**
     * Not supported by this learner.
     *
     * @param exampleSet ignored
     * @return never returns normally
     * @throws OperatorException always, as user error 916
     */
    public AttributeWeights getWeights(ExampleSet exampleSet) throws OperatorException {
        throw new UserError(this, 916, new Object[]{getName(), "calculation of weights not supported."});
    }

    /**
     * @return always {@code false}; see {@link #getEstimatedPerformance()}
     */
    public boolean shouldEstimatePerformance() {
        return false;
    }

    /**
     * @return always {@code false}; see {@link #getWeights(ExampleSet)}
     */
    public boolean shouldCalculateWeights() {
        return false;
    }

    /**
     * Reports every capability as supported.
     *
     * @param operatorCapability the capability asked for; not inspected
     * @return always {@code true}
     */
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        return true;
    }

    /**
     * Extends the inherited parameters by {@link #PARAMETER_NUMBER_OF_FOLDS}
     * (minimum 2, default 10).
     *
     * @return the parameter types of this operator
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMBER_OF_FOLDS, "Number of folds", 2, Integer.MAX_VALUE, 10));
        return parameterTypes;
    }
}
