package com.rapidminer.extension.interpretation.algorithm;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.column.Column;
import com.rapidminer.belt.reader.MixedRowReader;
import com.rapidminer.belt.reader.Readers;
import com.rapidminer.belt.table.Builders;
import com.rapidminer.belt.table.MixedRowWriter;
import com.rapidminer.belt.table.Table;
import com.rapidminer.belt.table.TableBuilder;
import com.rapidminer.belt.table.Writers;
import com.rapidminer.belt.transform.RowTransformer;
import com.rapidminer.belt.util.ColumnRole;
import com.rapidminer.core.concurrency.ConcurrencyContext;
import com.rapidminer.extension.interpretation.utility.BeltUtilities;
import com.rapidminer.extension.interpretation.utility.PermutationHelper;
import com.rapidminer.extension.modelsimulator.tools.KeyAndValue;
import com.rapidminer.operator.AbstractModel;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.parameter.ParameterType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

/**
 * Local interpretation algorithm that estimates how much each attribute contributes to a model's
 * prediction for individual rows, using Monte Carlo sampling in the style of Shapley values.
 *
 * <p>For every attribute {@code i} and every row of the testing columns, {@link #calculate}
 * draws {@code localSampleSize} rows via {@code PermutationHelper.getBootstrapedTable}. It then
 * builds a table in which a random subset of attributes (always including {@code i}) takes the
 * sampled values, while all other attributes keep the testing row's values. The model's
 * predictions on that table are compared with the predictions obtained after resetting attribute
 * {@code i} to the testing row's value. The mean difference is recorded as the weight of
 * attribute {@code i} for that testing row.
 */
public class Shapley extends LocalInterpretationAlgorithm {
    /** Parameter key for the number of permutations (not read within this class). */
    public static final String PARAMETER_NUMBER_OF_PERMUTATIONS = "permutations";

    /** Collection of intermediate tables; initialized empty and not populated within this class. */
    public IOObjectCollection<IOTable> intermediateTables;

    /**
     * Creates an instance bound only to a concurrency context and operator.
     *
     * @param concurrencyContext context used when applying the model
     * @param operator           the executing operator
     */
    protected Shapley(ConcurrencyContext concurrencyContext, Operator operator) {
        super(concurrencyContext, operator);
        this.intermediateTables = new IOObjectCollection<>();
    }

    /**
     * Creates a fully configured instance; the tables and model are handed to the superclass.
     *
     * @param concurrencyContext context used when applying the model
     * @param operator           the executing operator
     * @param table              first table forwarded to the superclass constructor
     * @param table2             second table forwarded to the superclass constructor
     * @param abstractModel      the model to interpret
     * @throws OperatorException propagated from the superclass constructor
     */
    public Shapley(ConcurrencyContext concurrencyContext, Operator operator, Table table, Table table2, AbstractModel abstractModel) throws OperatorException {
        super(concurrencyContext, operator, table, table2, abstractModel);
        this.intermediateTables = new IOObjectCollection<>();
    }

    /**
     * Computes a weight for every (testing row, attribute) pair and appends it to the
     * corresponding entry of {@code localWeights}.
     *
     * <p>Side effects: advances {@code progress} once per attribute and overwrites
     * {@code predictionColumn} / {@code predictionName} with the model output of the most
     * recently processed row.
     *
     * @param table  table from which sample rows are drawn; its column order is assumed to match
     *               the testing columns (NOTE(review): confirm against callers)
     * @param table2 not used by this implementation
     * @throws OperatorException propagated from model application, or a {@link UserError} if an
     *                           attribute has a type that cannot be reset to its original value
     */
    @Override // com.rapidminer.extension.interpretation.algorithm.LocalInterpretationAlgorithm
    public void calculate(Table table, Table table2) throws OperatorException {
        int attributeCount = this.testingColumns.size();
        this.progress.setTotal(attributeCount);
        // The model and operator do not change inside the loops, so decide once which role
        // (regression prediction vs. classification score) identifies the model output column.
        ColumnRole outputRole = BeltUtilities.isRegressionProblem(this.model, this.executingOperator)
                ? ColumnRole.PREDICTION
                : ColumnRole.SCORE;
        for (int attribute = 0; attribute < attributeCount; attribute++) {
            MixedRowReader testingRows = Readers.mixedRowReader(this.testingColumns);
            this.progress.step();
            int row = 0;
            while (testingRows.hasRemaining()) {
                testingRows.move();
                MixedRowReader sampleRows = PermutationHelper.getBootstrapedTable(table, this.names,
                        this.localSampleSize, this.randomGenerator, this.beltContext);
                MixedRowWriter writer = Writers.mixedRowWriter(this.names, this.types, this.localSampleSize, false);
                // Attribute `attribute` (plus a random subset of others) carries sampled values ...
                Table perturbed = permutateAttributes(table, sampleRows, writer, testingRows, attribute);
                // ... and here it is reset to the testing row's value, everything else unchanged.
                Table restored = setColumnOfInterestToOriginal(perturbed, testingRows, attribute);
                Table perturbedScored = BeltUtilities.applyModelToTable(this.model, perturbed, this.concurrencyContext);
                Table restoredScored = BeltUtilities.applyModelToTable(this.model, restored, this.concurrencyContext);
                this.predictionColumn = firstColumnWithRole(perturbedScored, outputRole);
                this.predictionName = (String) perturbedScored.select().withMetaData(outputRole).labels().get(0);
                Column restoredPrediction = firstColumnWithRole(restoredScored, outputRole);
                if (this.predictionColumn.type().category() == Column.Category.NUMERIC) {
                    // NOTE(review): this is mean(f(attribute sampled) - f(attribute original)), i.e.
                    // the negation of the usual Shapley marginal contribution
                    // f(with x_i) - f(without x_i). Confirm the sign downstream consumers expect
                    // before changing it.
                    double weight = sumOfDifferences(this.predictionColumn, restoredPrediction) / this.localSampleSize;
                    this.localWeights.get(row).add(new KeyAndValue(this.names.get(attribute), weight, false, true));
                }
                row++;
            }
        }
    }

