package com.rapidminer.extension.operator.models;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.extension.Utility.ParameterReplacementProcessXMLFilter;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.tree.TreeModel;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.CollectionMetaData;
import com.rapidminer.operator.ports.metadata.MDTransformationRule;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Operator chain that computes a local interpretation for every example of an input
 * {@link ExampleSet}.
 *
 * <p>A random sample of size {@value #PARAMETER_SAMPLE_SIZE} is drawn (via
 * {@code InterpretationHelper.getRandom}) and scored by the input {@link PredictionModel}. The
 * model's predicted label is then used as the label of that sample. For each input example the
 * sample is re-weighted relative to that example and filtered by the weight threshold. The result
 * is fed to the nested subprocess.
 *
 * <p>Per example, the subprocess's outputs are handled as follows:
 * <ul>
 *   <li>The subprocess's {@link AttributeWeights} are collected. The top-k attribute names and
 *       weights are written into special attributes of the input example set.</li>
 *   <li>If the inner model port is connected, the local models are collected. For tree models, the
 *       decision tree path of the example is stored.</li>
 *   <li>If the inner performance port is connected, the main criterion's average is stored.</li>
 * </ul>
 *
 * <p>NOTE(review): the input example set is modified in place (new attributes are added to it);
 * confirm that this is acceptable for upstream operators sharing the same object.
 */
public class LocalInterpretationWithSubprocess extends OperatorChain {
    public static final String PARAMETER_LOCALITY = "locality";
    public static final String PARAMETER_SAMPLE_SIZE = "sample_size";
    public static final String PARAMETER_WEIGHT_THRESHOLD = "weight_threshold";
    public static final String PARAMETER_NUMBER_ATTRIBUTES = "number_of_attributes";
    public static final String PARAMETER_LOCALITY_HEURISTICS = "use_locality_heuristics";

    /** Value type code passed to AttributeFactory for nominal attributes (presumably Ontology.NOMINAL). */
    private static final int VALUE_TYPE_NOMINAL = 1;
    /** Value type code passed to AttributeFactory for real attributes (presumably Ontology.REAL). */
    private static final int VALUE_TYPE_REAL = 4;

    /** Name/role prefix of the attributes holding the names of the top-k attributes. */
    private static final String IMPORTANT_ATTRIBUTE_PREFIX = "Important Attribute ";
    /** Name/role prefix of the attributes holding the weights of the top-k attributes. */
    private static final String IMPORTANCE_ATTRIBUTE_PREFIX = "Importance Attribute ";

    private final InputPort exaInput;
    private final InputPort modInput;
    private final OutputPort innerExampleSource;
    private final InputPort innerWeightOutput;
    private final InputPort innerModelOutput;
    private final InputPort innerPerformanceOutput;
    private final OutputPort exaOutput;
    private final OutputPort modOutput;
    private final OutputPort localWeightsOutput;
    private final OutputPort localModelsOutputPort;

    /**
     * Orders attribute names by the absolute value of their weight, largest first.
     *
     * <p>Static so that it does not keep a hidden reference to the enclosing operator.
     */
    private static final class WeightComparator implements Comparator<String> {
        private final Map<String, Double> weightMap;

        WeightComparator(Map<String, Double> weightMap) {
            this.weightMap = weightMap;
        }

        @Override
        public int compare(String first, String second) {
            // Arguments are swapped (instead of negating the result) to sort in descending order.
            return Double.compare(Math.abs(weightMap.get(second).doubleValue()),
                    Math.abs(weightMap.get(first).doubleValue()));
        }
    }

    public LocalInterpretationWithSubprocess(OperatorDescription operatorDescription) {
        this(operatorDescription, "Nested Process");
    }

    /**
     * Creates the operator with a single nested subprocess.
     *
     * @param operatorDescription the operator description
     * @param str name of the nested subprocess
     */
    protected LocalInterpretationWithSubprocess(OperatorDescription operatorDescription, String str) {
        super(operatorDescription, new String[]{str});
        this.exaInput = getInputPorts().createPort("exa", ExampleSet.class);
        this.modInput = getInputPorts().createPort("mod", PredictionModel.class);
        this.innerExampleSource = getSubprocess(0).getInnerSources().createPort("training set");
        this.innerWeightOutput = getSubprocess(0).getInnerSinks().createPort("Weight Vector", AttributeWeights.class);
        this.innerModelOutput = getSubprocess(0).getInnerSinks().createPassThroughPort("Prediction Model");
        this.innerPerformanceOutput = getSubprocess(0).getInnerSinks().createPort("Performance Vector", PerformanceVector.class);
        this.exaOutput = getOutputPorts().createPort("exa");
        this.modOutput = getOutputPorts().createPassThroughPort("mod");
        this.localWeightsOutput = getOutputPorts().createPort("wei");
        this.localModelsOutputPort = getOutputPorts().createPort("loc");
        getTransformer().addPassThroughRule(this.modInput, this.modOutput);
        // The inner model port is optional: it only expects a PredictionModel when something is connected.
        this.innerModelOutput.addPrecondition(new SimplePrecondition(this.innerModelOutput, new MetaData(PredictionModel.class)) {
            protected boolean isMandatory() {
                return false;
            }
        });
        getTransformer().addRule(new MDTransformationRule() {
            public void transformMD() {
                localWeightsOutput.deliverMD(new CollectionMetaData(new MetaData(AttributeWeights.class)));
            }
        });
        getTransformer().addRule(new MDTransformationRule() {
            public void transformMD() {
                localModelsOutputPort.deliverMD(new CollectionMetaData(new MetaData(PredictionModel.class)));
            }
        });
    }

