package com.rapidminer.ispr.operator.learner.selection.models;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.set.SelectedExampleSet;
import com.rapidminer.ispr.operator.learner.classifiers.IS_KNNClassificationModel;
import com.rapidminer.ispr.operator.learner.selection.models.decisionfunctions.IISDecisionFunction;
import com.rapidminer.ispr.operator.learner.tools.DataIndex;
import com.rapidminer.ispr.operator.learner.tools.KNNTools;
import com.rapidminer.ispr.operator.learner.tools.PRulesUtil;
import com.rapidminer.ispr.tools.math.container.DoubleObjectContainer;
import com.rapidminer.ispr.tools.math.container.GeometricCollectionTypes;
import com.rapidminer.ispr.tools.math.container.ISPRGeometricDataCollection;
import com.rapidminer.tools.math.similarity.DistanceMeasure;
import java.util.Arrays;
import java.util.Iterator;

/* loaded from: input_file:com/rapidminer/ispr/operator/learner/selection/models/AllKNNInstanceSelectionGeneralModel.class */
/**
 * All-kNN style instance selection model.
 *
 * <p>For every stored sample, its {@code upperK} nearest neighbours are retrieved from a
 * linear-search collection. For each neighbourhood prefix past the {@code lowerK} threshold, a
 * prediction is formed and compared with the sample's own label by the supplied
 * {@link IISDecisionFunction}:
 *
 * <ul>
 *   <li>nominal labels: majority vote over the neighbours seen so far;
 *   <li>numerical labels: mean of the neighbours' target values seen so far.
 * </ul>
 *
 * <p>The sample is marked as removed in the returned {@link DataIndex} as soon as any of those
 * evaluations yields a positive loss.
 */
public class AllKNNInstanceSelectionGeneralModel extends AbstractInstanceSelectorModel {
    private final DistanceMeasure measure;
    private final int lowerK;
    private final int upperK;
    private final IISDecisionFunction loss;
    // NOTE(review): not read or written anywhere in this class; kept because it is package-visible.
    IS_KNNClassificationModel<Number> model;

    /**
     * @param distanceMeasure distance used to build the nearest-neighbour collection
     * @param lowerK neighbourhood-position threshold; the loss is evaluated only at zero-based
     *     neighbour positions strictly greater than this value
     * @param upperK number of nearest neighbours requested for every sample
     * @param iISDecisionFunction decision function; a value {@code > 0} rejects the sample
     */
    public AllKNNInstanceSelectionGeneralModel(DistanceMeasure distanceMeasure, int lowerK, int upperK, IISDecisionFunction iISDecisionFunction) {
        this.measure = distanceMeasure;
        this.lowerK = lowerK;
        this.upperK = upperK;
        this.loss = iISDecisionFunction;
    }

    /**
     * Evaluates every sample of {@code selectedExampleSet} against its nearest neighbours and
     * clears the corresponding entry of the example set's index when the decision function
     * reports a positive loss.
     *
     * <p>Labels that are neither nominal nor numerical leave the index untouched.
     *
     * @param selectedExampleSet the examples to filter; its label attribute must be present
     * @return the example set's own index, with rejected samples set to {@code false}
     */
    @Override
    public DataIndex selectInstances(SelectedExampleSet selectedExampleSet) {
        Attributes attributes = selectedExampleSet.getAttributes();
        DataIndex index = selectedExampleSet.getIndex();
        Attribute label = attributes.getLabel();
        ISPRGeometricDataCollection<Number> samples = KNNTools.initializeKNearestNeighbourFactory(GeometricCollectionTypes.LINEAR_SEARCH, selectedExampleSet, this.measure);
        this.loss.init(samples);
        if (label.isNominal()) {
            selectNominal(samples, index, label.getMapping().size());
        } else if (label.isNumerical()) {
            selectNumerical(samples, index);
        }
        return index;
    }

    /** Rejects samples whose majority-vote prediction over a growing neighbourhood yields a positive loss. */
    private void selectNominal(ISPRGeometricDataCollection<Number> samples, DataIndex index, int numberOfClasses) {
        int[] votes = new int[numberOfClasses];
        Iterator<double[]> sampleIterator = samples.samplesIterator();
        Iterator<Number> labelIterator = samples.storedValueIterator();
        int instanceIndex = 0;
        while (sampleIterator.hasNext() && labelIterator.hasNext()) {
            double trueLabel = labelIterator.next().doubleValue();
            double[] values = sampleIterator.next();
            Arrays.fill(votes, 0);
            // NOTE(review): the queried sample is itself stored in the collection, so it is
            // presumably returned as its own nearest neighbour — confirm.
            Iterator<DoubleObjectContainer<Number>> neighbours = samples.getNearestValueDistances(this.upperK, values).iterator();
            int position = 0;
            while (neighbours.hasNext()) {
                // Stored values for nominal labels are used as class indices into the vote array.
                votes[neighbours.next().getSecond().intValue()]++;
                if (position > this.lowerK
                        && this.loss.getValue(trueLabel, PRulesUtil.findMostFrequentValue(votes), values) > 0.0d) {
                    index.set(instanceIndex, false);
                }
                position++;
            }
            instanceIndex++;
        }
    }

    /** Rejects samples whose mean-of-neighbours prediction over a growing neighbourhood yields a positive loss. */
    private void selectNumerical(ISPRGeometricDataCollection<Number> samples, DataIndex index) {
        Iterator<double[]> sampleIterator = samples.samplesIterator();
        Iterator<Number> valueIterator = samples.storedValueIterator();
        int instanceIndex = 0;
        while (sampleIterator.hasNext() && valueIterator.hasNext()) {
            double trueValue = valueIterator.next().doubleValue();
            double[] values = sampleIterator.next();
            Iterator<DoubleObjectContainer<Number>> neighbours = samples.getNearestValueDistances(this.upperK, values).iterator();
            double sum = 0.0d;
            int position = 0;
            while (neighbours.hasNext()) {
                sum += neighbours.next().getSecond().doubleValue();
                // After adding the neighbour at zero-based 'position', 'sum' covers position + 1
                // values, so that is the correct divisor for the running mean (dividing by
                // 'position' over-weights the sum and divides by zero when position == 0).
                if (position > this.lowerK
                        && this.loss.getValue(trueValue, sum / (position + 1), values) > 0.0d) {
                    index.set(instanceIndex, false);
                }
                position++;
            }
            instanceIndex++;
        }
    }
}
