package eu.radoop.spark;

import eu.radoop.transfer.model.ModelTransferObject;
import eu.radoop.transfer.model.SupportVectorMachineMTO;
import eu.radoop.transfer.parameter.ParameterTransferObject;
import eu.radoop.transfer.parameter.SparkSupportVectorMachineParameter;
import java.util.List;
import java.util.Map;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.optimization.GradientDescent;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.optimization.SimpleUpdater;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.regression.LabeledPoint;

/* loaded from: input_file:lib/radoop-spark3.jar:eu/radoop/spark/SparkSupportVectorMachineRunner.class */
/**
 * Spark job that trains a linear support vector machine using Spark MLlib's {@link SVMWithSGD}
 * and hands the trained model to the inherited {@code persistModel} helper.
 *
 * <p>Argument parsing ({@code init}), input loading, label handling, model persistence, error
 * reporting and cleanup are delegated to helpers inherited from {@link AbstractSparkRunner},
 * which is not visible here.
 */
public class SparkSupportVectorMachineRunner extends AbstractSparkRunner {

    /**
     * Command-line entry point.
     *
     * <p>Initialises the runner from {@code args}, trains the model and persists it. Any
     * {@link Exception} raised during those steps is routed to {@code processException}, and
     * {@code close()} is always invoked afterwards.
     *
     * @param args raw command-line arguments, passed unchanged to {@code init}
     * @throws SparkException declared for the inherited {@code processException} / {@code close}
     *     helpers; exceptions from the try body itself are all caught
     */
    public static void main(String[] args) throws SparkException {
        try {
            String encodedParameters = init(args);
            ModelTransferObject trainedModel = learn(encodedParameters);
            persistModel(trainedModel);
        } catch (Exception failure) {
            processException(failure);
        } finally {
            close();
        }
    }

    /**
     * Trains an SVM classifier on the runner's input data.
     *
     * <p>The hyper-parameters are read from a {@link ParameterTransferObject} built from
     * {@code parameterString}:
     * <ul>
     *   <li>iteration count</li>
     *   <li>regularization parameter</li>
     *   <li>updater type</li>
     *   <li>intercept flag</li>
     *   <li>feature-scaling flag</li>
     *   <li>step size</li>
     *   <li>mini-batch fraction</li>
     *   <li>convergence tolerance</li>
     * </ul>
     * They configure the {@link GradientDescent} optimizer of a fresh {@link SVMWithSGD}.
     *
     * <p>Training rows come from the inherited {@code getInputAsRDD()}. Label mappings are
     * discovered and validated with the inherited label helpers. The resulting
     * {@link LabeledPoint}s are cached before training.
     *
     * @param parameterString serialized parameters accepted by {@link ParameterTransferObject}
     * @return the trained model wrapped in a {@link SupportVectorMachineMTO}
     * @throws IllegalArgumentException if the updater string is not a valid name for
     *     {@code SparkSupportVectorMachineParameter.Updater.valueOf}, or it resolves to an
     *     updater other than Simple, L1 or Squared L2
     * @throws SparkException as declared; presumably raised by the inherited Spark helpers (confirm)
     */
    public static ModelTransferObject learn(String parameterString) throws SparkException {
        ParameterTransferObject params =
                new ParameterTransferObject(parameterString, SparkSupportVectorMachineParameter.class);
        System.out.println("pto:\n" + params);

        int iterations =
                params.getParameterAsInteger(SparkSupportVectorMachineParameter.ITERATIONS).intValue();
        double regularization =
                params.getParameterAsDouble(SparkSupportVectorMachineParameter.REG_PARAM).doubleValue();
        SparkSupportVectorMachineParameter.Updater updaterChoice =
                SparkSupportVectorMachineParameter.Updater.valueOf(
                        params.getParameterAsString(SparkSupportVectorMachineParameter.UPDATER));
        boolean addIntercept =
                params.getParameterAsBoolean(SparkSupportVectorMachineParameter.ADD_INTERCEPT).booleanValue();
        boolean useFeatureScaling =
                params.getParameterAsBoolean(SparkSupportVectorMachineParameter.USE_FEATURE_SCALING).booleanValue();
        double stepSize =
                params.getParameterAsDouble(SparkSupportVectorMachineParameter.STEP_SIZE).doubleValue();
        double miniBatchFraction =
                params.getParameterAsDouble(SparkSupportVectorMachineParameter.MINI_BATCH_FRACTION).doubleValue();
        double convergenceTolerance =
                params.getParameterAsDouble(SparkSupportVectorMachineParameter.CONVERGENCE_TOL).doubleValue();

        SVMWithSGD svm = new SVMWithSGD();

        // The setters are applied one at a time, in the same order as before, so that any
        // argument validation inside GradientDescent is triggered in an unchanged sequence.
        GradientDescent gradientDescent = svm.optimizer();
        gradientDescent.setStepSize(stepSize);
        gradientDescent.setMiniBatchFraction(miniBatchFraction);
        gradientDescent.setNumIterations(iterations);
        gradientDescent.setRegParam(regularization);
        gradientDescent.setConvergenceTol(convergenceTolerance);
        installUpdater(gradientDescent, updaterChoice);

        svm.setIntercept(addIntercept);
        svm.setFeatureScaling(useFeatureScaling);

        JavaRDD<String[]> rawRows = getInputAsRDD();
        // NOTE(review): the meaning of the boolean flag is defined by the inherited
        // discoverLabelMappings; confirm against AbstractSparkRunner.
        Map<Integer, List<String>> labelMappings = discoverLabelMappings(rawRows, true);
        JavaRDD<String[]> labelledRows = checkMissingLabel(rawRows, labelMappings);
        checkBinominalLabel(labelMappings);

        JavaRDD<LabeledPoint> trainingPoints = getLabeledPointRDD(labelledRows, labelMappings);
        // SGD runs a Spark job over this RDD on every iteration. Caching avoids recomputing
        // the whole input lineage each time.
        trainingPoints.cache();

        SVMModel fittedModel = svm.run(trainingPoints.rdd());
        return convertSVMModel(featureColumns, fittedModel, addIntercept, labelMappings);
    }

