package eu.radoop.spark;

import eu.radoop.spark.pipeline.SparkNullChecker;
import eu.radoop.spark.pipeline.SparkStringIndexer;
import eu.radoop.tools.CommonUtils;
import eu.radoop.transfer.model.ContainsSplitConditionTO;
import eu.radoop.transfer.model.EdgeTO;
import eu.radoop.transfer.model.GreaterSplitConditionTO;
import eu.radoop.transfer.model.LessEqualsSplitConditionTO;
import eu.radoop.transfer.model.ModelTransferObject;
import eu.radoop.transfer.model.NominalSplitConditionTO;
import eu.radoop.transfer.model.NotContainsSplitConditionTO;
import eu.radoop.transfer.model.TreeModelMTO;
import eu.radoop.transfer.model.TreeTO;
import eu.radoop.transfer.parameter.ParameterTransferObject;
import eu.radoop.transfer.parameter.SparkDecisionTreeMLParameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.spark.SparkException;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.tree.CategoricalSplit;
import org.apache.spark.ml.tree.ContinuousSplit;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.Split;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:lib/radoop-spark3.jar:eu/radoop/spark/SparkDecisionTreeMLRunner.class */
public class SparkDecisionTreeMLRunner extends AbstractSparkRunner {
    /**
     * Job entry point: initializes the runner from the command-line arguments, trains the
     * decision tree and persists the resulting model.
     *
     * <p>Every exception raised while doing so is passed to {@code processException};
     * {@code close()} is always invoked afterwards, whether or not training succeeded.
     *
     * @param strArr command-line arguments, forwarded unchanged to {@code init}
     * @throws SparkException declared by the signature; presumably raised by
     *     {@code processException} or {@code close} (neither is visible in this class)
     */
    public static void main(String[] strArr) throws SparkException {
        try {
            final ModelTransferObject trainedModel = learn(init(strArr));
            persistModel(trainedModel);
        } catch (Exception failure) {
            processException(failure);
        } finally {
            close();
        }
    }

