package eu.radoop.spark;

import com.google.common.collect.Lists;
import eu.radoop.transfer.TransferObject;
import eu.radoop.transfer.parameter.CommonParameter;
import eu.radoop.transfer.parameter.ParameterTransferObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import org.apache.hadoop.mapred.FileAlreadyExistsException;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/* loaded from: input_file:lib/radoop-spark3.jar:eu/radoop/spark/AbstractSparkRunner.class */
public abstract class AbstractSparkRunner {
    /** Position in the {@code args} of {@link #init(String[])} of the common parameter argument. */
    public static final int COMMON_PARAM_INDEX = 0;
    /** Position in the {@code args} of {@link #init(String[])} of the runner-specific parameter argument. */
    public static final int RUNNER_PARAM_INDEX = 1;
    /** Position of the positive value in a user-provided two-value nominal mapping. */
    public static final int POSITIVE_INDEX = 1;
    /** Position of the negative value in a user-provided two-value nominal mapping. */
    public static final int NEGATIVE_INDEX = 0;

    // Spark handles created by init().
    protected static SparkConf conf;
    protected static JavaSparkContext sc;
    protected static SQLContext sqlContext;
    protected static SparkVersion version;

    // Input/output locations and text-input parsing settings, read from the common parameters.
    protected static String inputDirectory;
    protected static String modelOutputDirectory;
    protected static String exampleSetOutputDirectory;
    protected static CommonParameter.FileFormat inputFormat;
    protected static String fieldSeparator;
    protected static String nullString;

    // Per-column metadata; all arrays are indexed by the column position in the input.
    protected static String[] columnNames;
    protected static Boolean[] isNominal;
    protected static Boolean[] isFeature;
    protected static Integer labelIndex;

    // Derived feature metadata, in input column order (one entry per feature column).
    protected static String[] featureColumns;
    // Nominal value lists keyed by column index; grown by the discover* methods.
    protected static Map<Integer, List<String>> knownNominalMappings;
    protected static String labelName;
    protected static Integer[] featuresIndex;
    protected static Boolean[] isNominalFeatureIndex;
    // Indexed by column position: true when init() received a negative/positive value pair for it.
    protected static Boolean[] mappingProvided;
    static ParameterTransferObject<CommonParameter> commonPto;

    /**
     * Writes {@code exc} as a single-line text file into {@code str} (best effort), cancels all
     * running jobs and rethrows {@code exc} wrapped in a {@link SparkException}.
     *
     * <p>On Spark 2.0.1 an "Unable to create database" failure gets a workaround hint for
     * SPARK-17810 appended to the written line. A {@code FileAlreadyExistsException} from the write
     * is ignored; any other write failure is attached to the thrown exception as suppressed.
     *
     * @param exc the failure to report
     * @param str output directory the error line is written to
     * @param javaSparkContext context used to write the line and to cancel jobs
     * @throws SparkException always, with {@code exc} as cause
     */
    public static void processException(Exception exc, String str, JavaSparkContext javaSparkContext) throws SparkException {
        System.out.println("Exception: " + exc.toString());
        String message = exc.getMessage();
        String hint = "";
        // getMessage() may be null, so guard before startsWith.
        if ("2.0.1".equals(javaSparkContext.version()) && message != null && message.startsWith("Unable to create database ")) {
            hint = ". This is a bug identified in Spark 2.0.1, for more details please see: https://issues.apache.org/jira/browse/SPARK-17810. You can override spark.sql.warehouse.dir setting as a workaround or you can update Spark to a newer version.";
        }
        List<String> lines = new ArrayList<>(1);
        lines.add(exc.toString() + hint);
        Exception writeFailure = null;
        try {
            javaSparkContext.parallelize(lines).coalesce(1).saveAsTextFile(str);
        } catch (Exception e) {
            // The Scala API declares no checked exceptions, so FileAlreadyExistsException (checked)
            // cannot be caught directly. An existing output presumably holds an earlier error report.
            if (!(e instanceof FileAlreadyExistsException)) {
                writeFailure = e;
            }
        } finally {
            javaSparkContext.cancelAllJobs();
        }
        SparkException sparkException = new SparkException(exc.toString(), exc);
        if (writeFailure != null) {
            sparkException.addSuppressed(writeFailure);
        }
        throw sparkException;
    }

