package eu.radoop.datahandler.hive.udf;

import com.sun.jersey.core.header.QualityFactor;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;

/**
 * Hive UDAF resolver for {@code correlation_matrix(num1, ..., numN)}, which aggregates N numeric
 * columns into their N x N pairwise correlation matrix.
 *
 * <p>The final result type is {@code map<int, array<double>>}. Its exact key/row layout is produced by
 * {@code SquareMatrix.getMxAsMap()} (presumably row index to row values; confirm in SquareMatrix).
 */
@Description(name = "correlation_matrix", value = "_FUNC_(num1, num2, ..., numN) - Return the correlation matrix of the given attributes.", extended = "The function takes as arguments an undefined number of numeric types and returns a 2D-list containing Doubles.\nAny attribute with a NULL value is ignored, the matrix will contain Double.NaN in the corresponding rows and columns.")
public class GenericUDAFCorrelationMatrix extends AbstractGenericUDAFResolver implements RadoopUDF {
    static final Log LOG = LogFactory.getLog(GenericUDAFCorrelationMatrix.class.getName());

    /**
     * Evaluator accumulating, per group, the row count, the per-attribute sums and a
     * {@code SquareMatrix} fed one row at a time through {@code SquareMatrix.addPartial}
     * (presumably the running sums of products x_i * x_k; confirm in SquareMatrix).
     *
     * <p>Partial results (PARTIAL1/PARTIAL2 output) are a struct
     * {@code {count: bigint, l: array<double>, q: array<array<double>>}}.
     *
     * <p>NOTE(review): the NULL flags in {@code attributeContainsNull} live on the evaluator, not in
     * the aggregation buffer or in the partial result. They are therefore (a) not shipped from a
     * PARTIAL1 evaluator to the FINAL evaluator, and (b) shared by every group this evaluator
     * processes, so in COMPLETE mode a NULL seen in one GROUP BY group turns that attribute into NaN
     * for all later groups too. NaN values pushed into {@code SquareMatrix.addPartial} may already
     * propagate to the result; that depends on SquareMatrix. Confirm before relying on either path.
     */
    public static class GenericUDAFCorrelationMatrixEvaluator extends GenericUDAFEvaluator implements Serializable {
        private static final long serialVersionUID = 4297797964852105983L;

        /** Partial-result struct field holding the number of aggregated rows. */
        private static final String COUNT_FIELD = "count";
        /** Partial-result struct field holding the per-attribute sums. */
        private static final String SUM_FIELD = "l";
        /** Partial-result struct field holding the accumulated product matrix. */
        private static final String MATRIX_FIELD = "q";

        /** Number of attributes (function arguments). */
        private int size;
        /** Scratch row reused by {@link #iterate}; NULL inputs are stored as NaN. */
        private Double[] row;
        /** Per-attribute flag set once a NULL has been seen for that attribute (see class NOTE). */
        private Boolean[] attributeContainsNull;
        private transient PrimitiveObjectInspector[] inputOIs;
        private transient StructObjectInspector soi;
        private transient StructField rowCountField;
        private transient StructField listField;
        private transient StructField matrixField;
        private transient WritableLongObjectInspector rowCountFieldOI;
        private transient ListObjectInspector listFieldOI;
        private transient ListObjectInspector matrixFieldOI;
        private transient ListObjectInspector matrixRowOI;
        /** Reused output object for {@link #terminatePartial}. */
        private transient Object[] partialResult;

        /** Per-group aggregation state: attribute sums {@code l}, product matrix {@code q}, row count. */
        public class CorrAggregationBuffer implements GenericUDAFEvaluator.AggregationBuffer {
            List<DoubleWritable> l = new ArrayList<DoubleWritable>();
            SquareMatrix q;
            Long rows;

            /** Creates a zeroed buffer sized to the enclosing evaluator's attribute count. */
            public CorrAggregationBuffer() {
                for (int i = 0; i < size; i++) {
                    l.add(new DoubleWritable(0.0d));
                }
                q = new SquareMatrix();
                q.initMatrix(size);
                rows = 0L;
            }
        }

