package org.nodes.util;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.jena.atlas.json.io.JSWriter;
import org.apache.thrift.protocol.TMultiplexedProtocol;
import org.nodes.Global;

/**
 * A frequency model: a mapping from tokens to accumulated (possibly fractional) weights.
 *
 * <p>Weights are accumulated with {@link #add(Object, double)}; {@link #total()} is the sum of
 * every weight ever added. Tokens are stored in a {@link LinkedHashMap}, so iteration order (used
 * by {@link #tokens()}, {@link #choose()} and {@link #toString()}) is first-insertion order. No
 * synchronization is performed, so instances should not be mutated concurrently.
 *
 * <p>Instances are mutable and their {@link #hashCode()} changes as tokens are added; do not
 * mutate a model while it is used as a key in a hash-based collection.
 *
 * @param <T> the token type
 */
public class FrequencyModel<T> {
    /** Accumulated weight per token, in first-insertion order. Values are never {@code null}. */
    protected Map<T, Double> frequencies = new LinkedHashMap<T, Double>();
    /** Sum of all weights passed to {@link #add(Object, double)}. */
    protected double total = 0.0d;
    /** Cached result of {@link #sorted()}; valid while {@code modsAtLastSort == mods}. */
    protected List<T> sorted = null;
    /** Value of {@link #mods} when {@link #sorted} was last computed (-1: never). */
    private long modsAtLastSort = -1L;
    /** Number of calls to {@link #add(Object, double)}; exposed through {@link #state()}. */
    private long mods = 0L;

    /**
     * Orders tokens by ascending frequency in a given model; tokens with equal frequency compare
     * as equal.
     *
     * @param <T> the token type
     */
    public static class Comparator<T> implements java.util.Comparator<T> {
        private final FrequencyModel<T> model;

        /**
         * @param model the model whose {@link FrequencyModel#frequency(Object)} defines the order
         */
        public Comparator(FrequencyModel<T> model) {
            this.model = model;
        }

        @Override
        public int compare(T first, T second) {
            return Double.compare(model.frequency(first), model.frequency(second));
        }
    }

    /** Creates an empty model. */
    public FrequencyModel() {
    }

    /**
     * Creates a copy of {@code model} containing all of its tokens with the same frequencies.
     *
     * @param model the model to copy
     */
    public FrequencyModel(FrequencyModel<T> model) {
        this(model, model.tokens());
    }

    /**
     * Creates a model containing exactly the given tokens, each weighted by its frequency in
     * {@code model} (0 for tokens that {@code model} does not contain). The new total is the sum
     * of the copied frequencies, not {@code model.total()}.
     *
     * @param model the model to read frequencies from
     * @param tokens the tokens to copy
     */
    public FrequencyModel(FrequencyModel<T> model, Collection<T> tokens) {
        this();
        for (T token : tokens) {
            add(token, model.frequency(token));
        }
    }

    /**
     * Creates a model in which every element of {@code tokens} is counted once (duplicates
     * accumulate).
     *
     * @param tokens the observed tokens
     */
    public FrequencyModel(Collection<T> tokens) {
        this();
        add(tokens);
    }

    /**
     * Adds every element of {@code tokens} with weight 1.
     *
     * @param tokens the observed tokens
     */
    public void add(Collection<T> tokens) {
        for (T token : tokens) {
            add(token);
        }
    }

    /**
     * Adds a single observation of {@code token} (weight 1).
     *
     * @param token the observed token
     */
    public void add(T token) {
        add(token, 1.0d);
    }

    /**
     * Adds {@code weight} to the frequency of {@code token} and to the total. The weight is not
     * validated: zero and negative weights are accepted as-is.
     *
     * @param token the token to update
     * @param weight the amount to add
     */
    public void add(T token, double weight) {
        mods++;
        // Values are never null, so a null result means the token is not present yet.
        Double current = frequencies.get(token);
        frequencies.put(token, current == null ? weight : current + weight);
        total += weight;
    }

    /**
     * @return the number of distinct tokens in the model (returned as a {@code double})
     */
    public double distinct() {
        return frequencies.size();
    }

