package cc.mallet.classify;

import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;

/* loaded from: input_file:cc/mallet/classify/ConfidencePredictingClassifier.class */
/**
 * A {@link Classifier} that pairs an underlying classifier with a second
 * classifier that judges whether the underlying prediction is right.
 *
 * <p>The second ("confidence-predicting") classifier is applied to the
 * {@link Classification} produced by the underlying classifier. Its label
 * vector is read by the label names {@code "correct"} and {@code "incorrect"}.
 * The score it assigns to {@code "correct"} becomes the only non-zero entry of
 * the label vector returned by {@link #classify(Instance)}.
 *
 * <p>Every call to {@link #classify(Instance)} also updates running statistics
 * comparing the confidence predictions with whether the underlying classifier
 * was actually right. They can be reported with {@link #printAverageScores()}
 * and {@link #printConfidenceAccuracy()}. The counters are plain,
 * unsynchronized fields.
 */
public class ConfidencePredictingClassifier extends Classifier {
    /** Label name the confidence classifier uses for "underlying prediction is right". */
    private static final String CORRECT = "correct";
    /** Label name the confidence classifier uses for "underlying prediction is wrong". */
    private static final String INCORRECT = "incorrect";

    /** Classifier whose predictions are being scored. */
    Classifier underlyingClassifier;
    /** Classifier applied to the underlying {@link Classification} to predict correct/incorrect. */
    Classifier confidencePredictingClassifier;
    /** Sum of "correct" scores over instances the underlying classifier got right. */
    double totalCorrect;
    /** Sum of "correct" scores over instances the underlying classifier got wrong. */
    double totalIncorrect;
    /** Sum of "incorrect" scores over instances the underlying classifier got wrong. */
    double totalIncorrectIncorrect;
    /** Sum of "incorrect" scores over instances the underlying classifier got right. */
    double totalIncorrectCorrect;
    /** Number of classified instances the underlying classifier got right. */
    int numCorrectInstances;
    /** Number of classified instances the underlying classifier got wrong. */
    int numIncorrectInstances;
    /** Instances where the confidence classifier's best label matched the actual outcome. */
    int numConfidenceCorrect;
    /** Underlying prediction wrong, but the confidence classifier's best label was not "incorrect". */
    int numFalsePositive;
    /** Underlying prediction right, but the confidence classifier's best label was not "correct". */
    int numFalseNegative;

    /**
     * Creates a confidence-predicting wrapper.
     *
     * @param classifier the underlying classifier; its instance pipe becomes this classifier's pipe
     * @param classifier2 the classifier that scores the underlying {@link Classification}
     *     with the labels {@code "correct"} and {@code "incorrect"}
     */
    public ConfidencePredictingClassifier(Classifier classifier, Classifier classifier2) {
        super(classifier.getInstancePipe());
        this.underlyingClassifier = classifier;
        this.confidencePredictingClassifier = classifier2;
        // All totals and counters start at Java's default value of zero.
    }

    /**
     * Classifies {@code instance} with the underlying classifier, scores that
     * result with the confidence-predicting classifier, and records statistics.
     *
     * <p>The returned label vector uses the underlying label alphabet and has
     * the same number of locations as the underlying label vector. The entry at
     * the underlying best index holds the confidence classifier's
     * {@code "correct"} score; every other entry is {@code 0.0}.
     *
     * <p>Side effect: updates the running totals and counters, based on
     * {@link Classification#bestLabelIsCorrect()} of the underlying result.
     * NOTE(review): that presumably requires the instance to carry its true
     * label; confirm.
     *
     * @param instance the instance to classify
     * @return a classification of {@code instance} attributed to this classifier
     */
    @Override // cc.mallet.classify.Classifier
    public Classification classify(Instance instance) {
        Classification underlying = this.underlyingClassifier.classify(instance);
        // The confidence classifier consumes the underlying Classification object itself.
        Classification confidence = this.confidencePredictingClassifier.classify(underlying);
        LabelVector labelVector = underlying.getLabelVector();
        LabelVector confidenceVector = confidence.getLabelVector();

        double correctScore = confidenceVector.value(CORRECT);
        double incorrectScore = confidenceVector.value(INCORRECT);
        String confidenceLabel = confidenceVector.getBestLabel().toString();

        int bestIndex = labelVector.getBestIndex();
        // A new double[] is zero-filled, so only the best index needs a value.
        // The bounds check mirrors the original loop, which set nothing when
        // bestIndex fell outside [0, numLocations).
        double[] values = new double[labelVector.numLocations()];
        if (bestIndex >= 0 && bestIndex < values.length) {
            values[bestIndex] = correctScore;
        }

        recordOutcome(underlying.bestLabelIsCorrect(), correctScore, incorrectScore, confidenceLabel);

        return new Classification(instance, this, new LabelVector(labelVector.getLabelAlphabet(), values));
    }

