package com.rapidminer.extension.shapelet.operator;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.utils.ExampleSetBuilder;
import com.rapidminer.example.utils.ExampleSets;
import com.rapidminer.extension.shapelet.ioobject.Shapelet;
import com.rapidminer.extension.shapelet.ioobject.ShapeletModel;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.error.AttributeWrongTypeError;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.CollectionPrecondition;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.GenerateNewExampleSetMDRule;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

/* loaded from: input_file:com/rapidminer/extension/shapelet/operator/ShapeletTransformation.class */
public class ShapeletTransformation extends Operator {
    private InputPort collectionInputPort;
    private InputPort shapeletModelInputPort;
    private OutputPort featureVectorOutputPort;
    private OutputPort originalCollectionOutputPort;
    private OutputPort originalShapeletModelOutputPort;
    public static final String PARAMETER_METRIC = "metric";
    public static final String PARAMETER_AGGREGATION = "aggregation";

    /**
     * Ways in which the distances computed for one shapelet can be condensed
     * into feature values. The label of each constant is what {@link #toString()}
     * returns and is the text the constant is looked up by.
     */
    public enum AggregationType {
        MIN("minimum"),
        MAX("maximum"),
        STATS("descriptive statistics");

        /** Human-readable label of the constant. */
        private final String label;

        AggregationType(String label) {
            this.label = label;
        }

        /**
         * @return the human-readable label rather than the constant name
         */
        @Override
        public String toString() {
            return label;
        }
    }

    /**
     * Distance / similarity measures that can be evaluated between a shapelet
     * and a window of a series. The label of each constant is what
     * {@link #toString()} returns and is the text the constant is looked up by.
     */
    public enum Metric {
        SQEUCLIDEAN("squared euclidean"),
        HAMMING("hamming"),
        CROSS_CORRELATION("cross correlation");

        /** Human-readable label of the constant. */
        private final String label;

        Metric(String label) {
            this.label = label;
        }

        /**
         * @return the human-readable label rather than the constant name
         */
        @Override
        public String toString() {
            return label;
        }
    }

    /**
     * Creates the operator, its two input ports and three output ports, and
     * registers the meta data rules connecting them.
     *
     * <p>The collection input must contain {@link ExampleSet}s. Both inputs are
     * passed through unchanged to their "original ..." outputs, and the feature
     * vector output is announced as an example set carrying a "Metal Unit ID"
     * attribute with the "id" role.
     *
     * @param operatorDescription description handed to the {@link Operator} super constructor
     */
    public ShapeletTransformation(OperatorDescription operatorDescription) {
        super(operatorDescription);

        // Input ports.
        collectionInputPort = getInputPorts().createPort("collection", IOObjectCollection.class);
        shapeletModelInputPort = getInputPorts().createPort("shapelet model", ShapeletModel.class);

        // Output ports.
        featureVectorOutputPort = getOutputPorts().createPort("feature vector");
        originalCollectionOutputPort = getOutputPorts().createPort("original collection");
        originalShapeletModelOutputPort = getOutputPorts().createPort("original shapelet model");

        // Every element of the incoming collection has to be an example set.
        MetaData elementMetaData = new MetaData(ExampleSet.class);
        SimplePrecondition elementPrecondition = new SimplePrecondition(collectionInputPort, elementMetaData);
        collectionInputPort.addPrecondition(new CollectionPrecondition(elementPrecondition));

        // Inputs are forwarded as-is.
        getTransformer().addPassThroughRule(collectionInputPort, originalCollectionOutputPort);
        getTransformer().addPassThroughRule(shapeletModelInputPort, originalShapeletModelOutputPort);

        // Meta data of the generated feature vector.
        getTransformer().addGenerationRule(featureVectorOutputPort, ExampleSet.class);
        getTransformer().addRule(new GenerateNewExampleSetMDRule(featureVectorOutputPort) {
            public MetaData modifyMetaData(ExampleSetMetaData generatedMetaData) {
                // Value type 1 is presumably Ontology.NOMINAL -- TODO confirm.
                generatedMetaData.addAttribute(new AttributeMetaData("Metal Unit ID", 1, "id"));
                return generatedMetaData;
            }
        });
    }

