package com.rapidminer.extension.interpretation.algorithm;

import com.rapidminer.belt.table.Builders;
import com.rapidminer.belt.table.Table;
import com.rapidminer.belt.util.ColumnRole;
import com.rapidminer.core.concurrency.ConcurrencyContext;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.extension.interpretation.utility.BeltUtilities;
import com.rapidminer.extension.interpretation.utility.CorrelationUtilities;
import com.rapidminer.extension.interpretation.utility.MinMaxNormalizer;
import com.rapidminer.extension.interpretation.utility.NominalShuffler;
import com.rapidminer.extension.interpretation.utility.weightprovider.DecisionTreeWeightProvider;
import com.rapidminer.extension.interpretation.utility.weightprovider.GLMWeightProvider;
import com.rapidminer.extension.modelsimulator.tools.KeyAndValue;
import com.rapidminer.operator.AbstractModel;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:com/rapidminer/extension/interpretation/algorithm/LIME.class */
/**
 * Local interpretation algorithm that explains each row of a data table using a locally weighted
 * surrogate.
 *
 * <p>For every explained row, the algorithm:
 * <ol>
 *   <li>Takes an artificial neighborhood from {@link LIMENeighborhoodGenerator}.</li>
 *   <li>Scores that neighborhood with the wrapped model.</li>
 *   <li>Weights each artificial point by its proximity to the explained row, controlled by
 *       {@link #setLocality(double)}.</li>
 *   <li>Derives per-attribute weights with the configured explanation algorithm (weighted
 *       correlation, a decision tree, or a GLM).</li>
 * </ol>
 */
public class LIME extends LocalInterpretationAlgorithm {

    /**
     * Locality value passed to {@link LIMENeighborhoodGenerator#getWeights}.
     * The parameter description says "the lower the more local".
     */
    double locality;

    /**
     * If {@code true}, a fresh artificial neighborhood is generated for every explained row.
     * Otherwise one neighborhood is generated once and reused for all rows.
     */
    boolean redrawEachIteration;

    public static final String PARAMETER_NUMBER_OF_DATA_POINTS = "number_of_data_points";
    public static final String PARAMETER_ALGORITHMS = "explanation_algorithm";
    public static final String PARAMETER_REDRRAW = "redraw_local_samples";
    public static final String PARAMETER_LOCALITY = "locality";

    /**
     * Name of the column, with role {@link ColumnRole#WEIGHT}, that holds the neighborhood weights.
     * NOTE(review): public and non-final; kept that way for compatibility with external callers.
     */
    public static String WEIGHT_NAME = "weight";

    /** Explanation algorithm: weighted correlation between each attribute and the prediction. */
    private static final String ALGORITHM_CORRELATION = "Correlation";

    /** Explanation algorithm: weights taken from a {@link DecisionTreeWeightProvider}. */
    private static final String ALGORITHM_DECISION_TREE = "DecisionTree";

    /** Explanation algorithm: weights taken from a {@link GLMWeightProvider}. */
    private static final String ALGORITHM_LINEAR_REGRESSION = "Linear Regression";

    /** Category values offered by the {@link #PARAMETER_ALGORITHMS} parameter, in UI order. */
    private static final String[] SUPPORTED_ALGORITHMS = {
        ALGORITHM_CORRELATION, ALGORITHM_DECISION_TREE, ALGORITHM_LINEAR_REGRESSION
    };

    /** Currently selected explanation algorithm; one of the {@code ALGORITHM_*} values. */
    private String usedAlgorithm;

    public LIME(ConcurrencyContext concurrencyContext, Operator operator, Table table, Table table2, AbstractModel abstractModel) throws OperatorException {
        super(concurrencyContext, operator, table, table2, abstractModel);
        this.locality = 0.2d;
        this.redrawEachIteration = false;
        this.usedAlgorithm = ALGORITHM_CORRELATION;
    }

    /**
     * Computes local attribute weights for every row of {@code table2}.
     *
     * <p>A min-max normalizer and a nominal shuffler are fitted on {@code table}. The normalizer is
     * applied to {@code table2}. For each row, the artificial neighborhood is weighted against the
     * normalized row, and the resulting attribute weights are appended to
     * {@code localWeights.get(row)}.
     *
     * @param table  table used to fit the normalizer and the nominal shuffler
     *               (NOTE(review): presumably the reference/training data — confirm)
     * @param table2 table whose rows are explained, one entry of {@code localWeights} per row
     * @throws OperatorException if applying the model or an explanation algorithm fails
     */
    @Override // com.rapidminer.extension.interpretation.algorithm.LocalInterpretationAlgorithm
    public void calculate(Table table, Table table2) throws OperatorException {
        MinMaxNormalizer normalizer = new MinMaxNormalizer(table, this.beltContext);
        Table normalizedRows = normalizer.apply(table2, this.beltContext);
        NominalShuffler nominalShuffler = new NominalShuffler(table, this.randomGenerator, this.beltContext);

        // The neighborhood is generated in normalized space; the model scores it in original units.
        Table neighborhood = null;
        Table scoredNeighborhood = null;
        for (int row = 0; row < table2.height(); row++) {
            if (neighborhood == null || this.redrawEachIteration) {
                neighborhood = LIMENeighborhoodGenerator.getArtificalTable(this.types, this.names, nominalShuffler, this.localSampleSize, this.randomGenerator, this.beltContext);
                scoredNeighborhood = BeltUtilities.applyModelToTable(this.model, normalizer.inverse_apply(neighborhood, this.beltContext), this.concurrencyContext);
                // Depends only on the scored table's metadata, so it is only refreshed on a redraw.
                this.predictionName = findPredictionName(scoredNeighborhood);
            }

            // Repeat the explained row once per artificial point so getWeights can compare row-wise.
            int[] repeatedRow = new int[this.localSampleSize];
            Arrays.fill(repeatedRow, row);

            Table weightedNeighborhood = Builders.newTableBuilder(scoredNeighborhood)
                    .add(WEIGHT_NAME, LIMENeighborhoodGenerator.getWeights(neighborhood, normalizedRows.rows(repeatedRow, this.beltContext), this.locality))
                    .addMetaData(WEIGHT_NAME, ColumnRole.WEIGHT)
                    .build(this.beltContext);

            explainRow(row, weightedNeighborhood);
        }
    }