    /**
     * Installs the MLlib updater that corresponds to {@code choice} on {@code optimizer}.
     *
     * <p>The mapping is:
     * <ul>
     *   <li>{@code L1_UPDATER} to {@link L1Updater}</li>
     *   <li>{@code L2_UPDATER} to {@link SquaredL2Updater}</li>
     *   <li>{@code SIMPLE_UPDATER} to {@link SimpleUpdater}</li>
     * </ul>
     *
     * @param optimizer the gradient-descent optimizer to configure
     * @param choice the requested updater
     * @throws IllegalArgumentException if {@code choice} is none of the three supported updaters
     */
    private static void installUpdater(
            GradientDescent optimizer, SparkSupportVectorMachineParameter.Updater choice) {
        if (choice.equals(SparkSupportVectorMachineParameter.Updater.L1_UPDATER)) {
            optimizer.setUpdater(new L1Updater());
            return;
        }
        if (choice.equals(SparkSupportVectorMachineParameter.Updater.L2_UPDATER)) {
            optimizer.setUpdater(new SquaredL2Updater());
            return;
        }
        if (choice.equals(SparkSupportVectorMachineParameter.Updater.SIMPLE_UPDATER)) {
            optimizer.setUpdater(new SimpleUpdater());
            return;
        }
        throw new IllegalArgumentException("Invalid updater! Valid updaters are Simple, L1 and Squared L2.");
    }

    /**
     * Wraps a trained MLlib {@link SVMModel} in a {@link SupportVectorMachineMTO}.
     *
     * @param featureNames feature column names passed through to the transfer object
     *     (presumably in training order; confirm)
     * @param model the trained model whose weight vector and intercept are copied
     * @param hasIntercept whether training was configured to fit an intercept term
     * @param labelMappings label mappings, converted with the inherited
     *     {@code convertNominalMappings}
     * @return the transfer object describing {@code model}
     */
    private static ModelTransferObject convertSVMModel(
            String[] featureNames, SVMModel model, boolean hasIntercept,
            Map<Integer, List<String>> labelMappings) {
        double[] weights = model.weights().toArray();
        double intercept = model.intercept();
        return new SupportVectorMachineMTO(
                featureNames, weights, intercept, hasIntercept, convertNominalMappings(labelMappings));
    }
}