    /**
     * Transforms every example set of the input collection into one row of a
     * feature vector example set.
     *
     * <p>For each element of the collection and each shapelet of the model, the
     * configured metric is evaluated between the shapelet values and every
     * window of the shapelet's attribute that fits into the series; the
     * resulting distances are condensed with the configured aggregation. The
     * first column of each row is a "Metal Unit ID" built from the element's
     * index in the collection and is given the "id" role.
     *
     * <p>Both inputs are delivered unchanged to their pass-through outputs.
     *
     * @throws UserError if a collection element is not an {@link ExampleSet}
     *         or lacks the attribute a shapelet refers to
     * @throws AttributeWrongTypeError if the shapelet's attribute is not numerical
     * @throws OperatorException if reading the input ports or reporting progress fails
     */
    @Override
    public void doWork() throws OperatorException {
        IOObjectCollection collection = this.collectionInputPort.getData(IOObjectCollection.class);
        int seriesCount = collection.size();
        getProgress().setTotal(seriesCount);
        this.originalCollectionOutputPort.deliver(collection);

        ShapeletModel shapeletModel = this.shapeletModelInputPort.getData(ShapeletModel.class);
        this.originalShapeletModelOutputPort.deliver(shapeletModel);

        Metric metric = getMetric(getParameterAsString(PARAMETER_METRIC));
        AggregationType aggregationType = getAggregationType(getParameterAsString(PARAMETER_AGGREGATION));
        Set<String> shapeletNames = shapeletModel.keys();

        // Column layout: id attribute first, then the aggregation columns of each shapelet.
        List<Attribute> attributes = new LinkedList<>();
        // Value type 1 is presumably Ontology.NOMINAL -- TODO confirm.
        Attribute idAttribute = AttributeFactory.createAttribute("Metal Unit ID", 1);
        attributes.add(idAttribute);
        addAttributesToList(attributes, metric, aggregationType, shapeletNames);
        ExampleSetBuilder builder = ExampleSets.from(attributes);

        // Report progress roughly every 10%. An integer step is used because a
        // floating-point modulo only hits exactly zero for some collection sizes.
        int progressStep = Math.max(1, seriesCount / 10);

        for (int seriesIndex = 0; seriesIndex < seriesCount; seriesIndex++) {
            if (seriesIndex % progressStep == 0) {
                getProgress().setCompleted(seriesIndex);
            }

            // Check the element type explicitly instead of catching
            // ClassCastException, which would also hide unrelated cast failures.
            Object element = collection.getElement(seriesIndex, false);
            if (!(element instanceof ExampleSet)) {
                throw new UserError(this, "shapelet_extension.shapelet.not_example_set", new Object[]{Integer.valueOf(seriesIndex)});
            }
            ExampleSet series = (ExampleSet) element;

            double[] row = new double[attributes.size()];
            row[0] = idAttribute.getMapping().mapString(String.valueOf(seriesIndex));
            int column = 1;

            for (String shapeletName : shapeletNames) {
                Shapelet shapelet = shapeletModel.getShapelet(shapeletName);
                List<Double> shapeletValues = shapeletModel.getValues(shapeletName);

                Attribute attribute = series.getAttributes().get(shapelet.getAttributeName());
                if (attribute == null) {
                    throw new UserError(this, "shapelet_extension.shapelet.attribute_not_in_example_set", new Object[]{Integer.valueOf(seriesIndex), shapelet.getAttributeName()});
                }
                if (!attribute.isNumerical()) {
                    throw new AttributeWrongTypeError(this, attribute, new int[]{2});
                }

                double[] seriesValues = getCurrentExampleData(series, attribute);

                // Number of windows of the shapelet's length that fit into the
                // series; not positive when the shapelet is longer than the series,
                // in which case no distance is collected.
                int windowCount = series.size() - shapeletValues.size() + 1;
                DescriptiveStatistics distances = new DescriptiveStatistics();
                for (int offset = 0; offset < windowCount; offset++) {
                    distances.addValue(getDistance(offset, seriesValues, shapeletValues, metric));
                }

                for (double aggregated : getAggregationResults(distances, aggregationType)) {
                    row[column] = aggregated;
                    column++;
                }
            }
            builder.addRow(row);
        }

        builder.withRole(idAttribute, "id");
        this.featureVectorOutputPort.deliver(builder.build());
        getProgress().complete();
    }

