package eu.radoop.spark;

import eu.radoop.transfer.model.LinearRegressionMTO;
import eu.radoop.transfer.model.ModelTransferObject;
import eu.radoop.transfer.parameter.ParameterTransferObject;
import eu.radoop.transfer.parameter.SparkLinearRegressionParameter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.optimization.GradientDescent;
import org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm;
import org.apache.spark.mllib.regression.GeneralizedLinearModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoWithSGD;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.mllib.regression.RidgeRegressionWithSGD;

/* loaded from: input_file:lib/radoop-spark3.jar:eu/radoop/spark/SparkLinearRegressionRunner.class */
public class SparkLinearRegressionRunner extends AbstractSparkRunner {

    /**
     * Entry point: initialises the runner from the command-line arguments, trains a regression
     * model via {@link #learn(String)} and persists it.
     *
     * <p>Any exception is passed to {@code processException}; {@code close()} is always invoked
     * afterwards, whether or not training succeeded.
     *
     * @param strArr command-line arguments forwarded to {@code init}
     * @throws SparkException as declared; presumably raised by the inherited helpers
     *     (e.g. {@code processException}) — TODO confirm
     */
    public static void main(String[] strArr) throws SparkException {
        try {
            persistModel(learn(init(strArr)));
        } catch (Exception e) {
            processException(e);
        } finally {
            close();
        }
    }

    /**
     * Trains an SGD-based regression model (plain linear, Lasso or ridge) on the runner's input.
     *
     * <p>If the label column is nominal, label mappings are discovered, rows with a missing label
     * are handled by {@code checkMissingLabel}, and the label is validated with
     * {@code checkBinominalLabel} before the data is converted into {@link LabeledPoint}s.
     *
     * @param str serialized parameters, parsed as a {@link ParameterTransferObject} of
     *     {@link SparkLinearRegressionParameter}
     * @return the trained model converted into a {@link LinearRegressionMTO}
     * @throws IllegalArgumentException if the configured method is not Linear, Lasso or Ridge
     * @throws SparkException as declared by the signature
     */
    public static ModelTransferObject learn(String str) throws SparkException {
        ParameterTransferObject parameters =
                new ParameterTransferObject(str, SparkLinearRegressionParameter.class);
        System.out.println("pto:\n" + String.valueOf(parameters));

        // All parameters are read up front, before the method is validated.
        SparkLinearRegressionParameter.Method method = SparkLinearRegressionParameter.Method.valueOf(
                parameters.getParameterAsString(SparkLinearRegressionParameter.METHOD));
        int numIterations =
                parameters.getParameterAsInteger(SparkLinearRegressionParameter.ITERATIONS).intValue();
        double regParam =
                parameters.getParameterAsDouble(SparkLinearRegressionParameter.REG_PARAM).doubleValue();
        boolean addIntercept =
                parameters.getParameterAsBoolean(SparkLinearRegressionParameter.ADD_INTERCEPT).booleanValue();
        boolean useFeatureScaling =
                parameters.getParameterAsBoolean(SparkLinearRegressionParameter.USE_FEATURE_SCALING).booleanValue();
        double stepSize =
                parameters.getParameterAsDouble(SparkLinearRegressionParameter.STEP_SIZE).doubleValue();
        double miniBatchFraction =
                parameters.getParameterAsDouble(SparkLinearRegressionParameter.MINI_BATCH_FRACTION).doubleValue();
        double convergenceTol =
                parameters.getParameterAsDouble(SparkLinearRegressionParameter.CONVERGENCE_TOL).doubleValue();

        GeneralizedLinearAlgorithm<? extends GeneralizedLinearModel> algorithm = createAlgorithm(
                method, stepSize, numIterations, regParam, miniBatchFraction, convergenceTol);
        algorithm.setIntercept(addIntercept);
        algorithm.setFeatureScaling(useFeatureScaling);

        JavaRDD<String[]> inputAsRDD = getInputAsRDD();
        // Stays empty unless the label column is nominal.
        Map<Integer, List<String>> labelMappings = new HashMap<>();
        if (isNominal[labelIndex.intValue()].booleanValue()) {
            labelMappings = discoverLabelMappings(inputAsRDD, true);
            inputAsRDD = checkMissingLabel(inputAsRDD, labelMappings);
            checkBinominalLabel(labelMappings);
        }

        JavaRDD<LabeledPoint> labeledPointRDD = getLabeledPointRDD(inputAsRDD, labelMappings);
        // The optimizer runs up to numIterations passes over this RDD; cache it so it is not
        // recomputed from the raw input on every pass.
        labeledPointRDD.cache();
        GeneralizedLinearModel model;
        try {
            model = algorithm.run(labeledPointRDD.rdd());
        } finally {
            // Training is over (or failed); release the cached blocks instead of leaving them
            // resident on the executors. The model's weights are already local at this point.
            labeledPointRDD.unpersist();
        }
        return convertLinearRegressionModel(featureColumns, model, addIntercept, labelMappings);
    }

