package com.rapidminer.extension.xgboost.model;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.column.Column;
import com.rapidminer.belt.table.Table;
import com.rapidminer.belt.table.Tables;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.IOTablePredictionModel;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.java.XGBoostError;

/* loaded from: input_file:com/rapidminer/extension/xgboost/model/XGBoostModel.class */
public class XGBoostModel extends IOTablePredictionModel {
    private final Map<String, String> parameters;
    private final int iterations;
    private final byte[] booster;

    public XGBoostModel() {
        this.parameters = Collections.emptyMap();
        this.iterations = 0;
        this.booster = null;
    }

    public XGBoostModel(IOTable iOTable, Map<String, String> map, int i, byte[] bArr) {
        super(iOTable, Tables.ColumnSetRequirement.EQUAL, new Tables.TypeRequirement[]{Tables.TypeRequirement.REQUIRE_MATCHING_TYPES});
        this.parameters = map;
        this.iterations = i;
        this.booster = bArr;
    }

    protected Column performPrediction(Table table, Map<String, Column> map, Operator operator) throws OperatorException {
        if (table.height() == 0) {
            return getLabelColumn();
        }
        try {
            return XGBoostWrapper.predict(this, table, map);
        } catch (ConversionException e) {
            throw new UserError((Operator) null, e, "xgboost.conversion_error", new Object[]{e.getMessage()});
        } catch (IOException | XGBoostError e2) {
            throw new UserError((Operator) null, e2, "xgboost.generic_error", new Object[]{e2.getMessage()});
        }
    }

    public String getName() {
        return "XGBoost";
    }

    public String toString() {
        StringBuilder append = new StringBuilder("XGBoost prediction model for label '").append(getLabelName()).append("'.\n\nTraining hyper parameters: \n\n");
        this.parameters.forEach((str, str2) -> {
            append.append(str).append(" = ").append(str2).append("\n");
        });
        append.append("\nBoosting iterations: ").append(this.iterations);
        return append.toString();
    }

    public Column getLabelColumn() {
        return super.getLabelColumn();
    }

    public Map<String, Object> getParameters() {
        return Collections.unmodifiableMap(this.parameters);
    }

    public int getIterations() {
        return this.iterations;
    }

    public byte[] getBooster() {
        return this.booster;
    }
}