    /**
     * @param token the token to look up
     * @return the accumulated weight of {@code token}, or 0 if it was never added
     */
    public double frequency(T token) {
        Double value = frequencies.get(token);
        return value == null ? 0.0d : value;
    }

    /**
     * @return the sum of all weights added so far
     */
    public double total() {
        return total;
    }

    /**
     * @return an unmodifiable live view of the tokens, in first-insertion order
     */
    public Set<T> tokens() {
        return Collections.unmodifiableSet(frequencies.keySet());
    }

    /**
     * Returns the tokens ordered by descending frequency; tokens with equal frequency keep their
     * insertion order (the sort is stable). The result is an unmodifiable snapshot that is cached
     * and only recomputed after the model has been modified.
     *
     * @return the tokens sorted by descending frequency
     */
    public List<T> sorted() {
        if (sorted == null || modsAtLastSort != mods) {
            List<T> byFrequency = new ArrayList<T>(tokens());
            Collections.sort(byFrequency, Collections.reverseOrder(new Comparator<T>(this)));
            sorted = Collections.unmodifiableList(byFrequency);
            // Record the state the cache corresponds to, so later calls can reuse it.
            modsAtLastSort = mods;
        }
        return sorted;
    }

    /**
     * Returns the first token (in insertion order) with the highest {@link #probability(Object)}.
     *
     * @return the most probable token, or {@code null} if the model is empty or no probability is
     *     comparable (e.g. all are NaN because {@link #total()} is 0)
     */
    public T maxToken() {
        // Start below every real value; Double.MIN_VALUE is the smallest *positive* double and
        // would wrongly skip tokens whose probability is 0.
        double bestProbability = Double.NEGATIVE_INFINITY;
        T best = null;
        for (T token : frequencies.keySet()) {
            double probability = probability(token);
            if (probability > bestProbability) {
                best = token;
                bestProbability = probability;
            }
        }
        return best;
    }

    /**
     * Draws one token with probability proportional to its frequency, using a uniform value from
     * {@code Global.random()}. If floating-point rounding leaves the cumulative probability below
     * the drawn value, the last token in insertion order is returned.
     *
     * @return a sampled token, or {@code null} if the model is empty
     */
    public T choose() {
        double threshold = Global.random().nextDouble();
        double cumulative = 0.0d;
        T chosen = null;
        for (T token : frequencies.keySet()) {
            chosen = token;
            cumulative += probability(token);
            if (cumulative > threshold) {
                break;
            }
        }
        return chosen;
    }

