package com.rapidminer.ispr.operator.learner.classifiers;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.ispr.operator.learner.tools.KNNTools;
import com.rapidminer.ispr.operator.learner.tools.PRulesUtil;
import com.rapidminer.ispr.tools.math.container.ISPRGeometricDataCollection;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/rapidminer/ispr/operator/learner/classifiers/IS_KNNClassificationModel.class */
public class IS_KNNClassificationModel<T extends Serializable> extends PredictionModel {
    private static final long serialVersionUID = -6292869962412072573L;

    /** Number of nearest neighbours passed to the voting routines. */
    private final int k;

    /** Number of examples in the training set; only reported by {@link #toString()}. */
    private final int size;

    /** Number of regular attributes (dimensions) the model was trained on. */
    private int attributesNumber;

    /** Stored training samples searched for neighbours. */
    private final ISPRGeometricDataCollection<T> samples;

    private boolean useCovariance;
    private HashMap<Integer, double[][]> covarianceMatrix;

    /** Voting scheme handed to {@code KNNTools}. */
    private final VotingType weightedNN;

    /** Selects how {@link #performPrediction} writes its result (label, cluster, or numeric value). */
    private PredictionType predictionType;

    private boolean generateID;

    /** Names of the training set's regular attributes, in training order. */
    private final List<String> trainingAttributeNames;

    /**
     * Creates a k-NN model over the given training samples.
     *
     * @param exampleSet     training set; its header, size and regular attribute names are recorded
     * @param samples        neighbour search structure built from the training data
     * @param k              number of neighbours to consult
     * @param votingType     voting scheme used when aggregating neighbours
     * @param predictionType kind of prediction this model produces
     */
    public IS_KNNClassificationModel(ExampleSet exampleSet, ISPRGeometricDataCollection<T> samples, int k, VotingType votingType, PredictionType predictionType) {
        super(exampleSet);
        this.generateID = false;
        this.k = k;
        this.size = exampleSet.size();
        this.samples = samples;
        this.weightedNN = votingType;
        Attributes attributes = exampleSet.getAttributes();
        this.trainingAttributeNames = new ArrayList<String>(attributes.size());
        for (Attribute trainingAttribute : attributes) {
            this.trainingAttributeNames.add(trainingAttribute.getName());
        }
        // Record the model's dimensionality up front so toString() is meaningful before any prediction.
        this.attributesNumber = this.trainingAttributeNames.size();
        this.useCovariance = false;
        this.predictionType = predictionType;
    }

    /**
     * Predicts every example of {@code exampleSet} and writes the result into {@code attribute}.
     * <ul>
     *   <li>Classification: writes the winning nominal index and one confidence per mapping value.</li>
     *   <li>Clustering: writes the winning nominal index only.</li>
     *   <li>Regression: writes the value returned by {@code KNNTools.getRegVotes}.</li>
     * </ul>
     *
     * @param exampleSet the set to score; its examples are modified in place
     * @param attribute  the prediction attribute to fill
     * @return the same {@code exampleSet} instance
     * @throws OperatorException declared by the overridden contract
     */
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute attribute) throws OperatorException {
        // Put the scored set's attributes into training order so vector components line up with the samples.
        List<Attribute> orderedAttributes = PRulesUtil.reorderAttributesByName(exampleSet.getAttributes(), this.trainingAttributeNames);
        // Size the query vector from the list actually used to fill it, so no stale trailing components remain.
        // The buffer is reused across examples; NOTE(review): assumes KNNTools does not retain the array — confirm.
        double[] query = new double[orderedAttributes.size()];
        for (Example example : exampleSet) {
            int position = 0;
            for (Attribute orderedAttribute : orderedAttributes) {
                query[position++] = example.getValue(orderedAttribute);
            }
            switch (this.predictionType) {
                case Classification: {
                    double[] votes = collectNominalVotes(attribute, query);
                    example.setValue(attribute, PRulesUtil.findMostFrequentValue(votes));
                    for (int index = 0; index < votes.length; index++) {
                        example.setConfidence(attribute.getMapping().mapIndex(index), votes[index]);
                    }
                    break;
                }
                case Clustering: {
                    double[] votes = collectNominalVotes(attribute, query);
                    example.setValue(attribute, PRulesUtil.findMostFrequentValue(votes));
                    break;
                }
                case Regression:
                    example.setValue(attribute, KNNTools.getRegVotes(query, this.samples, this.k, this.weightedNN));
                    break;
                default:
                    break;
            }
        }
        return exampleSet;
    }

    /** Allocates one vote slot per nominal value of {@code attribute} and lets KNNTools fill them for {@code query}. */
    private double[] collectNominalVotes(Attribute attribute, double[] query) {
        double[] votes = new double[attribute.getMapping().size()];
        KNNTools.doNNVotes(votes, query, this.samples, this.k, this.weightedNN);
        return votes;
    }

    /**
     * Describes the model: voting type, k, prediction type, training size and dimensionality.
     * For classification it also lists the label values; for clustering, the cluster values.
     */
    @Override
    public String toString() {
        String newline = Tools.getLineSeparator();
        StringBuilder sb = new StringBuilder();
        sb.append(this.weightedNN.toString());
        sb.append(this.k).append("-Nearest Neighbour model for ").append(this.predictionType).append(".").append(newline);
        sb.append("The model contains ").append(this.size).append(" examples with ").append(this.attributesNumber).append(" dimensions of the following classes:");
        sb.append(newline);
        if (this.predictionType == PredictionType.Classification) {
            for (String value : getTrainingHeader().getAttributes().getLabel().getMapping().getValues()) {
                sb.append("  ").append(value).append(newline);
            }
        }
        if (this.predictionType == PredictionType.Clustering) {
            for (String value : getTrainingHeader().getAttributes().getCluster().getMapping().getValues()) {
                sb.append("  ").append(value).append(newline);
            }
        }
        return sb.toString();
    }

    /** @return the neighbour search structure backing this model */
    public ISPRGeometricDataCollection<T> getSamples() {
        return this.samples;
    }

    public boolean isUseCovariance() {
        return this.useCovariance;
    }

    public void setUseCovariance(boolean z) {
        this.useCovariance = z;
    }

    public void setCovarianceMatrix(HashMap<Integer, double[][]> hashMap) {
        this.covarianceMatrix = hashMap;
    }

    public HashMap<Integer, double[][]> getCovarianceMatrix() {
        return this.covarianceMatrix;
    }

    public PredictionType getPredictionType() {
        return this.predictionType;
    }

    public void setPredictionType(PredictionType predictionType) {
        this.predictionType = predictionType;
    }

    public boolean isGenerateID() {
        return this.generateID;
    }

    public void setGenerateID(boolean z) {
        this.generateID = z;
    }
}