    /**
     * Trains a Spark ML {@link DecisionTreeClassifier} on the runner's input data and converts
     * the fitted tree into a transferable model.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>parse the serialized parameters and print them to standard output,</li>
     *   <li>load and cache the input data frame,</li>
     *   <li>run label / nominal-mapping discovery and the missing-label check,</li>
     *   <li>build the pre-processing stages (feature columns, label column, feature vector),</li>
     *   <li>configure the classifier from the parameters and append it as the last stage,</li>
     *   <li>fit the pipeline and convert the fitted classifier with
     *       {@link #convertTreeModel(DecisionTreeClassificationModel)}.</li>
     * </ol>
     *
     * @param str serialized parameters, parsed with {@link SparkDecisionTreeMLParameter} as the
     *     parameter definition
     * @return the converted decision tree model
     * @throws SparkException if pipeline preparation fails
     */
    public static ModelTransferObject learn(String str) throws SparkException {
        ParameterTransferObject parameters = new ParameterTransferObject(str, SparkDecisionTreeMLParameter.class);
        System.out.println("pto:\n" + parameters);
        Dataset<Row> inputAsDataFrame = getInputAsDataFrame();
        inputAsDataFrame.cache();
        boolean skipDiscover = parameters.getParameterAsBoolean(SparkDecisionTreeMLParameter.SKIP_DISCOVER).booleanValue();
        checkMissingLabel(discoverLabelAndNominalFeatureMappings(inputAsDataFrame, skipDiscover));

        // Names already in use; every prepare* call adds the aliases it creates, so later
        // aliases cannot collide with earlier ones.
        List<String> usedColumnNames = new ArrayList<>(Arrays.asList(columnNames));
        List<PipelineStage> stages = new ArrayList<>();
        String[] featureInputColumns = prepareFeatureColumns(usedColumnNames, stages);
        String labelColumn = prepareLabelColumn(usedColumnNames, stages);
        String featureVectorColumn = prepareFeatureVectorColumn(usedColumnNames, stages, featureInputColumns);
        String predictionColumn = CommonUtils.newUniqueAlias("p_", usedColumnNames);

        String impurity = parameters.getParameterAsString(SparkDecisionTreeMLParameter.IMPURITY);
        int maxBins = parameters.getParameterAsInteger(SparkDecisionTreeMLParameter.MAX_BINS).intValue();
        int maxDepth = parameters.getParameterAsInteger(SparkDecisionTreeMLParameter.MAX_DEPTH).intValue();
        double minInfoGain = parameters.getParameterAsDouble(SparkDecisionTreeMLParameter.MIN_INFO_GAIN).doubleValue();
        int minInstancesPerNode = parameters.getParameterAsInteger(SparkDecisionTreeMLParameter.MIN_INSTANCES_PER_NODE).intValue();
        boolean cacheNodeIds = parameters.getParameterAsBoolean(SparkDecisionTreeMLParameter.CACHE_NODE_IDS).booleanValue();
        int maxMemoryInMb = parameters.getParameterAsInteger(SparkDecisionTreeMLParameter.MAX_MEMORY_IN_MB).intValue();

        DecisionTreeClassifier classifier = new DecisionTreeClassifier().setLabelCol(labelColumn).setFeaturesCol(featureVectorColumn).setPredictionCol(predictionColumn);
        classifier.setImpurity(impurity);
        classifier.setMaxDepth(maxDepth);
        classifier.setMaxBins(maxBins);
        classifier.setMinInstancesPerNode(minInstancesPerNode);
        classifier.setMinInfoGain(minInfoGain);
        classifier.setCacheNodeIds(cacheNodeIds);
        classifier.setMaxMemoryInMB(maxMemoryInMb);
        stages.add(classifier);

        PipelineModel fittedPipeline = new Pipeline().setStages(stages.toArray(new PipelineStage[0])).fit(inputAsDataFrame);
        // The classifier was appended last, so its fitted counterpart is the last transformer.
        // PipelineModel.stages() is typed as Transformer[], hence the explicit downcast.
        return convertTreeModel((DecisionTreeClassificationModel) fittedPipeline.stages()[fittedPipeline.stages().length - 1]);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Appends a {@link VectorAssembler} stage that combines the given input columns into a
     * single feature-vector column.
     *
     * @param list column names already in use; the newly generated alias is added to it
     * @param list2 pipeline stages collected so far; the assembler is appended to it
     * @param strArr names of the columns the assembler reads
     * @return the generated name of the feature-vector column (a unique alias requested with
     *     the {@code "f_"} prefix)
     */
    public static String prepareFeatureVectorColumn(List<String> list, List<PipelineStage> list2, String[] strArr) {
        final String vectorColumn = CommonUtils.newUniqueAlias("f_", list);
        list.add(vectorColumn);

        final VectorAssembler assembler = new VectorAssembler();
        assembler.setInputCols(strArr);
        assembler.setOutputCol(vectorColumn);
        list2.add(assembler);

        return vectorColumn;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Appends the pipeline stages that prepare the label column and returns the name of the
     * column the classifier should use as its label.
     *
     * <ul>
     *   <li>Non-nominal label: a {@link SparkNullChecker} is added and the original label name
     *       is returned.</li>
     *   <li>Nominal label: a new alias (requested with prefix {@code "i_"} followed by the label
     *       name) is registered, a {@link SparkNullChecker} is added only when a mapping was
     *       provided for the label, and a {@link SparkStringIndexer} writing to the alias is
     *       added; the alias is returned.</li>
     * </ul>
     *
     * @param list column names already in use; a generated alias is added to it
     * @param list2 pipeline stages collected so far; new stages are appended to it
     * @return the name of the label column to train on
     * @throws SparkException declared by the signature; presumably raised by the stage
     *     constructors (not visible here)
     */
    public static String prepareLabelColumn(List<String> list, List<PipelineStage> list2) throws SparkException {
        final int labelPosition = labelIndex.intValue();

        // Numeric label: only null checking is needed, the column is used as is.
        if (!isNominal[labelPosition].booleanValue()) {
            list2.add(new SparkNullChecker(labelPosition, labelName));
            return labelName;
        }

        final String indexedLabel = CommonUtils.newUniqueAlias("i_" + labelName, list);
        list.add(indexedLabel);
        if (mappingProvided[labelPosition].booleanValue()) {
            list2.add(new SparkNullChecker(labelPosition, labelName));
        }
        list2.add(new SparkStringIndexer(labelPosition, indexedLabel, knownNominalMappings.get(labelIndex), labelName));
        return indexedLabel;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Appends one pre-processing stage per feature column and returns the column names that
     * should be fed into the feature vector.
     *
     * <p>For a nominal feature a {@link SparkStringIndexer} is added that writes to a newly
     * registered alias (requested with prefix {@code "i_"} followed by the feature name), and
     * that alias is reported. For any other feature a {@link SparkNullChecker} is added and the
     * original column name is reported.
     *
     * @param list column names already in use; generated aliases are added to it
     * @param list2 pipeline stages collected so far; new stages are appended to it
     * @return an array parallel to {@code featureColumns} holding the column name to use for
     *     each feature
     * @throws SparkException declared by the signature; presumably raised by the stage
     *     constructors (not visible here)
     */
    public static String[] prepareFeatureColumns(List<String> list, List<PipelineStage> list2) throws SparkException {
        final int featureCount = featureColumns.length;
        final String[] pipelineColumns = new String[featureCount];

        for (int position = 0; position < featureCount; position++) {
            final String featureName = featureColumns[position];
            final int columnIndex = featuresIndex[position].intValue();

            if (!isNominal[columnIndex].booleanValue()) {
                list2.add(new SparkNullChecker(columnIndex, featureName));
                pipelineColumns[position] = featureName;
                continue;
            }

            final String indexedName = CommonUtils.newUniqueAlias("i_" + featureName, list);
            list.add(indexedName);
            list2.add(new SparkStringIndexer(columnIndex, indexedName, knownNominalMappings.get(Integer.valueOf(columnIndex)), featureName));
            pipelineColumns[position] = indexedName;
        }

        return pipelineColumns;
    }

    /**
     * Wraps a fitted Spark decision tree into a {@link TreeModelMTO}.
     *
     * <p>The tree is converted recursively starting from its root node and is bundled with the
     * converted form of the currently known nominal mappings.
     *
     * @param decisionTreeClassificationModel the fitted Spark classifier model
     * @return the transferable tree model
     */
    static TreeModelMTO convertTreeModel(DecisionTreeClassificationModel decisionTreeClassificationModel) {
        final Node root = decisionTreeClassificationModel.rootNode();
        final TreeTO convertedRoot = convertTreeNode(root);
        return new TreeModelMTO(convertedRoot, convertNominalMappings(knownNominalMappings));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Recursively converts a Spark tree node into a {@link TreeTO}.
     *
     * <p>The node's numeric prediction is translated into a label value through the nominal
     * mapping registered for the label column. For an {@link InternalNode} the outgoing edges
     * (and, through them, both subtrees) are produced by {@code convertSplit}; a leaf gets an
     * empty edge list. The node's impurity statistics are truncated to {@code int} and stored
     * per label value, using the statistic's position as the index into the label mapping
     * (presumably these are per-class instance counts).
     *
     * @param node the Spark tree node to convert
     * @return the converted node including its whole subtree
     */
    public static TreeTO convertTreeNode(Node node) {
        List<String> labelValues = knownNominalMappings.get(labelIndex);
        String predictedLabel = labelValues.get((int) node.prediction());

        List<EdgeTO> edges = new ArrayList<>(2);
        if (node instanceof InternalNode) {
            InternalNode internalNode = (InternalNode) node;
            convertSplit(internalNode.split(), internalNode.leftChild(), internalNode.rightChild(), edges);
        }

        // Fetch the statistics once instead of on every loop test and iteration.
        double[] classStats = node.impurityStats().stats();
        HashMap<String, Integer> labelCounts = new HashMap<>();
        for (int i = 0; i < classStats.length; i++) {
            labelCounts.put(labelValues.get(i), Integer.valueOf((int) classStats[i]));
        }
        return new TreeTO(predictedLabel, edges, labelCounts);
    }

    /**
     * Converts the split of an internal node into two edges and appends them to {@code list},
     * left child first. Both children are converted recursively via {@code convertTreeNode}.
     *
     * <ul>
     *   <li>{@link ContinuousSplit}: "less or equal threshold" for the left child and
     *       "greater than threshold" for the right child.</li>
     *   <li>Categorical split on a feature with exactly two known values: one nominal-equality
     *       condition per child, taken from the first left and first right category.</li>
     *   <li>Any other categorical split: a contains / not-contains pair built on the smaller of
     *       the two category sets (the left set when both have the same size). Both conditions
     *       share the same set instance.</li>
     * </ul>
     *
     * @param split the split of the internal node
     * @param node the left child
     * @param node2 the right child
     * @param list receives the two resulting edges
     */
    private static void convertSplit(Split split, Node node, Node node2, List<EdgeTO> list) {
        int featureIndex = split.featureIndex();
        String featureName = featureColumns[featureIndex];

        if (split instanceof ContinuousSplit) {
            double threshold = ((ContinuousSplit) split).threshold();
            list.add(new EdgeTO(convertTreeNode(node), new LessEqualsSplitConditionTO(featureName, threshold)));
            list.add(new EdgeTO(convertTreeNode(node2), new GreaterSplitConditionTO(featureName, threshold)));
            return;
        }

        CategoricalSplit categoricalSplit = (CategoricalSplit) split;
        List<String> categoryNames = knownNominalMappings.get(featuresIndex[featureIndex]);
        boolean binaryFeature = categoryNames.size() == 2;

        // leftCategories()/rightCategories() build a new array on every call, so fetch each
        // of them exactly once instead of inside loop conditions and bodies.
        double[] leftCategories = categoricalSplit.leftCategories();
        double[] rightCategories = categoricalSplit.rightCategories();

        if (binaryFeature) {
            list.add(new EdgeTO(convertTreeNode(node), new NominalSplitConditionTO(featureName, categoryNames.get((int) leftCategories[0]))));
            list.add(new EdgeTO(convertTreeNode(node2), new NominalSplitConditionTO(featureName, categoryNames.get((int) rightCategories[0]))));
            return;
        }

        if (rightCategories.length < leftCategories.length) {
            // The right side is smaller: describe both edges in terms of the right-hand values.
            HashSet<String> rightNames = collectCategoryNames(categoryNames, rightCategories);
            list.add(new EdgeTO(convertTreeNode(node), new NotContainsSplitConditionTO(featureName, rightNames)));
            list.add(new EdgeTO(convertTreeNode(node2), new ContainsSplitConditionTO(featureName, rightNames)));
            return;
        }

        HashSet<String> leftNames = collectCategoryNames(categoryNames, leftCategories);
        list.add(new EdgeTO(convertTreeNode(node), new ContainsSplitConditionTO(featureName, leftNames)));
        list.add(new EdgeTO(convertTreeNode(node2), new NotContainsSplitConditionTO(featureName, leftNames)));
    }

    /** Maps category indices (stored as doubles by Spark) to their nominal values. */
    private static HashSet<String> collectCategoryNames(List<String> categoryNames, double[] categories) {
        HashSet<String> names = new HashSet<>();
        for (double category : categories) {
            names.add(categoryNames.get((int) category));
        }
        return names;
    }
}
