package eu.radoop.datahandler.hive.udf;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.ByteObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.ShortObjectInspector;

/**
 * Hive generic UDF {@code score_naive_bayes}: scores a row with a serialized naive Bayes model.
 *
 * <p>Arguments: the first argument must be a constant string naming the model file; the remaining
 * arguments are the attribute values, in the order of the model's attributes. The model file is
 * reported to Hive through {@link #getRequiredFiles()} and parsed lazily on the first
 * {@link #evaluate} call.
 *
 * <p>Returns a list of doubles with one entry per class: the per-class log-scores normalized into
 * probabilities that sum to 1 (classes whose score is NaN get 0), or {@code NaN} for every class
 * when all class scores are NaN or negative infinity.
 */
@Description(name = "score_naive_bayes")
public class GenericUDFScoreNaiveBayes extends GenericUDF implements RadoopUDF {
    /** Object inspectors of the call arguments; index 0 is the constant model path. */
    private ObjectInspector[] argOIs;
    /** Per-class output list, reused across evaluate() calls; created together with the model. */
    private List<DoubleWritable> result;
    /** Lazily parsed model; {@code null} until the first successful load in evaluate(). */
    private HiveSimpleDistributionModelSkeleton model;

    /**
     * In-memory form of a serialized naive Bayes model, parsed from its text representation.
     *
     * <p>The text is consumed left to right as: the number of classes and the number of attributes
     * (comma-terminated integers); one Laplace-correction flag character followed by one nominal
     * flag character per attribute ({@code 't'} means true; these flag characters are not
     * comma-separated); for each nominal attribute, its value count followed by that many
     * double-quoted values ({@code \} escapes {@code \} and {@code "}); one prior per class; and
     * finally, for every attribute and every class, either one double per nominal value or three
     * doubles for a numeric attribute.
     */
    public static class HiveSimpleDistributionModelSkeleton {
        /** Number of classes (length of {@link #priors}). */
        public final int numberOfClasses;
        /** Number of model attributes. */
        public final int numberOfAttributes;
        /** Per attribute: {@code true} if nominal, {@code false} if numeric. */
        public final boolean[] nominal;
        /** Per attribute: the value list of a nominal attribute; {@code null} for numeric ones. */
        public final String[][] attributeValues;
        /** One entry per class; the scorer adds these as the starting log-scores. */
        public final double[] priors;
        /**
         * Indexed {@code [attribute][class]}: one entry per value for nominal attributes, or three
         * entries for numeric ones. NOTE(review): the scorer uses the numeric triple like the
         * (mean, standard deviation, log normalizer) of a Gaussian log-density — confirm with the
         * model writer.
         */
        public final double[][][] distributionProperties;
        /** Laplace-correction flag read from the model; not used by the scoring code in this file. */
        public final boolean laplaceCorrectionEnabled;
        /** One map per nominal attribute (in attribute order) from value text to value index. */
        public final List<Map<String, Integer>> nominalMap;

        /** Parse cursor into the serialized text; only meaningful during construction. */
        private int index = 0;
        /** Scratch buffer for the token currently being parsed. */
        private final StringBuilder tempStr = new StringBuilder();

        /**
         * Parses a serialized model.
         *
         * @param str the complete serialized model text
         * @throws RuntimeException if a number in the text cannot be parsed (the cause is kept)
         */
        public HiveSimpleDistributionModelSkeleton(String str) {
            try {
                this.numberOfClasses = parseNextInt(str);
                this.numberOfAttributes = parseNextInt(str);
                this.laplaceCorrectionEnabled = parseNextBooleans(str, 1)[0];
                this.nominal = parseNextBooleans(str, this.numberOfAttributes);
                // Only the first dimension is sized: numeric attributes keep a null value list.
                this.attributeValues = new String[this.numberOfAttributes][];
                this.nominalMap = new ArrayList<>();
                for (int attr = 0; attr < this.numberOfAttributes; attr++) {
                    if (!this.nominal[attr]) {
                        continue;
                    }
                    String[] values = parseNextStrings(str, parseNextInt(str));
                    this.attributeValues[attr] = values;
                    // The last value is intentionally not mapped: the scorer uses the last
                    // distribution entry for NULL inputs. NOTE(review): presumably that slot is
                    // reserved for missing values — confirm with the model writer.
                    Map<String, Integer> valueIndex = new HashMap<>();
                    for (int v = 0; v < values.length - 1; v++) {
                        valueIndex.put(values[v], v);
                    }
                    this.nominalMap.add(valueIndex);
                }
                this.priors = parseNextDoubles(str, this.numberOfClasses);
                // The innermost arrays differ in length per attribute, so leave that dimension open.
                this.distributionProperties =
                        new double[this.numberOfAttributes][this.numberOfClasses][];
                for (int attr = 0; attr < this.numberOfAttributes; attr++) {
                    int width = this.nominal[attr] ? this.attributeValues[attr].length : 3;
                    for (int cls = 0; cls < this.numberOfClasses; cls++) {
                        this.distributionProperties[attr][cls] = parseNextDoubles(str, width);
                    }
                }
            } catch (NumberFormatException e) {
                throw new RuntimeException("Error during model deserialization: at index "
                        + this.index + ": " + e.getMessage(), e);
            }
        }