    /**
     * Marks a message as an input-data error by prefixing it with {@code "INPUT"}.
     *
     * @param str the error message
     * @return {@code "INPUT" + str}
     */
    public static String createInputException(String str) {
        return "INPUT" + str;
    }

    /**
     * Stops the shared Spark context. {@code JavaSparkContext.close()} only delegates to
     * {@code stop()}, so a single call is sufficient.
     */
    public static void close() {
        sc.stop();
    }

    /**
     * Reports {@code exc} into {@link #modelOutputDirectory} using the shared context.
     *
     * @param exc the failure to report
     * @throws SparkException always, with {@code exc} as cause
     */
    public static void processException(Exception exc) throws SparkException {
        processException(exc, modelOutputDirectory, sc);
    }

    /**
     * Saves the JSON form of {@code transferObject} as a single-partition text file into
     * {@link #modelOutputDirectory}.
     *
     * @param transferObject the model to persist
     */
    public static void persistModel(TransferObject transferObject) {
        List<String> lines = new ArrayList<>(1);
        lines.add(transferObject.toJson());
        sc.parallelize(lines).coalesce(1).saveAsTextFile(modelOutputDirectory);
    }

    /**
     * Creates the Spark context and fills every static configuration field from the common
     * parameters.
     *
     * @param strArr program arguments; {@link #COMMON_PARAM_INDEX} and {@link #RUNNER_PARAM_INDEX}
     *     are each passed to {@code RunnerTools.readFromArgFile}
     * @return the runner-specific parameters read from {@code strArr[RUNNER_PARAM_INDEX]}
     * @throws SparkException if an {@link IOException} occurs while reading the configuration
     *     (for example when the input directory cannot be resolved)
     */
    public static String init(String[] strArr) throws SparkException {
        conf = new SparkConf();
        sc = new JavaSparkContext(conf);
        sqlContext = new SQLContext(new SparkSession(sc.sc()));
        commonPto = new ParameterTransferObject<>(RunnerTools.readFromArgFile(strArr[COMMON_PARAM_INDEX]), CommonParameter.class);
        System.out.println("commonPto:\n" + commonPto);
        version = SparkVersion.valueOf(commonPto.getParameterAsString(CommonParameter.SPARK_VERSION));
        try {
            inputDirectory = RunnerTools.resolveInputDir(sc, commonPto.getParameterAsString(CommonParameter.INPUT_DIR));
            modelOutputDirectory = commonPto.getParameterAsString(CommonParameter.MODEL_OUTPUT_DIR);
            exampleSetOutputDirectory = commonPto.getParameterAsString(CommonParameter.EXAMPLE_SET_OUTPUT_DIR);
            inputFormat = CommonParameter.FileFormat.valueOf(commonPto.getParameterAsString(CommonParameter.INPUT_FORMAT));
            fieldSeparator = commonPto.getParameterAsString(CommonParameter.FIELD_SEPARATOR);
            nullString = commonPto.getParameterAsString(CommonParameter.NULL_STRING);
            columnNames = commonPto.getParameterAsStringArray(CommonParameter.COLUMN_NAMES);
            isNominal = commonPto.getParameterAsBooleanArray(CommonParameter.IS_NOMINAL);
            isFeature = commonPto.getParameterAsBooleanArray(CommonParameter.IS_FEATURE);
            labelIndex = commonPto.getParameterAsInteger(CommonParameter.LABEL_INDEX);
            if (labelIndex != null) {
                labelName = columnNames[labelIndex];
            }
            initFeatureColumns();
            initProvidedNominalMappings();
            return RunnerTools.readFromArgFile(strArr[RUNNER_PARAM_INDEX]);
        } catch (IOException e) {
            throw new SparkException("Could not resolve input directory", e);
        }
    }

    /** Derives {@link #featureColumns}, {@link #featuresIndex} and {@link #isNominalFeatureIndex} from {@link #isFeature}. */
    private static void initFeatureColumns() {
        List<String> names = new ArrayList<>();
        List<Integer> indices = new ArrayList<>();
        List<Boolean> nominalFlags = new ArrayList<>();
        for (int column = 0; column < isFeature.length; column++) {
            if (isFeature[column]) {
                names.add(columnNames[column]);
                indices.add(column);
                nominalFlags.add(isNominal[column]);
            }
        }
        featureColumns = names.toArray(new String[0]);
        featuresIndex = indices.toArray(new Integer[0]);
        isNominalFeatureIndex = nominalFlags.toArray(new Boolean[0]);
    }

