package com.rapidminer.extension.interpretation.algorithm;

import com.rapidminer.adaption.belt.ContextAdapter;
import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.buffer.Buffers;
import com.rapidminer.belt.buffer.NominalBuffer;
import com.rapidminer.belt.column.Column;
import com.rapidminer.belt.column.ColumnType;
import com.rapidminer.belt.execution.Context;
import com.rapidminer.belt.table.Builders;
import com.rapidminer.belt.table.ColumnSelector;
import com.rapidminer.belt.table.MixedRowWriter;
import com.rapidminer.belt.table.Table;
import com.rapidminer.belt.table.TableViewCreator;
import com.rapidminer.belt.table.Tables;
import com.rapidminer.belt.table.Writers;
import com.rapidminer.belt.util.ColumnRole;
import com.rapidminer.core.concurrency.ConcurrencyContext;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.set.ExampleSetUtilities;
import com.rapidminer.extension.interpretation.utility.BeltUtilities;
import com.rapidminer.extension.modelsimulator.operator.scoring.ExplainPredictionsIOObject;
import com.rapidminer.extension.modelsimulator.tools.KeyAndValue;
import com.rapidminer.operator.AbstractModel;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorProgress;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.IOTablePredictionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.belt.BeltTools;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/* loaded from: input_file:com/rapidminer/extension/interpretation/algorithm/LocalInterpretationAlgorithm.class */
/**
 * Base class for algorithms that explain a model's predictions row by row.
 *
 * <p>After {@link #init(Table, Table)} the algorithm holds one (initially empty) list of
 * {@link KeyAndValue} weights per row of the testing table; subclasses fill these lists in
 * {@link #calculate(Table, Table)}. The accessors in this class then aggregate or render those
 * local weights in several output formats.
 */
public abstract class LocalInterpretationAlgorithm {
    protected ConcurrencyContext concurrencyContext;
    protected Context beltContext;
    protected OperatorProgress progress;
    /** One list of (attribute name, weight) pairs per testing-table row. */
    protected List<List<KeyAndValue>> localWeights;
    protected boolean initialized;
    /** Set when the explained model is an {@link AbstractModel}. */
    protected AbstractModel model;
    /** Set when the explained model is an {@link IOTablePredictionModel}. */
    protected IOTablePredictionModel ioTablePredictionModel;
    protected Boolean isIOTablePredictionModel;
    protected Boolean isRegression;
    /** The testing table as passed to {@link #init(Table, Table)}, before any column reordering. */
    protected Table storedTestingTable;
    /** Type ids of the regular training columns, in training-column order. */
    protected List<Column.TypeId> types;
    /** Labels of the regular training columns. */
    protected List<String> names;
    /** Regular testing columns, reordered to match the training columns if necessary. */
    protected List<Column> testingColumns;
    /** Regular training columns. */
    protected List<Column> trainingColums;
    protected Column predictionColumn;
    /** Name written into the "Interpreted Column" and the weights annotations; may be null. */
    protected String predictionName;
    protected int localSampleSize;
    protected int numberOfExplainingAttributes;
    protected Operator executingOperator;
    protected RandomGenerator randomGenerator;
    /** Name of the nominal column added by {@link #getExplainedTable()}. */
    public static String INTERPRETATION_COLUMN_NAME = "interpretation";

    /**
     * Creates an uninitialized algorithm with default settings: local sample size 100,
     * 3 explaining attributes, and a regression problem assumed.
     *
     * @param concurrencyContext context used for parallel work; also adapted to a belt context
     * @param operator the executing operator, used for progress reporting and error context
     */
    public LocalInterpretationAlgorithm(ConcurrencyContext concurrencyContext, Operator operator) {
        this.isIOTablePredictionModel = false;
        this.isRegression = true;
        this.predictionName = null;
        this.localSampleSize = 100;
        this.numberOfExplainingAttributes = 3;
        this.concurrencyContext = concurrencyContext;
        this.beltContext = ContextAdapter.adapt(concurrencyContext);
        this.progress = operator.getProgress();
        this.executingOperator = operator;
        this.initialized = false;
    }

