package edu.stanford.nlp.parser.lexparser;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Numberer;
import edu.stanford.nlp.util.StringUtils;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/stanford/nlp/parser/lexparser/MLEDependencyGrammar.class */
public class MLEDependencyGrammar extends AbstractDependencyGrammar {
    // Compile-time switches, both off; neither is read by the code visible in this file.
    private static final boolean useSmoothTagProjection = false;
    private static final boolean useUnigramWordSmoothing = false;
    /** Incremented once per {@code expandArg} call; reported by {@code toString}. */
    protected int numWordTokens;
    /** Counts for head/argument dependencies at several back-off levels; filled by {@code expandArg}, read by {@code probTB} and {@code countHistory}. */
    protected Counter<IntDependency> argCounter;
    /** Counts for stop events and their wildcard totals; filled by {@code expandStop}, read by {@code getStopProb}. */
    protected Counter<IntDependency> stopCounter;
    /** Smoothing weight of the tag-level argument term in {@code probTB}; loaded from smoothing params index 0 and re-tuned by {@code tune}. */
    public double smooth_aT_hTWd;
    /** Smoothing weight of the word-level argument term in {@code probTB}; loaded from smoothing params index 1 and re-tuned by {@code tune}. */
    public double smooth_aTW_hTWd;
    /** Smoothing weight used by {@code getStopProb}; loaded from smoothing params index 2 and re-tuned by {@code tune}. */
    public double smooth_stop;
    /** Interpolation weight between the word-level and tag-level terms in {@code probTB}; loaded from smoothing params index 3 and re-tuned by {@code tune}. */
    public double interp;
    // The next three are set to defaults by the constructor and only printed by tune() in this file.
    public double smooth_aTW_aT;
    public double smooth_aTW_hTd;
    public double smooth_aT_hTd;
    // Set to a default by the constructor; not otherwise read in this file.
    public double smooth_aPTW_aPT;
    // Single shared scratch object used (and returned) by treeToDependencyHelper.
    // NOTE(review): shared mutable static with no visible synchronization -- concurrent callers would presumably interfere; confirm.
    static transient EndHead tempEndHead = new EndHead();
    /** Lazily built cache of tag-only tagged words, indexed by tag bin + 2; see {@code getCachedITW}. */
    protected transient List<IntTaggedWord> tagITWList;
    /** Projection applied by {@code tagProject}; created in the constructor. */
    private TagProjection smoothTP;
    /** Numberer used by {@code tagProject}; created lazily on first use. */
    private Numberer smoothTPNumberer;
    /** Prefix prepended to projected tag names before they are numbered in {@code tagProject}. */
    private static final String TP_PREFIX = ".*TP*.";
    // Not read by the code visible in this file.
    private static final boolean verbose = false;
    /** Probabilities below this value are clamped to zero by {@code probTB}. */
    protected static final double MIN_PROBABILITY = 1.0E-40d;
    private static final long serialVersionUID = 1;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:edu/stanford/nlp/parser/lexparser/MLEDependencyGrammar$EndHead.class */
    /**
     * Mutable pair of word positions produced by {@code treeToDependencyHelper}
     * for a subtree: the position just past its last word and the position of
     * its head word.
     */
    public static class EndHead {
        /** Index one past the last word covered by the subtree. */
        public int end;
        /** Index of the subtree's head word. */
        public int head;

        /** Creates a pair with both positions at their default of zero. */
        EndHead() {
        }
    }

    /**
     * Creates a grammar using the tag projection selected by
     * {@code LexicalizedParser.basicCategoryTagsInDependencyGrammar}.
     *
     * @param treebankLangParserParams language parameters; supplies the language pack and smoothing parameters
     * @param z  forwarded unchanged to the five-argument constructor
     * @param z2 forwarded unchanged to the five-argument constructor
     * @param z3 forwarded unchanged to the five-argument constructor
     */
    public MLEDependencyGrammar(TreebankLangParserParams treebankLangParserParams, boolean z, boolean z2, boolean z3) {
        this(defaultTagProjection(treebankLangParserParams), treebankLangParserParams, z, z2, z3);
    }

    /** Chooses the basic-category projection when the parser-wide flag is set, otherwise the test projection. */
    private static TagProjection defaultTagProjection(TreebankLangParserParams params) {
        if (LexicalizedParser.basicCategoryTagsInDependencyGrammar) {
            return new BasicCategoryTagProjection(params.treebankLanguagePack());
        }
        return new TestTagProjection();
    }