    /**
     * Condenses the collected distances into the feature values of one shapelet.
     *
     * <ul>
     *   <li>{@code MIN} / {@code MAX}: two values, the extreme distance and the
     *       index at which it first occurs.</li>
     *   <li>{@code STATS}: seven values in the order minimum, maximum, mean,
     *       median, standard deviation, skewness, kurtosis.</li>
     * </ul>
     *
     * <p>When no distance was collected every returned value is {@code NaN}.
     * For {@code MIN} / {@code MAX} the same holds when all distances are
     * {@code NaN}, so that no sentinel value ends up in the feature vector.
     *
     * @param descriptiveStatistics the distances, in the order they were computed
     * @param aggregationType the aggregation to apply
     * @return the aggregated values; length 2 for MIN and MAX, 7 for STATS
     * @throws IllegalArgumentException if the aggregation type is not supported
     */
    private double[] getAggregationResults(DescriptiveStatistics descriptiveStatistics, AggregationType aggregationType) {
        boolean empty = descriptiveStatistics.getN() == 0;
        switch (aggregationType) {
            case MIN:
                return findExtremum(descriptiveStatistics, true);
            case MAX:
                return findExtremum(descriptiveStatistics, false);
            case STATS:
                if (empty) {
                    double[] missing = new double[7];
                    Arrays.fill(missing, Double.NaN);
                    return missing;
                }
                return new double[]{
                    descriptiveStatistics.getMin(),
                    descriptiveStatistics.getMax(),
                    descriptiveStatistics.getMean(),
                    descriptiveStatistics.getPercentile(50.0d),
                    descriptiveStatistics.getStandardDeviation(),
                    descriptiveStatistics.getSkewness(),
                    descriptiveStatistics.getKurtosis()
                };
            default:
                throw new IllegalArgumentException("Provided aggregation method is not supported: " + aggregationType);
        }
    }

    /**
     * Finds the smallest or largest non-NaN distance and the index of its first occurrence.
     *
     * @param distances the distances to scan
     * @param minimum {@code true} to search for the minimum, {@code false} for the maximum
     * @return {@code {value, index}}, or {@code {NaN, NaN}} if there is no non-NaN distance
     */
    private static double[] findExtremum(DescriptiveStatistics distances, boolean minimum) {
        double best = Double.NaN;
        int position = -1;
        long count = distances.getN();
        for (int index = 0; index < count; index++) {
            double candidate = distances.getElement(index);
            if (Double.isNaN(candidate)) {
                continue;
            }
            // The first comparable value always wins; this also covers infinite
            // distances, which a finite starting sentinel could never record.
            boolean better = minimum ? candidate < best : candidate > best;
            if (position < 0 || better) {
                best = candidate;
                position = index;
            }
        }
        return new double[]{best, position < 0 ? Double.NaN : position};
    }

    /**
     * Evaluates the metric between the shapelet values and the window of the
     * series that starts at the given offset.
     *
     * <ul>
     *   <li>{@code SQEUCLIDEAN}: sum of squared differences.</li>
     *   <li>{@code HAMMING}: fraction of positions whose values differ.</li>
     *   <li>{@code CROSS_CORRELATION}: sum of element-wise products.</li>
     * </ul>
     *
     * @param i offset of the window's first element in {@code dArr}
     * @param dArr values of the series
     * @param list values of the shapelet; its size is the window length
     * @param metric metric to evaluate
     * @return the accumulated value; {@code 0} for an empty shapelet or a
     *         metric that is not handled
     */
    private double getDistance(int i, double[] dArr, List<Double> list, Metric metric) {
        final int length = list.size();
        double distance = 0.0d;
        if (length == 0) {
            // Nothing to compare; the metric is not even inspected in this case.
            return distance;
        }
        switch (metric) {
            case SQEUCLIDEAN:
                for (int k = 0; k < length; k++) {
                    distance += Math.pow(list.get(k).doubleValue() - dArr[i + k], 2.0d);
                }
                break;
            case HAMMING:
                for (int k = 0; k < length; k++) {
                    if (list.get(k).doubleValue() != dArr[i + k]) {
                        distance += 1.0d / length;
                    }
                }
                break;
            case CROSS_CORRELATION:
                for (int k = 0; k < length; k++) {
                    distance += list.get(k).doubleValue() * dArr[i + k];
                }
                break;
        }
        return distance;
    }

