package eu.radoop.spark;

import eu.radoop.tools.CommonUtils;
import eu.radoop.transfer.model.ForestMTO;
import eu.radoop.transfer.model.ModelTransferObject;
import eu.radoop.transfer.model.TreeModelMTO;
import eu.radoop.transfer.parameter.ParameterTransferObject;
import eu.radoop.transfer.parameter.SparkRandomForestParameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.spark.SparkException;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:lib/radoop-spark3.jar:eu/radoop/spark/SparkRandomForestRunner.class */
/**
 * Spark job that trains a Spark ML {@link RandomForestClassifier} on the job input and converts the
 * fitted forest into a {@link ForestMTO} transfer object.
 *
 * <p>Argument parsing, input loading, column preparation, model persistence and cleanup are
 * delegated to helpers inherited from {@link SparkDecisionTreeMLRunner}.
 */
public class SparkRandomForestRunner extends SparkDecisionTreeMLRunner {

    /**
     * Job entry point: initializes from {@code strArr}, learns the forest and persists the resulting
     * model. Any exception is handed to {@code processException}; {@code close()} always runs last.
     *
     * @param strArr command-line arguments forwarded to {@code init}
     * @throws SparkException declared for callers; NOTE(review): presumably propagated by one of the
     *     inherited helpers ({@code processException}/{@code close}) — confirm
     */
    public static void main(String[] strArr) throws SparkException {
        try {
            persistModel(learn(init(strArr)));
        } catch (Exception e) {
            processException(e);
        } finally {
            close();
        }
    }

    /**
     * Trains a random forest classifier and converts the fitted model to a transfer object.
     *
     * <p>{@code str} is parsed into a {@link ParameterTransferObject} keyed by
     * {@link SparkRandomForestParameter}. The input data frame is cached, label and nominal feature
     * mappings are discovered (the {@code SKIP_DISCOVER} flag is forwarded to that step) and checked
     * for a missing label, and a pipeline of preparation stages followed by the classifier is fitted.
     *
     * @param str serialized parameters for this run
     * @return the fitted forest as a {@link ForestMTO}
     * @throws SparkException declared for callers; not thrown directly by this method
     * @throws ClassCastException if the last fitted pipeline stage is not a random forest model
     */
    public static ModelTransferObject learn(String str) throws SparkException {
        ParameterTransferObject params = new ParameterTransferObject(str, SparkRandomForestParameter.class);
        System.out.println("pto:\n" + params);

        Dataset<Row> input = getInputAsDataFrame();
        // The same frame feeds both mapping discovery and the pipeline fit below.
        input.cache();
        boolean skipDiscover = params.getParameterAsBoolean(SparkRandomForestParameter.SKIP_DISCOVER).booleanValue();
        checkMissingLabel(discoverLabelAndNominalFeatureMappings(input, skipDiscover));

        // Declared as ArrayList (not List) so they match whatever list type the inherited helpers accept.
        // NOTE(review): presumably tracks column names in use so generated aliases stay unique — confirm.
        ArrayList<String> usedNames = new ArrayList<>(Arrays.asList(columnNames));
        // Pipeline stages; presumably appended to by the prepare* helpers. The classifier is added last.
        ArrayList<PipelineStage> stages = new ArrayList<>();
        String[] featureColumns = prepareFeatureColumns(usedNames, stages);
        String labelColumn = prepareLabelColumn(usedNames, stages);
        String featureVectorColumn = prepareFeatureVectorColumn(usedNames, stages, featureColumns);
        String predictionColumn = CommonUtils.newUniqueAlias("p_", usedNames);

        stages.add(buildClassifier(params, labelColumn, featureVectorColumn, predictionColumn));
        PipelineModel fitted = new Pipeline().setStages(stages.toArray(new PipelineStage[0])).fit(input);

        // The classifier was appended last, so the last fitted stage is the forest model.
        // PipelineModel.stages() yields Transformer[]; an explicit downcast is required here.
        PipelineStage[] fittedStages = fitted.stages();
        return convertForestModel((RandomForestClassificationModel) fittedStages[fittedStages.length - 1]);
    }