    /**
     * Runs the nested subprocess once per input example and collects the local interpretations.
     *
     * @throws OperatorException if reading parameters/ports or running the subprocess fails
     */
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = this.exaInput.getData(ExampleSet.class);
        PredictionModel model = this.modInput.getData(PredictionModel.class);
        IOObjectCollection<AttributeWeights> localWeights = new IOObjectCollection<AttributeWeights>();
        IOObjectCollection<PredictionModel> localModels = new IOObjectCollection<PredictionModel>();
        getProgress().setTotal(exampleSet.size());

        double locality = getParameterAsDouble(PARAMETER_LOCALITY);
        int sampleSize = getParameterAsInt(PARAMETER_SAMPLE_SIZE);
        double weightThreshold = getParameterAsDouble(PARAMETER_WEIGHT_THRESHOLD);
        int numberOfAttributes = getParameterAsInt(PARAMETER_NUMBER_ATTRIBUTES);
        InterpretationHelper helper = new InterpretationHelper(exampleSet,
                RandomGenerator.getRandomGenerator(getParameterAsBoolean("use_local_random_seed"),
                        getParameterAsInt("local_random_seed")));
        helper.storeMinMax(exampleSet);
        if (getParameterAsBoolean(PARAMETER_LOCALITY_HEURISTICS)) {
            // NOTE(review): this computes sqrt(#atts * 0.2), whereas the parameter description
            // advertises sqrt(#atts)*0.2 -- confirm which formula is intended.
            helper.setKernel_width(Math.sqrt(exampleSet.getAttributes().size() * 0.2d));
        } else {
            helper.setKernel_width(locality);
        }

        GetDecTreePath treePathExtractor = new GetDecTreePath(getOperatorDescription());
        addImportanceAttributes(exampleSet, numberOfAttributes);

        // Score a random sample with the global model and use its predictions as the label that
        // the local models are trained on.
        ExampleSet normalizedSample = helper.normalize(model.apply(helper.deNormalize(helper.getRandom(sampleSize, exampleSet))));
        normalizedSample.getAttributes().setLabel(normalizedSample.getAttributes().getPredictedLabel());
        ExampleSet normalizedInput = helper.normalize(exampleSet);
        ExampleSet weightedSample = helper.deNormalize(normalizedSample);

        // Port connections do not change while the operator executes.
        boolean collectModels = this.innerModelOutput.isConnected();
        boolean collectPerformance = this.innerPerformanceOutput.isConnected();
        Attribute treePathAttribute = null;
        Attribute performanceAttribute = collectPerformance ? addPerformanceAttribute(exampleSet) : null;

        int row = 0;
        for (Example example : exampleSet) {
            weightedSample = helper.addWeights(normalizedSample, weightedSample, normalizedInput.getExample(row));
            this.innerExampleSource.deliver(helper.filterForWeights(weightedSample, weightThreshold));
            super.doWork();

            AttributeWeights weights = this.innerWeightOutput.getData(AttributeWeights.class);
            localWeights.add(weights);

            if (collectModels) {
                PredictionModel localModel = this.innerModelOutput.getData(PredictionModel.class);
                localModels.add(localModel);
                // Inspect the model produced in THIS iteration (previously only the first collected
                // model was checked, which broke when model types differed between iterations).
                if (localModel instanceof TreeModel) {
                    if (treePathAttribute == null) {
                        treePathAttribute = addTreePathAttribute(exampleSet);
                    }
                    example.setValue(treePathAttribute, treePathExtractor.getPath(example,
                            ((TreeModel) localModel).getRoot(), new StringBuffer(), true));
                }
            }
            if (collectPerformance) {
                example.setValue(performanceAttribute,
                        this.innerPerformanceOutput.getData(PerformanceVector.class).getMainCriterion().getAverage());
            }
            addImportances(example, weights, numberOfAttributes);
            // Drop the per-example "weight" attribute so the next iteration starts from a clean sample
            // (assumes addWeights adds an attribute named "weight" -- TODO confirm).
            weightedSample.getAttributes().remove(weightedSample.getAttributes().get("weight"));
            row++;
            getProgress().step();
        }

