package com.rapidminer.extension.operator.learner;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.extension.operator.AbstractSmileLearner;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import smile.clustering.Clustering;
import smile.regression.GradientTreeBoost;
import smile.regression.Regression;

/* loaded from: input_file:com/rapidminer/extension/operator/learner/GradientBoostedTreeOperator.class */
/**
 * Regression learner backed by Smile's {@link GradientTreeBoost}.
 *
 * <p>After each successful fit, the model's {@code importance()} values are exposed as
 * {@link AttributeWeights} via {@link #getWeights(ExampleSet)}.
 */
public class GradientBoostedTreeOperator extends AbstractSmileLearner {
    public static final String PARAMETER_NTREES = "number_of_trees";
    public static final String PARAMETER_MAX_NODES = "maximum_nodes";
    // NOTE(review): key contains spaces/parentheses; kept verbatim because renaming it would
    // break previously saved processes.
    public static final String PARAMETER_SHRINKAGE = "shrinkage (learning_rate)";
    public static final String PARAMETER_SAMPLING = "sample_ratio";
    public static final String PARAMETER_LOSS = "loss_function";

    /** Selectable loss names; their order defines the category indices of {@link #PARAMETER_LOSS}. */
    protected static final String[] LOSS_TYPES = {
        GradientTreeBoost.Loss.LeastSquares.name(),
        GradientTreeBoost.Loss.LeastAbsoluteDeviation.name(),
        GradientTreeBoost.Loss.Huber.name()
    };

    /** Importances from the most recent successful fit; {@code null} until one has completed. */
    private AttributeWeights weights;

    public GradientBoostedTreeOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.weights = null;
    }

    /**
     * Trains a gradient-boosted regression tree ensemble and records its attribute importances.
     *
     * @param x feature matrix passed straight to {@link GradientTreeBoost}
     * @param y regression targets passed straight to {@link GradientTreeBoost}
     * @param attributes attributes whose iteration order is assumed to match the columns of
     *     {@code x} and the entries of {@code importance()} (TODO confirm against the caller)
     * @return the trained model
     * @throws UndefinedParameterError if a required parameter is not set
     */
    @Override
    protected Regression<double[]> createRegressionModel(double[][] x, double[] y, Attributes attributes)
            throws UndefinedParameterError {
        // Discard weights from an earlier run so a failed fit cannot expose stale importances.
        this.weights = null;

        GradientTreeBoost model = new GradientTreeBoost(
                x,
                y,
                selectedLoss(),
                getParameterAsInt(PARAMETER_NTREES),
                getParameterAsInt(PARAMETER_MAX_NODES),
                getParameterAsDouble(PARAMETER_SHRINKAGE),
                getParameterAsDouble(PARAMETER_SAMPLING));

        // Fetch the importance vector once instead of on every loop iteration.
        double[] importance = model.importance();
        AttributeWeights newWeights = new AttributeWeights();
        int column = 0;
        for (Attribute attribute : attributes) {
            newWeights.setWeight(attribute.getName(), importance[column]);
            column++;
        }
        // Publish only after all weights have been computed.
        this.weights = newWeights;
        return model;
    }

    /**
     * Maps the {@link #PARAMETER_LOSS} category to a Smile loss.
     *
     * @return the matching loss, or {@code LeastSquares} for any unrecognized value
     * @throws UndefinedParameterError if the parameter is not set
     */
    private GradientTreeBoost.Loss selectedLoss() throws UndefinedParameterError {
        String lossName = getParameterAsString(PARAMETER_LOSS);
        if (GradientTreeBoost.Loss.LeastAbsoluteDeviation.name().equals(lossName)) {
            return GradientTreeBoost.Loss.LeastAbsoluteDeviation;
        }
        if (GradientTreeBoost.Loss.Huber.name().equals(lossName)) {
            return GradientTreeBoost.Loss.Huber;
        }
        return GradientTreeBoost.Loss.LeastSquares;
    }

    /**
     * Returns the attribute importances recorded by the last successful call to
     * {@code createRegressionModel}. The {@code exampleSet} argument is not used.
     *
     * @throws OperatorException if no model has been trained successfully yet
     */
    public AttributeWeights getWeights(ExampleSet exampleSet) throws OperatorException {
        if (this.weights == null) {
            throw new OperatorException("GBT weights are not available");
        }
        return this.weights;
    }

    /** Always {@code true}: importances are recorded on every successful fit. */
    public boolean canCalculateWeights() {
        return true;
    }

    /** Supports numerical attributes and numerical labels only. */
    @Override
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        return operatorCapability == OperatorCapability.NUMERICAL_ATTRIBUTES
                || operatorCapability == OperatorCapability.NUMERICAL_LABEL;
    }

    /**
     * Declares this learner's parameters, followed by those of the superclass.
     *
     * <p>The loss default index 1 selects {@code LOSS_TYPES[1]} (LeastAbsoluteDeviation).
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = new ArrayList<>();
        types.add(new ParameterTypeCategory(PARAMETER_LOSS,
                "Used Loss Function. Similar to 'distribution' in Gradient Boosted Tree.", LOSS_TYPES, 1, false));
        types.add(new ParameterTypeInt(PARAMETER_NTREES,
                "The number of trees in the ensemble.", 1, Integer.MAX_VALUE, 100, false));
        types.add(new ParameterTypeInt(PARAMETER_MAX_NODES,
                "The maximum number of nodes per tree.", 1, Integer.MAX_VALUE, 512, false));
        // NOTE(review): the lower bounds of 0.0 below may be rejected by the Smile constructor;
        // verify against the Smile version in use before widening or narrowing these ranges.
        types.add(new ParameterTypeDouble(PARAMETER_SHRINKAGE,
                "The shrinkage parameter, also known as learning rate.", 0.0d, 1.0d, 0.1d, false));
        types.add(new ParameterTypeDouble(PARAMETER_SAMPLING,
                "Sub-Sample Ratio.", 0.0d, 1.0d, 0.9d, false));
        types.addAll(super.getParameterTypes());
        return types;
    }
}