        /**
         * Sets up object inspectors for the given mode.
         *
         * @param mode Hive aggregation mode
         * @param objectInspectorArr raw argument inspectors (PARTIAL1/COMPLETE) or a single
         *     partial-result struct inspector (PARTIAL2/FINAL)
         * @return the partial-result struct inspector for PARTIAL1/PARTIAL2, otherwise a
         *     {@code map<int, array<double>>} inspector
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] objectInspectorArr) throws HiveException {
            super.init(mode, objectInspectorArr);
            // The no-arg constructor (used when the plan is deserialized) does not allocate the flags.
            ensureNullFlags();
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                // Raw rows: one primitive inspector per numeric argument.
                row = new Double[size];
                inputOIs = new PrimitiveObjectInspector[size];
                for (int i = 0; i < size; i++) {
                    inputOIs[i] = (PrimitiveObjectInspector) objectInspectorArr[i];
                }
            } else {
                // Partial results: a struct produced by terminatePartial.
                soi = (StructObjectInspector) objectInspectorArr[0];
                rowCountField = soi.getStructFieldRef(COUNT_FIELD);
                listField = soi.getStructFieldRef(SUM_FIELD);
                matrixField = soi.getStructFieldRef(MATRIX_FIELD);
                rowCountFieldOI = (WritableLongObjectInspector) rowCountField.getFieldObjectInspector();
                listFieldOI = (ListObjectInspector) listField.getFieldObjectInspector();
                matrixFieldOI = (ListObjectInspector) matrixField.getFieldObjectInspector();
                matrixRowOI = (ListObjectInspector) matrixFieldOI.getListElementObjectInspector();
            }

            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
                List<String> fieldNames = new ArrayList<String>();
                fieldNames.add(COUNT_FIELD);
                fieldNames.add(SUM_FIELD);
                fieldNames.add(MATRIX_FIELD);
                partialResult = new Object[3];
                partialResult[0] = new LongWritable(0L);
                partialResult[1] = new ArrayList<DoubleWritable>();
                partialResult[2] = new ArrayList<Object>();
                return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
            }
            return ObjectInspectorFactory.getStandardMapObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector, ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        }

        /** No-arg constructor for deserialization; {@code size} must be restored before use. */
        public GenericUDAFCorrelationMatrixEvaluator() {
        }

        /**
         * Creates an evaluator for {@code i} numeric attributes.
         *
         * @param i number of function arguments
         */
        public GenericUDAFCorrelationMatrixEvaluator(int i) {
            this.size = i;
            ensureNullFlags();
        }

        /** Allocates all-false NULL flags when missing or when {@code size} changed since allocation. */
        private void ensureNullFlags() {
            if (attributeContainsNull == null || attributeContainsNull.length != size) {
                attributeContainsNull = new Boolean[size];
                for (int i = 0; i < size; i++) {
                    attributeContainsNull[i] = false;
                }
            }
        }

        /** Returns a fresh, zeroed {@link CorrAggregationBuffer}. */
        @Override
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            CorrAggregationBuffer buffer = new CorrAggregationBuffer();
            reset(buffer);
            return buffer;
        }