    /**
     * Seeds {@link #knownNominalMappings} and {@link #mappingProvided} from the negative/positive
     * value parameters. Each nominal column with both values set gets the mapping
     * {@code [negative, positive]}, that is, positions {@link #NEGATIVE_INDEX} and {@link #POSITIVE_INDEX}.
     */
    private static void initProvidedNominalMappings() {
        mappingProvided = new Boolean[isFeature.length];
        Arrays.fill(mappingProvided, Boolean.FALSE);
        String[] negativeValues = commonPto.getParameterAsStringArray(CommonParameter.NEGATIVE_VALUES);
        String[] positiveValues = commonPto.getParameterAsStringArray(CommonParameter.POSITIVE_VALUES);
        knownNominalMappings = new HashMap<>();
        for (int column = 0; column < isFeature.length; column++) {
            if (isNominal[column] && negativeValues[column] != null && positiveValues[column] != null) {
                knownNominalMappings.put(column, Lists.newArrayList(negativeValues[column], positiveValues[column]));
                mappingProvided[column] = true;
            }
        }
    }

    /**
     * Adds the label's distinct values to {@link #knownNominalMappings} unless a mapping for the
     * label is already known.
     *
     * @param javaRDD input rows
     * @param cacheInput whether to {@code cache()} {@code javaRDD} before scanning it
     * @return {@link #knownNominalMappings} (the shared map itself, not a copy)
     */
    public static Map<Integer, List<String>> discoverLabelMappings(JavaRDD<String[]> javaRDD, boolean cacheInput) {
        if (!knownNominalMappings.containsKey(labelIndex)) {
            knownNominalMappings.putAll(discoverNominalMappings(javaRDD, cacheInput, labelIndex.intValue()));
        }
        return knownNominalMappings;
    }

    /**
     * Discovers nominal value lists for the label (if not yet known) and for the nominal feature
     * columns, and merges them into {@link #knownNominalMappings}.
     *
     * @param javaRDD input rows
     * @param cacheInput whether to {@code cache()} {@code javaRDD} before scanning it
     * @param skipKnownMappings when true, nominal features that already have a mapping are not
     *     rescanned; when false, every nominal feature is rediscovered
     * @return {@link #knownNominalMappings} (the shared map itself, not a copy)
     */
    public static Map<Integer, List<String>> discoverLabelAndNominalFeatureMappings(JavaRDD<String[]> javaRDD, boolean cacheInput, boolean skipKnownMappings) {
        int[] columnsToScan = new int[1 + featuresIndex.length];
        int count = 0;
        if (!knownNominalMappings.containsKey(labelIndex)) {
            columnsToScan[count++] = labelIndex;
        }
        for (Integer featureIndex : featuresIndex) {
            int column = featureIndex;
            if (isNominal[column] && !(skipKnownMappings && knownNominalMappings.containsKey(column))) {
                columnsToScan[count++] = column;
            }
        }
        if (count > 0) {
            Map<Integer, List<String>> discovered = discoverNominalMappings(javaRDD, cacheInput, Arrays.copyOf(columnsToScan, count));
            reorderDiscoveredMappings(discovered);
            // NOTE(review): putAll replaces any known list for the same column with the discovered one,
            // so the adjustment made by reorderDiscoveredMappings is not kept in this map. Confirm intent.
            knownNominalMappings.putAll(discovered);
        }
        return knownNominalMappings;
    }

    /**
     * Dataset variant of {@link #discoverLabelAndNominalFeatureMappings(JavaRDD, boolean, boolean)};
     * the converted rows are not cached.
     *
     * @param dataset input data
     * @param skipKnownMappings see the RDD variant
     * @return {@link #knownNominalMappings}
     */
    public static Map<Integer, List<String>> discoverLabelAndNominalFeatureMappings(Dataset<Row> dataset, boolean skipKnownMappings) {
        return discoverLabelAndNominalFeatureMappings(convertDataFrameToRDD(dataset), false, skipKnownMappings);
    }