    /**
     * Creates an algorithm for the given model, detects whether it is a regression problem,
     * and initializes it with the training and testing tables.
     *
     * <p>Note: this constructor calls the overridable {@link #init(Table, Table)}, so
     * subclass fields are not yet initialized when an override runs.
     *
     * @param concurrencyContext context used for parallel work
     * @param operator the executing operator
     * @param table the training table
     * @param table2 the testing table whose rows are explained
     * @param iOObject the model; must be an {@link AbstractModel} or an {@link IOTablePredictionModel}
     * @throws OperatorException if the model type is unsupported (including {@code null}) or
     *     the tables fail the compatibility checks in {@link #init(Table, Table)}
     */
    public LocalInterpretationAlgorithm(ConcurrencyContext concurrencyContext, Operator operator, Table table, Table table2, IOObject iOObject) throws OperatorException {
        this(concurrencyContext, operator);
        if (iOObject instanceof AbstractModel) {
            this.model = (AbstractModel) iOObject;
            this.isRegression = Boolean.valueOf(BeltUtilities.isRegressionProblem(this.model, this.executingOperator));
        } else if (iOObject instanceof IOTablePredictionModel) {
            this.ioTablePredictionModel = (IOTablePredictionModel) iOObject;
            this.isIOTablePredictionModel = true;
            List<Column> labelColumns = this.ioTablePredictionModel.getTrainingHeader().getTable()
                    .select().withMetaData(ColumnRole.LABEL).columns();
            // No label column -> treated as regression; otherwise only a nominal label means classification.
            this.isRegression = labelColumns.isEmpty() || !labelColumns.get(0).type().equals(ColumnType.NOMINAL);
        } else {
            // Guard against null so callers get the documented OperatorException instead of an NPE.
            String typeName = iOObject == null ? "null" : iOObject.getClass().getName();
            throw new OperatorException("Cannot handle Model type: " + typeName);
        }
        init(table, table2);
    }

    /**
     * Computes the local weights for every testing row and stores them in {@link #localWeights}.
     *
     * @param table the training table
     * @param table2 the testing table
     * @throws OperatorException on computation failure
     */
    public abstract void calculate(Table table, Table table2) throws OperatorException;

    /**
     * Wraps an on-write view of the stored testing table together with the local weights.
     *
     * @return a new {@link ExplainPredictionsIOObject} backed by {@link #localWeights}
     */
    public ExplainPredictionsIOObject getExplainPredictionsIOObject() {
        return new ExplainPredictionsIOObject(TableViewCreator.INSTANCE.convertOnWriteView(new IOTable(this.storedTestingTable), false), this.localWeights);
    }

    /**
     * Aggregates the local weights into global weights.
     *
     * <p>Each attribute's weights are summed over all rows and divided by the total number of
     * rows. Rows in which an attribute does not appear therefore contribute 0 to its average.
     * If {@link #predictionName} is set, it is recorded as an annotation.
     *
     * @return the averaged attribute weights
     */
    public AttributeWeights getGlobalWeights() {
        Map<String, Double> summedWeights = new HashMap<>();
        for (List<KeyAndValue> rowWeights : this.localWeights) {
            for (KeyAndValue keyAndValue : rowWeights) {
                summedWeights.merge(keyAndValue.getKey(), keyAndValue.getValue(), Double::sum);
            }
        }
        int rowCount = this.localWeights.size();
        AttributeWeights attributeWeights = new AttributeWeights();
        for (Map.Entry<String, Double> entry : summedWeights.entrySet()) {
            attributeWeights.setWeight(entry.getKey(), entry.getValue() / rowCount);
        }
        if (this.predictionName != null) {
            attributeWeights.getAnnotations().put("Prediction Column: ", this.predictionName);
        }
        return attributeWeights;
    }

    /**
     * Flattens the local weights into a long-format table.
     *
     * <p>The table has one row per (testing row, attribute) pair, with columns "Row No"
     * (0-based testing row index), "Name" and "Importance".
     *
     * @return the details table
     */
    public IOTable getInterpretationDetails() {
        MixedRowWriter mixedRowWriter = Writers.mixedRowWriter(Arrays.asList("Row No", "Name", "Importance"), Arrays.asList(Column.TypeId.INTEGER_53_BIT, Column.TypeId.NOMINAL, Column.TypeId.REAL), false);
        for (int row = 0; row < this.localWeights.size(); row++) {
            for (KeyAndValue keyAndValue : this.localWeights.get(row)) {
                mixedRowWriter.move();
                mixedRowWriter.set(0, row);
                mixedRowWriter.set(1, keyAndValue.getKey());
                mixedRowWriter.set(2, keyAndValue.getValue());
            }
        }
        return new IOTable(mixedRowWriter.create());
    }

    /**
     * Returns the stored testing table extended with two metadata columns.
     *
     * <p>The {@link #INTERPRETATION_COLUMN_NAME} column holds, for each row, the first
     * {@link #numberOfExplainingAttributes} entries formatted as {@code "name: value; "}, with
     * values rendered by {@code DecimalFormat("#.###")} in the default locale. The
     * "Interpreted Column" column holds {@link #predictionName} for every row.
     *
     * <p>Side effect: each per-row list in {@link #localWeights} is sorted in place, using
     * the natural ordering of {@link KeyAndValue}.
     *
     * @return the extended table
     */
    public IOTable getExplainedTable() {
        int rowCount = this.localWeights.size();
        NominalBuffer nominalBuffer = Buffers.nominalBuffer(rowCount);
        // DecimalFormat is not thread-safe, so create one per call and reuse it across all rows.
        DecimalFormat decimalFormat = new DecimalFormat("#.###");
        for (int row = 0; row < rowCount; row++) {
            List<KeyAndValue> rowWeights = this.localWeights.get(row);
            // Intentionally in place: later readers of localWeights see the sorted order.
            Collections.sort(rowWeights);
            int shown = Math.min(this.numberOfExplainingAttributes, rowWeights.size());
            StringBuilder sb = new StringBuilder();
            for (int k = 0; k < shown; k++) {
                KeyAndValue keyAndValue = rowWeights.get(k);
                sb.append(keyAndValue.getKey()).append(": ").append(decimalFormat.format(keyAndValue.getValue())).append("; ");
            }
            nominalBuffer.set(row, sb.toString());
        }
        return new IOTable(Builders.newTableBuilder(this.storedTestingTable)
                .add(INTERPRETATION_COLUMN_NAME, nominalBuffer.toColumn())
                .addMetaData(INTERPRETATION_COLUMN_NAME, ColumnRole.METADATA)
                .addNominal("Interpreted Column", rowIndex -> this.predictionName)
                .addMetaData("Interpreted Column", ColumnRole.METADATA)
                .build(this.beltContext));
    }

