package com.rapidminer.extension.interpretation.algorithm;

import com.rapidminer.Process;
import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.table.BeltConverter;
import com.rapidminer.belt.table.Table;
import com.rapidminer.belt.table.TableViewCreator;
import com.rapidminer.core.concurrency.ConcurrencyContext;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.extension.modelsimulator.operator.scoring.ExplainPredictionsIOObject;
import com.rapidminer.extension.modelsimulator.operator.scoring.ExplainPredictionsOperator;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.OperatorService;

/* loaded from: input_file:com/rapidminer/extension/interpretation/algorithm/ExplainPredictionsWrapper.class */
/**
 * Local interpretation algorithm that delegates to the {@link ExplainPredictionsOperator}.
 *
 * <p>{@link #calculate(Table, Table)} builds a throw-away {@link Process} containing a single
 * Explain Predictions operator, runs it on the model and the two given tables and stores the four
 * results. The getters expose those results in the form expected by
 * {@link LocalInterpretationAlgorithm}.
 */
public class ExplainPredictionsWrapper extends LocalInterpretationAlgorithm {
    /** Parameter key of the wrapped operator that receives {@code localSampleSize}. */
    private static final String PARAMETER_LOCAL_SAMPLE_SIZE = "local sample size";
    /** Parameter key of the wrapped operator that receives {@code numberOfExplainingAttributes}. */
    private static final String PARAMETER_MAX_IMPORTANCES = "apply maximum to importances output";
    /** Columns of the operator's example set output that are merged into the interpretation column. */
    private static final String SUPPORT_COLUMN = "Support Prediction";
    private static final String CONTRADICT_COLUMN = "Contradict Prediction";
    /** Number of operator outputs that are routed to the process result. */
    private static final int CONNECTED_OUTPUTS = 4;

    private ExplainPredictionsIOObject explainPredictionsIOObject;
    private AttributeWeights globalWeights;
    private ExampleSet importances;
    private ExampleSet explainedTable;
    /** Cached result of {@link #getExplainedTable()}; reset by every {@link #calculate(Table, Table)}. */
    private IOTable explainedTableResult;

    public ExplainPredictionsWrapper(ConcurrencyContext concurrencyContext, Operator operator, Table table, Table table2, PredictionModel predictionModel) throws OperatorException {
        super(concurrencyContext, operator, table, table2, predictionModel);
    }

    /**
     * Runs the Explain Predictions operator in a private process and stores its outputs.
     *
     * @param table  passed as second process input, which is wired to the operator's "training data" port
     * @param table2 passed as third process input, which is wired to the operator's "test data" port
     * @throws OperatorException if the operator cannot be created or the process run fails; the
     *                           original exception is attached as cause
     */
    @Override
    public void calculate(Table table, Table table2) throws OperatorException {
        ExplainPredictionsOperator explainOperator;
        try {
            explainOperator = OperatorService.createOperator(ExplainPredictionsOperator.class);
        } catch (OperatorCreationException e) {
            // Keep the cause so the creation failure stays traceable.
            throw new OperatorException(e.getMessage(), e);
        }
        explainOperator.setParameter(PARAMETER_LOCAL_SAMPLE_SIZE, Integer.toString(this.localSampleSize));
        explainOperator.setParameter(PARAMETER_MAX_IMPORTANCES, Integer.toString(this.numberOfExplainingAttributes));

        Process process = new Process();
        process.getRootOperator().setParameter("logverbosity", "off");
        process.getRootOperator().getSubprocess(0).addOperator(explainOperator);

        // Process inputs 0..2 -> model, training data, test data.
        process.getRootOperator().getSubprocess(0).getInnerSources().getPortByIndex(0).connectTo(explainOperator.getInputPorts().getPortByName("model"));
        process.getRootOperator().getSubprocess(0).getInnerSources().getPortByIndex(1).connectTo(explainOperator.getInputPorts().getPortByName("training data"));
        process.getRootOperator().getSubprocess(0).getInnerSources().getPortByIndex(2).connectTo(explainOperator.getInputPorts().getPortByName("test data"));

        // Operator outputs 0..3 -> process results 0..3, in the same order.
        for (int port = 0; port < CONNECTED_OUTPUTS; port++) {
            explainOperator.getOutputPorts().getPortByIndex(port).connectTo(process.getRootOperator().getSubprocess(0).getInnerSinks().getPortByIndex(port));
        }

        IOObject[] inputs = new IOObject[]{
            this.model,
            TableViewCreator.INSTANCE.convertOnWriteView(new IOTable(table), false),
            TableViewCreator.INSTANCE.convertOnWriteView(new IOTable(table2), false)
        };
        IOObject[] results;
        try {
            results = process.run(new IOContainer(inputs)).getIOObjects();
        } catch (UserError e) {
            // UserError must be handled before OperatorException, otherwise this handler is unreachable.
            throw new OperatorException("Error running Explain Predictions: " + e.getDetails(), e);
        } catch (OperatorException e) {
            throw new OperatorException("Error running Explain Predictions: " + e.getMessage(), e);
        }

        this.explainPredictionsIOObject = (ExplainPredictionsIOObject) results[0];
        this.explainedTable = (ExampleSet) results[1];
        this.importances = (ExampleSet) results[2];
        this.globalWeights = (AttributeWeights) results[3];
        // A new run invalidates whatever was derived from the previous one.
        this.explainedTableResult = null;
    }