    /**
     * Reads the configured input as a DataFrame. Text input is parsed using the column metadata;
     * Parquet boolean columns are cast to strings.
     *
     * @return the input, or {@code null} if {@link #inputFormat} is neither TEXTFILE nor PARQUET
     */
    public static Dataset<Row> getInputAsDataFrame() {
        if (inputFormat.equals(CommonParameter.FileFormat.TEXTFILE)) {
            return createDataFrame(readTextFile(inputDirectory), columnNames, isNominal);
        }
        if (inputFormat.equals(CommonParameter.FileFormat.PARQUET)) {
            return readParquetFile();
        }
        return null;
    }

    /**
     * Converts rows into labeled points. A nominal label is encoded as its index in
     * {@code map.get(labelIndex)}; a numeric label is parsed as a double. Features are converted by
     * {@link #rawValuesToTypedVector}.
     *
     * @param javaRDD input rows
     * @param map nominal mappings keyed by column index
     * @return the labeled points; conversion errors surface as {@link SparkException} when the RDD is evaluated
     */
    public static JavaRDD<LabeledPoint> getLabeledPointRDD(JavaRDD<String[]> javaRDD, final Map<Integer, List<String>> map) {
        // Copy static state into locals so the closure ships the values to the executors.
        final int label = labelIndex;
        final Integer[] featureIndices = featuresIndex;
        final Boolean[] nominalFlags = isNominal;
        final String[] names = columnNames;
        return javaRDD.map(new Function<String[], LabeledPoint>() {
            private static final long serialVersionUID = -6927287925476540061L;

            @Override
            public LabeledPoint call(String[] row) throws Exception {
                String rawLabel = row[label];
                String labelColumn = names[label];
                double labelValue;
                if (nominalFlags[label]) {
                    List<String> labelValues = map.get(label);
                    if (labelValues == null) {
                        throw new SparkException(AbstractSparkRunner.createInputException("The nominal mapping of label attribute: " + labelColumn + " is undiscovered."));
                    }
                    int position = labelValues.indexOf(rawLabel);
                    if (position == -1) {
                        throw new SparkException(AbstractSparkRunner.createInputException("The nominal mapping of label attribute: " + labelColumn + " does not contain member: " + rawLabel));
                    }
                    labelValue = position;
                } else {
                    if (rawLabel == null) {
                        throw new SparkException(AbstractSparkRunner.createInputException("The non-nominal label attribute: " + labelColumn + " should not contain missing."));
                    }
                    try {
                        labelValue = Double.parseDouble(rawLabel);
                    } catch (NumberFormatException e) {
                        throw new SparkException(AbstractSparkRunner.createInputException("The non-nominal label attribute: " + labelColumn + " should not contain the non-numeric value: " + rawLabel));
                    }
                }
                return new LabeledPoint(labelValue, AbstractSparkRunner.rawValuesToTypedVector(map, featureIndices, nominalFlags, names, row));
            }
        });
    }

    /**
     * Converts rows into dense feature vectors without any nominal mapping. Any nominal feature
     * therefore fails with an "undiscovered" input error when the RDD is evaluated.
     *
     * @param javaRDD input rows
     * @return the feature vectors
     */
    public static JavaRDD<Vector> getFeaturesAsVectorRDD(JavaRDD<String[]> javaRDD) {
        return javaRDD.map(new Function<String[], Vector>() {
            private static final long serialVersionUID = -3542143556424237842L;
            private final Integer[] featuresIndexFinal = AbstractSparkRunner.featuresIndex;
            private final Boolean[] isNominalFinal = AbstractSparkRunner.isNominal;
            private final String[] columnNamesFinal = AbstractSparkRunner.columnNames;

            @Override
            public Vector call(String[] row) throws Exception {
                return AbstractSparkRunner.rawValuesToTypedVector(Collections.<Integer, List<String>>emptyMap(), this.featuresIndexFinal, this.isNominalFinal, this.columnNamesFinal, row);
            }
        });
    }

    /**
     * Reads the configured input as rows of raw string values (null for missing values).
     *
     * @return the rows, or {@code null} if {@link #inputFormat} is neither TEXTFILE nor PARQUET
     */
    public static JavaRDD<String[]> getInputAsRDD() {
        if (inputFormat.equals(CommonParameter.FileFormat.TEXTFILE)) {
            return readTextFile(inputDirectory);
        }
        if (inputFormat.equals(CommonParameter.FileFormat.PARQUET)) {
            return convertDataFrameToRDD(readParquetFile());
        }
        return null;
    }