        if (collectModels) {
            this.localModelsOutputPort.deliver(localModels);
        }
        this.localWeightsOutput.deliver(localWeights);
        this.exaOutput.deliver(exampleSet);
        this.modOutput.deliver(model);
        getProgress().complete();
    }

    /**
     * Adds {@code count} pairs of special attributes to hold the top-k attribute names (nominal) and
     * their weights (real).
     */
    private void addImportanceAttributes(ExampleSet exampleSet, int count) {
        for (int rank = 0; rank < count; rank++) {
            String importantName = IMPORTANT_ATTRIBUTE_PREFIX + Integer.toString(rank);
            addSpecialAttribute(exampleSet, importantName, VALUE_TYPE_NOMINAL, importantName);
            String importanceName = IMPORTANCE_ATTRIBUTE_PREFIX + Integer.toString(rank);
            addSpecialAttribute(exampleSet, importanceName, VALUE_TYPE_REAL, importanceName);
        }
    }

    /** Adds the nominal special attribute (role "path") that holds the decision tree path. */
    private Attribute addTreePathAttribute(ExampleSet exampleSet) {
        return addSpecialAttribute(exampleSet, "Decision Tree Path", VALUE_TYPE_NOMINAL, "path");
    }

    /** Adds the real special attribute (role "performance") that holds the local performance. */
    private Attribute addPerformanceAttribute(ExampleSet exampleSet) {
        return addSpecialAttribute(exampleSet, "Performance", VALUE_TYPE_REAL, "performance");
    }

    /**
     * Creates an attribute and registers it both in the example set's attributes (with the given
     * special role) and in the underlying example table.
     *
     * @return the newly created attribute
     */
    private Attribute addSpecialAttribute(ExampleSet exampleSet, String name, int valueType, String role) {
        Attributes attributes = exampleSet.getAttributes();
        Attribute attribute = AttributeFactory.createAttribute(name, valueType);
        attribute.setTableIndex(attributes.size());
        AttributeRole attributeRole = new AttributeRole(attribute);
        attributeRole.setSpecial(role);
        attributes.add(attributeRole);
        exampleSet.getExampleTable().addAttribute(attribute);
        return attribute;
    }

    /**
     * Writes the names and weights of the {@code i} attributes with the largest absolute weight into
     * the importance attributes of {@code example}. Ties keep the order of
     * {@code attributeWeights.getAttributeNames()} because the sort is stable.
     */
    private void addImportances(Example example, AttributeWeights attributeWeights, int i) throws OperatorException {
        Map<String, Double> weightByName = new LinkedHashMap<String, Double>();
        for (String name : attributeWeights.getAttributeNames()) {
            weightByName.put(name, Double.valueOf(attributeWeights.getWeight(name)));
        }
        List<String> ranked = new ArrayList<String>(attributeWeights.getAttributeNames());
        Collections.sort(ranked, new WeightComparator(weightByName));

        Attributes attributes = example.getAttributes();
        int limit = Math.min(i, ranked.size());
        for (int rank = 0; rank < limit; rank++) {
            String name = ranked.get(rank);
            Attribute nameAttribute = attributes.get(IMPORTANT_ATTRIBUTE_PREFIX + Integer.toString(rank));
            Attribute weightAttribute = attributes.get(IMPORTANCE_ATTRIBUTE_PREFIX + Integer.toString(rank));
            example.setValue(nameAttribute, name);
            example.setValue(weightAttribute, attributeWeights.getWeight(name));
        }
    }

    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_LOCALITY_HEURISTICS, "Use Heuristics for Locality: sqrt(#atts)*0.2.", true));
        ParameterTypeDouble localityType = new ParameterTypeDouble(PARAMETER_LOCALITY, "The amount of locality used for the local models. Smaller means more local.", 0.0d, Double.MAX_VALUE, 0.3d);
        // Only shown when the heuristic is disabled.
        localityType.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LOCALITY_HEURISTICS, true, false));
        parameterTypes.add(localityType);
        parameterTypes.add(new ParameterTypeInt(PARAMETER_SAMPLE_SIZE, "Number of examples to be drawn.", 0, Integer.MAX_VALUE, 1000));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUMBER_ATTRIBUTES, "Top k attributes to be added to the ExampleSet.", 1, Integer.MAX_VALUE, 3));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_WEIGHT_THRESHOLD, "Threshold on the weight (~Distance) to the point to get the interpretation for.", 0.0d, Double.MAX_VALUE, 1.0E-7d));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }

    // Maps legacy parameter keys found in old process XML files to the current keys.
    static {
        ParameterReplacementProcessXMLFilter.registerReplacement(LocalInterpretationWithSubprocess.class, "Locality", PARAMETER_LOCALITY);
        ParameterReplacementProcessXMLFilter.registerReplacement(LocalInterpretationWithSubprocess.class, "Sample Size", PARAMETER_SAMPLE_SIZE);
        ParameterReplacementProcessXMLFilter.registerReplacement(LocalInterpretationWithSubprocess.class, "Weight Threshold", PARAMETER_WEIGHT_THRESHOLD);
        ParameterReplacementProcessXMLFilter.registerReplacement(LocalInterpretationWithSubprocess.class, "Number of Attributes", PARAMETER_NUMBER_ATTRIBUTES);
        ParameterReplacementProcessXMLFilter.registerReplacement(LocalInterpretationWithSubprocess.class, "Use Locality Heuristics", PARAMETER_LOCALITY_HEURISTICS);
    }
}