    /** Updates the totals and counters for one classified instance. */
    private void recordOutcome(boolean underlyingWasRight, double correctScore,
                               double incorrectScore, String confidenceLabel) {
        if (underlyingWasRight) {
            this.numCorrectInstances++;
            this.totalCorrect += correctScore;
            this.totalIncorrectCorrect += incorrectScore;
            if (CORRECT.equals(confidenceLabel)) {
                this.numConfidenceCorrect++;
            } else {
                this.numFalseNegative++;
            }
        } else {
            this.numIncorrectInstances++;
            this.totalIncorrect += correctScore;
            this.totalIncorrectIncorrect += incorrectScore;
            if (INCORRECT.equals(confidenceLabel)) {
                this.numConfidenceCorrect++;
            } else {
                this.numFalsePositive++;
            }
        }
    }

    /**
     * Prints to standard output the mean "correct" and "incorrect" scores,
     * separately for instances the underlying classifier got right and wrong.
     * A mean over zero instances is reported as {@code 0.0}.
     */
    public void printAverageScores() {
        System.out.println("Mean score of correct for correct instances = " + meanCorrect());
        System.out.println("Mean score of correct for incorrect instances = " + meanIncorrect());
        System.out.println("Mean score of incorrect for correct instances = "
                + mean(this.totalIncorrectCorrect, this.numCorrectInstances));
        System.out.println("Mean score of incorrect for incorrect instances = "
                + mean(this.totalIncorrectIncorrect, this.numIncorrectInstances));
    }

    /**
     * Prints to standard output the fraction of instances where the confidence
     * classifier's best label matched the actual outcome, plus false-negative
     * and false-positive counts. The accuracy is {@code 0.0} before any
     * instance has been classified.
     */
    public void printConfidenceAccuracy() {
        int totalInstances = this.numIncorrectInstances + this.numCorrectInstances;
        // Floating-point division: with int operands the ratio would truncate to 0 or 1,
        // and an empty history would throw ArithmeticException.
        double accuracy = mean(this.numConfidenceCorrect, totalInstances);
        System.out.println("Confidence predicting accuracy = " + accuracy
                + " false negatives: " + this.numFalseNegative + "/" + this.numCorrectInstances
                + " false positives: " + this.numFalsePositive + " / " + this.numIncorrectInstances);
    }

    /**
     * Returns the mean "correct" score over instances the underlying classifier got right.
     *
     * @return the mean, or {@code 0.0} if no such instance has been classified
     */
    public double meanCorrect() {
        return mean(this.totalCorrect, this.numCorrectInstances);
    }

    /**
     * Returns the mean "correct" score over instances the underlying classifier got wrong.
     *
     * @return the mean, or {@code 0.0} if no such instance has been classified
     */
    public double meanIncorrect() {
        return mean(this.totalIncorrect, this.numIncorrectInstances);
    }

    /** Returns {@code total / count} in floating point, or {@code 0.0} when {@code count} is zero. */
    private static double mean(double total, int count) {
        return count == 0 ? 0.0d : total / count;
    }
}