    /**
     * Converts each DataFrame row to a {@code String[]} using {@code toString()} of each value.
     * SQL nulls become {@code null} entries.
     *
     * @param dataset input data
     * @return the string rows
     */
    protected static JavaRDD<String[]> convertDataFrameToRDD(Dataset<Row> dataset) {
        return dataset.toJavaRDD().map(new Function<Row, String[]>() {
            private static final long serialVersionUID = -5580788934085658408L;

            @Override
            public String[] call(Row row) throws Exception {
                String[] values = new String[row.size()];
                for (int i = 0; i < values.length; i++) {
                    if (!row.isNullAt(i)) {
                        values[i] = row.get(i).toString();
                    }
                }
                return values;
            }
        });
    }

    /**
     * Returns a lazily validated view of {@code javaRDD} that fails when any of the given columns
     * is missing ({@code null}) in a row.
     *
     * @param javaRDD input rows
     * @param numArr column indices that must not be null
     * @return the same rows, unchanged
     * @throws SparkException declared for callers; the check itself runs during evaluation
     */
    protected static JavaRDD<String[]> checkForNull(JavaRDD<String[]> javaRDD, final Integer... numArr) throws SparkException {
        final String[] names = columnNames;
        return javaRDD.map(new Function<String[], String[]>() {
            private static final long serialVersionUID = -533664575373006435L;

            @Override
            public String[] call(String[] row) throws Exception {
                for (Integer index : numArr) {
                    int column = index;
                    if (row[column] == null) {
                        throw new SparkException(AbstractSparkRunner.createInputException("The attribute " + names[column] + " should not contain missings for this algorithm!"));
                    }
                }
                return row;
            }
        });
    }

    /**
     * Checks that the discovered label values in {@code map} do not include a missing value.
     *
     * @param map nominal mappings keyed by column index
     * @throws SparkException if the label has no mapping or its mapping contains {@code null}
     */
    public static void checkMissingLabel(Map<Integer, List<String>> map) throws SparkException {
        checkMissingLabel(null, map);
    }

    /**
     * Ensures the label has no missing values. If the label mapping was provided up front, its
     * values were not discovered, so the rows are checked lazily instead of the mapping.
     *
     * @param javaRDD input rows, or {@code null} to check only the mapping
     * @param map nominal mappings keyed by column index
     * @return a null-checking view of {@code javaRDD} when the label mapping was provided, else {@code javaRDD}
     * @throws SparkException if the label has no mapping or its mapping contains {@code null}
     */
    public static JavaRDD<String[]> checkMissingLabel(JavaRDD<String[]> javaRDD, Map<Integer, List<String>> map) throws SparkException {
        if (javaRDD != null && mappingProvided[labelIndex]) {
            return checkForNull(javaRDD, labelIndex);
        }
        List<String> labelValues = map.get(labelIndex);
        if (labelValues == null) {
            throw new SparkException("Could not find nominal values of the label!");
        }
        if (labelValues.contains(null)) {
            throw new SparkException(createInputException("Label should not contain missing values!"));
        }
        return javaRDD;
    }

    /**
     * Ensures the label has exactly two distinct values, optionally plus {@code null} for missing.
     *
     * @param map nominal mappings keyed by column index
     * @throws SparkException if the label has no mapping or is not binominal
     */
    public static void checkBinominalLabel(Map<Integer, List<String>> map) throws SparkException {
        List<String> labelValues = map.get(labelIndex);
        if (labelValues == null) {
            throw new SparkException("Could not find nominal values of the label!");
        }
        int size = labelValues.size();
        boolean binominal = size == 2 || (size == 3 && labelValues.contains(null));
        if (!binominal) {
            throw new SparkException(createInputException("This algorithm does not have sufficient capabilities for handling an example set with less or more than 2 different values!"));
        }
    }

    /**
     * Re-keys nominal mappings from column index to column name.
     *
     * @param map mappings keyed by column index
     * @return a new map keyed by {@link #columnNames}
     */
    public static Map<String, List<String>> convertNominalMappings(Map<Integer, List<String>> map) {
        Map<String, List<String>> byName = new HashMap<>();
        for (Map.Entry<Integer, List<String>> entry : map.entrySet()) {
            byName.put(columnNames[entry.getKey()], entry.getValue());
        }
        return byName;
    }

