package cc.mallet.classify;

import cc.mallet.classify.NaiveBayesTrainer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Multinomial;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/classify/NaiveBayesEMTrainer.class */
/**
 * Trains a {@link NaiveBayes} classifier from a mix of labeled and unlabeled instances using an
 * expectation-maximization (EM) loop.
 *
 * <p>An initial classifier is trained on the input list as given. Then, each EM round builds a
 * new training list:
 * <ul>
 *   <li>instances with a non-null labeling are added with weight 1.0;</li>
 *   <li>every other instance is added as a shallow copy labeled with the current classifier's
 *       prediction, weighted by {@link #getUnlabeledDataWeight()}.</li>
 * </ul>
 * A fresh classifier is then trained on that list. Rounds stop once the relative change in data
 * log-likelihood between consecutive rounds drops below {@value #CONVERGENCE_TOLERANCE}.
 */
public class NaiveBayesEMTrainer extends ClassifierTrainer<NaiveBayes> {
    private static final Logger logger = MalletLogger.getLogger(NaiveBayesEMTrainer.class.getName());

    /** Relative log-likelihood change below which the EM loop is considered converged. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-4d;

    /** Initial value of {@link #docLengthNormalization}; also restored on deserialization. */
    private static final double DEFAULT_DOC_LENGTH_NORMALIZATION = -1.0d;

    /** Initial value of {@link #unlabeledDataWeight}; also restored on deserialization. */
    private static final double DEFAULT_UNLABELED_DATA_WEIGHT = 1.0d;

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
    Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();
    double docLengthNormalization = DEFAULT_DOC_LENGTH_NORMALIZATION;
    double unlabeledDataWeight = DEFAULT_UNLABELED_DATA_WEIGHT;
    /** Number of EM rounds completed by the most recent call to {@link #train}. */
    int iteration = 0;
    /** Factory producing the plain Naive Bayes trainer used in every EM round. */
    NaiveBayesTrainer.Factory nbTrainer = new NaiveBayesTrainer.Factory();
    /** Classifier produced by the most recent successful call to {@link #train}; null before. */
    NaiveBayes classifier;

    /** Creates a trainer with Laplace feature/prior estimators and default settings. */
    public NaiveBayesEMTrainer() {
        configureFactory();
    }

    /** Pushes the current normalization and estimator settings into {@link #nbTrainer}. */
    private void configureFactory() {
        this.nbTrainer.setDocLengthNormalization(this.docLengthNormalization);
        this.nbTrainer.setFeatureMultinomialEstimator(this.featureEstimator);
        this.nbTrainer.setPriorMultinomialEstimator(this.priorEstimator);
    }

    /** Returns the estimator used for per-class feature multinomials. */
    public Multinomial.Estimator getFeatureMultinomialEstimator() {
        return this.featureEstimator;
    }

    /**
     * Sets the estimator used for per-class feature multinomials and forwards it to the
     * underlying Naive Bayes trainer factory.
     *
     * @param estimator the feature estimator
     */
    public void setFeatureMultinomialEstimator(Multinomial.Estimator estimator) {
        this.featureEstimator = estimator;
        this.nbTrainer.setFeatureMultinomialEstimator(this.featureEstimator);
    }

    /** Returns the estimator used for the class-prior multinomial. */
    public Multinomial.Estimator getPriorMultinomialEstimator() {
        return this.priorEstimator;
    }

    /**
     * Sets the estimator used for the class-prior multinomial and forwards it to the underlying
     * Naive Bayes trainer factory.
     *
     * @param estimator the prior estimator
     */
    public void setPriorMultinomialEstimator(Multinomial.Estimator estimator) {
        this.priorEstimator = estimator;
        this.nbTrainer.setPriorMultinomialEstimator(this.priorEstimator);
    }

    /**
     * Sets the document-length normalization value and forwards it to the underlying Naive Bayes
     * trainer factory. (NOTE(review): the meaning of the default -1.0 is defined by
     * NaiveBayesTrainer; presumably "no normalization" — confirm there.)
     *
     * @param d the normalization value
     */
    public void setDocLengthNormalization(double d) {
        this.docLengthNormalization = d;
        this.nbTrainer.setDocLengthNormalization(this.docLengthNormalization);
    }

    /** Returns the document-length normalization value passed to the Naive Bayes trainer. */
    public double getDocLengthNormalization() {
        return this.docLengthNormalization;
    }

    /** Returns the instance weight given to unlabeled instances in each EM round. */
    public double getUnlabeledDataWeight() {
        return this.unlabeledDataWeight;
    }

    /**
     * Sets the instance weight given to unlabeled (classifier-labeled) instances in each EM round.
     *
     * @param d the weight
     */
    public void setUnlabeledDataWeight(double d) {
        this.unlabeledDataWeight = d;
    }

    /** Returns the number of EM rounds completed by the most recent call to {@link #train}. */
    public int getIteration() {
        return this.iteration;
    }

