package cc.mallet.classify;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.logging.Logger;

/**
 * Static helpers for building feature constraints for classifier training.
 *
 * <p>A constraint is keyed by a feature index in an {@link InstanceList}'s data alphabet and maps
 * to a per-label array: either one target value per label ({@code double[]}) or a two-element
 * range per label ({@code double[][]}). The helpers read such constraints from text files, select
 * candidate features (by information gain or by LDA topic words), assign labels to features from
 * labeled data, and turn feature labels into per-label target distributions.
 */
public class FeatureConstraintUtil {
    private static final Logger logger = MalletLogger.getLogger(FeatureConstraintUtil.class.getName());

    /** Additive smoothing applied to a count vector before normalizing it to sum to one. */
    private static final double SMOOTHING = 1.0E-8d;

    /**
     * (position, value) pair ordered by value; used by {@link #getMaxIndices(double[])} to rank
     * array positions. Public only for source compatibility.
     */
    public static class Element implements Comparable<Element> {
        private final int index;
        private final double value;

        public Element(int i, double d) {
            this.index = i;
            this.value = d;
        }

        @Override // java.lang.Comparable
        public int compareTo(Element element) {
            return Double.compare(this.value, element.value);
        }
    }

    /**
     * Reads per-label range constraints from a whitespace-separated text file.
     *
     * <p>Each non-blank line is {@code featureName label:v ...} or {@code featureName label:a,b ...}.
     * {@code a,b} is stored as {@code [a, b]}; a single {@code v} is stored as {@code [v, v]}.
     * Labels not mentioned on a line keep {@code -Infinity} in both slots. The label alphabet is
     * echoed to standard error first. On any error (unreadable file, unknown feature or label,
     * malformed number) the stack trace is printed and the JVM exits with status 1.
     *
     * @param filename path of the constraints file
     * @param list supplies the data (feature) and target (label) alphabets
     * @return map from feature index to a {@code [numLabels][2]} array of ranges
     */
    public static HashMap<Integer, double[][]> readRangeConstraintsFromFile(String filename, InstanceList list) {
        HashMap<Integer, double[][]> constraints = new HashMap<>();
        Alphabet features = list.getDataAlphabet();
        Alphabet labels = list.getTargetAlphabet();
        int numLabels = labels.size();
        for (int li = 0; li < numLabels; li++) {
            System.err.println(labels.lookupObject(li));
        }
        try (BufferedReader reader = new BufferedReader(new FileReader(filename))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String[] tokens = tokenize(line);
                if (tokens.length == 0) {
                    continue;
                }
                int featureIndex = lookupFeature(features, tokens[0]);
                double[][] ranges = new double[numLabels][2];
                for (double[] range : ranges) {
                    Arrays.fill(range, Double.NEGATIVE_INFINITY);
                }
                for (int t = 1; t < tokens.length; t++) {
                    String[] labelAndValue = tokens[t].split(":");
                    int li = lookupLabel(labels, labelAndValue[0]);
                    if (labelAndValue[1].contains(",")) {
                        String[] bounds = labelAndValue[1].split(",");
                        double first = Double.parseDouble(bounds[0]);
                        double second = Double.parseDouble(bounds[1]);
                        ranges[li][0] = first;
                        ranges[li][1] = second;
                    } else {
                        double point = Double.parseDouble(labelAndValue[1]);
                        ranges[li][0] = point;
                        ranges[li][1] = point;
                    }
                }
                constraints.put(featureIndex, ranges);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return constraints;
    }

    /**
     * Reads per-label target constraints, auto-detecting the file format. A file whose first
     * non-blank line contains no {@code ':'} is read as index-based; otherwise it is read as
     * name-based.
     *
     * @param filename path of the constraints file
     * @param list supplies the data and target alphabets
     * @return map from feature index to a per-label array of targets
     */
    public static HashMap<Integer, double[]> readConstraintsFromFile(String filename, InstanceList list) {
        if (testConstraintsFileIndexBased(filename)) {
            return readConstraintsFromFileIndex(filename, list);
        }
        return readConstraintsFromFileString(filename, list);
    }

    /**
     * Reads name-based constraints: each non-blank line is {@code featureName label:value ...}.
     * An unknown feature or label name, or any I/O or parse error, prints the stack trace and
     * exits the JVM with status 1. With assertions enabled, a line must list exactly one value
     * per label.
     *
     * @param filename path of the constraints file
     * @param list supplies the data and target alphabets
     * @return map from feature index to a {@code numLabels}-long array of targets
     */
    public static HashMap<Integer, double[]> readConstraintsFromFileString(String filename, InstanceList list) {
        HashMap<Integer, double[]> constraints = new HashMap<>();
        Alphabet features = list.getDataAlphabet();
        Alphabet labels = list.getTargetAlphabet();
        int numLabels = labels.size();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(filename)))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String[] tokens = tokenize(line);
                if (tokens.length == 0) {
                    continue;
                }
                int featureIndex = lookupFeature(features, tokens[0]);
                assert tokens.length - 1 == numLabels : String.valueOf(tokens.length) + " " + numLabels;
                // Sized by the label alphabet because values are stored at label indices.
                double[] targets = new double[numLabels];
                for (int t = 1; t < tokens.length; t++) {
                    String[] labelAndValue = tokens[t].split(":");
                    targets[lookupLabel(labels, labelAndValue[0])] = Double.parseDouble(labelAndValue[1]);
                }
                constraints.put(featureIndex, targets);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return constraints;
    }

    /**
     * Reads index-based constraints: each non-blank line is {@code featureIndex v0 v1 ...},
     * with values given in label-index order. Any I/O or parse error prints the stack trace and
     * exits the JVM with status 1. With assertions enabled, a line must list exactly one value
     * per label.
     *
     * @param filename path of the constraints file
     * @param list supplies the target alphabet size (checked only under assertions)
     * @return map from feature index to the listed values
     */
    public static HashMap<Integer, double[]> readConstraintsFromFileIndex(String filename, InstanceList list) {
        HashMap<Integer, double[]> constraints = new HashMap<>();
        int numLabels = list.getTargetAlphabet().size();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(filename)))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String[] tokens = tokenize(line);
                if (tokens.length == 0) {
                    continue;
                }
                int featureIndex = Integer.parseInt(tokens[0]);
                assert tokens.length - 1 == numLabels;
                double[] targets = new double[tokens.length - 1];
                for (int t = 1; t < tokens.length; t++) {
                    targets[t - 1] = Double.parseDouble(tokens[t]);
                }
                constraints.put(featureIndex, targets);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return constraints;
    }

    /**
     * Returns true when the first non-blank line of the file has no {@code ':'}, meaning it is
     * index-based. An empty file is reported as index-based, and either reader returns an empty
     * map for it.
     */
    private static boolean testConstraintsFileIndexBased(String filename) {
        String firstLine = null;
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(filename)))) {
            String line = reader.readLine();
            while (line != null && line.trim().isEmpty()) {
                line = reader.readLine();
            }
            firstLine = line;
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return firstLine == null || !firstLine.contains(":");
    }

    /** Splits a constraints-file line on whitespace; a blank line yields an empty array. */
    private static String[] tokenize(String line) {
        String trimmed = line.trim();
        return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
    }

    /** Resolves a feature name, throwing if it is absent from the data alphabet. */
    private static int lookupFeature(Alphabet features, String name) {
        int index = features.lookupIndex(name, false);
        if (index == -1) {
            throw new RuntimeException("Feature " + name + " not found in the alphabet!");
        }
        return index;
    }

    /** Resolves a label name, throwing if it is absent from the target alphabet. */
    private static int lookupLabel(Alphabet labels, String name) {
        int index = labels.lookupIndex(name, false);
        if (index == -1) {
            throw new RuntimeException("Label " + name + " not found");
        }
        return index;
    }

    /**
     * Returns the indices of the {@code numFeatures} features with the highest information gain
     * on {@code list}, best first.
     */
    public static ArrayList<Integer> selectFeaturesByInfoGain(InstanceList list, int numFeatures) {
        ArrayList<Integer> selected = new ArrayList<>();
        InfoGain infoGain = new InfoGain(list);
        for (int rank = 0; rank < numFeatures; rank++) {
            selected.add(infoGain.getIndexAtRank(rank));
        }
        return selected;
    }

    /**
     * Selects features by walking topic word lists round-robin. Each topic's rank-0 word is taken
     * first, then each topic's rank-1 word, and so on. A word is kept when it exists in
     * {@code dataAlphabet} and has not been selected yet. Stops as soon as exactly
     * {@code numFeatures} are selected; if {@code numFeatures <= 0} no limit applies.
     *
     * @param numFeatures number of features to select
     * @param lda trained topic model supplying the ranked words
     * @param dataAlphabet alphabet whose indices are returned
     * @return selected feature indices in selection order
     */
    public static ArrayList<Integer> selectTopLDAFeatures(int numFeatures, ParallelTopicModel lda, Alphabet dataAlphabet) {
        ArrayList<Integer> selected = new ArrayList<>();
        Set<Integer> seen = new HashSet<>();
        int vocabularySize = lda.getAlphabet().size();
        int numTopics = lda.getNumTopics();
        Object[][] topWords = lda.getTopWords(vocabularySize);
        for (int rank = 0; rank < vocabularySize; rank++) {
            for (int topic = 0; topic < numTopics; topic++) {
                // Guard in case a topic's ranked list is shorter than the vocabulary.
                if (rank >= topWords[topic].length) {
                    continue;
                }
                String word = topWords[topic][rank].toString();
                int featureIndex = dataAlphabet.lookupIndex(word, false);
                if (featureIndex >= 0 && seen.add(featureIndex)) {
                    logger.info("Selected feature: " + word);
                    selected.add(featureIndex);
                    if (selected.size() == numFeatures) {
                        return selected;
                    }
                }
            }
        }
        return selected;
    }

    /** Same as {@code setTargetsUsingData(list, features, true)}. */
    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features) {
        return setTargetsUsingData(list, features, true);
    }

    /** Same as {@code setTargetsUsingData(list, features, false, normalize)}. */
    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean normalize) {
        return setTargetsUsingData(list, features, false, normalize);
    }

    /**
     * Uses labeled data to build each feature's per-label target vector.
     *
     * @param list labeled instances
     * @param features feature indices to build targets for
     * @param useValues weight counts by feature value (true) or count occurrences (false)
     * @param normalize smooth and normalize each vector to sum to one
     * @return map from feature index to its label-count (or distribution) vector
     */
    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean useValues, boolean normalize) {
        HashMap<Integer, double[]> targets = new HashMap<>();
        double[][] featureLabelCounts = getFeatureLabelCounts(list, useValues);
        // The count matrix has one row per alphabet entry; index == size has no row.
        int skippedIndex = list.getDataAlphabet().size();
        for (int featureIndex : features) {
            if (featureIndex == skippedIndex) {
                continue;
            }
            double[] counts = featureLabelCounts[featureIndex];
            if (normalize) {
                smoothAndNormalize(counts);
            }
            targets.put(featureIndex, counts);
        }
        return targets;
    }

    /**
     * Builds a target distribution per labeled feature. Probability mass {@code majorityProb} is
     * split evenly over the feature's labels, and the remainder is split evenly over the other
     * labels.
     *
     * @param labeledFeatures map from feature index to its assigned label indices
     * @param numLabels size of the label alphabet
     * @param majorityProb total mass given to the assigned labels
     */
    public static HashMap<Integer, double[]> setTargetsUsingHeuristic(HashMap<Integer, ArrayList<Integer>> labeledFeatures, int numLabels, double majorityProb) {
        HashMap<Integer, double[]> targets = new HashMap<>();
        for (Integer featureIndex : labeledFeatures.keySet()) {
            targets.put(featureIndex, getHeuristicPrior(labeledFeatures.get(featureIndex), numLabels, majorityProb));
        }
        return targets;
    }

    /**
     * Builds target distributions from co-occurrence of labeled features with instance labels.
     * A labeled instance contributes a one-hot vector at its best label. An unlabeled instance
     * contributes a vote distribution from the labeled features it contains. Contributions are
     * weighted by the feature's value, then each feature's vector is smoothed and normalized.
     */
    public static HashMap<Integer, double[]> setTargetsUsingFeatureVoting(HashMap<Integer, ArrayList<Integer>> labeledFeatures, InstanceList list) {
        int numLabels = list.getTargetAlphabet().size();
        // Snapshot the key order once so row r of counts always belongs to the same feature.
        int[] featureIndices = new int[labeledFeatures.size()];
        int next = 0;
        for (int featureIndex : labeledFeatures.keySet()) {
            featureIndices[next++] = featureIndex;
        }
        double[][] counts = new double[featureIndices.length][numLabels];
        for (int n = 0; n < list.size(); n++) {
            Instance instance = list.get(n);
            FeatureVector fv = (FeatureVector) instance.getData();
            Labeling labeling = instance.getLabeling();
            double[] labelDist = new double[numLabels];
            if (labeling == null) {
                labelByVoting(labeledFeatures, instance, labelDist);
            } else {
                labelDist[labeling.getBestIndex()] = 1.0d;
            }
            for (int row = 0; row < featureIndices.length; row++) {
                int loc = fv.location(featureIndices[row]);
                if (loc < 0) {
                    continue;
                }
                double featureValue = fv.valueAtLocation(loc);
                for (int li = 0; li < numLabels; li++) {
                    counts[row][li] += labelDist[li] * featureValue;
                }
            }
        }
        HashMap<Integer, double[]> targets = new HashMap<>();
        for (int row = 0; row < featureIndices.length; row++) {
            smoothAndNormalize(counts[row]);
            targets.put(featureIndices[row], counts[row]);
        }
        return targets;
    }

    /**
     * "Oracle" feature labeler driven by labeled data. With more than two labels, a feature gets
     * every label whose probability exceeds half the best label's. With two labels it gets only
     * the best label.
     *
     * <p>When {@code reject} is true, a feature is skipped (logged, not returned) if either:
     * <ul>
     *   <li>its info gain is below the mean over the top {@code 100 * numLabels} ranks (or fewer,
     *       when the info-gain vector is shorter); or</li>
     *   <li>it would be given more than {@code numLabels / 2} labels.</li>
     * </ul>
     *
     * @param list labeled instances
     * @param features candidate feature indices
     * @param reject whether to apply the rejection rules
     * @return map from accepted feature index to its label indices
     */
    public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features, boolean reject) {
        HashMap<Integer, ArrayList<Integer>> labeledFeatures = new HashMap<>();
        double[][] featureLabelCounts = getFeatureLabelCounts(list, true);
        int numLabels = list.getTargetAlphabet().size();
        InfoGain infoGain = new InfoGain(list);
        int numRanks = Math.min(100 * numLabels, infoGain.numLocations());
        double sum = 0.0d;
        for (int rank = 0; rank < numRanks; rank++) {
            sum += infoGain.getValueAtRank(rank);
        }
        double meanInfoGain = sum / numRanks;
        for (int featureIndex : features) {
            // Written as !(x >= mean) so a NaN comparison still rejects, as before.
            if (reject && !(infoGain.value(featureIndex) >= meanInfoGain)) {
                logRejected(list, featureIndex);
                continue;
            }
            double[] dist = featureLabelCounts[featureIndex];
            smoothAndNormalize(dist);
            int[] ranked = getMaxIndices(dist);
            ArrayList<Integer> labels = new ArrayList<>();
            if (numLabels > 2) {
                double threshold = dist[ranked[0]] / 2.0d;
                boolean rejected = false;
                for (int li = 0; li < numLabels; li++) {
                    if (dist[li] > threshold) {
                        labels.add(li);
                    }
                    if (reject && labels.size() > numLabels / 2) {
                        logRejected(list, featureIndex);
                        rejected = true;
                        break;
                    }
                }
                if (rejected) {
                    // Do not record a labeling that was just reported as rejected.
                    continue;
                }
            } else {
                labels.add(ranked[0]);
            }
            labeledFeatures.put(featureIndex, labels);
        }
        return labeledFeatures;
    }

    /** Same as {@code labelFeatures(list, features, true)}. */
    public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features) {
        return labelFeatures(list, features, true);
    }

    /**
     * Accumulates a {@code [numFeatures][numLabels]} matrix. Each instance adds its labeling value
     * for a label times the feature's value (or times 1 when {@code useValues} is false). Instances
     * without a labeling are skipped.
     */
    public static double[][] getFeatureLabelCounts(InstanceList list, boolean useValues) {
        int numFeatures = list.getDataAlphabet().size();
        int numLabels = list.getTargetAlphabet().size();
        double[][] counts = new double[numFeatures][numLabels];
        for (int n = 0; n < list.size(); n++) {
            Instance instance = list.get(n);
            Labeling labeling = instance.getLabeling();
            if (labeling == null) {
                continue;
            }
            FeatureVector fv = (FeatureVector) instance.getData();
            for (int loc = 0; loc < fv.numLocations(); loc++) {
                double[] row = counts[fv.indexAtLocation(loc)];
                double featureValue = useValues ? fv.valueAtLocation(loc) : 1.0d;
                for (int li = 0; li < numLabels; li++) {
                    row[li] += labeling.value(li) * featureValue;
                }
            }
        }
        return counts;
    }

    /**
     * Splits {@code majorityProb} evenly over {@code labels} and the remainder evenly over the
     * other labels. Returns the uniform distribution when no labels, or every label, is given.
     */
    private static double[] getHeuristicPrior(ArrayList<Integer> labels, int numLabels, double majorityProb) {
        int numSelected = labels.size();
        double[] prior = new double[numLabels];
        if (numSelected == 0 || numSelected == numLabels) {
            Arrays.fill(prior, 1.0d / numLabels);
            return prior;
        }
        // Track membership explicitly: a zero-valued entry is not a reliable "unselected" marker
        // (e.g. when majorityProb == 0).
        boolean[] isSelected = new boolean[numLabels];
        double selectedMass = majorityProb / numSelected;
        for (int li : labels) {
            prior[li] = selectedMass;
            isSelected[li] = true;
        }
        double otherMass = (1.0d - majorityProb) / (numLabels - numSelected);
        for (int li = 0; li < numLabels; li++) {
            if (!isSelected[li]) {
                prior[li] = otherMass;
            }
        }
        assert Maths.almostEquals(MatrixOps.sum(prior), 1.0d);
        return prior;
    }

    /**
     * Fills {@code scores} (assumed zero on entry) with a normalized vote distribution. Each
     * labeled feature present in the instance adds one vote to each of its labels. If there are
     * no votes the result is uniform.
     */
    private static void labelByVoting(HashMap<Integer, ArrayList<Integer>> labeledFeatures, Instance instance, double[] scores) {
        FeatureVector fv = (FeatureVector) instance.getData();
        int numFeatures = instance.getDataAlphabet().size() + 1;
        for (Integer featureIndex : labeledFeatures.keySet()) {
            assert featureIndex < numFeatures;
            if (fv.location(featureIndex) >= 0) {
                for (int li : labeledFeatures.get(featureIndex)) {
                    scores[li] += 1.0d;
                }
            }
        }
        double sum = MatrixOps.sum(scores);
        if (sum == 0.0d) {
            MatrixOps.plusEquals(scores, 1.0d);
            sum = MatrixOps.sum(scores);
        }
        for (int li = 0; li < scores.length; li++) {
            scores[li] /= sum;
        }
    }

    /**
     * Returns array positions ordered by descending value. Ties come out in descending position
     * order, because a stable ascending sort is reversed.
     */
    private static int[] getMaxIndices(double[] values) {
        ArrayList<Element> elements = new ArrayList<>(values.length);
        for (int i = 0; i < values.length; i++) {
            elements.add(new Element(i, values[i]));
        }
        Collections.sort(elements);
        Collections.reverse(elements);
        int[] order = new int[values.length];
        for (int r = 0; r < order.length; r++) {
            order[r] = elements.get(r).index;
        }
        return order;
    }

    /** Adds {@link #SMOOTHING} to every entry, then scales the vector to sum to one, in place. */
    private static void smoothAndNormalize(double[] vector) {
        MatrixOps.plusEquals(vector, SMOOTHING);
        MatrixOps.timesEquals(vector, 1.0d / MatrixOps.sum(vector));
    }

    /** Logs that the oracle labeler rejected the given feature. */
    private static void logRejected(InstanceList list, int featureIndex) {
        logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(featureIndex));
    }
}