    /** Returns the first result of the last {@link #calculate(Table, Table)} run, or {@code null} before any run. */
    @Override
    public ExplainPredictionsIOObject getExplainPredictionsIOObject() {
        return this.explainPredictionsIOObject;
    }

    /** Returns the importances example set of the last run converted to an {@link IOTable}. */
    @Override
    public IOTable getInterpretationDetails() {
        return BeltConverter.convert(this.importances, this.concurrencyContext);
    }

    /**
     * Returns the explained data with the "Support Prediction" and "Contradict Prediction" columns
     * merged into a single special column named {@code INTERPRETATION_COLUMN_NAME}.
     *
     * <p>The merge modifies the stored example set (one column added, two removed), so it is done
     * only once per {@link #calculate(Table, Table)} run; later calls return the cached table.
     */
    @Override
    public IOTable getExplainedTable() {
        if (this.explainedTableResult == null) {
            this.explainedTableResult = buildExplainedTable();
        }
        return this.explainedTableResult;
    }

    /** Adds the merged interpretation column to the stored example set and converts it to an {@link IOTable}. */
    private IOTable buildExplainedTable() {
        // NOTE(review): value type 1 is presumably the nominal type of RapidMiner's Ontology - confirm.
        Attribute interpretation = AttributeFactory.createAttribute(INTERPRETATION_COLUMN_NAME, 1);
        interpretation.setTableIndex(this.explainedTable.getAttributes().size());
        AttributeRole interpretationRole = new AttributeRole(interpretation);
        interpretationRole.setSpecial(INTERPRETATION_COLUMN_NAME);
        this.explainedTable.getAttributes().add(interpretationRole);
        this.explainedTable.getExampleTable().addAttribute(interpretation);

        Attribute supporting = this.explainedTable.getAttributes().get(SUPPORT_COLUMN);
        Attribute contradicting = this.explainedTable.getAttributes().get(CONTRADICT_COLUMN);
        for (Example example : this.explainedTable) {
            String merged = example.getValueAsString(supporting) + " " + example.getValueAsString(contradicting);
            example.setValue(interpretation, merged);
        }

        // The two source columns are replaced by the merged one in the returned view.
        this.explainedTable.getAttributes().remove(supporting);
        this.explainedTable.getAttributes().remove(contradicting);
        return BeltConverter.convert(this.explainedTable, this.concurrencyContext);
    }

    /** Returns the fourth result of the last {@link #calculate(Table, Table)} run, or {@code null} before any run. */
    @Override
    public AttributeWeights getGlobalWeights() {
        return this.globalWeights;
    }

    /** Performs no checks; this implementation never rejects a model or table. */
    @Override
    public void checkCapability(Operator operator, PredictionModel predictionModel, Table table) throws UserError {
    }

    /** Returns the display name of this algorithm. */
    @Override
    public String getName() {
        return "Explain Predictions";
    }
}