    /**
     * Creates an empty grammar with the given tag projection.
     *
     * <p>The four tunable smoothing values are read from
     * {@code treebankLangParserParams.MLEDependencyGrammarSmoothingParams()} in the order
     * {@code smooth_aT_hTWd, smooth_aTW_hTWd, smooth_stop, interp}; the remaining
     * smoothing fields receive fixed defaults.
     *
     * @param tagProjection            projection handed to the superclass
     * @param treebankLangParserParams language parameters; supplies the language pack and smoothing parameters
     * @param z  forwarded to the superclass constructor (NOTE(review): presumably the directional flag -- confirm)
     * @param z2 forwarded to the superclass constructor (NOTE(review): presumably a distance flag -- confirm)
     * @param z3 forwarded to the superclass constructor (NOTE(review): presumably a distance flag -- confirm)
     */
    public MLEDependencyGrammar(TagProjection tagProjection, TreebankLangParserParams treebankLangParserParams, boolean z, boolean z2, boolean z3) {
        super(treebankLangParserParams.treebankLanguagePack(), tagProjection, z, z2, z3);

        // Fixed defaults for the smoothing fields that are not language-configurable.
        this.smooth_aTW_aT = 96.0d;
        this.smooth_aTW_hTd = 32.0d;
        this.smooth_aT_hTd = 32.0d;
        this.smooth_aPTW_aPT = 16.0d;

        this.argCounter = new Counter<>();
        this.stopCounter = new Counter<>();

        // Language-specific values for the four parameters that tune() may later adjust.
        final double[] smoothingParams = treebankLangParserParams.MLEDependencyGrammarSmoothingParams();
        this.smooth_aT_hTWd = smoothingParams[0];
        this.smooth_aTW_hTWd = smoothingParams[1];
        this.smooth_stop = smoothingParams[2];
        this.interp = smoothingParams[3];

        this.smoothTP = new BasicCategoryTagProjection(treebankLangParserParams.treebankLanguagePack());
    }

    /**
     * Returns a one-line summary of the grammar: the unqualified class name,
     * the number of tag bins and the number of word tokens counted so far.
     *
     * @return text of the form {@code Name[tagbins=N,wordTokens=M; head -> arg\n]}
     */
    @Override
    public String toString() {
        // Strip the package prefix by hand; getSimpleName() would differ for nested classes.
        String qualifiedName = getClass().getName();
        String shortName = qualifiedName.substring(qualifiedName.lastIndexOf('.') + 1);

        StringBuilder sb = new StringBuilder();
        sb.append(shortName).append("[tagbins=").append(this.numTagBins);
        sb.append(",wordTokens=").append(this.numWordTokens).append("; head -> arg\n");
        sb.append("]");
        return sb.toString();
    }

    /**
     * Tells whether the given tagged word carries one of the language pack's
     * punctuation tags.
     *
     * @param intTaggedWord tagged word whose tag is tested
     * @return {@code true} if its tag number matches any punctuation tag
     */
    public boolean pruneTW(IntTaggedWord intTaggedWord) {
        boolean isPunctuation = false;
        for (String punctuationTag : this.tlp.punctuationTags()) {
            if (intTaggedWord.tag == tagNumberer().number(punctuationTag)) {
                isPunctuation = true;
                break;
            }
        }
        return isPunctuation;
    }

    /**
     * Recursively walks a tree whose labels carry a head word and tag, appending
     * three dependencies for every binary node: the head/argument attachment and
     * the argument's left and right stop events.
     *
     * <p>Leaves and preterminals cover exactly one word; unary nodes delegate to
     * their only child. The result is always the shared {@code tempEndHead}
     * scratch object, so its values must be copied out before the next recursive
     * call overwrites them.
     *
     * @param tree subtree to process
     * @param list receives the generated dependencies
     * @param i    index of the first word covered by {@code tree}
     * @return the shared scratch object holding this subtree's end and head positions
     */
    protected static EndHead treeToDependencyHelper(Tree tree, List<IntDependency> list, int i) {
        if (tree.isLeaf() || tree.isPreTerminal()) {
            tempEndHead.head = i;
            tempEndHead.end = i + 1;
            return tempEndHead;
        }

        Tree[] kids = tree.children();
        if (kids.length == 1) {
            return treeToDependencyHelper(kids[0], list, i);
        }
        Tree leftKid = kids[0];
        Tree rightKid = kids[1];

        // Left child first; copy its span out of the shared scratch object before recursing again.
        tempEndHead = treeToDependencyHelper(leftKid, list, i);
        int leftHead = tempEndHead.head;
        int split = tempEndHead.end;
        tempEndHead = treeToDependencyHelper(rightKid, list, split);
        int end = tempEndHead.end;
        int rightHead = tempEndHead.head;

        String headTag = ((HasTag) tree.label()).tag();
        String leftTag = ((HasTag) leftKid.label()).tag();
        String rightTag = ((HasTag) rightKid.label()).tag();
        String headWord = ((HasWord) tree.label()).word();
        String leftWord = ((HasWord) leftKid.label()).word();
        String rightWord = ((HasWord) rightKid.label()).word();

        // The child that shares the parent's head word is the head; the other one is the argument.
        boolean leftHeaded = headWord.equals(leftWord);
        String argTag = leftHeaded ? rightTag : leftTag;
        String argWord = leftHeaded ? rightWord : leftWord;

        int headTagId = tagNumberer().number(headTag);
        int argTagId = tagNumberer().number(argTag);
        // Words not seen before are mapped to the lexicon's unknown-word entry.
        int headWordId = wordNumberer().hasSeen(headWord) ? wordNumberer().number(headWord) : wordNumberer().number(Lexicon.UNKNOWN_WORD);
        int argWordId = wordNumberer().hasSeen(argWord) ? wordNumberer().number(argWord) : wordNumberer().number(Lexicon.UNKNOWN_WORD);

        int headPos = leftHeaded ? leftHead : rightHead;
        int argPos = leftHeaded ? rightHead : leftHead;

        int attachDistance = leftHeaded ? (split - headPos) - 1 : headPos - split;
        list.add(new IntDependency(headWordId, headTagId, argWordId, argTagId, leftHeaded, attachDistance));

        int leftStopDistance = leftHeaded ? argPos - split : argPos - i;
        list.add(new IntDependency(argWordId, argTagId, -2, -2, false, leftStopDistance));

        int rightStopDistance = leftHeaded ? (end - argPos) - 1 : (split - argPos) - 1;
        list.add(new IntDependency(argWordId, argTagId, -2, -2, true, rightStopDistance));

        // end is already that of the right child; only the head position needs updating.
        tempEndHead.head = headPos;
        return tempEndHead;
    }

