package com.rapidminer.extension.operator.clustering;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.ExampleSetUtilities;
import com.rapidminer.example.set.RemappedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.extension.exceptions.MissingValueException;
import com.rapidminer.extension.utility.SmileHelper;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.clustering.ClusterModel;
import com.rapidminer.operator.preprocessing.MaterializeDataInMemory;
import com.rapidminer.operator.preprocessing.normalization.AbstractNormalizationModel;
import com.rapidminer.operator.preprocessing.normalization.MinMaxNormalizationModel;
import com.rapidminer.tools.container.Tupel;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import smile.stat.distribution.MultivariateGaussianDistribution;
import smile.stat.distribution.MultivariateGaussianMixture;
import smile.stat.distribution.MultivariateMixture;

/**
 * Cluster model backed by a Smile {@link MultivariateGaussianMixture}.
 *
 * <p>Only the fitted parameters (per-component mean vectors, covariance matrices and priors) are
 * stored; the Smile mixture is rebuilt from them on every {@link #apply(ExampleSet)} call. Input
 * data can optionally be min-max normalized to [0, 1] before fitting and scoring.
 *
 * <p>NOTE: field names/types must stay stable because this class is {@link Serializable}.
 */
public class MultivariateGaussianMixtureModel extends ClusterModel implements Serializable {
    private static final long serialVersionUID = 5834394253003990127L;
    public static final String LOG_NEGATIVE_LIKELIHOOD = "negative_log_likelihood";
    public static final String LIKELIHOOD = "likelihood";
    public static final String INVERTED_LIKELIHOOD = "inverted_likelihood";
    public static final String ATTRIBUTE_NAME = "score";
    public static final String ATTRIBUTE_ROLE = "score";
    private final String usedScoringMethod;
    private final boolean useDiagonal;
    private int numberOfComponents;
    private final boolean useBicOptimization;
    private final boolean addScore;
    private final boolean addDetailedScores;
    private final boolean normalize;
    private List<double[]> means;
    private List<double[][]> covs;
    private List<Double> prioris;
    private double BIC;
    private AbstractNormalizationModel normalizationModel;
    private String descriptionString;

    /**
     * Creates an unfitted model; call {@link #fit(ExampleSet)} before {@link #apply(ExampleSet)}.
     *
     * @param exampleSet         training data whose header is stored by the super class
     * @param numberOfComponents requested number of mixture components (ignored when
     *                           {@code useBicOptimization} is set; the fitted count replaces it)
     * @param useDiagonal        fit diagonal covariance matrices only (non-BIC path)
     * @param useBicOptimization let Smile pick the number of components
     * @param addScore           add an overall score attribute on apply
     * @param addDetailedScores  additionally add one score attribute per component (only effective
     *                           together with {@code addScore})
     * @param scoringMethod      one of {@link #LOG_NEGATIVE_LIKELIHOOD}, {@link #LIKELIHOOD},
     *                           {@link #INVERTED_LIKELIHOOD}; any other value yields the raw
     *                           log-density
     * @param normalize          min-max normalize numerical attributes before fitting/scoring
     */
    public MultivariateGaussianMixtureModel(ExampleSet exampleSet, int numberOfComponents, boolean useDiagonal,
            boolean useBicOptimization, boolean addScore, boolean addDetailedScores, String scoringMethod,
            boolean normalize) {
        super(exampleSet, numberOfComponents, false, false);
        this.BIC = Double.NaN;
        this.addScore = addScore;
        this.addDetailedScores = addDetailedScores;
        this.numberOfComponents = numberOfComponents;
        this.useDiagonal = useDiagonal;
        this.useBicOptimization = useBicOptimization;
        this.usedScoringMethod = scoringMethod;
        this.normalize = normalize;
    }

