package com.rapidminer.extension.xgboost.model;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.buffer.Buffers;
import com.rapidminer.belt.buffer.NominalBuffer;
import com.rapidminer.belt.buffer.NumericBuffer;
import com.rapidminer.belt.column.Column;
import com.rapidminer.belt.column.Columns;
import com.rapidminer.belt.column.Dictionary;
import com.rapidminer.belt.reader.CategoricalReader;
import com.rapidminer.belt.reader.NumericReader;
import com.rapidminer.belt.reader.Readers;
import com.rapidminer.belt.table.Table;
import com.rapidminer.belt.util.ColumnRole;
import com.rapidminer.extension.xgboost.model.CheckedDMatrix;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;
import java.util.stream.IntStream;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Bridges Belt {@link Table}s and the XGBoost4J native library: encodes tables as dense {@code float} matrices,
 * trains boosters and converts booster predictions back into Belt {@link Column}s.
 *
 * <p>Feature encoding: only columns without any {@link ColumnRole} are used. Numeric columns are copied as-is,
 * at most bicategorical columns become a single 0/1 feature and other categorical columns are one-hot encoded,
 * where the hot slot is {@code 1} and every other slot stays {@code NaN} (missing). Numeric features come first,
 * followed by categorical features. Columns of other categories are skipped.
 *
 * <p>Every interaction with the native library happens while holding {@link #XGB_LOCK}.
 */
public class XGBoostWrapper {

    /** Serializes all calls into the native XGBoost library. */
    private static final Object XGB_LOCK = new Object();

    /** Upper bound for the length of the dense feature array ({@code Integer.MAX_VALUE - 8}). */
    private static final long MAX_ARRAY_LENGTH = Integer.MAX_VALUE - 8L;

    private XGBoostWrapper() {
        throw new AssertionError("Static utility class must not be initialized");
    }

    /**
     * Trains an XGBoost booster on the given table.
     *
     * <p>If {@code map} contains no {@code objective}, one is derived from the label column (see
     * {@link #selectObjective(Table, Map)}). The (possibly extended) parameter map is stored in the returned model.
     *
     * @param table           training data; must have at least one row and a label column
     * @param table2          optional validation data ({@code null} for none); must have at least one row if given
     * @param map             XGBoost parameters; not modified
     * @param i               number of boosting rounds
     * @param i2              passed to XGBoost as its {@code earlyStoppingRound} argument
     * @param booleanSupplier sentinel installed on the training matrix
     * @return the trained model, or {@code null} if {@link CheckedDMatrix.UsageBlockedException} is raised during
     *         training (presumably when the sentinel blocks further use of the matrix)
     * @throws IllegalArgumentException if the training or validation table is empty or has no label
     * @throws XGBoostError             if the native library fails
     * @throws ConversionException      if the encoded data set exceeds the array size limit
     */
    public static XGBoostModel train(Table table, Table table2, Map<String, String> map, int i, int i2,
                                     BooleanSupplier booleanSupplier) throws XGBoostError, ConversionException {
        if (table.height() == 0) {
            throw new IllegalArgumentException("Training table must not be empty");
        }
        if (table2 != null && table2.height() == 0) {
            // an empty validation table would otherwise fail with a division by zero in toMatrix
            throw new IllegalArgumentException("Validation table must not be empty");
        }
        Map<String, float[]> trainingData = createTrainingMatrices(table);
        Map<String, float[]> validationData = table2 == null ? null : createTrainingMatrices(table2);
        HashMap<String, String> parameters = new HashMap<>(map);
        selectObjective(table, parameters);
        synchronized (XGB_LOCK) {
            CheckedDMatrix trainingMatrix = toMatrix(trainingData);
            CheckedDMatrix validationMatrix = null;
            try {
                trainingMatrix.setSentinel(booleanSupplier);
                Map<String, DMatrix> watches;
                if (validationData == null) {
                    watches = Collections.emptyMap();
                } else {
                    validationMatrix = toMatrix(validationData);
                    watches = Collections.singletonMap("validation", validationMatrix);
                }
                // XGBoost4J takes Map<String, Object>; hand it a copy so our parameter map stays untouched
                Map<String, Object> boosterParameters = new HashMap<>(parameters);
                Booster booster = XGBoost.train(trainingMatrix, boosterParameters, i, watches,
                        (float[][]) null, null, null, i2);
                int version;
                byte[] serialized;
                try {
                    // NOTE(review): presumably the number of completed rounds (two checkpoints per round) - confirm
                    version = booster.getVersion() / 2;
                    serialized = booster.toByteArray();
                } finally {
                    booster.dispose();
                }
                return new XGBoostModel(new IOTable(table), parameters, version, serialized);
            } catch (CheckedDMatrix.UsageBlockedException ignored) {
                // training was blocked via the matrix sentinel; signal this with a null model
                return null;
            } finally {
                // native memory is not managed by the GC, so always release both matrices
                trainingMatrix.dispose();
                if (validationMatrix != null) {
                    validationMatrix.dispose();
                }
            }
        }
    }

    /**
     * Scores a table with a trained model.
     *
     * @param xGBoostModel the model to apply
     * @param table        the data to score; must have at least one row and the same feature layout as the
     *                     training data
     * @param map          for classification models, receives one confidence column per class value, keyed by
     *                     that value; untouched for regression models
     * @return the predicted column (nominal for categorical labels, real otherwise)
     * @throws IllegalArgumentException if the table is empty or contains no feature column
     * @throws XGBoostError             if the native library fails
     * @throws IOException              if the stored booster cannot be read
     */
    public static Column predict(XGBoostModel xGBoostModel, Table table, Map<String, Column> map)
            throws XGBoostError, IOException {
        if (table.height() == 0) {
            throw new IllegalArgumentException("Scoring table must not be empty");
        }
        float[] features = createFeatureMatrix(table);
        int rows = table.height();
        int columns = features.length / rows;
        float[][] predictions;
        synchronized (XGB_LOCK) {
            Booster booster = XGBoost.loadModel(xGBoostModel.getBooster());
            try {
                DMatrix matrix = new DMatrix(features, rows, columns);
                try {
                    predictions = booster.predict(matrix);
                } finally {
                    matrix.dispose();
                }
            } finally {
                booster.dispose();
            }
        }
        Column labelColumn = xGBoostModel.getLabelColumn();
        if (labelColumn.type().category() != Column.Category.CATEGORICAL) {
            return predictRegression(predictions);
        }
        return Columns.isAtMostBicategorical(labelColumn)
                ? predictBicategorical(predictions, labelColumn, map)
                : predictCategorical(predictions, labelColumn, map);
    }

    /**
     * Adds an {@code objective} (and for multi-class problems {@code num_class}) derived from the label column,
     * unless the caller already specified an objective.
     */
    private static void selectObjective(Table table, Map<String, String> map) {
        if (map.containsKey("objective")) {
            return;
        }
        Column label = table.select().withMetaData(ColumnRole.LABEL).columns().get(0);
        if (label.type().category() != Column.Category.CATEGORICAL) {
            map.put("objective", "reg:squarederror");
        } else if (Columns.isAtMostBicategorical(label)) {
            map.put("objective", "binary:logistic");
        } else {
            map.put("objective", "multi:softprob");
            // labels are encoded as category index - 1, so with gaps in the dictionary the largest label is
            // maximalIndex - 1; num_class must cover all of them (equals size() for gap-free dictionaries)
            map.put("num_class", Integer.toString(label.getDictionary().maximalIndex()));
        }
    }

    /**
     * Builds a labelled (and optionally weighted) matrix from the arrays produced by
     * {@link #createTrainingMatrices(Table)}. The matrix is disposed again if initialization fails.
     */
    private static CheckedDMatrix toMatrix(Map<String, float[]> map) throws XGBoostError {
        float[] features = map.get("features");
        float[] labels = map.get("label");
        int rows = labels.length;
        CheckedDMatrix matrix = new CheckedDMatrix(features, rows, features.length / rows);
        try {
            matrix.setLabel(labels);
            float[] weights = map.get("weights");
            if (weights != null) {
                matrix.setWeight(weights);
            }
            return matrix;
        } catch (Throwable t) {
            matrix.dispose();
            throw t;
        }
    }

    /**
     * Converts single-column binary predictions into a nominal column plus one confidence column per class.
     * A score below {@code 0.5} selects the negative class.
     */
    private static Column predictBicategorical(float[][] fArr, Column column, Map<String, Column> map) {
        Dictionary dictionary = column.getDictionary();
        String negativeValue = valueAt(dictionary, negativeIndex(dictionary));
        String positiveValue = valueAt(dictionary, positiveIndex(dictionary));
        int rows = fArr.length;
        NominalBuffer labels = Buffers.nominalBuffer(rows);
        NumericBuffer negativeConfidences = Buffers.realBuffer(rows, false);
        NumericBuffer positiveConfidences = Buffers.realBuffer(rows, false);
        for (int row = 0; row < rows; row++) {
            float score = fArr[row][0];
            labels.set(row, score < 0.5d ? negativeValue : positiveValue);
            negativeConfidences.set(row, 1.0d - score);
            positiveConfidences.set(row, score);
        }
        // a class value may be absent (single-valued or empty dictionary); never use null as a map key
        if (negativeValue != null) {
            map.put(negativeValue, negativeConfidences.toColumn());
        }
        if (positiveValue != null) {
            map.put(positiveValue, positiveConfidences.toColumn());
        }
        return Columns.changeDictionary(labels.toColumn(), column);
    }

    /**
     * Converts per-class scores into a nominal column (arg max) plus one confidence column per class value.
     * Class {@code k} corresponds to dictionary index {@code k + 1}.
     */
    private static Column predictCategorical(float[][] fArr, Column column, Map<String, Column> map) {
        int rows = fArr.length;
        int classes = fArr[0].length;
        Dictionary dictionary = column.getDictionary();
        NominalBuffer labels = Buffers.nominalBuffer(rows);
        NumericBuffer[] confidences = new NumericBuffer[classes];
        for (int c = 0; c < classes; c++) {
            confidences[c] = Buffers.realBuffer(rows, false);
        }
        for (int row = 0; row < rows; row++) {
            float[] scores = fArr[row];
            float best = Float.NEGATIVE_INFINITY;
            int bestClass = -1;
            for (int c = 0; c < classes; c++) {
                float score = scores[c];
                if (score > best) {
                    best = score;
                    bestClass = c;
                }
                confidences[c].set(row, score);
            }
            // if no score beats -infinity (e.g. all NaN), bestClass stays -1 and index 0 is looked up
            labels.set(row, dictionary.get(bestClass + 1));
        }
        for (int c = 0; c < classes; c++) {
            String value = dictionary.get(c + 1);
            if (value != null) {
                map.put(value, confidences[c].toColumn());
            }
        }
        return Columns.changeDictionary(labels.toColumn(), column);
    }

    /** Copies the first score of every row into a real column. */
    private static Column predictRegression(float[][] fArr) {
        NumericBuffer values = Buffers.realBuffer(fArr.length, false);
        for (int row = 0; row < fArr.length; row++) {
            values.set(row, fArr[row][0]);
        }
        return values.toColumn();
    }

    /**
     * Encodes all role-less columns of the table as a row-major dense matrix (see class documentation).
     *
     * @throws IllegalArgumentException if the table has no role-less column
     * @throws ConversionException      if the encoded matrix would exceed {@link #MAX_ARRAY_LENGTH} entries
     */
    private static float[] createFeatureMatrix(Table table) throws ConversionException {
        List<Column> features = table.select().withoutMetaData(ColumnRole.class).columns();
        if (features.isEmpty()) {
            throw new IllegalArgumentException("Data table does not contain any feature");
        }
        List<Column> numericColumns = new ArrayList<>();
        List<Column> categoricalColumns = new ArrayList<>();
        long width = 0;
        for (Column column : features) {
            switch (column.type().category()) {
                case NUMERIC:
                    numericColumns.add(column);
                    width++;
                    break;
                case CATEGORICAL:
                    categoricalColumns.add(column);
                    width += Columns.isAtMostBicategorical(column) ? 1 : column.getDictionary().size();
                    break;
                default:
                    // columns of other categories are not encoded
                    break;
            }
        }
        if (width * table.height() > MAX_ARRAY_LENGTH) {
            throw new ConversionException("Size of encoded data set exceeds runtime limit");
        }
        int rowWidth = (int) width;
        float[] matrix = new float[rowWidth * table.height()];
        // NaN marks missing entries for XGBoost; readers only overwrite the slots they know
        Arrays.fill(matrix, Float.NaN);
        int offset = 0;
        for (Column column : numericColumns) {
            readNumericColumn(column, matrix, offset, rowWidth);
            offset++;
        }
        for (Column column : categoricalColumns) {
            if (Columns.isAtMostBicategorical(column)) {
                readBicategoricalColumn(column, matrix, offset, rowWidth);
                offset++;
            } else {
                readCategoricalColumn(column, matrix, offset, rowWidth);
                offset += column.getDictionary().size();
            }
        }
        return matrix;
    }

    /**
     * Extracts the arrays needed for training under the keys {@code "features"}, {@code "label"} and (if the
     * table has a weight column) {@code "weights"}.
     *
     * @throws IllegalArgumentException if the table has no label or the label category is unsupported
     */
    private static Map<String, float[]> createTrainingMatrices(Table table) throws ConversionException {
        Map<String, float[]> matrices = new HashMap<>();
        matrices.put("features", createFeatureMatrix(table));
        List<Column> weightColumns = table.select().withMetaData(ColumnRole.WEIGHT).columns();
        if (!weightColumns.isEmpty()) {
            float[] weights = new float[table.height()];
            readNumericColumn(weightColumns.get(0), weights, 0, 1);
            matrices.put("weights", weights);
        }
        List<Column> labelColumns = table.select().withMetaData(ColumnRole.LABEL).columns();
        if (labelColumns.isEmpty()) {
            throw new IllegalArgumentException("Input table has no label");
        }
        Column label = labelColumns.get(0);
        float[] labels = new float[table.height()];
        switch (label.type().category()) {
            case NUMERIC:
                readNumericColumn(label, labels, 0, 1);
                break;
            case CATEGORICAL:
                if (Columns.isAtMostBicategorical(label)) {
                    // NOTE(review): labels is zero-initialized and missing values are skipped by the reader,
                    // so missing binary labels end up encoded as the negative class (0) - confirm intended
                    readBicategoricalColumn(label, labels, 0, 1);
                } else {
                    // numeric reading yields the category index; XGBoost expects classes 0..num_class-1
                    readNumericColumn(label, labels, 0, 1);
                    for (int row = 0; row < labels.length; row++) {
                        labels[row] -= 1.0f;
                    }
                }
                break;
            default:
                throw new IllegalArgumentException("Unsupported label column");
        }
        matrices.put("label", labels);
        return matrices;
    }

    /** Writes the column's numeric values to {@code target[offset + k * stride]} for row {@code k}. */
    private static void readNumericColumn(Column column, float[] target, int offset, int stride) {
        NumericReader reader = Readers.numericReader(column);
        for (int position = offset; reader.hasRemaining(); position += stride) {
            target[position] = (float) reader.read();
        }
    }

    /**
     * Writes {@code 0} for the negative and {@code 1} for the positive class to
     * {@code target[offset + k * stride]}; missing values (category 0) leave the slot untouched.
     */
    private static void readBicategoricalColumn(Column column, float[] target, int offset, int stride) {
        Dictionary dictionary = column.getDictionary();
        int negative = negativeIndex(dictionary);
        int positive = positiveIndex(dictionary);
        CategoricalReader reader = Readers.categoricalReader(column);
        for (int position = offset; reader.hasRemaining(); position += stride) {
            int category = reader.read();
            if (category <= 0) {
                // missing value; must not match a positive index of 0 from an empty dictionary
                continue;
            }
            if (category == negative) {
                target[position] = 0.0f;
            } else if (category == positive) {
                target[position] = 1.0f;
            }
        }
    }

    /**
     * One-hot encodes the column into {@code dictionary.size()} consecutive slots starting at
     * {@code offset + k * stride}. Dictionary indices are compacted so gaps do not consume slots.
     */
    private static void readCategoricalColumn(Column column, float[] target, int offset, int stride) {
        Dictionary dictionary = column.getDictionary();
        if (dictionary.size() == 0) {
            return;
        }
        int maxIndex = dictionary.maximalIndex();
        int[] slotOf = new int[maxIndex + 1];
        int slot = 0;
        for (int index = 1; index <= maxIndex; index++) {
            if (dictionary.get(index) != null) {
                slotOf[index] = slot++;
            }
        }
        CategoricalReader reader = Readers.categoricalReader(column);
        for (int position = offset; reader.hasRemaining(); position += stride) {
            int category = reader.read();
            if (category > 0) {
                target[position + slotOf[category]] = 1.0f;
            }
        }
    }

    /**
     * Returns the category index treated as the negative class: the dictionary's negative index for boolean
     * dictionaries, otherwise the first used index below the maximal index, or {@code -1} if there is none.
     */
    private static int negativeIndex(Dictionary dictionary) {
        if (dictionary.isBoolean()) {
            return dictionary.getNegativeIndex();
        }
        return IntStream.range(1, dictionary.maximalIndex())
                .filter(index -> dictionary.get(index) != null)
                .findFirst()
                .orElse(-1);
    }

    /**
     * Returns the category index treated as the positive class: the dictionary's positive index for boolean
     * dictionaries, otherwise the maximal index.
     */
    private static int positiveIndex(Dictionary dictionary) {
        return dictionary.isBoolean() ? dictionary.getPositiveIndex() : dictionary.maximalIndex();
    }

    /** Looks up a dictionary value, returning {@code null} for the "no such class" index {@code -1}. */
    private static String valueAt(Dictionary dictionary, int index) {
        return index < 0 ? null : dictionary.get(index);
    }
}