        /**
         * Reads the next token: everything from the cursor up to the next comma, or to the end of
         * the string. The cursor is moved past the terminating character.
         *
         * @return the token text, or {@code null} if the cursor is already at the end of {@code str}
         */
        private String nextToken(String str) {
            while (this.index < str.length()) {
                char c = str.charAt(this.index);
                boolean isSeparator = c == ',';
                if (!isSeparator) {
                    this.tempStr.append(c);
                }
                boolean isLastChar = this.index == str.length() - 1;
                this.index++;
                if (isSeparator || isLastChar) {
                    String token = this.tempStr.toString();
                    this.tempStr.setLength(0);
                    return token;
                }
            }
            return null;
        }

        /** Parses the next token as an int; returns -1 when the input is exhausted. */
        private int parseNextInt(String str) {
            String token = nextToken(str);
            return token == null ? -1 : Integer.parseInt(token);
        }

        /** Parses the next token as a double; returns -1.0 when the input is exhausted. */
        private double parseNextDouble(String str) {
            String token = nextToken(str);
            return token == null ? -1.0d : Double.parseDouble(token);
        }

        /** Parses {@code count} consecutive comma-separated doubles. */
        private double[] parseNextDoubles(String str, int count) {
            double[] values = new double[count];
            for (int i = 0; i < count; i++) {
                values[i] = parseNextDouble(str);
            }
            return values;
        }

        /**
         * Reads {@code count} single-character flags with no separators; {@code 't'} is true and
         * any other character is false.
         */
        private boolean[] parseNextBooleans(String str, int count) {
            boolean[] flags = new boolean[count];
            for (int i = 0; i < count; i++) {
                flags[i] = str.charAt(this.index++) == 't';
            }
            return flags;
        }

        /**
         * Reads one double-quoted string. Inside the quotes, {@code \\} stands for a backslash and
         * {@code \"} for a quote; any other character after a backslash is rejected.
         *
         * @return the unescaped content, or {@code null} if the input ends before the closing quote
         * @throws IllegalArgumentException on an invalid escape sequence
         */
        private String parseNextString(String str) {
            boolean beforeOpeningQuote = true;
            boolean escaped = false;
            while (this.index < str.length()) {
                char c = str.charAt(this.index);
                if (c == '\"') {
                    if (beforeOpeningQuote) {
                        beforeOpeningQuote = false;
                    } else if (escaped) {
                        this.tempStr.append(c);
                        escaped = false;
                    } else {
                        String value = this.tempStr.toString();
                        this.tempStr.setLength(0);
                        this.index++;
                        return value;
                    }
                } else if (c == '\\') {
                    if (escaped) {
                        this.tempStr.append(c);
                    }
                    escaped = !escaped;
                } else {
                    if (escaped) {
                        throw new IllegalArgumentException("An escape char should only be followed by an escape char or a quote (index at " + this.index + ")");
                    }
                    this.tempStr.append(c);
                }
                this.index++;
            }
            this.tempStr.setLength(0);
            return null;
        }

        /** Reads {@code count} quoted strings, each optionally followed by a comma. */
        private String[] parseNextStrings(String str, int count) {
            String[] values = new String[count];
            for (int i = 0; i < count; i++) {
                values[i] = parseNextString(str);
                if (this.index < str.length() && str.charAt(this.index) == ',') {
                    this.index++;
                }
            }
            return values;
        }
    }