    /**
     * Instantiates the SGD regression algorithm for {@code method} and applies the optimizer
     * settings to its {@link GradientDescent} optimizer.
     *
     * <p>The optimizer is taken from the concrete algorithm type, because the common base class
     * only exposes a generic {@code Optimizer}.
     *
     * @param method            which regression variant to build
     * @param stepSize          SGD step size
     * @param numIterations     number of SGD iterations
     * @param regParam          regularisation parameter; only re-applied to the optimizer for
     *                          the non-linear (Lasso / ridge) methods
     * @param miniBatchFraction fraction of the data sampled per iteration
     * @param convergenceTol    convergence tolerance of the optimizer
     * @return the configured algorithm
     * @throws IllegalArgumentException if {@code method} is not Linear, Lasso or Ridge
     */
    private static GeneralizedLinearAlgorithm<? extends GeneralizedLinearModel> createAlgorithm(
            SparkLinearRegressionParameter.Method method, double stepSize, int numIterations,
            double regParam, double miniBatchFraction, double convergenceTol) {
        GeneralizedLinearAlgorithm<? extends GeneralizedLinearModel> algorithm;
        GradientDescent optimizer;
        if (method == SparkLinearRegressionParameter.Method.LINEAR_REGRESSION) {
            LinearRegressionWithSGD linear =
                    new LinearRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction);
            algorithm = linear;
            optimizer = linear.optimizer();
        } else if (method == SparkLinearRegressionParameter.Method.LASSO_REGRESSION) {
            LassoWithSGD lasso = new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction);
            algorithm = lasso;
            optimizer = lasso.optimizer();
        } else if (method == SparkLinearRegressionParameter.Method.RIDGE_REGRESSION) {
            RidgeRegressionWithSGD ridge =
                    new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction);
            algorithm = ridge;
            optimizer = ridge.optimizer();
        } else {
            throw new IllegalArgumentException("Invalid regression method! Valid methods are Linear, Lasso and Ridge.");
        }

        optimizer.setStepSize(stepSize).setMiniBatchFraction(miniBatchFraction).setNumIterations(numIterations);
        // regParam is deliberately not re-applied for plain linear regression.
        if (method != SparkLinearRegressionParameter.Method.LINEAR_REGRESSION) {
            optimizer.setRegParam(regParam);
        }
        optimizer.setConvergenceTol(convergenceTol);
        return algorithm;
    }

    /**
     * Converts a trained Spark model into its transfer representation.
     *
     * @param featureNames  names of the feature columns, in weight order
     * @param model         the trained model whose weights and intercept are exported
     * @param addIntercept  whether the model was trained with an intercept term
     * @param labelMappings nominal label mappings (empty for a numeric label)
     * @return the transfer object describing the model
     */
    private static ModelTransferObject convertLinearRegressionModel(String[] featureNames, GeneralizedLinearModel model, boolean addIntercept, Map<Integer, List<String>> labelMappings) {
        return new LinearRegressionMTO(featureNames, model.weights().toArray(), model.intercept(), addIntercept, convertNominalMappings(labelMappings));
    }
}
