package eu.radoop.datahandler.hive.udf;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.DataRow;
import com.rapidminer.example.table.DataRowFactory;
import com.rapidminer.example.utils.ExampleSets;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.Ontology;
import eu.radoop.hive.HiveStaticUtils;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;

/**
 * Hive table-generating function that applies a serialized RapidMiner {@link Model} to every
 * input row and forwards exactly one output row with the prediction.
 *
 * <p>Argument layout (validated by {@link #initialize(ObjectInspector[])}):
 * <ol>
 *   <li>index 0 - model file name (string); looked up as given first, then in Hive's downloaded
 *       resources directory.</li>
 *   <li>index 1 - array of regular attribute names; the i-th name is bound to the i-th attribute
 *       value argument.</li>
 *   <li>index 2 - constant int/bigint: number of class values for classification, {@code 0} for
 *       regression, a negative value for clustering.</li>
 *   <li>indices 3 .. 3+n-1 - constant string class names (only when n &gt; 0).</li>
 *   <li>remaining indices - attribute values (int, bigint, float, double, smallint, string or
 *       boolean).</li>
 * </ol>
 *
 * <p>Output columns: {@code predicted_label} (string for classification/clustering, double for
 * regression) followed, for classification, by one {@code confidence_<class>} double per class.
 *
 * <p>The model, its training header and the attribute-name index are loaded on the first
 * processed row and reused for all following rows of this instance.
 */
@Description(name = "apply_model", value = "_FUNC_(model, array of list of regular attribute names, # of class values or 0 for regression or -1 for clustering, list of class values, list of regular attribute parameters) - reads a model file from hadoop cache and applies it to all values of a row.")
public class GenericUDTFApplyModel extends GenericUDTF implements RadoopUDF {

    /** Argument index of the first class name; attribute values follow the class names. */
    private static final int FIRST_CLASS_NAME_ARG = 3;

    PrimitiveObjectInspector modelOI = null;
    ListObjectInspector nameOI = null;
    /** Inspectors of the attribute value arguments, in argument order. */
    List<PrimitiveObjectInspector> parameterOIs = new ArrayList<>();
    /** Attribute name -> argument index of its value; built from the first row's name array. */
    Map<String, Integer> attributeNameIndex = null;
    Model model = null;
    ExampleSet header = null;
    /** Attribute name -> canonical name, only for names whose canonical form differs. */
    Map<String, String> attributeCanonicalMap;
    boolean clustering;
    int numberOfClasses;
    List<String> classNames;

    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
    public void close() throws HiveException {
        // Nothing to flush: every output row is forwarded directly from process(), and the model
        // file stream is closed as soon as the model has been read.
    }