    /** Prints the number of entries in the argument and stop counters to standard output. */
    public void dumpSizes() {
        final java.io.PrintStream out = System.out;
        out.println("arg counter " + this.argCounter.size());
        out.println("stop counter " + this.stopCounter.size());
    }

    /**
     * Extracts all dependencies (attachments and stop events) from a tree,
     * numbering words from position 0.
     *
     * @param tree tree whose labels carry head words and tags
     * @return a new mutable list of the extracted dependencies
     */
    public static List<IntDependency> treeToDependencyList(Tree tree) {
        // Typed list rather than a raw ArrayList, so the return needs no unchecked conversion.
        List<IntDependency> dependencies = new ArrayList<IntDependency>();
        treeToDependencyHelper(tree, dependencies, 0);
        return dependencies;
    }

    /**
     * Sums the scores of the given dependencies, skipping any whose score is
     * negative infinity.
     *
     * @param collection dependencies to score
     * @return total of all scores greater than negative infinity
     */
    public double scoreAll(Collection<IntDependency> collection) {
        double total = 0.0d;
        for (IntDependency dependency : collection) {
            double dependencyScore = score(dependency);
            if (dependencyScore > Double.NEGATIVE_INFINITY) {
                total += dependencyScore;
            }
        }
        return total;
    }

