package org.fabi.visualizations.rapidminer.scatter;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.DoubleArrayDataRow;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.example.table.PolynominalAttribute;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import org.fabi.visualizations.scatter.sources.ModelSource;

/* loaded from: input_file:org/fabi/visualizations/rapidminer/scatter/RapidMinerModelSource.class */
/**
 * {@link ModelSource} adapter around a RapidMiner {@link PredictionModel}.
 *
 * <p>Input vectors are loaded into an in-memory example table, the wrapped model is applied
 * to the example set built on that table, and the predictions are read back as plain
 * {@code double} arrays.
 */
public class RapidMinerModelSource implements ModelSource {
    /** Wrapped RapidMiner model that is applied to the input vectors. */
    protected PredictionModel model;
    /** In-memory table refilled with the query rows on every {@link #getModelResponses} call. */
    protected MemoryExampleTable table;
    /** Example set created from {@link #table}; the model is applied to it. */
    protected ExampleSet referenceSet;
    /** Label attribute of the example set given at construction; may be {@code null}. */
    protected Attribute label;
    /** Not referenced by this class; retained for subclasses. */
    protected Double[] valueArray;
    /**
     * Number of outputs: the label mapping size for a polynominal label, 1 for a numerical
     * label, otherwise 0.
     */
    protected int oNumber;
    /** Number of inputs, i.e. the number of regular attributes. */
    protected int iNumber;
    /** Name reported by the wrapped model at construction time. */
    protected String name;

    /**
     * Creates a source that uses the model's own training header as the attribute description.
     *
     * @param predictionModel model to wrap
     */
    public RapidMinerModelSource(PredictionModel predictionModel) {
        this(predictionModel, predictionModel.getTrainingHeader());
    }

    /**
     * Creates a source whose inputs are the regular attributes of {@code exampleSet} and whose
     * outputs are derived from its label attribute.
     *
     * @param predictionModel model to wrap
     * @param exampleSet example set supplying the regular attributes and the label
     */
    public RapidMinerModelSource(PredictionModel predictionModel, ExampleSet exampleSet) {
        this.model = predictionModel;
        this.name = predictionModel.getName();

        Attributes attributes = exampleSet.getAttributes();
        Attribute[] regularAttributes = attributes.createRegularAttributeArray();
        this.table = new MemoryExampleTable(regularAttributes);
        this.referenceSet = this.table.createExampleSet();
        this.iNumber = regularAttributes.length;

        this.label = attributes.getLabel();
        if (this.label != null) {
            if (this.label instanceof PolynominalAttribute) {
                // One output per class value (confidences).
                this.oNumber = this.label.getMapping().size();
            } else if (this.label.isNumerical()) {
                // Single output: the predicted value.
                this.oNumber = 1;
            }
        }
    }

    /**
     * Applies the wrapped model to the given input vectors.
     *
     * <p>For a polynominal label each result row holds the confidence of every class, in label
     * mapping index order. Otherwise column 0 of each row holds the predicted label value.
     * Note that when {@link #outputsNumber()} is 0 and {@code dArr} is non-empty, the write to
     * column 0 fails with an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param dArr input vectors, one row per example; each row is wrapped in a data row
     *             without copying
     * @return an array of {@code dArr.length} rows with {@link #outputsNumber()} columns
     * @throws RuntimeException if applying the model raises an {@link OperatorException};
     *                          the original exception is attached as the cause
     */
    @Override
    public double[][] getModelResponses(double[][] dArr) {
        try {
            // Replace the previous query rows with the new input vectors.
            this.table.clear();
            for (double[] inputRow : dArr) {
                this.table.addDataRow(new DoubleArrayDataRow(inputRow));
            }
            this.model.apply(this.referenceSet);

            boolean polynominalLabel = this.label instanceof PolynominalAttribute;
            NominalMapping mapping = polynominalLabel ? this.label.getMapping() : null;
            int classCount = polynominalLabel ? mapping.size() : 0;

            double[][] responses = new double[dArr.length][this.oNumber];
            int row = 0;
            for (Example example : this.referenceSet) {
                if (polynominalLabel) {
                    for (int classIndex = 0; classIndex < classCount; classIndex++) {
                        responses[row][classIndex] = example.getConfidence(mapping.mapIndex(classIndex));
                    }
                } else {
                    responses[row][0] = example.getPredictedLabel();
                }
                row++;
            }
            return responses;
        } catch (OperatorException e) {
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new RuntimeException("OperatorException: " + e.getMessage(), e);
        }
    }

    /**
     * Returns the number of input dimensions.
     *
     * @return the number of regular attributes found at construction time
     */
    @Override
    public int inputsNumber() {
        return this.iNumber;
    }

    /**
     * Returns the number of output dimensions.
     *
     * @return the label mapping size for a polynominal label, 1 for a numerical label,
     *         otherwise 0
     */
    @Override
    public int outputsNumber() {
        return this.oNumber;
    }

    /**
     * Returns the display name of this source.
     *
     * @return the name the wrapped model reported at construction time
     */
    @Override
    public String getName() {
        return this.name;
    }
}