    /**
     * Fits the Gaussian mixture on the regular attributes of the training header and stores the
     * resulting component parameters, description and BIC.
     *
     * @param exampleSet training data
     * @throws OperatorException if the data contains missing values (cause preserved)
     */
    public void fit(ExampleSet exampleSet) throws OperatorException {
        ExampleSet trainingSet = exampleSet;
        if (this.normalize) {
            this.normalizationModel = fitNormalizationModel(trainingSet);
            trainingSet = this.normalizationModel.apply(trainingSet);
        }
        double[][] data;
        try {
            // The trailing 'true' is the variant whose MissingValueException is handled here;
            // presumably it makes SmileHelper reject missing values — TODO confirm in SmileHelper.
            data = SmileHelper.exampleSetToDoubleArray(trainingSet, getTrainingAttributes(), true);
        } catch (MissingValueException e) {
            throw new OperatorException("The data contains missing values. Those are not allowed.", e);
        }
        MultivariateGaussianMixture mixture = this.useBicOptimization
                ? new MultivariateGaussianMixture(data)
                : new MultivariateGaussianMixture(data, this.numberOfComponents, this.useDiagonal);

        List<MultivariateMixture.Component> components = mixture.getComponents();
        List<double[]> fittedMeans = new ArrayList<>(components.size());
        List<double[][]> fittedCovs = new ArrayList<>(components.size());
        List<Double> fittedPriors = new ArrayList<>(components.size());
        for (MultivariateMixture.Component component : components) {
            fittedMeans.add(component.distribution.mean());
            fittedCovs.add(component.distribution.cov());
            fittedPriors.add(component.priori);
        }
        this.means = fittedMeans;
        this.covs = fittedCovs;
        this.prioris = fittedPriors;
        // With BIC optimization the component count is chosen by Smile, so take it from the result.
        this.numberOfComponents = components.size();
        this.descriptionString = mixture.toString();
        this.BIC = mixture.bic(data);
    }

    /**
     * Scores and clusters the given data.
     *
     * <p>Adds a nominal cluster attribute, one confidence attribute per component and, depending on
     * the constructor flags, an overall score attribute and per-component score attributes. Each
     * example is assigned to the component with the highest prior-weighted log-density, which is
     * also the component with the highest confidence value.
     *
     * @param exampleSet data with attributes compatible with the training header
     * @return a materialized copy of the data with the result attributes added
     * @throws OperatorException if attributes do not match or the model has not been fitted
     */
    @Override
    public ExampleSet apply(ExampleSet exampleSet) throws OperatorException {
        if (this.means == null || this.covs == null || this.prioris == null) {
            throw new OperatorException("The Gaussian mixture model has not been fitted yet.");
        }
        ExampleSetUtilities.checkAttributesMatching(getOperator(), getTrainingHeader().getAttributes(),
                exampleSet.getAttributes(), ExampleSetUtilities.SetsCompareOption.ALLOW_SUPERSET,
                ExampleSetUtilities.TypesCompareOption.ALLOW_SUPERTYPES);
        ExampleSet result = MaterializeDataInMemory.materializeExampleSet(
                RemappedExampleSet.create(exampleSet, getTrainingHeader(), true, true));
        ExampleSet scoringSet = this.normalize ? this.normalizationModel.apply(result) : result;
        double[][] data = SmileHelper.exampleSetToDoubleArray(scoringSet, getTrainingAttributes());

        MultivariateGaussianMixture mixture = reconstruct();
        List<MultivariateMixture.Component> components = mixture.getComponents();
        double[] logPriors = new double[this.numberOfComponents];
        for (int k = 0; k < this.numberOfComponents; k++) {
            logPriors[k] = Math.log(components.get(k).priori);
        }

        // Per-component scores are only ever written together with the overall score, so only
        // create their attributes in that case (avoids adding all-missing columns).
        boolean writeDetailedScores = this.addScore && this.addDetailedScores;
        Attribute clusterAttribute = addClusterAttribute(result);
        Attribute scoreAttribute = this.addScore ? addScoreAttribute(result) : null;
        List<Attribute> confidenceAttributes = addConfidenceAttributes(result, this.numberOfComponents);
        List<Attribute> detailedScoreAttributes = writeDetailedScores
                ? addIndividualScoreAttributes(result, this.numberOfComponents)
                : null;

        int row = 0;
        for (Example example : result) {
            double[] point = data[row++];
            if (this.addScore) {
                example.setValue(scoreAttribute, convertLogPToScore(mixture.logp(point)));
            }
            double[] logPosteriors = new double[this.numberOfComponents];
            int bestComponent = 0;
            for (int k = 0; k < this.numberOfComponents; k++) {
                double componentLogP = components.get(k).distribution.logp(point);
                if (writeDetailedScores) {
                    example.setValue(detailedScoreAttributes.get(k), convertLogPToScore(componentLogP));
                }
                // Unnormalized log-posterior: log p(x | k) + log P(k).
                logPosteriors[k] = componentLogP + logPriors[k];
                if (logPosteriors[k] > logPosteriors[bestComponent]) {
                    bestComponent = k;
                }
            }
            example.setValue(clusterAttribute, "cluster_" + bestComponent);
            double[] confidences = normalizeScores(logPosteriors);
            for (int k = 0; k < this.numberOfComponents; k++) {
                example.setValue(confidenceAttributes.get(k), confidences[k]);
            }
        }
        return result;
    }