    /**
     * Creates a {@link RandomForestClassifier} bound to the given columns and configured from the
     * random forest parameters stored in {@code params}.
     *
     * @param params parsed job parameters
     * @param labelColumn name of the label column
     * @param featuresColumn name of the assembled feature vector column
     * @param predictionColumn name of the column the prediction is written to
     * @return the configured (unfitted) classifier
     */
    private static RandomForestClassifier buildClassifier(ParameterTransferObject params, String labelColumn,
            String featuresColumn, String predictionColumn) {
        // All parameters are read before the classifier is created, as in the original flow.
        String featureSubsetStrategy = params.getParameterAsString(SparkRandomForestParameter.FEATURE_SUBSET_STRATEGY);
        String impurity = params.getParameterAsString(SparkRandomForestParameter.IMPURITY);
        int maxBins = params.getParameterAsInteger(SparkRandomForestParameter.MAX_BINS).intValue();
        int maxDepth = params.getParameterAsInteger(SparkRandomForestParameter.MAX_DEPTH).intValue();
        double minInfoGain = params.getParameterAsDouble(SparkRandomForestParameter.MIN_INFO_GAIN).doubleValue();
        int minInstancesPerNode = params.getParameterAsInteger(SparkRandomForestParameter.MIN_INSTANCES_PER_NODE).intValue();
        int numTrees = params.getParameterAsInteger(SparkRandomForestParameter.NUM_TREES).intValue();
        long seed = params.getParameterAsLong(SparkRandomForestParameter.SEED).longValue();
        double subsamplingRate = params.getParameterAsDouble(SparkRandomForestParameter.SUBSAMPLING_RATE).doubleValue();
        boolean cacheNodeIds = params.getParameterAsBoolean(SparkRandomForestParameter.CACHE_NODE_IDS).booleanValue();
        int maxMemoryInMB = params.getParameterAsInteger(SparkRandomForestParameter.MAX_MEMORY_IN_MB).intValue();

        RandomForestClassifier classifier = new RandomForestClassifier()
                .setLabelCol(labelColumn)
                .setFeaturesCol(featuresColumn)
                .setPredictionCol(predictionColumn);
        classifier.setFeatureSubsetStrategy(featureSubsetStrategy);
        classifier.setImpurity(impurity);
        classifier.setMaxBins(maxBins);
        classifier.setMaxDepth(maxDepth);
        classifier.setMinInfoGain(minInfoGain);
        classifier.setMinInstancesPerNode(minInstancesPerNode);
        classifier.setNumTrees(numTrees);
        classifier.setSeed(seed);
        classifier.setSubsamplingRate(subsamplingRate);
        classifier.setCacheNodeIds(cacheNodeIds);
        classifier.setMaxMemoryInMB(maxMemoryInMB);
        return classifier;
    }

    /**
     * Converts a fitted random forest into a {@link ForestMTO}.
     *
     * <p>Each tree's root node is converted with {@code convertTreeNode}; every resulting
     * {@link TreeModelMTO} and the forest itself share the same converted nominal mappings.
     *
     * @param randomForestClassificationModel the fitted forest
     * @return the transfer object holding the trees, the model's tree weights and nominal mappings
     */
    private static ForestMTO convertForestModel(RandomForestClassificationModel randomForestClassificationModel) {
        Map<String, List<String>> nominalMappings = convertNominalMappings(knownNominalMappings);
        int treeCount = randomForestClassificationModel.trees().length;
        TreeModelMTO[] trees = new TreeModelMTO[treeCount];
        for (int i = 0; i < treeCount; i++) {
            trees[i] = new TreeModelMTO(convertTreeNode(randomForestClassificationModel.trees()[i].rootNode()), nominalMappings);
        }
        return new ForestMTO(trees, randomForestClassificationModel.treeWeights(), nominalMappings);
    }
}