    /**
     * Validates the argument inspectors and declares the output row structure.
     *
     * @param objectInspectorArr inspectors of the call's arguments (layout described on the class)
     * @return struct inspector with {@code predicted_label} and, for classification, one
     *     {@code confidence_<class>} column per class name
     * @throws UDFArgumentException if arguments are missing or of the wrong type, or if the class
     *     count or class names are not constants
     */
    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length < FIRST_CLASS_NAME_ARG) {
            throw new UDFArgumentException("This function requires at least three parameters: model file name, array of attribute names and number of class values.");
        }
        if (objectInspectorArr[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentException("This function requires a model name as first parameter.");
        }
        if (((PrimitiveObjectInspector) objectInspectorArr[0]).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentTypeException(0, "A model file name was required instead " + objectInspectorArr[0].getTypeName() + " was passed.");
        }
        if (objectInspectorArr[1].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentException("This function requires an array of attribute names as second parameter.");
        }
        if (((ListObjectInspector) objectInspectorArr[1]).getListElementObjectInspector().getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(1, "An array of attribute names in string format needs to be passed.");
        }
        if (objectInspectorArr[2].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentException("This function requires the number of class values as third parameter.");
        }

        int classCount = readClassCount(objectInspectorArr[2]);
        this.clustering = classCount < 0;
        this.numberOfClasses = Math.max(classCount, 0);

        // Class names occupy [3, firstValueArg); attribute values occupy [firstValueArg, length).
        int firstValueArg = FIRST_CLASS_NAME_ARG + this.numberOfClasses;
        if (objectInspectorArr.length < firstValueArg) {
            throw new UDFArgumentException("This function requires " + this.numberOfClasses + " class names after the number of class values.");
        }
        this.classNames = new ArrayList<>(this.numberOfClasses);
        for (int argIndex = FIRST_CLASS_NAME_ARG; argIndex < firstValueArg; argIndex++) {
            ObjectInspector classNameOI = objectInspectorArr[argIndex];
            if (classNameOI.getCategory() != ObjectInspector.Category.PRIMITIVE
                    || ((PrimitiveObjectInspector) classNameOI).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
                throw new UDFArgumentTypeException(argIndex, "A list of class names in string format needs to be passed.");
            }
            this.classNames.add(constantToString(classNameOI, argIndex, "A list of class names in string format needs to be passed."));
        }

        // Reset so that a repeated initialize() does not accumulate inspectors from earlier calls.
        this.parameterOIs.clear();
        for (int argIndex = firstValueArg; argIndex < objectInspectorArr.length; argIndex++) {
            this.parameterOIs.add(checkValueInspector(objectInspectorArr[argIndex], argIndex));
        }

        this.modelOI = (PrimitiveObjectInspector) objectInspectorArr[0];
        this.nameOI = (ListObjectInspector) objectInspectorArr[1];

        List<String> fieldNames = new ArrayList<>();
        List<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldNames.add("predicted_label");
        // Clustering and classification emit a nominal prediction; regression emits a number.
        if (this.clustering || this.numberOfClasses > 0) {
            fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        } else {
            fieldOIs.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
        }
        for (String className : this.classNames) {
            fieldNames.add("confidence_" + className);
            fieldOIs.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Reads the constant class-count argument (index 2).
     *
     * @throws UDFArgumentException if it is not an int/bigint constant holding an int value
     */
    private static int readClassCount(ObjectInspector classCountOI) throws UDFArgumentException {
        switch (((PrimitiveObjectInspector) classCountOI).getPrimitiveCategory()) {
            case INT:
            case LONG:
                String text = constantToString(classCountOI, 2, "The third parameter must be a constant integer.");
                try {
                    return Integer.parseInt(text);
                } catch (NumberFormatException e) {
                    throw new UDFArgumentTypeException(2, "The third parameter must be an integer. " + e.getMessage());
                }
            default:
                throw new UDFArgumentTypeException(2, "The number of classes was required instead " + classCountOI.getTypeName() + " was passed.");
        }
    }

    /**
     * Returns the constant value of an argument in string form.
     *
     * @throws UDFArgumentTypeException with {@code message} if the argument is not a constant or
     *     its constant value is null
     */
    private static String constantToString(ObjectInspector oi, int argIndex, String message) throws UDFArgumentTypeException {
        Object value = ObjectInspectorUtils.isConstantObjectInspector(oi) ? ObjectInspectorUtils.getWritableConstantValue(oi) : null;
        if (value == null) {
            throw new UDFArgumentTypeException(argIndex, message);
        }
        return value.toString();
    }

    /**
     * Checks that an attribute value argument is a supported primitive and returns its inspector.
     *
     * @throws UDFArgumentException if the argument is not primitive or has an unsupported type
     */
    private static PrimitiveObjectInspector checkValueInspector(ObjectInspector oi, int argIndex) throws UDFArgumentException {
        if (oi.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentException("This function requires all table parameters to deliver single values.");
        }
        PrimitiveObjectInspector primitiveOI = (PrimitiveObjectInspector) oi;
        switch (primitiveOI.getPrimitiveCategory()) {
            case INT:
            case LONG:
            case FLOAT:
            case DOUBLE:
            case SHORT:
            case STRING:
            case BOOLEAN:
                return primitiveOI;
            default:
                throw new UDFArgumentTypeException(argIndex, "Only primitive type arguments (excluding bytes and time stamps) are accepted instead " + oi.getTypeName() + " was passed.");
        }
    }

    /**
     * Locates the model file: first the path as given, then Hive's downloaded resources directory
     * with the first two characters of the name stripped (presumably a leading "./" - confirm
     * against the code that registers the resource).
     *
     * @throws HiveException if the name is null or no file can be found
     */
    private static File locateModelFile(String str) throws HiveException {
        if (str == null) {
            throw new HiveException("Could not read model file: no model file name was passed.");
        }
        File file = new File(str);
        if (file.exists()) {
            return file;
        }
        SessionState session = SessionState.get();
        if (session == null || session.getConf() == null || str.length() < 2) {
            throw new HiveException("Could not read model file from Distributed Cache (tried resource dir): " + str);
        }
        File cached = new File(session.getConf().getVar(HiveConf.ConfVars.DOWNLOADED_RESOURCES_DIR) + File.separator + str.substring(2));
        if (!cached.exists()) {
            throw new HiveException("Could not read model file from Distributed Cache: " + str);
        }
        return cached;
    }

    /**
     * Deserializes the model from {@code str} into {@link #model}.
     *
     * <p>NOTE(review): this uses Java native deserialization; it is only safe if the model file
     * comes from a trusted source.
     *
     * @throws HiveException if the file cannot be found, read, or does not contain a {@link Model}
     */
    private void readModel(String str) throws HiveException {
        File file = locateModelFile(str);
        Object loaded;
        // Both streams are resources so the file is closed even if the stream header is invalid.
        try (FileInputStream fileIn = new FileInputStream(file);
             ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            loaded = objectIn.readObject();
        } catch (ClassNotFoundException e) {
            throw new UDFArgumentTypeException(0, "Could not instantiate model from model file: " + e.toString());
        } catch (IOException e) {
            throw new UDFArgumentTypeException(0, "A valid model file is required: " + e.getMessage());
        }
        if (!(loaded instanceof Model)) {
            String type = loaded == null ? "null" : loaded.getClass().getSimpleName();
            throw new UDFArgumentTypeException(0, "A valid RapidMiner model is required." + str + " does not contain one, its type is: " + type);
        }
        this.model = (Model) loaded;
    }

    /**
     * Loads the model, caches its training header and records which attribute names have a
     * different canonical form. The map is only published once it is complete.
     */
    private void loadModel(String modelFile) throws HiveException {
        readModel(modelFile);
        this.header = this.model.getTrainingHeader();
        if (this.header == null) {
            throw new HiveException("The model does not provide a training header: " + modelFile);
        }
        Map<String, String> canonicalNames = new HashMap<>();
        Iterator<Attribute> allAttributes = this.header.getAttributes().allAttributes();
        while (allAttributes.hasNext()) {
            Attribute attribute = allAttributes.next();
            String canonicalAttributeName = HiveStaticUtils.getCanonicalAttributeName(attribute.getName());
            if (!canonicalAttributeName.equals(attribute.getName())) {
                canonicalNames.put(attribute.getName(), canonicalAttributeName);
            }
        }
        this.attributeCanonicalMap = canonicalNames;
    }

    /**
     * Maps each attribute name of the name array (argument 1) to the argument index holding its
     * value. Elements are read through the list's element inspector, so any primitive string
     * representation (not only {@link Text}) is accepted.
     */
    private Map<String, Integer> buildAttributeNameIndex(Object nameList) throws HiveException {
        List<?> names = this.nameOI.getList(nameList);
        if (names == null) {
            throw new HiveException("The attribute name list (second parameter) must not be null.");
        }
        PrimitiveObjectInspector elementOI = (PrimitiveObjectInspector) this.nameOI.getListElementObjectInspector();
        Map<String, Integer> index = new HashMap<>();
        int argIndex = FIRST_CLASS_NAME_ARG + this.numberOfClasses;
        for (Object name : names) {
            index.put(String.valueOf(elementOI.getPrimitiveJavaObject(name)), argIndex);
            argIndex++;
        }
        return index;
    }

    /**
     * Applies the model to one input row and forwards one output row.
     *
     * @param objArr the row's argument values (layout described on the class)
     * @throws HiveException if the model cannot be loaded, an attribute value is missing or
     *     unparsable, or applying the model fails
     */
    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
    public void process(Object[] objArr) throws HiveException {
        if (this.model != null) {
            // Silence RapidMiner logging once a model has been loaded (i.e. from the second row on).
            LogService.getRoot().setLevel(Level.OFF);
        }
        if (this.model == null || this.header == null || this.attributeCanonicalMap == null) {
            loadModel((String) this.modelOI.getPrimitiveJavaObject(objArr[0]));
        }
        if (this.attributeNameIndex == null) {
            this.attributeNameIndex = buildAttributeNameIndex(objArr[1]);
        }
        ExampleSet exampleSet = createExampleSet(objArr, this.header.getAttributes());
        copySpecialRoles(exampleSet);
        try {
            ExampleSet applied = this.model.apply(exampleSet);
            forward(buildOutputRow(applied));
        } catch (OperatorException e) {
            throw new HiveException("The model could not successfully be applied to current table data: " + e.getMessage(), e);
        }
    }

    /** Re-applies the training header's special roles to same-named attributes of the row set. */
    private void copySpecialRoles(ExampleSet exampleSet) {
        Iterator<AttributeRole> specialRoles = this.header.getAttributes().specialAttributes();
        while (specialRoles.hasNext()) {
            AttributeRole role = specialRoles.next();
            Attribute target = exampleSet.getAttributes().get(role.getAttribute().getName());
            if (target != null) {
                exampleSet.getAttributes().setSpecialAttribute(target, role.getSpecialName());
            }
        }
    }

    /**
     * Builds the output row: the prediction, followed by one confidence per class name for
     * classification.
     */
    private Object[] buildOutputRow(ExampleSet applied) throws HiveException {
        Example example = applied.getExample(0);
        Object[] row = new Object[this.numberOfClasses + 1];
        if (this.clustering) {
            Attribute label = this.model.isAddingLabel() ? applied.getAttributes().getLabel() : applied.getAttributes().getCluster();
            row[0] = nominalOrNull(example, label);
        } else if (this.numberOfClasses > 0) {
            row[0] = nominalOrNull(example, applied.getAttributes().getPredictedLabel());
            int column = 1;
            for (String className : this.classNames) {
                row[column] = Double.valueOf(example.getConfidence(className));
                column++;
            }
        } else {
            row[0] = Double.valueOf(example.getPredictedLabel());
        }
        return row;
    }

    /** Returns the nominal value of {@code attribute}, or null when the value is missing (NaN). */
    private static String nominalOrNull(Example example, Attribute attribute) throws HiveException {
        if (attribute == null) {
            throw new HiveException("The model did not produce the expected prediction attribute.");
        }
        return Double.isNaN(example.getValue(attribute)) ? null : example.getNominalValue(attribute);
    }

    /**
     * Creates a one-row example set shaped like {@code attributes}, filled from the row's
     * argument values. Special attributes without a value argument are set to 0.
     *
     * @throws HiveException if a regular attribute has no value argument, has an unsupported
     *     value type, or its value cannot be parsed
     */
    private ExampleSet createExampleSet(Object[] objArr, Attributes attributes) throws HiveException {
        DataRow dataRow = new DataRowFactory(0, '.').create(attributes.allSize());
        List<Attribute> rowAttributes = new ArrayList<>();
        Iterator<AttributeRole> roles = attributes.allAttributeRoles();
        int tableIndex = 0;
        while (roles.hasNext()) {
            AttributeRole role = roles.next();
            Attribute attribute = role.getAttribute();
            Attribute copy = (Attribute) attribute.clone();
            copy.setTableIndex(tableIndex);
            tableIndex++;
            rowAttributes.add(copy);
            String canonical = this.attributeCanonicalMap.get(attribute.getName());
            String name = canonical == null ? attribute.getName() : canonical;
            Integer argIndex = this.attributeNameIndex.get(name);
            if (argIndex != null) {
                dataRow.set(copy, toDouble(objArr, argIndex.intValue(), attribute));
            } else {
                if (!role.isSpecial()) {
                    throw new HiveException("Could not find attribute '" + attribute.getName() + "' in the attribute name list (second parameter)");
                }
                dataRow.set(copy, 0.0d);
            }
        }
        return ExampleSets.from(rowAttributes).addDataRow(dataRow).build();
    }

    /**
     * Converts the value argument at {@code argIndex} to RapidMiner's internal double
     * representation: numbers are parsed, nominal values are mapped via the attribute's mapping,
     * and null becomes NaN.
     */
    private double toDouble(Object[] objArr, int argIndex, Attribute attribute) throws HiveException {
        int valueType = attribute.getValueType();
        boolean numerical = Ontology.ATTRIBUTE_VALUE_TYPE.isA(valueType, Ontology.NUMERICAL);
        if (!numerical && !Ontology.ATTRIBUTE_VALUE_TYPE.isA(valueType, Ontology.NOMINAL)) {
            throw new HiveException("Could not apply model with attribute type: " + valueType + " (attribute: " + attribute.getName() + ").");
        }
        int paramIndex = argIndex - FIRST_CLASS_NAME_ARG - this.numberOfClasses;
        if (argIndex >= objArr.length || paramIndex < 0 || paramIndex >= this.parameterOIs.size()) {
            throw new HiveException("No value was passed for attribute '" + attribute.getName() + "' (expected at parameter index " + argIndex + ").");
        }
        PrimitiveObjectInspector valueOI = this.parameterOIs.get(paramIndex);
        Object value = valueOI.getPrimitiveJavaObject(objArr[argIndex]);
        if (value == null) {
            return Double.NaN;
        }
        String text = value.toString();
        if (!numerical) {
            return attribute.getMapping().mapString(text);
        }
        try {
            if (!Ontology.ATTRIBUTE_VALUE_TYPE.isA(valueType, Ontology.INTEGER)) {
                return Double.parseDouble(text);
            }
            if (valueOI.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.INT) {
                return Integer.parseInt(text);
            }
            return Long.parseLong(text);
        } catch (NumberFormatException e) {
            throw new HiveException("Could not parse value '" + text + "' for attribute '" + attribute.getName() + "'.", e);
        }
    }

    public String toString() {
        return UDFUtils.getLocalName(GenericUDTFApplyModel.class);
    }

    public static String getName() {
        return UDFUtils.getLocalName(GenericUDTFApplyModel.class);
    }
}