    /**
     * Tunes the smoothing parameters by grid search over the dependencies
     * extracted from the given trees.
     *
     * <p>First {@code smooth_stop} is swept geometrically from 0.01 up to 100 and the
     * value with the highest stop/continue log-likelihood is kept. Stop events are
     * then removed and {@code smooth_aTW_hTWd}, {@code smooth_aT_hTWd} (geometric,
     * from 0.5 up to 100) and {@code interp} (steps of 0.02 below 1.0) are swept
     * jointly, keeping the combination with the highest total score. Progress is
     * written to standard error.
     *
     * @param collection trees to tune on
     */
    @Override // edu.stanford.nlp.parser.lexparser.AbstractDependencyGrammar, edu.stanford.nlp.parser.lexparser.DependencyGrammar
    public void tune(Collection<Tree> collection) {
        List<IntDependency> dependencies = new ArrayList<IntDependency>();
        for (Tree tree : collection) {
            dependencies.addAll(treeToDependencyList(tree));
        }

        // Phase 1: the stop smoothing weight on its own.
        double bestStopLogLikelihood = Double.NEGATIVE_INFINITY;
        double bestSmoothStop = 0.0d;
        System.err.println("Tuning smooth_stop...");
        for (this.smooth_stop = 0.01d; this.smooth_stop < 100.0d; this.smooth_stop *= 1.25d) {
            double logLikelihood = 0.0d;
            for (IntDependency dependency : dependencies) {
                if (rootTW(dependency.head)) {
                    continue;
                }
                double prob = getStopProb(dependency);
                if (!dependency.arg.equals(stopTW)) {
                    // A real argument was generated, so score the "did not stop" event.
                    prob = 1.0d - prob;
                }
                if (prob > 0.0d) {
                    logLikelihood += Math.log(prob);
                }
            }
            if (logLikelihood > bestStopLogLikelihood) {
                bestStopLogLikelihood = logLikelihood;
                bestSmoothStop = this.smooth_stop;
            }
        }
        this.smooth_stop = bestSmoothStop;
        System.err.println("Tuning selected smooth_stop: " + this.smooth_stop);

        // Stop events take no part in the second phase.
        for (Iterator<IntDependency> it = dependencies.iterator(); it.hasNext();) {
            if (it.next().arg.equals(stopTW)) {
                it.remove();
            }
        }

        // Phase 2: joint sweep of the two argument smoothing weights and the interpolation weight.
        System.err.println("Tuning other parameters...");
        double bestScore = Double.NEGATIVE_INFINITY;
        double bestSmoothATWhTWd = 0.0d;
        double bestSmoothAThTWd = 0.0d;
        double bestInterp = 0.0d;
        for (this.smooth_aTW_hTWd = 0.5d; this.smooth_aTW_hTWd < 100.0d; this.smooth_aTW_hTWd *= 1.25d) {
            System.err.print(".");
            for (this.smooth_aT_hTWd = 0.5d; this.smooth_aT_hTWd < 100.0d; this.smooth_aT_hTWd *= 1.25d) {
                for (this.interp = 0.02d; this.interp < 1.0d; this.interp += 0.02d) {
                    double totalScore = 0.0d;
                    for (IntDependency dependency : dependencies) {
                        double dependencyScore = score(dependency);
                        if (dependencyScore > Double.NEGATIVE_INFINITY) {
                            totalScore += dependencyScore;
                        }
                    }
                    if (totalScore > bestScore) {
                        bestScore = totalScore;
                        bestInterp = this.interp;
                        bestSmoothATWhTWd = this.smooth_aTW_hTWd;
                        bestSmoothAThTWd = this.smooth_aT_hTWd;
                        System.err.println("Current best interp: " + this.interp + " with score " + totalScore);
                    }
                }
            }
        }
        this.smooth_aTW_hTWd = bestSmoothATWhTWd;
        this.smooth_aT_hTWd = bestSmoothAThTWd;
        this.interp = bestInterp;
        System.err.println("\nTuning selected smooth_aTW_hTWd: " + this.smooth_aTW_hTWd + " smooth_aT_hTWd: " + this.smooth_aT_hTWd + " interp: " + this.interp + " smooth_aTW_aT: " + this.smooth_aTW_aT + " smooth_aTW_hTd: " + this.smooth_aTW_hTd + " smooth_aT_hTd: " + this.smooth_aT_hTd);
    }

    /**
     * Adds an observed dependency with the given weight to the counters.
     * When the grammar is not directional, the dependency's direction flag is
     * cleared in place before counting.
     *
     * @param intDependency dependency to record; may be modified
     * @param d             count to add
     */
    public void addRule(IntDependency intDependency, double d) {
        final boolean ignoreDirection = !this.directional;
        if (ignoreDirection) {
            intDependency.leftHeaded = false;
        }
        expandDependency(intDependency, d);
    }

    /**
     * Returns the shared tag-only tagged word (word = -1) for the bin of the
     * given tag, creating and caching it on first request.
     *
     * <p>The cache has {@code numTagBins + 2} slots and is indexed by tag bin + 2.
     * NOTE(review): the offset presumably leaves room for the negative sentinel
     * tags -1 and -2 -- confirm against {@code tagBin}.
     *
     * @param s tag number
     * @return the cached tagged word for that tag's bin
     */
    private IntTaggedWord getCachedITW(short s) {
        if (this.tagITWList == null) {
            int slots = this.numTagBins + 2;
            // Typed list instead of a raw ArrayList; pre-filled with nulls so set() can be used below.
            List<IntTaggedWord> cache = new ArrayList<IntTaggedWord>(slots);
            for (int i = 0; i < slots; i++) {
                cache.add(null);
            }
            this.tagITWList = cache;
        }
        // Compute the bin once instead of three times (assumes tagBin is a pure lookup).
        int bin = tagBin(s);
        int slot = bin + 2;
        IntTaggedWord cached = this.tagITWList.get(slot);
        if (cached == null) {
            cached = new IntTaggedWord(-1, bin);
            this.tagITWList.set(slot, cached);
        }
        return cached;
    }

    /**
     * Records one dependency in the counters. Dependencies with a missing head
     * or argument are ignored. An argument whose word is not the -2 stop marker
     * is counted as an attachment (binned by valence); every dependency is also
     * counted as a stop-model event (binned by distance).
     *
     * @param intDependency dependency to record
     * @param d             count to add
     */
    protected void expandDependency(IntDependency intDependency, double d) {
        boolean complete = intDependency.head != null && intDependency.arg != null;
        if (complete) {
            boolean isRealArgument = intDependency.arg.word != -2;
            if (isRealArgument) {
                expandArg(intDependency, valenceBin(intDependency.distance), d);
            }
            expandStop(intDependency, distanceBin(intDependency.distance), d, true);
        }
    }

