package com.rapidminer.extension.interpretation.operator;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.table.Table;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.extension.interpretation.algorithm.ExplainPredictionsWrapper;
import com.rapidminer.extension.interpretation.algorithm.KernelSHAP;
import com.rapidminer.extension.interpretation.algorithm.LIME;
import com.rapidminer.extension.interpretation.algorithm.LocalInterpretationAlgorithm;
import com.rapidminer.extension.interpretation.algorithm.Shapley;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.ports.IncompatibleMDClassException;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.conditions.EqualStringCondition;
import com.rapidminer.studio.internal.Resources;
import com.rapidminer.tools.RandomGenerator;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/rapidminer/extension/interpretation/operator/GenerateInterpretation.class */
/**
 * Operator that explains the predictions of a {@link PredictionModel} using a selectable local
 * interpretation algorithm ("Explain Predictions", "KernelSHAP", "LIME" or "Shapley").
 *
 * <p>Inputs: the model ({@code mod}), a training example set ({@code training}) and the example set
 * to be interpreted ({@code test}). Outputs: the explained test table ({@code example set}), a table
 * of interpretation details ({@code importance}), global attribute weights
 * ({@code global weights}) and the unchanged input model ({@code model}).
 */
public class GenerateInterpretation extends Operator {
    InputPort modelInput;
    InputPort trainingInput;
    InputPort testInput;
    OutputPort exaOuput;
    OutputPort importancesOutput;
    OutputPort globalWeightsPort;
    OutputPort modelOutput;
    public static final String PARAMETER_ALGORITHM = "algorithm";
    public static final String PARAMETER_SAMPLE_SIZE = "sample_size";

    private static final String ALGORITHM_EXPLAIN_PREDICTIONS = "Explain Predictions";
    private static final String ALGORITHM_KERNEL_SHAP = "KernelSHAP";
    private static final String ALGORITHM_LIME = "LIME";
    private static final String ALGORITHM_SHAPLEY = "Shapley";

    /** Category values of {@link #PARAMETER_ALGORITHM}; the array index is the category index. */
    private static final String[] SUPPORTED_ALGORITHMS = {
        ALGORITHM_EXPLAIN_PREDICTIONS, ALGORITHM_KERNEL_SHAP, ALGORITHM_LIME, ALGORITHM_SHAPLEY
    };

    /** Default category index for {@link #PARAMETER_ALGORITHM}: 3 selects "Shapley". */
    private static final int DEFAULT_ALGORITHM_INDEX = 3;

    /**
     * Creates the operator, its ports and the metadata transformation rules.
     *
     * @param operatorDescription the description passed through to {@link Operator}
     */
    public GenerateInterpretation(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.modelInput = getInputPorts().createPort("mod", PredictionModel.class);
        this.trainingInput = getInputPorts().createPort("training", ExampleSet.class);
        this.testInput = getInputPorts().createPort("test", ExampleSet.class);
        this.exaOuput = getOutputPorts().createPort("example set");
        this.importancesOutput = getOutputPorts().createPort("importance");
        this.globalWeightsPort = getOutputPorts().createPort("global weights");
        this.modelOutput = getOutputPorts().createPassThroughPort("model");
        try {
            getTransformer().addGenerationRule(this.globalWeightsPort, AttributeWeights.class);
            getTransformer().addPassThroughRule(this.modelInput, this.modelOutput);
            getTransformer().addRule(this::deliverImportanceMetaData);
            getTransformer().addRule(this::deliverExplainedMetaData);
        } catch (NullPointerException ignored) {
            // NOTE(review): kept from the original as best-effort rule registration. It is unclear
            // which call is expected to throw here; confirm and replace with an explicit check.
        }
    }

    /** Delivers the fixed schema of the "importance" output table. */
    private void deliverImportanceMetaData() {
        ExampleSetMetaData metaData = new ExampleSetMetaData();
        // Raw value-type ids; presumably Ontology INTEGER (3), POLYNOMINAL (7), REAL (4) — confirm.
        metaData.addAttribute(new AttributeMetaData("Row No", 3));
        metaData.addAttribute(new AttributeMetaData("Name", 7));
        metaData.addAttribute(new AttributeMetaData("Importance", 4));
        this.importancesOutput.deliverMD(metaData);
    }

    /**
     * Delivers the metadata of the "example set" output: the test metadata plus the interpretation
     * attribute. Falls back to empty metadata when the test metadata is missing or incompatible.
     */
    private void deliverExplainedMetaData() {
        ExampleSetMetaData metaData;
        try {
            ExampleSetMetaData testMetaData = this.testInput.getMetaData(ExampleSetMetaData.class);
            if (testMetaData == null) {
                // No metadata available (e.g. unconnected port); treat like incompatible metadata.
                metaData = new ExampleSetMetaData();
            } else {
                // Work on a copy so the upstream port's metadata is not modified in place.
                metaData = (ExampleSetMetaData) testMetaData.clone();
                AttributeMetaData interpretation =
                        new AttributeMetaData(LocalInterpretationAlgorithm.INTERPRETATION_COLUMN_NAME, 7);
                interpretation.setRole(LocalInterpretationAlgorithm.INTERPRETATION_COLUMN_NAME);
                metaData.addAttribute(interpretation);
            }
        } catch (IncompatibleMDClassException e) {
            metaData = new ExampleSetMetaData();
        }
        this.exaOuput.deliverMD(metaData);
    }