        /** Zeroes the sums, re-creates the product matrix and clears the row count. */
        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            CorrAggregationBuffer buffer = (CorrAggregationBuffer) aggregationBuffer;
            buffer.l.clear();
            for (int i = 0; i < size; i++) {
                buffer.l.add(new DoubleWritable(0.0d));
            }
            buffer.q = new SquareMatrix();
            buffer.q.initMatrix(size);
            buffer.rows = 0L;
        }

        /**
         * Adds one input row: non-NULL values are added to the per-attribute sums, NULLs set the
         * attribute's NULL flag and are passed to the matrix as NaN.
         */
        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            CorrAggregationBuffer buffer = (CorrAggregationBuffer) aggregationBuffer;
            for (int i = 0; i < objArr.length; i++) {
                if (objArr[i] == null) {
                    attributeContainsNull[i] = true;
                    row[i] = Double.NaN;
                } else {
                    double value = PrimitiveObjectInspectorUtils.getDouble(objArr[i], inputOIs[i]);
                    row[i] = value;
                    DoubleWritable sum = buffer.l.get(i);
                    sum.set(sum.get() + value);
                }
            }
            buffer.q.addPartial(row);
            buffer.rows = buffer.rows + 1;
        }

        /** Emits {@code [count, sums, matrix-as-list]}, reusing the same output array each call. */
        @Override
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            CorrAggregationBuffer buffer = (CorrAggregationBuffer) aggregationBuffer;
            partialResult[0] = new LongWritable(buffer.rows.longValue());
            partialResult[1] = buffer.l;
            partialResult[2] = buffer.q.getMxAsList();
            return partialResult;
        }

        /** Folds a partial result from {@link #terminatePartial} into the buffer; NULL partials are ignored. */
        @Override
        @SuppressWarnings({"rawtypes", "unchecked"})
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }
            CorrAggregationBuffer buffer = (CorrAggregationBuffer) aggregationBuffer;
            long partialRows = rowCountFieldOI.get(soi.getStructFieldData(obj, rowCountField));
            List<?> partialSums = listFieldOI.getList(soi.getStructFieldData(obj, listField));

            // Raw list on purpose: its element type must match SquareMatrix.addMatrix's parameter,
            // which is declared outside this file.
            ArrayList partialMatrix = new ArrayList();
            for (Object matrixRow : matrixFieldOI.getList(soi.getStructFieldData(obj, matrixField))) {
                partialMatrix.add(matrixRowOI.getList(matrixRow));
            }
            buffer.q.addMatrix(partialMatrix);

            int index = 0;
            for (Object element : partialSums) {
                DoubleWritable incoming = (DoubleWritable) element;
                DoubleWritable sum = buffer.l.get(index++);
                sum.set(sum.get() + incoming.get());
            }
            buffer.rows = buffer.rows + partialRows;
        }

        /** Produces the final correlation matrix from the buffer's sums, product matrix and row count. */
        @Override
        public Object terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            CorrAggregationBuffer buffer = (CorrAggregationBuffer) aggregationBuffer;
            return calculateCorrelationMatrix(buffer.l, buffer.q, buffer.rows.longValue());
        }

        /**
         * Builds the correlation matrix. Cells whose row or column attribute has its NULL flag set
         * become NaN.
         *
         * @param list per-attribute sums
         * @param squareMatrix accumulated product matrix
         * @param j number of aggregated rows
         * @return the matrix as produced by {@code SquareMatrix.getMxAsMap()}
         */
        public Map<IntWritable, List<DoubleWritable>> calculateCorrelationMatrix(List<DoubleWritable> list, SquareMatrix squareMatrix, long j) {
            int n = squareMatrix.getSize();
            SquareMatrix correlations = new SquareMatrix();
            correlations.initMatrix(n);
            for (int i = 0; i < n; i++) {
                for (int k = 0; k < n; k++) {
                    double value = attributeContainsNull[i] || attributeContainsNull[k]
                            ? Double.NaN
                            : calculateCorrElement(i, k, list, squareMatrix, j);
                    // Both (i, k) and (k, i) are written; for i != k the later visit of the pair
                    // determines the stored value of both cells, keeping the output symmetric.
                    correlations.set(i, k, value);
                    correlations.set(k, i, value);
                }
            }
            return correlations.getMxAsMap();
        }

        /**
         * Computes one Pearson-style correlation coefficient:
         * {@code (n*Q[i][k] - S_i*S_k) / (sqrt(n*Q[i][i] - S_i^2) * sqrt(n*Q[k][k] - S_k^2))},
         * where {@code S} are the sums and {@code Q} the product matrix. This is the Pearson
         * coefficient provided Q holds sums of products (see class doc).
         */
        private double calculateCorrElement(int i, int i2, List<DoubleWritable> list, SquareMatrix squareMatrix, long j) {
            double sumI = list.get(i).get();
            double sumK = list.get(i2).get();
            double scaledCovariance = (j * squareMatrix.get(i, i2)) - (sumI * sumK);
            double scaledVarianceI = (j * squareMatrix.get(i, i)) - (sumI * sumI);
            double scaledVarianceK = (j * squareMatrix.get(i2, i2)) - (sumK * sumK);
            return scaledCovariance / (Math.sqrt(scaledVarianceI) * Math.sqrt(scaledVarianceK));
        }

        /** Returns the number of attributes. */
        public int getSize() {
            return this.size;
        }

        /** Sets the number of attributes; NULL flags are re-sized on the next {@link #init}. */
        public void setSize(int i) {
            this.size = i;
        }
    }

    /**
     * Validates the argument types and creates the evaluator.
     *
     * @param typeInfoArr argument types; at least one, all numeric primitives
     *     (BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, DECIMAL, TIMESTAMP)
     * @return an evaluator sized to the number of arguments
     * @throws SemanticException (as {@link UDFArgumentTypeException}) when there are no arguments
     *     or an argument is non-primitive or non-numeric
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfoArr) throws SemanticException {
        if (typeInfoArr.length == 0) {
            throw new UDFArgumentTypeException(0, "At least one argument is expected.");
        }
        for (int i = 0; i < typeInfoArr.length; i++) {
            // UDFArgumentTypeException's argument id is 0-based; the message stays 1-based for users.
            if (typeInfoArr[i].getCategory() != ObjectInspector.Category.PRIMITIVE) {
                throw new UDFArgumentTypeException(i, "Only primitive type arguments are accepted but " + typeInfoArr[i].getTypeName() + " was passed as parameter " + (i + 1));
            }
            switch (((PrimitiveTypeInfo) typeInfoArr[i]).getPrimitiveCategory()) {
                case BYTE:
                case SHORT:
                case INT:
                case LONG:
                case FLOAT:
                case DOUBLE:
                case DECIMAL:
                case TIMESTAMP:
                    // Convertible to double via PrimitiveObjectInspectorUtils.getDouble: accepted.
                    break;
                default:
                    throw new UDFArgumentTypeException(i, "Only numeric type arguments are accepted but " + typeInfoArr[i].getTypeName() + " is passed.");
            }
        }
        return new GenericUDAFCorrelationMatrixEvaluator(typeInfoArr.length);
    }

    /** Returns the registered (local) function name of this UDAF. */
    public static String getName() {
        return UDFUtils.getLocalName(GenericUDAFCorrelationMatrix.class);
    }
}
