package com.rapidminer.extension.interpretation.algorithm.conformal_prediction.regression;

import com.rapidminer.adaption.belt.ContextAdapter;
import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.buffer.Buffers;
import com.rapidminer.belt.buffer.NumericBuffer;
import com.rapidminer.belt.column.Column;
import com.rapidminer.belt.reader.NumericReader;
import com.rapidminer.belt.reader.NumericRowReader;
import com.rapidminer.belt.reader.Readers;
import com.rapidminer.belt.table.Builders;
import com.rapidminer.belt.table.Table;
import com.rapidminer.belt.util.ColumnRole;
import com.rapidminer.extension.interpretation.algorithm.conformal_prediction.ConformalPredictionModel;
import com.rapidminer.extension.interpretation.utility.BeltUtilities;
import com.rapidminer.operator.GeneralModel;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.IOTablePredictionModel;
import com.rapidminer.studio.concurrency.internal.SequentialConcurrencyContext;

/* loaded from: input_file:com/rapidminer/extension/interpretation/algorithm/conformal_prediction/regression/ConformalRegressionModel.class */
/**
 * Base class for conformal regression models that wrap an {@link IOTablePredictionModel}
 * and add a lower and an upper bound column around its numeric prediction.
 *
 * <p>Subclasses supply the per-row uncertainty estimate through
 * {@link #estimateUncertainty(Table, GeneralModel, Operator)}. {@link #train} derives the
 * scaling factor {@code qHat} from it, and {@link #apply} uses
 * {@code prediction -/+ uncertainty * qHat} as the bounds.
 */
public abstract class ConformalRegressionModel extends ConformalPredictionModel {
    /**
     * Name of the column that {@link #apply} adds for the lower bound.
     * NOTE(review): non-final public static field; kept mutable so that existing code
     * assigning to it (if any) keeps working.
     */
    public static String LOWER_BOUND_NAME = "Lower Bound";

    /**
     * Name of the column that {@link #apply} adds for the upper bound.
     * NOTE(review): non-final public static field; kept mutable so that existing code
     * assigning to it (if any) keeps working.
     */
    public static String UPPER_BOUND_NAME = "Upper Bound";

    /**
     * Creates a conformal regression model around the given prediction model.
     *
     * @param iOTablePredictionModel the model handed to the superclass constructor
     */
    public ConformalRegressionModel(IOTablePredictionModel iOTablePredictionModel) {
        super(iOTablePredictionModel);
    }

    /**
     * Computes {@code qHat} from the given data.
     *
     * <p>For every row the score {@code |label - prediction| / uncertainty} is computed,
     * where {@code uncertainty} comes from
     * {@link #estimateUncertainty(Table, GeneralModel, Operator)}. The scores and {@code d}
     * are then passed to {@code getAdjustedPercentile}, whose result is stored in
     * {@code qHat}.
     *
     * @param iOTable  data to calibrate on; it is passed through
     *                 {@code BeltUtilities.applyModelIfUnApplied} first and must afterwards
     *                 contain a column with role PREDICTION and one with role LABEL
     * @param d        value forwarded unchanged to {@code getAdjustedPercentile}
     *                 (presumably the confidence/error level; confirm against the superclass)
     * @param operator operator forwarded to the helper calls
     * @throws OperatorException if one of the helper calls fails
     */
    @Override // com.rapidminer.extension.interpretation.algorithm.conformal_prediction.ConformalPredictionModel
    public void train(IOTable iOTable, double d, Operator operator) throws OperatorException {
        Table scored = BeltUtilities.applyModelIfUnApplied(this.model, iOTable, operator).getTable();
        NumericBuffer uncertainty = estimateUncertainty(scored, this.model, operator);
        Column prediction = firstColumnWithRole(scored, ColumnRole.PREDICTION);
        Column label = firstColumnWithRole(scored, ColumnRole.LABEL);

        // Reader column 0 is the prediction, column 1 is the label (argument order below).
        NumericRowReader rows = Readers.numericRowReader(prediction, label, new Column[0]);
        double[] scores = new double[prediction.size()];
        int row = 0;
        while (rows.hasRemaining()) {
            rows.move();
            double predicted = rows.get(0);
            double actual = rows.get(1);
            // NOTE(review): a zero uncertainty is not guarded against; the division then
            // yields an infinite or NaN score.
            scores[row] = Math.abs(actual - predicted) / uncertainty.get(row);
            row++;
        }
        this.qHat = getAdjustedPercentile(scores, d);
    }