    @Override // cc.mallet.classify.ClassifierTrainer
    public boolean isFinishedTraining() {
        return false;
    }

    /** Returns the classifier produced by the most recent {@link #train} call, or null. */
    @Override // cc.mallet.classify.ClassifierTrainer
    public NaiveBayes getClassifier() {
        return this.classifier;
    }

    /**
     * Runs EM until the data log-likelihood converges and returns the final classifier.
     *
     * <p>The log-likelihood of each round is printed to standard error. On return,
     * {@link #getClassifier()} yields the returned classifier and {@link #getIteration()} the
     * number of EM rounds performed.
     *
     * @param instanceList training instances; those with a null labeling are treated as unlabeled
     * @return the classifier from the last EM round
     * @throws IllegalStateException if the log-likelihood becomes NaN. Convergence can then
     *     never be detected, and the loop would otherwise run forever.
     */
    @Override // cc.mallet.classify.ClassifierTrainer
    public NaiveBayes train(InstanceList instanceList) {
        NaiveBayes current = this.nbTrainer.newClassifierTrainer().train(instanceList);
        double previousLogLikelihood = 0.0d;
        this.iteration = 0;
        boolean converged = false;
        while (!converged) {
            InstanceList roundData = relabelUnlabeled(instanceList, current);
            current = this.nbTrainer.newClassifierTrainer().train(roundData);
            double logLikelihood = current.dataLogLikelihood(roundData);
            System.err.println("Loglikelihood = " + logLikelihood);
            this.iteration++;
            if (Double.isNaN(logLikelihood)) {
                throw new IllegalStateException(
                        "NaiveBayesEMTrainer: log-likelihood is NaN after iteration " + this.iteration);
            }
            converged = hasConverged(logLikelihood, previousLogLikelihood);
            previousLogLikelihood = logLikelihood;
        }
        this.classifier = current;
        return current;
    }

    /**
     * Builds the training list for one EM round.
     *
     * <p>Labeled instances are added with weight 1.0. Each unlabeled instance is represented by a
     * shallow copy carrying {@code current}'s predicted labeling, weighted by
     * {@link #unlabeledDataWeight}. The unlabeled originals themselves are never relabeled.
     */
    private InstanceList relabelUnlabeled(InstanceList instances, NaiveBayes current) {
        InstanceList roundData = new InstanceList(instances.getPipe());
        for (int index = 0; index < instances.size(); index++) {
            Instance instance = instances.get(index);
            if (instance.getLabeling() != null) {
                roundData.add(instance, 1.0d);
            } else {
                Instance guessed = instance.shallowCopy();
                guessed.unLock();
                guessed.setLabeling(current.classify(instance).getLabeling());
                guessed.lock();
                roundData.add(guessed, this.unlabeledDataWeight);
            }
        }
        return roundData;
    }

    /**
     * Returns true when the relative change from {@code previous} to {@code current} is below
     * {@link #CONVERGENCE_TOLERANCE}.
     */
    private static boolean hasConverged(double current, double previous) {
        // Equal values (including 0 == 0 and matching infinities) mean no change. Without this
        // check the relative change below would be 0/0 or inf/inf = NaN, which is never < tolerance.
        if (current == previous) {
            return true;
        }
        return Math.abs((current - previous) / current) < CONVERGENCE_TOLERANCE;
    }

    /** Describes this trainer, appending settings whose values differ from 1.0. */
    @Override
    public String toString() {
        StringBuilder description = new StringBuilder("NaiveBayesEMTrainer");
        // NOTE(review): docLengthNormalization defaults to -1.0, so a default-configured trainer
        // still prints it; confirm whether the intended comparison is against the default.
        if (this.docLengthNormalization != 1.0d) {
            description.append(",docLengthNormalization=").append(this.docLengthNormalization);
        }
        if (this.unlabeledDataWeight != 1.0d) {
            description.append(",unlabeledDataWeight=").append(this.unlabeledDataWeight);
        }
        return description.toString();
    }

    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
        objectOutputStream.writeObject(this.featureEstimator);
        objectOutputStream.writeObject(this.priorEstimator);
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        int version = objectInputStream.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched NaiveBayesEMTrainer versions: wanted "
                    + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.featureEstimator = (Multinomial.Estimator) objectInputStream.readObject();
        this.priorEstimator = (Multinomial.Estimator) objectInputStream.readObject();
        // Field initializers of a serializable class are not run during deserialization, and the
        // stream carries only the two estimators. Without restoring the rest explicitly,
        // nbTrainer would be null (train() would throw NPE) and the weights would be 0.0.
        this.docLengthNormalization = DEFAULT_DOC_LENGTH_NORMALIZATION;
        this.unlabeledDataWeight = DEFAULT_UNLABELED_DATA_WEIGHT;
        this.iteration = 0;
        this.nbTrainer = new NaiveBayesTrainer.Factory();
        configureFactory();
    }
}