    /**
     * Checks that the model path argument is a constant.
     *
     * @return an inspector for a list of writable doubles (one per class)
     * @throws UDFArgumentException if no argument is given or the first one is not constant
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length == 0) {
            throw new UDFArgumentException("score_naive_bayes requires at least one argument: the constant model file path.");
        }
        this.argOIs = arguments;
        if (!ObjectInspectorUtils.isConstantObjectInspector(arguments[0])) {
            throw new UDFArgumentTypeException(0, "The first argument must be a constant string (model file path) but " + arguments[0].toString() + " was given.");
        }
        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    }

    /**
     * Loads and parses the model file named by the first (constant) argument, and allocates the
     * reusable result list.
     *
     * <p>Only the file-name part of the path is used. The file is looked up first in the current
     * working directory and then in the session's downloaded-resources directory.
     *
     * @throws HiveException if the file cannot be found in either location
     * @throws RuntimeException if reading the file fails or its content cannot be parsed
     */
    private void initModel() throws HiveException {
        String modelPath = ObjectInspectorUtils.getWritableConstantValue(this.argOIs[0]).toString();
        String fileName = modelPath.substring(modelPath.lastIndexOf('/') + 1);
        String localPath = "./" + fileName;
        StringBuilder content = new StringBuilder();
        // try-with-resources closes the reader even when reading fails part-way.
        try (InputStreamReader reader = openModelReader(localPath, fileName)) {
            char[] buffer = new char[8192];
            for (int n = reader.read(buffer); n != -1; n = reader.read(buffer)) {
                content.append(buffer, 0, n);
            }
        } catch (IOException e) {
            throw new RuntimeException("IOException during reading " + localPath, e);
        }
        HiveSimpleDistributionModelSkeleton parsed = new HiveSimpleDistributionModelSkeleton(content.toString());
        List<DoubleWritable> scores = new ArrayList<>(parsed.numberOfClasses);
        for (int c = 0; c < parsed.numberOfClasses; c++) {
            scores.add(new DoubleWritable(Double.NaN));
        }
        // Publish the model last, so a non-null model always comes with a matching result list.
        this.result = scores;
        this.model = parsed;
    }

    /**
     * Opens the model file as UTF-8 text, first from {@code localPath}, then from the session's
     * downloaded-resources directory.
     *
     * @throws HiveException if there is no session to consult, or the file is in neither location
     */
    private static InputStreamReader openModelReader(String localPath, String fileName) throws HiveException {
        try {
            return new InputStreamReader(new FileInputStream(localPath), StandardCharsets.UTF_8);
        } catch (FileNotFoundException notInWorkingDir) {
            SessionState session = SessionState.get();
            if (session == null || session.getConf() == null) {
                throw new HiveException("Could not read model file from Distributed Cache (tried resource dir): " + localPath);
            }
            String resourcePath = session.getConf().getVar(HiveConf.ConfVars.DOWNLOADED_RESOURCES_DIR) + File.separator + fileName;
            try {
                return new InputStreamReader(new FileInputStream(resourcePath), StandardCharsets.UTF_8);
            } catch (FileNotFoundException e) {
                throw new HiveException("Could not read model file from Distributed Cache (file not found in resource dir): " + resourcePath, e);
            }
        }
    }