    /**
     * Adds the lower and upper bound columns to the given data.
     *
     * <p>The bounds are {@code prediction - uncertainty * qHat} and
     * {@code prediction + uncertainty * qHat}. Both columns are added under
     * {@link #LOWER_BOUND_NAME} and {@link #UPPER_BOUND_NAME} with role INTERPRETATION.
     *
     * @param iOTable  data to compute bounds for
     * @param operator operator forwarded to the helper calls
     * @return the result of {@code BeltUtilities.applyModelIfUnApplied} on the input table
     *         extended by the two bound columns
     * @throws OperatorException if one of the helper calls fails
     */
    public IOTable apply(IOTable iOTable, Operator operator) throws OperatorException {
        Table scored = BeltUtilities.applyModelIfUnApplied(this.model, iOTable, operator).getTable();
        NumericBuffer uncertainty = estimateUncertainty(scored, this.model, operator);
        Column prediction = firstColumnWithRole(scored, ColumnRole.PREDICTION);

        NumericReader predictions = Readers.numericReader(prediction);
        NumericBuffer lowerBounds = Buffers.realBuffer(prediction.size());
        NumericBuffer upperBounds = Buffers.realBuffer(prediction.size());
        int row = 0;
        while (predictions.hasRemaining()) {
            double predicted = predictions.read();
            double margin = uncertainty.get(row) * this.qHat;
            lowerBounds.set(row, predicted - margin);
            upperBounds.set(row, predicted + margin);
            row++;
        }

        // The bounds are appended to the table of the ORIGINAL input (not to `scored`), and
        // the extended table is then passed through applyModelIfUnApplied again.
        Table withBounds = Builders.newTableBuilder(iOTable.getTable())
                .add(LOWER_BOUND_NAME, lowerBounds.toColumn())
                .addMetaData(LOWER_BOUND_NAME, ColumnRole.INTERPRETATION)
                .add(UPPER_BOUND_NAME, upperBounds.toColumn())
                .addMetaData(UPPER_BOUND_NAME, ColumnRole.INTERPRETATION)
                .build(ContextAdapter.adapt(new SequentialConcurrencyContext()));
        return BeltUtilities.applyModelIfUnApplied(this.model, new IOTable(withBounds), operator);
    }

    /**
     * Estimates the uncertainty of each row's prediction.
     *
     * <p>{@link #train} and {@link #apply} read the returned buffer at the row index of the
     * prediction column, so it has to provide one value per row of {@code table}.
     *
     * @param table        the table that already went through the wrapped model
     * @param generalModel the wrapped model
     * @param operator     the calling operator
     * @return buffer with one uncertainty value per row
     * @throws OperatorException if the estimation fails
     */
    public abstract NumericBuffer estimateUncertainty(Table table, GeneralModel generalModel, Operator operator) throws OperatorException;

    /**
     * Tells whether this model is of the given kind.
     *
     * @param modelKind the kind to test; {@code null} yields {@code false}
     * @return {@code true} only for {@code GeneralModel.ModelKind.POSTPROCESSING}
     */
    @Override // com.rapidminer.extension.interpretation.algorithm.conformal_prediction.ConformalPredictionModel
    public boolean isModelKind(GeneralModel.ModelKind modelKind) {
        // Compare from the constant side so that a null argument cannot cause an NPE.
        return GeneralModel.ModelKind.POSTPROCESSING.equals(modelKind);
    }

    /** Returns the first column of {@code table} that carries the given role. */
    private static Column firstColumnWithRole(Table table, ColumnRole role) {
        return (Column) table.select().withMetaData(role).columns().get(0);
    }
}