    /**
     * Converts a log-density into the configured score representation.
     *
     * @param logP natural-log density value
     * @return {@code -logP}, {@code exp(logP)}, {@code 1/exp(logP)} or {@code logP} unchanged,
     *         depending on {@link #usedScoringMethod}
     */
    private double convertLogPToScore(double logP) {
        if (this.usedScoringMethod == null) {
            return logP;
        }
        switch (this.usedScoringMethod) {
            case LOG_NEGATIVE_LIKELIHOOD:
                return -logP;
            case LIKELIHOOD:
                return Math.exp(logP);
            case INVERTED_LIKELIHOOD:
                return 1.0d / Math.exp(logP);
            default:
                return logP;
        }
    }

    /** Rebuilds the Smile mixture from the stored means, covariances and priors. */
    private MultivariateGaussianMixture reconstruct() {
        List<MultivariateMixture.Component> components = new ArrayList<>(this.numberOfComponents);
        for (int k = 0; k < this.numberOfComponents; k++) {
            MultivariateMixture.Component component = new MultivariateMixture.Component();
            component.distribution = new MultivariateGaussianDistribution(this.means.get(k), this.covs.get(k));
            component.priori = this.prioris.get(k);
            components.add(component);
        }
        return new MultivariateGaussianMixture(components);
    }

    /** Returns the regular attributes of the training header, in header order. */
    public List<Attribute> getTrainingAttributes() {
        Attributes trainingAttributes = getTrainingHeader().getAttributes();
        List<Attribute> result = new ArrayList<>(trainingAttributes.size());
        Iterator<AttributeRole> roles = trainingAttributes.regularAttributes();
        while (roles.hasNext()) {
            result.add(roles.next().getAttribute());
        }
        return result;
    }

    /**
     * Returns the existing {@value #ATTRIBUTE_NAME} attribute, or adds a new real-valued one with the
     * special role {@value #ATTRIBUTE_ROLE}.
     */
    public Attribute addScoreAttribute(ExampleSet exampleSet) {
        return addRealSpecialAttribute(exampleSet, ATTRIBUTE_NAME, ATTRIBUTE_ROLE);
    }

    /** Returns the existing "cluster" attribute, or adds a new nominal one set as cluster attribute. */
    public Attribute addClusterAttribute(ExampleSet exampleSet) {
        Attributes attributes = exampleSet.getAttributes();
        Attribute existing = attributes.get("cluster");
        if (existing != null) {
            return existing;
        }
        Attribute created = AttributeFactory.createAttribute("cluster", 1); // 1 == Ontology.NOMINAL
        created.setTableIndex(attributes.size());
        attributes.setCluster(created);
        exampleSet.getExampleTable().addAttribute(created);
        return created;
    }