    /**
     * Returns the label of the column the surrogate explains.
     *
     * @param scoredNeighborhood neighborhood table after the model has been applied
     * @return the first {@link ColumnRole#PREDICTION} column for regression problems,
     *         otherwise the first {@link ColumnRole#SCORE} column
     * @throws OperatorException if the regression check fails
     */
    private String findPredictionName(Table scoredNeighborhood) throws OperatorException {
        ColumnRole role = BeltUtilities.isRegressionProblem(this.model, this.executingOperator)
                ? ColumnRole.PREDICTION
                : ColumnRole.SCORE;
        return (String) scoredNeighborhood.select().withMetaData(role).labels().get(0);
    }

    /**
     * Appends the attribute weights for one explained row to {@code localWeights.get(row)}.
     *
     * <p>The correlation algorithm produces one entry per attribute in {@code names}. Any other
     * value uses a decision-tree provider for {@code "DecisionTree"} and a GLM provider otherwise.
     *
     * @param row                  index of the explained row
     * @param weightedNeighborhood scored neighborhood including the {@link #WEIGHT_NAME} column
     * @throws OperatorException if a weight provider fails
     */
    private void explainRow(int row, Table weightedNeighborhood) throws OperatorException {
        if (this.usedAlgorithm.equals(ALGORITHM_CORRELATION)) {
            for (String attributeName : this.names) {
                this.localWeights.get(row).add(new KeyAndValue(attributeName, CorrelationUtilities.getWeightedCorrelation(weightedNeighborhood, attributeName, this.predictionName, WEIGHT_NAME), false, true));
            }
            return;
        }
        AttributeWeights attributeWeights = (this.usedAlgorithm.equals(ALGORITHM_DECISION_TREE) ? new DecisionTreeWeightProvider() : new GLMWeightProvider()).calculate(weightedNeighborhood, this.predictionName, WEIGHT_NAME);
        for (String attributeName : attributeWeights.getAttributeNames()) {
            this.localWeights.get(row).add(new KeyAndValue(attributeName, attributeWeights.getWeight(attributeName), false, true));
        }
    }

    /** Performs no capability checks: every model/table combination is accepted. */
    @Override // com.rapidminer.extension.interpretation.algorithm.LocalInterpretationAlgorithm
    public void checkCapability(Operator operator, AbstractModel abstractModel, Table table) throws UserError {
    }

    /**
     * Builds the operator parameters specific to this algorithm.
     *
     * <p>The redraw parameter defaults to {@code true}, while {@link #redrawEachIteration} starts as
     * {@code false}. NOTE(review): presumably the operator pushes the parameter value through
     * {@link #setRedrawEachIteration(boolean)} — confirm.
     *
     * @return a new mutable list with the redraw, algorithm and locality parameters
     */
    public static List<ParameterType> getListOfParameters() {
        List<ParameterType> parameters = new ArrayList<>();
        parameters.add(new ParameterTypeBoolean(PARAMETER_REDRRAW, "if set to false we use one big set of artifical data points", true, true));
        parameters.add(new ParameterTypeCategory(PARAMETER_ALGORITHMS, "Algorithm used to explain the prediction", SUPPORTED_ALGORITHMS, 0, true));
        parameters.add(new ParameterTypeDouble(PARAMETER_LOCALITY, "defines how local the surrogate model will be. The lower the more local", 0.0d, Double.MAX_VALUE, 0.2d));
        return parameters;
    }

    /** @return the display name of this algorithm, {@code "LIME"} */
    @Override // com.rapidminer.extension.interpretation.algorithm.LocalInterpretationAlgorithm
    public String getName() {
        return "LIME";
    }

    /** @param z whether a new artificial neighborhood is drawn for every explained row */
    public void setRedrawEachIteration(boolean z) {
        this.redrawEachIteration = z;
    }

    /** @return the currently selected explanation algorithm name */
    public String getUsedAlgorithm() {
        return this.usedAlgorithm;
    }

    /**
     * @param str explanation algorithm name; {@code "Correlation"} and {@code "DecisionTree"} select
     *            those algorithms, and any other non-null value falls back to the GLM provider
     */
    public void setUsedAlgorithm(String str) {
        this.usedAlgorithm = str;
    }

    /** @param d locality value forwarded to the neighborhood weighting */
    public void setLocality(double d) {
        this.locality = d;
    }
}