    /**
     * Prepares the algorithm for the given training and testing tables.
     *
     * <p>This method:
     * <ul>
     *   <li>stores the testing table as passed in;</li>
     *   <li>records the regular training columns, their labels and their type ids;</li>
     *   <li>reorders the testing table's columns to match the training table if the regular
     *       column labels differ;</li>
     *   <li>validates the tables against the model and against each other;</li>
     *   <li>allocates one empty weight list per testing row.</li>
     * </ul>
     *
     * @param table the training table
     * @param table2 the testing table
     * @throws OperatorException if the header or table comparisons fail
     */
    public void init(Table table, Table table2) throws OperatorException {
        this.storedTestingTable = table2;
        ColumnSelector trainingSelector = BeltTools.selectRegularColumns(table);
        this.trainingColums = trainingSelector.columns();
        this.names = trainingSelector.labels();

        Table testingTable = table2;
        ColumnSelector testingSelector = BeltTools.selectRegularColumns(testingTable);
        if (!this.names.equals(testingSelector.labels())) {
            testingTable = Tables.adapt(testingTable, table, Tables.ColumnHandling.REORDER, Tables.DictionaryHandling.UNCHANGED);
            testingSelector = BeltTools.selectRegularColumns(testingTable);
        }
        this.testingColumns = testingSelector.columns();

        if (!this.isIOTablePredictionModel) {
            BeltUtilities.compareTableToHeader(table, this.model, ExampleSetUtilities.SetsCompareOption.EQUAL, ExampleSetUtilities.TypesCompareOption.ALLOW_SAME_PARENTS, this.executingOperator);
        }
        BeltUtilities.compareTables(table, testingTable, ExampleSetUtilities.SetsCompareOption.ALLOW_SUPERSET, ExampleSetUtilities.TypesCompareOption.ALLOW_SUPERTYPES, this.executingOperator);

        this.types = this.trainingColums.stream()
                .map(column -> column.type().id())
                .collect(Collectors.toList());

        // One list per testing row, so presize with the testing height (not the training height).
        int testingRows = testingTable.height();
        this.localWeights = new ArrayList<>(testingRows);
        for (int i = 0; i < testingRows; i++) {
            this.localWeights.add(new ArrayList<>(this.names.size()));
        }
    }

    /**
     * Applies the wrapped model, whichever kind it is, to the given table.
     *
     * @param table the table to score
     * @param concurrencyContext context used when applying an {@link AbstractModel}
     * @return the scored table
     * @throws OperatorException if applying the model fails
     */
    public Table applyModelToTable(Table table, ConcurrencyContext concurrencyContext) throws OperatorException {
        if (this.isIOTablePredictionModel) {
            return this.ioTablePredictionModel.apply(new IOTable(table), this.executingOperator).getTable();
        }
        return BeltUtilities.applyModelToTable(this.model, table, concurrencyContext);
    }

    /**
     * Returns the algorithm-specific parameter types; the base class defines none.
     *
     * @return a new, empty, mutable list
     */
    public static List<ParameterType> getListOfParameters() {
        return new ArrayList<>();
    }

    protected boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Checks whether this algorithm can handle the given model and table.
     *
     * @throws UserError if it cannot
     */
    public abstract void checkCapability(Operator operator, AbstractModel abstractModel, Table table) throws UserError;

    /** @return the display name of this algorithm */
    public abstract String getName();

    public int getLocalSampleSize() {
        return this.localSampleSize;
    }

    public void setLocalSampleSize(int i) {
        this.localSampleSize = i;
    }

    /** @return how many attributes per row {@link #getExplainedTable()} renders */
    public int getNumberOfExplainingAttributes() {
        return this.numberOfExplainingAttributes;
    }

    public void setNumberOfExplainingAttributes(int i) {
        this.numberOfExplainingAttributes = i;
    }

    public RandomGenerator getRandomGenerator() {
        return this.randomGenerator;
    }

    public void setRandomGenerator(RandomGenerator randomGenerator) {
        this.randomGenerator = randomGenerator;
    }
}