    /**
     * Collects the distinct values of the given columns. Each list is sorted in natural order with
     * {@code null} (missing) last.
     *
     * @param javaRDD input rows
     * @param cacheInput whether to {@code cache()} {@code javaRDD} first
     * @param iArr column indices to scan
     * @return distinct sorted values keyed by column index
     */
    private static Map<Integer, List<String>> discoverNominalMappings(JavaRDD<String[]> javaRDD, boolean cacheInput, final int... iArr) {
        if (cacheInput) {
            javaRDD.cache();
        }
        // Zero value for aggregateByKey; Spark gives each partition its own copy, so mutating it is safe.
        Set<String> emptyValues = new HashSet<>();
        return javaRDD.flatMapToPair(new PairFlatMapFunction<String[], Integer, String>() {
            private static final long serialVersionUID = 0;

            @Override
            public Iterator<scala.Tuple2<Integer, String>> call(String[] row) throws Exception {
                List<scala.Tuple2<Integer, String>> pairs = new ArrayList<>(iArr.length);
                for (int column : iArr) {
                    if (row == null || row.length <= column) {
                        throw new SparkException("Could not find nominal value: attribute index is " + column + ", values array size is " + (row == null ? "null" : String.valueOf(row.length)));
                    }
                    pairs.add(new scala.Tuple2<Integer, String>(column, row[column]));
                }
                return pairs.iterator();
            }
        }).aggregateByKey(emptyValues, new Function2<Set<String>, String, Set<String>>() {
            private static final long serialVersionUID = 0;

            @Override
            public Set<String> call(Set<String> values, String value) throws Exception {
                values.add(value);
                return values;
            }
        }, new Function2<Set<String>, Set<String>, Set<String>>() {
            private static final long serialVersionUID = 0;

            @Override
            public Set<String> call(Set<String> values, Set<String> otherValues) throws Exception {
                values.addAll(otherValues);
                return values;
            }
        }).mapToPair(new PairFunction<scala.Tuple2<Integer, Set<String>>, Integer, List<String>>() {
            private static final long serialVersionUID = 0;

            @Override
            public scala.Tuple2<Integer, List<String>> call(scala.Tuple2<Integer, Set<String>> entry) throws Exception {
                List<String> values = new ArrayList<>(entry._2());
                Collections.sort(values, Comparator.nullsLast(Comparator.<String>naturalOrder()));
                return new scala.Tuple2<Integer, List<String>>(entry._1(), values);
            }
        }).collectAsMap();
    }

    /** Builds a nullable schema: nominal columns as strings, all others as doubles. */
    private static StructType createSchema(String[] strArr, Boolean[] boolArr) {
        StructField[] fields = new StructField[strArr.length];
        for (int i = 0; i < fields.length; i++) {
            fields[i] = new StructField(strArr[i], boolArr[i] ? DataTypes.StringType : DataTypes.DoubleType, true, Metadata.empty());
        }
        return new StructType(fields);
    }

    /** Reads delimited text lines from {@code str} and splits them into fields using the configured separator and null string. */
    private static JavaRDD<String[]> readTextFile(String str) {
        return splitStringRDD(sc.textFile(str), fieldSeparator, nullString);
    }

    /** Reads Parquet input from {@link #inputDirectory}, casting every boolean column to string. */
    private static Dataset<Row> readParquetFile() {
        Dataset<Row> parquet = sqlContext.read().parquet(inputDirectory);
        for (StructField field : parquet.schema().fields()) {
            if (field.dataType().equals(DataTypes.BooleanType)) {
                parquet = parquet.withColumn(field.name(), parquet.col(field.name()).cast(DataTypes.StringType));
            }
        }
        return parquet;
    }