    /**
     * Runs the selected interpretation algorithm and delivers all outputs.
     *
     * @throws UserError if the model's label is nominal with more than two values
     * @throws OperatorException if the algorithm parameter names an unsupported algorithm, or if
     *     reading inputs/parameters or the algorithm itself fails
     */
    @Override
    public void doWork() throws OperatorException {
        Table trainingTable = this.trainingInput.getData(IOTable.class).getTable();
        Table testTable = this.testInput.getData(IOTable.class).getTable();
        PredictionModel model = this.modelInput.getData(PredictionModel.class);
        if (model.getLabel().isNominal() && model.getLabel().getMapping().size() > 2) {
            throw new UserError(this, "interpretation.LabelisPolynominal", new Object[] {model.getLabel().getName()});
        }

        LocalInterpretationAlgorithm algorithm = createAlgorithm(trainingTable, testTable, model);
        algorithm.setRandomGenerator(RandomGenerator.getRandomGenerator(this));
        algorithm.setLocalSampleSize(getParameterAsInt(PARAMETER_SAMPLE_SIZE));
        algorithm.checkCapability(this, model, trainingTable);
        algorithm.calculate(trainingTable, testTable);

        AttributeWeights globalWeights = algorithm.getGlobalWeights();
        // Result unused; call kept in case it has side effects the later getters rely on — verify.
        algorithm.getExplainPredictionsIOObject();
        IOTable interpretationDetails = algorithm.getInterpretationDetails();
        IOTable explainedTable = algorithm.getExplainedTable();

        this.globalWeightsPort.deliver(globalWeights);
        this.importancesOutput.deliver(interpretationDetails);
        this.exaOuput.deliver(explainedTable);
        this.modelOutput.deliver(model);
    }

    /**
     * Instantiates the algorithm selected by {@link #PARAMETER_ALGORITHM}. LIME additionally
     * receives its algorithm-specific parameters.
     *
     * @throws OperatorException if the selected algorithm is not supported
     */
    private LocalInterpretationAlgorithm createAlgorithm(Table trainingTable, Table testTable, PredictionModel model)
            throws OperatorException {
        String algorithmName = getParameterAsString(PARAMETER_ALGORITHM);
        switch (algorithmName) {
            case ALGORITHM_SHAPLEY:
                return new Shapley(Resources.getConcurrencyContext(this), this, trainingTable, testTable, model);
            case ALGORITHM_LIME:
                LIME lime = new LIME(Resources.getConcurrencyContext(this), this, trainingTable, testTable, model);
                lime.setRedrawEachIteration(getParameterAsBoolean(LIME.PARAMETER_REDRRAW));
                lime.setUsedAlgorithm(getParameterAsString(LIME.PARAMETER_ALGORITHMS));
                lime.setLocality(getParameterAsDouble(LIME.PARAMETER_LOCALITY));
                return lime;
            case ALGORITHM_KERNEL_SHAP:
                return new KernelSHAP(Resources.getConcurrencyContext(this), this, trainingTable, testTable, model);
            case ALGORITHM_EXPLAIN_PREDICTIONS:
                return new ExplainPredictionsWrapper(
                        Resources.getConcurrencyContext(this), this, trainingTable, testTable, model);
            default:
                throw new OperatorException("Not yet implemented");
        }
    }

    /**
     * Declares the operator parameters: algorithm choice, sample size, the Shapley and LIME
     * specific parameters (each conditioned on its algorithm being selected) and the random
     * generator parameters.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeCategory(PARAMETER_ALGORITHM, "which algorithm to take", SUPPORTED_ALGORITHMS,
                DEFAULT_ALGORITHM_INDEX, false));
        types.add(new ParameterTypeInt(PARAMETER_SAMPLE_SIZE, "number of rows drawn for each interpretation.", 1,
                Integer.MAX_VALUE, 100));
        addAlgorithmParameters(types, Shapley.getListOfParameters(), ALGORITHM_SHAPLEY);
        addAlgorithmParameters(types, LIME.getListOfParameters(), ALGORITHM_LIME);
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }

    /**
     * Appends {@code parameters} to {@code types}, making each depend on {@link #PARAMETER_ALGORITHM}
     * being equal to {@code algorithm}.
     */
    private void addAlgorithmParameters(List<ParameterType> types, List<ParameterType> parameters, String algorithm) {
        for (ParameterType parameter : parameters) {
            parameter.registerDependencyCondition(
                    new EqualStringCondition(this, PARAMETER_ALGORITHM, false, new String[] {algorithm}));
            types.add(parameter);
        }
    }
}