    /**
     * Returns the first column of {@code scored} whose metadata includes {@code role}.
     *
     * @param scored table produced by applying the model
     * @param role   the role to select by
     * @return the first matching column
     */
    private static Column firstColumnWithRole(Table scored, ColumnRole role) {
        return (Column) scored.select().withMetaData(role).columns().get(0);
    }

    /**
     * Sums {@code minuend[r] - subtrahend[r]} over all rows {@code r} of the two numeric columns.
     *
     * @param minuend    column whose values are added
     * @param subtrahend column whose values are subtracted
     * @return the row-wise sum of differences
     */
    private double sumOfDifferences(Column minuend, Column subtrahend) {
        double[] total = (double[]) new RowTransformer(Arrays.asList(minuend, subtrahend)).reduceNumeric(
                () -> new double[1],
                (acc, numericRow) -> acc[0] += numericRow.get(0) - numericRow.get(1),
                (left, right) -> left[0] += right[0],
                this.beltContext);
        return total[0];
    }

    /**
     * Draws a random set of column indices that is guaranteed to contain {@code i}.
     *
     * <p>The subset size is drawn once via {@code nextIntInRange(1, trainingColums.size())}.
     * Subsets of that size are then re-drawn until one contains {@code i} (rejection sampling).
     *
     * @param i index of the attribute of interest
     * @return a random index set containing {@code i}
     */
    private Set<Integer> getRandomColumnIndices(int i) {
        int columnCount = this.trainingColums.size();
        int subsetSize = this.randomGenerator.nextIntInRange(1, columnCount);
        Set<Integer> subset;
        do {
            subset = this.randomGenerator.nextIntSetWithRange(0, columnCount, subsetSize);
        } while (!subset.contains(Integer.valueOf(i)));
        return subset;
    }

    /**
     * Writes one output row per sample row. Columns selected by
     * {@link #getRandomColumnIndices(int)} take the sample row's value; all other columns take
     * the current row of {@code mixedRowReader2}.
     *
     * @param table           table whose column types decide numeric vs. object access; assumed
     *                        to share column order with both readers (NOTE(review): confirm)
     * @param mixedRowReader  reader over the sampled rows; consumed by this call
     * @param mixedRowWriter  writer receiving the mixed rows
     * @param mixedRowReader2 reader positioned on the testing row of interest
     * @param i               index of the attribute of interest (always taken from the sample)
     * @return the table created by {@code mixedRowWriter}
     */
    private Table permutateAttributes(Table table, MixedRowReader mixedRowReader, MixedRowWriter mixedRowWriter, MixedRowReader mixedRowReader2, int i) {
        while (mixedRowReader.hasRemaining()) {
            mixedRowReader.move();
            Set<Integer> sampledColumns = getRandomColumnIndices(i);
            mixedRowWriter.move();
            for (int col = 0; col < mixedRowReader2.width(); col++) {
                MixedRowReader source = sampledColumns.contains(Integer.valueOf(col)) ? mixedRowReader : mixedRowReader2;
                if (table.column(col).type().category() == Column.Category.NUMERIC) {
                    mixedRowWriter.set(col, source.getNumeric(col));
                } else {
                    mixedRowWriter.set(col, source.getObject(col));
                }
            }
        }
        return mixedRowWriter.create();
    }

    /**
     * Returns a copy of {@code table} in which every value of attribute {@code i} is replaced by
     * that attribute's value in the current row of {@code mixedRowReader}.
     *
     * @param table          table to copy
     * @param mixedRowReader reader positioned on the testing row of interest
     * @param i              index of the attribute to reset
     * @return the rebuilt table
     * @throws UserError if the attribute's type is not REAL, INTEGER_53_BIT or NOMINAL
     */
    private Table setColumnOfInterestToOriginal(Table table, MixedRowReader mixedRowReader, int i) throws UserError {
        TableBuilder builder = Builders.newTableBuilder(table);
        String label = this.names.get(i);
        switch (this.types.get(i)) {
            case REAL: {
                double value = mixedRowReader.getNumeric(i);
                builder.replaceReal(label, rowIndex -> value);
                break;
            }
            case INTEGER_53_BIT: {
                double value = mixedRowReader.getNumeric(i);
                builder.replaceInt53Bit(label, rowIndex -> value);
                break;
            }
            case NOMINAL: {
                String value = (String) mixedRowReader.getObject(i, String.class);
                builder.replaceNominal(label, rowIndex -> value, 1);
                break;
            }
            default:
                throw new UserError((Operator) null, "not.handled.type", new Object[]{this.names.get(i), this.types.get(i)});
        }
        return builder.build(this.beltContext);
    }

    /**
     * Returns the parameter types specific to this algorithm.
     *
     * @return a new, empty, mutable list
     */
    public static List<ParameterType> getListOfParameters() {
        return new ArrayList<>();
    }

    @Override // com.rapidminer.extension.interpretation.algorithm.LocalInterpretationAlgorithm
    public String getName() {
        return "Shapley";
    }

    /** Performs no checks; this implementation accepts any operator, model and table. */
    @Override // com.rapidminer.extension.interpretation.algorithm.LocalInterpretationAlgorithm
    public void checkCapability(Operator operator, AbstractModel abstractModel, Table table) throws UserError {
    }
}