    /**
     * Appends the output attributes of every shapelet to the given list, in the
     * iteration order of the shapelet names. The number and order of attributes
     * added per shapelet depend on the aggregation type: two for MIN and MAX
     * (value and position), seven for STATS.
     *
     * @param list list the attributes are appended to
     * @param metric metric whose label becomes part of the attribute names
     * @param aggregationType aggregation that determines which attributes are created
     * @param set names of the shapelets
     * @return the same list instance that was passed in
     */
    private List<Attribute> addAttributesToList(List<Attribute> list, Metric metric, AggregationType aggregationType, Set<String> set) {
        // Value types 4 and 3 are presumably Ontology.REAL and Ontology.INTEGER -- TODO confirm.
        final String[] statisticLabels = {" Min ", " Max ", " Mean ", " Median ", " Std ", " Skew ", " Kurt "};
        for (String shapeletName : set) {
            switch (aggregationType) {
                case MIN: {
                    String suffix = "Minimum(" + metric + ") of " + shapeletName;
                    list.add(AttributeFactory.createAttribute(suffix, 4));
                    list.add(AttributeFactory.createAttribute("Position " + suffix, 3));
                    break;
                }
                case MAX: {
                    list.add(AttributeFactory.createAttribute(shapeletName + " Max " + metric, 4));
                    list.add(AttributeFactory.createAttribute(shapeletName + " MaxPosition " + metric, 3));
                    break;
                }
                case STATS: {
                    for (String statisticLabel : statisticLabels) {
                        list.add(AttributeFactory.createAttribute(shapeletName + statisticLabel + metric, 4));
                    }
                    break;
                }
            }
        }
        return list;
    }

    /**
     * Copies the values of one attribute into an array, one entry per example
     * in the iteration order of the example set.
     *
     * @param exampleSet example set to read from
     * @param attribute attribute whose values are extracted
     * @return array of length {@code exampleSet.size()} holding the values
     */
    private double[] getCurrentExampleData(ExampleSet exampleSet, Attribute attribute) {
        double[] values = new double[exampleSet.size()];
        int position = 0;
        for (Example example : exampleSet) {
            values[position++] = example.getValue(attribute);
        }
        return values;
    }

    /**
     * Extends the inherited parameters with the two category parameters
     * {@code metric} and {@code aggregation}. The categories are the labels of
     * the {@link Metric} and {@link AggregationType} constants in declaration
     * order; the first category is the default of each parameter.
     *
     * @return the inherited parameter list with the two parameters appended
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        String[] metricLabels = Arrays.stream(Metric.values())
                .map(Metric::toString)
                .toArray(String[]::new);
        types.add(new ParameterTypeCategory(
                PARAMETER_METRIC,
                "Metric used for calculating the distance of the sub time series and the shapelet candidates.",
                metricLabels,
                0,
                false));

        String[] aggregationLabels = Arrays.stream(AggregationType.values())
                .map(AggregationType::toString)
                .toArray(String[]::new);
        types.add(new ParameterTypeCategory(
                PARAMETER_AGGREGATION,
                "Aggregation used for transforming the calculated distances into a feature vector.",
                aggregationLabels,
                0,
                false));

        return types;
    }

    /**
     * Looks up the aggregation type whose label equals the given text.
     *
     * @param str label as returned by {@link AggregationType#toString()}
     * @return the first constant with a matching label
     * @throws IllegalArgumentException if no constant has that label
     */
    private AggregationType getAggregationType(String str) {
        return Arrays.stream(AggregationType.values())
                .filter(candidate -> str.equals(candidate.toString()))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Aggregation type for String " + str + " could not be found."));
    }

    /**
     * Looks up the metric whose label equals the given text.
     *
     * @param str label as returned by {@link Metric#toString()}
     * @return the first constant with a matching label
     * @throws IllegalArgumentException if no constant has that label
     */
    private Metric getMetric(String str) {
        return Arrays.stream(Metric.values())
                .filter(candidate -> str.equals(candidate.toString()))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Metric for String " + str + " could not be found."));
    }
}