    /**
     * Returns (creating where missing) one real-valued "score(component_i)" attribute per component.
     *
     * @param exampleSet target data
     * @param count      number of components
     * @return attributes in component order
     */
    public List<Attribute> addIndividualScoreAttributes(ExampleSet exampleSet, int count) {
        List<Attribute> result = new ArrayList<>(Math.max(count, 0));
        for (int k = 0; k < count; k++) {
            result.add(addRealSpecialAttribute(exampleSet, "score(component_" + k + ")", "scorecomponent_" + k));
        }
        return result;
    }

    /**
     * Returns (creating where missing) one real-valued "confidence(cluster_i)" attribute per
     * component.
     *
     * @param exampleSet target data
     * @param count      number of components
     * @return attributes in component order
     */
    public List<Attribute> addConfidenceAttributes(ExampleSet exampleSet, int count) {
        List<Attribute> result = new ArrayList<>(Math.max(count, 0));
        for (int k = 0; k < count; k++) {
            result.add(addRealSpecialAttribute(exampleSet, "confidence(cluster_" + k + ")",
                    "confidence_cluster_" + k));
        }
        return result;
    }

    /**
     * Returns the attribute named {@code name} if present; otherwise creates a real-valued attribute
     * with that name, gives it the special role {@code role}, and registers it with both the
     * attribute set and the underlying example table.
     */
    private static Attribute addRealSpecialAttribute(ExampleSet exampleSet, String name, String role) {
        Attributes attributes = exampleSet.getAttributes();
        Attribute existing = attributes.get(name);
        if (existing != null) {
            return existing;
        }
        Attribute created = AttributeFactory.createAttribute(name, 4); // 4 == Ontology.REAL
        created.setTableIndex(attributes.size());
        AttributeRole attributeRole = new AttributeRole(created);
        attributeRole.setSpecial(role);
        attributes.add(attributeRole);
        exampleSet.getExampleTable().addAttribute(created);
        return created;
    }

    /**
     * Builds a min-max normalization model mapping every numerical attribute's observed
     * [minimum, maximum] range onto [0, 1].
     */
    public AbstractNormalizationModel fitNormalizationModel(ExampleSet exampleSet) {
        HashMap<String, Tupel<Double, Double>> ranges = new HashMap<>();
        exampleSet.recalculateAllAttributeStatistics();
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (attribute.isNumerical()) {
                Double min = Double.valueOf(exampleSet.getStatistics(attribute, "minimum"));
                Double max = Double.valueOf(exampleSet.getStatistics(attribute, "maximum"));
                ranges.put(attribute.getName(), new Tupel<Double, Double>(min, max));
            }
        }
        return new MinMaxNormalizationModel(exampleSet, 0.0d, 1.0d, ranges);
    }

    /**
     * Converts unnormalized log-weights into probabilities (softmax), shifting by the maximum for
     * numerical stability.
     *
     * @param logScores log-weights; not modified
     * @return values in [0, 1] summing to 1 (empty input yields an empty array)
     */
    public double[] normalizeScores(double[] logScores) {
        if (logScores.length == 0) {
            return new double[0];
        }
        double max = Arrays.stream(logScores).max().getAsDouble();
        double[] result = new double[logScores.length];
        double sum = 0.0d;
        for (int k = 0; k < logScores.length; k++) {
            result[k] = Math.exp(logScores[k] - max);
            sum += result[k];
        }
        for (int k = 0; k < result.length; k++) {
            result[k] /= sum;
        }
        return result;
    }

    /** Returns the number of fitted components (the requested count before {@link #fit}). */
    @Override
    public int getNumberOfClusters() {
        return this.numberOfComponents;
    }

    /** Returns Smile's description of the fitted mixture, or the default text before fitting. */
    @Override
    public String toString() {
        return this.descriptionString != null ? this.descriptionString : super.toString();
    }

    /** Returns the Bayesian information criterion of the fit, or NaN before fitting. */
    public double getBIC() {
        return this.BIC;
    }
}