    /**
     * Maps a tag number to the number of its projected tag, as produced by
     * {@code smoothTP} and registered under {@code TP_PREFIX} in a private
     * numberer. Negative tag numbers are returned unchanged.
     *
     * @param s tag number
     * @return number of the projected tag, or {@code s} itself if negative
     */
    private short tagProject(short s) {
        // The numberer is created on first use, even when the tag turns out to be negative.
        if (this.smoothTPNumberer == null) {
            this.smoothTPNumberer = new Numberer(tagNumberer());
        }
        if (s >= 0) {
            String tagName = (String) this.smoothTPNumberer.object(s);
            String projectedKey = TP_PREFIX + this.smoothTP.project(tagName);
            return (short) this.smoothTPNumberer.number(projectedKey);
        }
        return s;
    }

    /**
     * Adds an attachment to {@code argCounter} at every back-off level used by
     * {@code probTB}: head and argument each either fully lexicalized or
     * tag-only, the argument replaced by the wildcard, and the head replaced by
     * the wildcard (stored with direction false and distance -1). Also
     * increments {@code numWordTokens}.
     *
     * @param intDependency attachment to count; tags are mapped to tag bins
     * @param s             valence bin to store the counts under
     * @param d             count to add
     */
    private void expandArg(IntDependency intDependency, short s, double d) {
        IntTaggedWord headTagOnly = getCachedITW(intDependency.head.tag);
        IntTaggedWord argTagOnly = getCachedITW(intDependency.arg.tag);
        IntTaggedWord headFull = new IntTaggedWord(intDependency.head.word, tagBin(intDependency.head.tag));
        IntTaggedWord argFull = new IntTaggedWord(intDependency.arg.word, tagBin(intDependency.arg.tag));
        boolean leftHeaded = intDependency.leftHeaded;
        final short noDistance = (short) -1;

        // Head x argument, each at word or tag level.
        this.argCounter.incrementCount(intern(headFull, argFull, leftHeaded, s), d);
        this.argCounter.incrementCount(intern(headTagOnly, argFull, leftHeaded, s), d);
        this.argCounter.incrementCount(intern(headFull, argTagOnly, leftHeaded, s), d);
        this.argCounter.incrementCount(intern(headTagOnly, argTagOnly, leftHeaded, s), d);

        // Totals per head, any argument.
        this.argCounter.incrementCount(intern(headFull, wildTW, leftHeaded, s), d);
        this.argCounter.incrementCount(intern(headTagOnly, wildTW, leftHeaded, s), d);

        // Totals per argument, any head, independent of direction and distance.
        this.argCounter.incrementCount(intern(wildTW, argFull, false, noDistance), d);
        this.argCounter.incrementCount(intern(wildTW, argTagOnly, false, noDistance), d);

        this.numWordTokens++;
    }

    /**
     * Adds an event to {@code stopCounter}. When the argument is the -2 stop
     * marker, the stop event itself is counted for the lexicalized and the
     * tag-only head. The wildcard totals for both heads are counted when
     * {@code z} is true, or when the argument is not a stop marker.
     *
     * @param intDependency event to count; tags are mapped to tag bins
     * @param s             distance bin to store the counts under
     * @param d             count to add
     * @param z             whether stop events should also add to the wildcard totals
     */
    private void expandStop(IntDependency intDependency, short s, double d, boolean z) {
        IntTaggedWord headTagOnly = getCachedITW(intDependency.head.tag);
        IntTaggedWord headFull = new IntTaggedWord(intDependency.head.word, tagBin(intDependency.head.tag));
        IntTaggedWord argFull = new IntTaggedWord(intDependency.arg.word, tagBin(intDependency.arg.tag));
        boolean leftHeaded = intDependency.leftHeaded;
        boolean argIsStop = argFull.word == -2;

        if (argIsStop) {
            this.stopCounter.incrementCount(intern(headFull, argFull, leftHeaded, s), d);
            this.stopCounter.incrementCount(intern(headTagOnly, argFull, leftHeaded, s), d);
        }
        if (z || !argIsStop) {
            this.stopCounter.incrementCount(intern(headFull, wildTW, leftHeaded, s), d);
            this.stopCounter.incrementCount(intern(headTagOnly, wildTW, leftHeaded, s), d);
        }
    }