    /**
     * Draws {@code i} distinct tokens by repeated calls to {@link #choose()}, discarding repeats.
     *
     * @param i the number of distinct tokens to draw
     * @return a set of {@code i} distinct tokens
     * @throws IllegalArgumentException if {@code i} is negative, larger than {@link #distinct()},
     *     or larger than the number of tokens with non-zero probability (such a request could
     *     never be satisfied and would loop indefinitely)
     */
    public Set<T> chooseWithoutReplacement(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("Input (" + i + ") cannot be negative.");
        }
        if (i > distinct()) {
            throw new IllegalArgumentException("Input (" + i + ") cannot be larger than the number of distinct elements in model (" + distinct() + ").");
        }
        // choose() (practically) never returns tokens with zero probability, so more than this
        // many distinct draws cannot be reached.
        int drawable = 0;
        for (T token : frequencies.keySet()) {
            if (probability(token) > 0.0d) {
                drawable++;
            }
        }
        if (i > drawable) {
            throw new IllegalArgumentException("Input (" + i + ") cannot be larger than the number of elements with non-zero probability (" + drawable + ").");
        }
        Set<T> chosen = new HashSet<T>();
        while (chosen.size() < i) {
            chosen.add(choose());
        }
        return chosen;
    }

    /**
     * Computes {@code -sum(p * Functions.log2(p))} over all tokens with non-zero probability
     * (assuming {@code Functions.log2} is the base-2 logarithm, the result is in bits).
     *
     * @return the entropy of the token distribution
     */
    public double entropy() {
        double sum = 0.0d;
        for (T token : tokens()) {
            double probability = probability(token);
            // 0 * log(0) is taken as 0; skip it to avoid 0 * -Infinity = NaN.
            if (probability != 0.0d) {
                sum += probability * Functions.log2(probability);
            }
        }
        return -sum;
    }

    /**
     * Writes a human-readable summary (total, distinct count, entropy) followed by every token and
     * its frequency, in descending frequency order.
     *
     * @param out the stream to write to
     */
    public void print(PrintStream out) {
        out.printf("total:    %.0f \n", Double.valueOf(total()));
        out.printf("distinct: %.0f \n", Double.valueOf(distinct()));
        out.printf("entropy:  %.3f \n", Double.valueOf(entropy()));
        out.println("tokens: ");
        for (T token : sorted()) {
            out.println("  " + token + ", " + frequency(token));
        }
    }

    /**
     * Renders the model as {@code [token:probability, ...]} with probabilities formatted to two
     * decimals (default locale). If the first token is {@link Comparable}, tokens are sorted by
     * natural order; otherwise insertion order is kept.
     *
     * @return the formatted model
     * @throws ClassCastException if the first token is comparable but the tokens are not mutually
     *     comparable
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public String toStringLong() {
        List<T> keys = new ArrayList<T>(frequencies.keySet());
        if (!keys.isEmpty() && keys.get(0) instanceof Comparable) {
            // Only the first element is inspected; mixed token types fail inside the sort.
            Collections.sort((List) keys);
        }
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        boolean first = true;
        for (T key : keys) {
            if (!first) {
                sb.append(", ");
            }
            first = false;
            sb.append(key).append(':');
            sb.append(String.format("%.2f", Double.valueOf(probability(key))));
        }
        sb.append(']');
        return sb.toString();
    }

    /**
     * @return a counter that changes whenever the model is modified (the number of
     *     {@link #add(Object, double)} calls)
     */
    public long state() {
        return mods;
    }

    /**
     * Renders the model as {@code [token:frequency, ...]} in insertion order.
     *
     * @return the formatted model
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        boolean first = true;
        for (Map.Entry<T, Double> entry : frequencies.entrySet()) {
            if (!first) {
                sb.append(", ");
            }
            first = false;
            sb.append(entry.getKey()).append(':').append(entry.getValue());
        }
        sb.append(']');
        return sb.toString();
    }

    /**
     * @param token the token to look up
     * @return {@code frequency(token) / total()}; NaN or infinite when {@link #total()} is 0
     */
    public double probability(T token) {
        return frequency(token) / total();
    }

    /**
     * @param token the token to look up
     * @return the natural logarithm of {@link #probability(Object)}
     */
    public double logProbability(T token) {
        return Math.log(probability(token));
    }

    /**
     * Two models are equal when they have the same total, the same token set, and the same
     * frequency (compared with {@code ==}) for every token. Insertion order is ignored.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof FrequencyModel)) {
            return false;
        }
        FrequencyModel<?> other = (FrequencyModel<?>) obj;
        if (total() != other.total() || distinct() != other.distinct()) {
            return false;
        }
        // Same size plus "every key of this is in other" implies identical key sets.
        for (Map.Entry<T, Double> entry : frequencies.entrySet()) {
            Double otherFrequency = other.frequencies.get(entry.getKey());
            if (otherFrequency == null
                    || entry.getValue().doubleValue() != otherFrequency.doubleValue()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: an order-independent combination of each
     * token's hash and its frequency.
     */
    @Override
    public int hashCode() {
        int hash = 0;
        for (Map.Entry<T, Double> entry : frequencies.entrySet()) {
            T key = entry.getKey();
            double value = entry.getValue().doubleValue();
            // equals() uses ==, under which 0.0 == -0.0, so both must hash identically.
            long bits = Double.doubleToLongBits(value == 0.0d ? 0.0d : value);
            hash += (key == null ? 0 : key.hashCode()) ^ (int) (bits ^ (bits >>> 32));
        }
        return hash;
    }
}