    /**
     * Splits each line on the literal separator {@code str}. Fields equal to {@code str2} become
     * {@code null}. Trailing empty fields are kept so rows keep one entry per column.
     */
    private static JavaRDD<String[]> splitStringRDD(JavaRDD<String> javaRDD, final String str, final String str2) {
        return javaRDD.map(new Function<String, String[]>() {
            private static final long serialVersionUID = 1;
            // Compiled once per function instance instead of once per row; Pattern is Serializable.
            private final Pattern separator = Pattern.compile(Pattern.quote(str));

            @Override
            public String[] call(String line) throws Exception {
                // Limit -1: a plain split would drop trailing empty fields and shorten the row.
                String[] fields = separator.split(line, -1);
                for (int i = 0; i < fields.length; i++) {
                    if (fields[i].equals(str2)) {
                        fields[i] = null;
                    }
                }
                return fields;
            }
        });
    }

    /**
     * Builds a DataFrame from string rows. Nominal columns are kept as strings; other columns are
     * parsed as doubles, and null values stay null.
     */
    private static Dataset<Row> createDataFrame(JavaRDD<String[]> javaRDD, final String[] strArr, final Boolean[] boolArr) {
        return sqlContext.createDataFrame(javaRDD.map(new Function<String[], Row>() {
            private static final long serialVersionUID = 1;

            @Override
            public Row call(String[] row) throws Exception {
                Object[] values = new Object[strArr.length];
                for (int i = 0; i < strArr.length; i++) {
                    String raw = row[i];
                    if (boolArr[i]) {
                        values[i] = raw;
                    } else if (raw != null) {
                        try {
                            values[i] = Double.valueOf(raw);
                        } catch (NumberFormatException e) {
                            throw new SparkException(AbstractSparkRunner.createInputException("The non-nominal attribute: " + strArr[i] + " should not contain the non-numeric value: " + raw + " other than the missing placeholder."), e);
                        }
                    }
                }
                return RowFactory.create(values);
            }
        }), createSchema(strArr, boolArr));
    }

    /**
     * For every discovered column that also has a known mapping, appends {@code null} to the
     * known list when the discovered values include {@code null}, the known list does not, and
     * both contain the same non-null values. Mutates the lists in {@link #knownNominalMappings};
     * {@code map} itself is not modified.
     */
    private static void reorderDiscoveredMappings(Map<Integer, List<String>> map) {
        for (Map.Entry<Integer, List<String>> entry : map.entrySet()) {
            List<String> known = knownNominalMappings.get(entry.getKey());
            if (known == null) {
                continue;
            }
            List<String> discovered = entry.getValue();
            Set<String> discoveredValues = new HashSet<>(discovered);
            discoveredValues.remove(null);
            Set<String> knownValues = new HashSet<>(known);
            knownValues.remove(null);
            if (discovered.contains(null) && !known.contains(null) && discoveredValues.equals(knownValues)) {
                known.add(null);
            }
        }
    }

    /**
     * Converts the feature columns of one row into a dense vector. Nominal values become their
     * index in {@code map}; other values are parsed as doubles. Missing values are rejected unless
     * a nominal mapping contains {@code null}.
     *
     * @throws SparkException (input error) for an undiscovered mapping, an unknown nominal value,
     *     a missing numeric value, or a non-numeric value in a numeric column
     */
    private static Vector rawValuesToTypedVector(Map<Integer, List<String>> map, Integer[] numArr, Boolean[] boolArr, String[] strArr, String[] strArr2) throws SparkException {
        double[] vector = new double[numArr.length];
        for (int i = 0; i < vector.length; i++) {
            int column = numArr[i];
            String raw = strArr2[column];
            String name = strArr[column];
            if (boolArr[column]) {
                List<String> values = map.get(column);
                if (values == null) {
                    throw new SparkException(createInputException("The nominal mapping of feature attribute: " + name + " is undiscovered."));
                }
                int position = values.indexOf(raw);
                if (position == -1) {
                    if (raw == null) {
                        throw new SparkException(createInputException("The nominal feature attribute: " + name + " should not contain missings as nominal value discovering is skipped."));
                    }
                    throw new SparkException(createInputException("The nominal mapping of feature attribute: " + name + " does not contain member: " + raw));
                }
                vector[i] = position;
            } else {
                if (raw == null) {
                    throw new SparkException(createInputException("The non-nominal feature attribute: " + name + " should not contain missing."));
                }
                try {
                    vector[i] = Double.parseDouble(raw);
                } catch (NumberFormatException e) {
                    throw new SparkException(createInputException("The non-nominal feature attribute: " + name + " should not contain the non-numeric value: " + raw));
                }
            }
        }
        return Vectors.dense(vector);
    }
}