    /**
     * Looks up how often the dependency's head (with its tag mapped to a tag bin)
     * was seen with any argument in the same direction and valence bin.
     *
     * <p>The dependency is modified in place for the lookup and put back to its
     * original state before returning.
     *
     * @param intDependency dependency whose head history is counted
     * @return the wildcard-argument count from {@code argCounter}
     */
    public double countHistory(IntDependency intDependency) {
        // Remember everything the lookup key overwrites.
        final short savedHeadTag = intDependency.head.tag;
        final IntTaggedWord savedArg = intDependency.arg;
        final short savedDistance = intDependency.distance;

        // Turn the dependency into the lookup key.
        intDependency.head.tag = (short) tagBin(intDependency.head.tag);
        intDependency.distance = valenceBin(savedDistance);
        intDependency.arg = wildTW;
        final double historyCount = this.argCounter.getCount(intDependency);

        // Put the original values back.
        intDependency.head.tag = savedHeadTag;
        intDependency.arg = savedArg;
        intDependency.distance = savedDistance;
        return historyCount;
    }

    /**
     * Scores a dependency as the natural log of its probability, scaled by
     * {@code Test.depWeight}.
     *
     * @param intDependency dependency to score
     * @return weighted log probability
     */
    @Override // edu.stanford.nlp.parser.lexparser.DependencyGrammar
    public double scoreTB(IntDependency intDependency) {
        final double weight = Test.depWeight;
        final double logProbability = Math.log(probTB(intDependency));
        return weight * logProbability;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the smoothed probability of a dependency.
     *
     * <p>For a stop dependency (argument word -2) the stop probability is returned
     * directly. Otherwise the result is (1 - stop probability) multiplied by an
     * interpolation, weighted by {@code interp}, of a word-level and a tag-level
     * argument estimate, each smoothed towards counts taken with the head word
     * wildcarded. The dependency is mutated in place to form each lookup key and
     * restored afterwards; the {@code equals} checks against a snapshot guard the
     * restoration.
     *
     * <p>If the grammar is not directional, {@code leftHeaded} is cleared on the
     * argument and stays cleared.
     *
     * @param intDependency dependency to evaluate; temporarily modified
     * @return the probability; 1.0 for pruned punctuation when {@code Test.prunePunc}
     *         is set; 0.0 if the value is NaN or below {@code MIN_PROBABILITY}
     * @throws RuntimeException if the dependency is not restored correctly between lookups
     */
    public double probTB(IntDependency intDependency) {
        if (!this.directional) {
            intDependency.leftHeaded = false;
        }
        // Save every field that the lookups below overwrite.
        boolean z = intDependency.leftHeaded;
        short s = intDependency.distance;
        int i = intDependency.head.word;
        int i2 = intDependency.arg.word;
        short s2 = intDependency.head.tag;
        short s3 = intDependency.arg.tag;
        IntTaggedWord intTaggedWord = intDependency.arg;
        IntTaggedWord intTaggedWord2 = intDependency.head;
        // A root head never stops; otherwise ask the stop model.
        double stopProb = rootTW(intDependency.head) ? 0.0d : getStopProb(intDependency);
        if (intDependency.arg.word == -2) {
            return stopProb;
        }
        // Probability of continuing, i.e. of generating a real argument.
        double d = 1.0d - stopProb;
        intDependency.distance = valenceBin(s);
        short s4 = intDependency.distance;
        // Snapshot used by the sanity checks below.
        IntDependency intDependency2 = new IntDependency(intDependency.head, intDependency.arg, intDependency.leftHeaded, intDependency.distance);
        // Lexicalized head: count with full argument, tag-only argument, and any argument.
        double count = this.argCounter.getCount(intDependency);
        intDependency.arg.word = -1;
        double count2 = this.argCounter.getCount(intDependency);
        intDependency.arg.word = i2;
        intDependency.arg = wildTW;
        double count3 = this.argCounter.getCount(intDependency);
        intDependency.arg = intTaggedWord;
        if (!intDependency.equals(intDependency2)) {
            throw new RuntimeException("Dependencies not equal: " + intDependency + " and " + intDependency2);
        }
        // Tag-only head (word = -1): the same three counts.
        intDependency.head.word = -1;
        double count4 = this.argCounter.getCount(intDependency);
        intDependency.arg.word = -1;
        double count5 = this.argCounter.getCount(intDependency);
        intDependency.arg.word = i2;
        intDependency.arg = wildTW;
        double count6 = this.argCounter.getCount(intDependency);
        intDependency.arg = intTaggedWord;
        intDependency.head.word = i;
        if (!intDependency.equals(intDependency2)) {
            throw new RuntimeException("Dependencies not equal: " + intDependency + " and " + intDependency2);
        }
        // Wildcard head, direction false, distance -1: argument counts with and without its word.
        intDependency.head = wildTW;
        intDependency.leftHeaded = false;
        intDependency.distance = (short) -1;
        double count7 = this.argCounter.getCount(intDependency);
        intDependency.arg.word = -1;
        double count8 = this.argCounter.getCount(intDependency);
        intDependency.arg.word = i2;
        intDependency.arg.tag = (short) -1;
        // Result of this lookup is discarded.
        this.argCounter.getCount(intDependency);
        intDependency.arg.tag = s3;
        intDependency.head = intTaggedWord2;
        intDependency.leftHeaded = z;
        intDependency.distance = s4;
        if (!intDependency.equals(intDependency2)) {
            throw new RuntimeException("Dependencies not equal: " + intDependency + " and " + intDependency2);
        }
        intDependency.distance = s;
        // Ratio of the two wildcard-head counts: full argument over tag-only argument; 1 if the full argument was never seen.
        double d2 = count7 > 0.0d ? count7 / count8 : 1.0d;
        // interp * smoothed word-level estimate + (1 - interp) * d2 * smoothed tag-level estimate, times the continue probability.
        double d3 = ((this.interp * ((count + (this.smooth_aTW_hTWd * (count6 > 0.0d ? count4 / count6 : 0.0d))) / (count3 + this.smooth_aTW_hTWd))) + ((1.0d - this.interp) * d2 * ((count2 + (this.smooth_aT_hTWd * (count6 > 0.0d ? count5 / count6 : 0.0d))) / (count3 + this.smooth_aT_hTWd)))) * d;
        // Punctuation arguments get probability 1 when pruning is enabled.
        if (Test.prunePunc && pruneTW(intTaggedWord)) {
            return 1.0d;
        }
        if (Double.isNaN(d3)) {
            d3 = 0.0d;
        }
        if (d3 < MIN_PROBABILITY) {
            d3 = 0.0d;
        }
        return d3;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Estimates the probability that the dependency's head stops generating
     * arguments in the dependency's direction and distance bin.
     *
     * <p>The count for the lexicalized head is smoothed, with weight
     * {@code smooth_stop}, towards the stop ratio of the tag-only head; that ratio
     * is taken as 1 when the tag-only head has no recorded events. The dependency
     * is modified in place for the lookups and restored before returning.
     *
     * @param intDependency dependency whose head, direction and distance are used
     * @return the smoothed stop probability
     */
    public double getStopProb(IntDependency intDependency) {
        final int savedHeadWord = intDependency.head.word;
        final IntTaggedWord savedArg = intDependency.arg;
        final short savedDistance = intDependency.distance;

        intDependency.distance = distanceBin(savedDistance);

        // Stop events: lexicalized head, then tag-only head.
        intDependency.arg = stopTW;
        double stopsForHeadWord = this.stopCounter.getCount(intDependency);
        intDependency.head.word = -1;
        double stopsForHeadTag = this.stopCounter.getCount(intDependency);
        intDependency.head.word = savedHeadWord;

        // All events: lexicalized head, then tag-only head.
        intDependency.arg = wildTW;
        double eventsForHeadWord = this.stopCounter.getCount(intDependency);
        intDependency.head.word = -1;
        double eventsForHeadTag = this.stopCounter.getCount(intDependency);
        intDependency.head.word = savedHeadWord;

        intDependency.arg = savedArg;
        intDependency.distance = savedDistance;

        double tagLevelStopRatio = eventsForHeadTag > 0.0d ? stopsForHeadTag / eventsForHeadTag : 1.0d;
        return (stopsForHeadWord + (this.smooth_stop * tagLevelStopRatio)) / (eventsForHeadWord + this.smooth_stop);
    }

    /**
     * Restores the grammar from a stream. The counters arrive in the compact
     * form written by {@code writeObject}; each stored entry is re-expanded into
     * fresh counters so that all back-off levels are present again.
     *
     * @param objectInputStream stream to read from
     * @throws IOException            if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();

        final Counter<IntDependency> storedArgs = this.argCounter;
        final Counter<IntDependency> storedStops = this.stopCounter;
        this.argCounter = new Counter<>();
        this.stopCounter = new Counter<>();

        for (IntDependency stored : storedArgs.keySet()) {
            double weight = storedArgs.getCount(stored);
            expandArg(stored, stored.distance, weight);
        }
        for (IntDependency stored : storedStops.keySet()) {
            double weight = storedStops.getCount(stored);
            // Wildcard totals come only from non-stop entries, hence false.
            expandStop(stored, stored.distance, weight, false);
        }

        this.expandDependencyMap = null;
    }

    /**
     * Serializes the grammar with compacted counters: only fully lexicalized,
     * non-wildcard argument entries and lexicalized-head stop entries are
     * written, since {@code readObject} can rebuild the rest.
     *
     * <p>The full counters are swapped out only for the duration of the write and
     * are always restored, even if the write fails.
     *
     * @param objectOutputStream stream to write to
     * @throws IOException if writing fails
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        final Counter<IntDependency> fullArgCounter = this.argCounter;
        final Counter<IntDependency> fullStopCounter = this.stopCounter;

        Counter<IntDependency> compactArgCounter = new Counter<>();
        for (IntDependency dependency : fullArgCounter.keySet()) {
            // wildTW is compared by identity: the counters hold that exact shared instance.
            if (dependency.head != wildTW && dependency.arg != wildTW && dependency.head.word != -1 && dependency.arg.word != -1) {
                compactArgCounter.incrementCount(dependency, fullArgCounter.getCount(dependency));
            }
        }
        Counter<IntDependency> compactStopCounter = new Counter<>();
        for (IntDependency dependency : fullStopCounter.keySet()) {
            if (dependency.head.word != -1) {
                compactStopCounter.incrementCount(dependency, fullStopCounter.getCount(dependency));
            }
        }

        this.argCounter = compactArgCounter;
        this.stopCounter = compactStopCounter;
        try {
            objectOutputStream.defaultWriteObject();
        } finally {
            // Without this, a failed write would leave the live grammar with truncated counters.
            this.argCounter = fullArgCounter;
            this.stopCounter = fullStopCounter;
        }
    }

    /**
     * Reads counts in the text format produced by {@code writeData}: argument
     * entries, a {@code BEGIN_STOP} line, then stop entries. Reading ends at the
     * first empty line or at end of input.
     *
     * <p>Each entry line is split on spaces (double-quote quoting, backslash
     * escape); field 0 is the head, field 2 the argument, field 3 the direction
     * ({@code left} or anything else), field 4 the distance and field 5 the count.
     *
     * @param bufferedReader source of the data
     * @throws IOException if the reader fails, or if a line cannot be parsed; in the
     *                     latter case the message names the physical line number and
     *                     the underlying exception is attached as the cause
     */
    @Override // edu.stanford.nlp.parser.lexparser.AbstractDependencyGrammar, edu.stanford.nlp.parser.lexparser.DependencyGrammar
    public void readData(BufferedReader bufferedReader) throws IOException {
        int lineNumber = 0;
        boolean inStopSection = false;
        // Reused for every line; expandArg/expandStop build their own tagged words from it.
        IntDependency scratch = new IntDependency(-2, -2, -2, -2, false, 0);

        for (String line = bufferedReader.readLine(); line != null && line.length() > 0; line = bufferedReader.readLine()) {
            // Count every physical line, including the section marker, so error messages point at the right place.
            lineNumber++;
            if (line.equals("BEGIN_STOP")) {
                inStopSection = true;
                continue;
            }
            try {
                String[] fields = StringUtils.splitOnCharWithQuoting(line, ' ', '\"', '\\');
                scratch.leftHeaded = fields[3].equals("left");
                short distance = (short) Integer.parseInt(fields[4]);
                scratch.head = new IntTaggedWord(fields[0], '/');
                scratch.arg = new IntTaggedWord(fields[2], '/');
                double count = Double.parseDouble(fields[5]);
                if (inStopSection) {
                    expandStop(scratch, distance, count, false);
                } else {
                    expandArg(scratch, distance, count);
                }
            } catch (Exception e) {
                // Broad catch kept from the original: any parse failure becomes an IOException, now with its cause.
                throw new IOException("Error on line " + lineNumber + ": " + line, e);
            }
        }
    }

    /**
     * Writes the counters in the text format understood by {@code readData}:
     * one line per fully lexicalized, non-wildcard argument entry, a
     * {@code BEGIN_STOP} line, then one line per lexicalized-head stop entry.
     * Each line is the dependency's string form, a space, and its count.
     *
     * @param printWriter destination; flushed before returning
     * @throws IOException declared by the interface; not thrown by this implementation
     */
    @Override // edu.stanford.nlp.parser.lexparser.AbstractDependencyGrammar, edu.stanford.nlp.parser.lexparser.DependencyGrammar
    public void writeData(PrintWriter printWriter) throws IOException {
        // Must be the same character readData splits on; spelled out here rather than
        // borrowed from an unrelated matrix-formatting constant.
        final String fieldSeparator = " ";

        for (IntDependency dependency : this.argCounter.keySet()) {
            if (dependency.head != wildTW && dependency.arg != wildTW && dependency.head.word != -1 && dependency.arg.word != -1) {
                printWriter.println(dependency + fieldSeparator + this.argCounter.getCount(dependency));
            }
        }
        printWriter.println("BEGIN_STOP");
        for (IntDependency dependency : this.stopCounter.keySet()) {
            if (dependency.head.word != -1) {
                printWriter.println(dependency + fieldSeparator + this.stopCounter.getCount(dependency));
            }
        }
        printWriter.flush();
    }
}