    /**
     * Scores one row.
     *
     * @param arguments the model path followed by the attribute values in model attribute order;
     *     fewer attribute values than the model has attributes are accepted
     * @return the reused per-class probability list (see the class documentation)
     * @throws HiveException if the model cannot be loaded, more attribute values than model
     *     attributes are given, or a numeric argument has an unsupported type
     */
    @Override
    public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
        if (this.model == null) {
            initModel();
        }
        int attributeCount = arguments.length - 1;
        if (attributeCount > this.model.numberOfAttributes) {
            throw new HiveException("The model has " + this.model.numberOfAttributes
                    + " attributes, but " + attributeCount + " attribute arguments were given.");
        }
        // Scores are accumulated in log space, starting from the class priors.
        double[] logScores = this.model.priors.clone();
        int nominalIndex = 0;
        for (int attribute = 0; attribute < attributeCount; attribute++) {
            Object value = arguments[attribute + 1].get();
            if (this.model.nominal[attribute]) {
                addNominalScore(logScores, attribute, nominalIndex, value);
                nominalIndex++;
            } else if (value != null) {
                // NULL numeric values do not contribute to the scores.
                addNumericScore(logScores, attribute, numericValue(attribute + 1, value));
            }
        }
        return writeProbabilities(logScores);
    }

    /**
     * Adds a nominal attribute's contribution to every class score. NULL uses the attribute's last
     * distribution entry; values not in the model's value map leave the scores unchanged.
     */
    private void addNominalScore(double[] logScores, int attribute, int nominalIndex, Object value) {
        double[][] perClass = this.model.distributionProperties[attribute];
        if (value == null) {
            for (int c = 0; c < logScores.length; c++) {
                logScores[c] += perClass[c][perClass[c].length - 1];
            }
            return;
        }
        Integer valueIndex = this.model.nominalMap.get(nominalIndex).get(value.toString());
        if (valueIndex == null) {
            return;
        }
        int v = valueIndex;
        for (int c = 0; c < logScores.length; c++) {
            if (v < perClass[c].length) {
                logScores[c] += perClass[c][v];
            }
        }
    }

    /**
     * Subtracts, for every class, {@code p[2] + 0.5 * z * z} where {@code z = (x - p[0]) / p[1]}
     * and {@code p} is that class's numeric triple for the attribute.
     */
    private void addNumericScore(double[] logScores, int attribute, double x) {
        double[][] perClass = this.model.distributionProperties[attribute];
        for (int c = 0; c < logScores.length; c++) {
            double[] p = perClass[c];
            double z = (x - p[0]) / p[1];
            logScores[c] -= p[2] + 0.5d * z * z;
        }
    }

    /**
     * Reads a non-null numeric argument as a double using its primitive object inspector.
     *
     * @param argument the argument position (1-based over attributes, i.e. after the model path)
     * @throws HiveException if the primitive category is not a supported numeric type
     */
    private double numericValue(int argument, Object value) throws HiveException {
        ObjectInspector inspector = this.argOIs[argument];
        switch (((PrimitiveObjectInspector) inspector).getPrimitiveCategory()) {
            // NOTE(review): VOID is routed to the float inspector; VOID arguments should only
            // carry NULLs, which evaluate() skips before calling this method — confirm.
            case VOID:
            case FLOAT:
                return ((FloatObjectInspector) inspector).get(value);
            case DOUBLE:
                return ((DoubleObjectInspector) inspector).get(value);
            case BYTE:
                return ((ByteObjectInspector) inspector).get(value);
            case SHORT:
                return ((ShortObjectInspector) inspector).get(value);
            case INT:
                return ((IntObjectInspector) inspector).get(value);
            case LONG:
                return ((LongObjectInspector) inspector).get(value);
            default:
                throw new HiveException("Unknown type for argument " + argument + ": " + inspector.getTypeName());
        }
    }

    /**
     * Turns log-scores into probabilities and stores them in the reused result list.
     * Subtracting the maximum before exponentiating avoids overflow; NaN scores map to 0.
     * If no class has a score above negative infinity, every entry is set to NaN so that values
     * from a previous row are never returned.
     */
    private List<DoubleWritable> writeProbabilities(double[] logScores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double score : logScores) {
            if (!Double.isNaN(score) && score > max) {
                max = score;
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            for (DoubleWritable probability : this.result) {
                probability.set(Double.NaN);
            }
            return this.result;
        }
        double sum = 0.0d;
        for (int c = 0; c < logScores.length; c++) {
            if (Double.isNaN(logScores[c])) {
                logScores[c] = 0.0d;
            } else {
                logScores[c] = Math.exp(logScores[c] - max);
                sum += logScores[c];
            }
        }
        for (int c = 0; c < logScores.length; c++) {
            this.result.get(c).set(logScores[c] / sum);
        }
        return this.result;
    }

    /** Returns the model file path so that Hive makes it available to the tasks. */
    @Override
    public String[] getRequiredFiles() {
        return new String[]{ObjectInspectorUtils.getWritableConstantValue(this.argOIs[0]).toString()};
    }

    /** Renders the call for EXPLAIN output and error messages. */
    @Override
    public String getDisplayString(String[] children) {
        return UDFUtils.getFuctionString(getName(), children);
    }

    /** Returns the local (registered) name of this UDF. */
    public static String getName() {
        return UDFUtils.getLocalName(GenericUDFScoreNaiveBayes.class);
    }
}
