package com.medallia.word2vec;

import com.google.common.base.MoreObjects;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.Multiset;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkType;
import com.medallia.word2vec.util.AutoLog;
import java.util.List;
import org.apache.commons.logging.Log;

/**
 * Fluent builder that collects Word2Vec training settings and runs training via {@link #train}.
 *
 * <p>Every setting is optional. Any setting left unset is resolved to a default when
 * {@link #train} is called. Resolving defaults does NOT modify the builder. A builder can
 * therefore be reconfigured (for example with a different {@link #type}) and reused, and each
 * {@link #train} call picks defaults that match the settings in effect at that call.
 *
 * <p>NOTE(review): instances hold unsynchronized mutable state; presumably a builder is meant to
 * be confined to one thread.
 */
public class Word2VecTrainerBuilder {
    private static final Log LOG = AutoLog.getLog();

    /** Iterations used when {@link #setNumIterations} was not called. */
    private static final int DEFAULT_ITERATIONS = 5;
    /** Layer size used when {@link #setLayerSize} was not called. */
    private static final int DEFAULT_LAYER_SIZE = 100;
    /** Window size used when {@link #setWindowSize} was not called. */
    private static final int DEFAULT_WINDOW_SIZE = 5;
    /** Down-sampling rate used when {@link #setDownSamplingRate} was not called. */
    private static final double DEFAULT_DOWN_SAMPLE_RATE = 0.001d;
    /** Minimum vocabulary frequency used when {@link #setMinVocabFrequency} was not called. */
    private static final int DEFAULT_MIN_FREQUENCY = 5;

    /**
     * Listener used when none was supplied: prints each progress update to {@code System.out}.
     * It is declared in a static context so it does not retain a reference to any builder.
     */
    private static final TrainingProgressListener STDOUT_LISTENER = new TrainingProgressListener() {
        @Override
        public void update(TrainingProgressListener.Stage stage, double progress) {
            System.out.println(String.format("Stage %s, progress %s%%", stage, Double.valueOf(progress)));
        }
    };

    // Each field is null (or the primitive default) until its setter is called; train() fills
    // gaps with defaults in local variables and never writes them back here.
    private Integer layerSize;
    private Integer windowSize;
    private Integer numThreads;
    private NeuralNetworkType type;
    private int negativeSamples;
    private boolean useHierarchicalSoftmax;
    private Multiset<String> vocab;
    private Integer minFrequency;
    private Double initialLearningRate;
    private Double downSampleRate;
    private Integer iterations;
    private TrainingProgressListener listener;

    /** Callback that receives progress reports while {@link #train} runs. */
    public interface TrainingProgressListener {

        /** Phases of training for which progress is reported. */
        public enum Stage {
            ACQUIRE_VOCAB,
            FILTER_SORT_VOCAB,
            CREATE_HUFFMAN_ENCODING,
            TRAIN_NEURAL_NETWORK
        }

        /**
         * Reports progress within a stage.
         *
         * @param stage the stage currently running
         * @param progress progress value supplied by the trainer (NOTE(review): scale, e.g.
         *     fraction vs. percentage, is determined by the trainer — confirm there)
         */
        void update(Stage stage, double progress);
    }

    /**
     * Sets the neural network layer size. Defaults to {@value #DEFAULT_LAYER_SIZE}.
     *
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public Word2VecTrainerBuilder setLayerSize(int i) {
        Preconditions.checkArgument(i > 0, "Value must be positive");
        this.layerSize = i;
        return this;
    }

    /**
     * Sets the window size. Defaults to {@value #DEFAULT_WINDOW_SIZE}.
     *
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public Word2VecTrainerBuilder setWindowSize(int i) {
        Preconditions.checkArgument(i > 0, "Value must be positive");
        this.windowSize = i;
        return this;
    }

    /**
     * Sets the number of threads handed to the network configuration. Defaults to
     * {@link Runtime#availableProcessors()} at the time {@link #train} is called.
     *
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public Word2VecTrainerBuilder useNumThreads(int i) {
        Preconditions.checkArgument(i > 0, "Value must be positive");
        this.numThreads = i;
        return this;
    }

    /**
     * Sets the network type. Defaults to {@link NeuralNetworkType#CBOW}. Also determines the
     * default initial learning rate when {@link #setInitialLearningRate} is not called.
     *
     * @throws NullPointerException if {@code neuralNetworkType} is null
     */
    public Word2VecTrainerBuilder type(NeuralNetworkType neuralNetworkType) {
        this.type = Preconditions.checkNotNull(neuralNetworkType);
        return this;
    }

    /** Enables hierarchical softmax (off by default). */
    public Word2VecTrainerBuilder useHierarchicalSoftmax() {
        this.useHierarchicalSoftmax = true;
        return this;
    }

    /**
     * Sets the number of negative samples. Defaults to 0.
     *
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public Word2VecTrainerBuilder useNegativeSamples(int i) {
        Preconditions.checkArgument(i >= 0, "Value must be non-negative");
        this.negativeSamples = i;
        return this;
    }

    /**
     * Supplies a precomputed vocabulary. If this is never called, the trainer receives an absent
     * vocabulary. The multiset is stored by reference, not copied.
     *
     * @throws NullPointerException if {@code multiset} is null
     */
    public Word2VecTrainerBuilder useVocab(Multiset<String> multiset) {
        this.vocab = Preconditions.checkNotNull(multiset);
        return this;
    }

    /**
     * Sets the minimum vocabulary frequency passed to the trainer. Defaults to
     * {@value #DEFAULT_MIN_FREQUENCY}.
     *
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public Word2VecTrainerBuilder setMinVocabFrequency(int i) {
        Preconditions.checkArgument(i >= 0, "Value must be non-negative");
        this.minFrequency = i;
        return this;
    }

    /**
     * Sets the initial learning rate. Defaults to the configured type's
     * {@code getDefaultInitialLearningRate()}.
     *
     * @throws IllegalArgumentException if {@code d} is negative or NaN
     */
    public Word2VecTrainerBuilder setInitialLearningRate(double d) {
        Preconditions.checkArgument(d >= 0.0d, "Value must be non-negative");
        this.initialLearningRate = d;
        return this;
    }

    /**
     * Sets the down-sampling rate. Defaults to {@value #DEFAULT_DOWN_SAMPLE_RATE}.
     *
     * @throws IllegalArgumentException if {@code d} is negative or NaN
     */
    public Word2VecTrainerBuilder setDownSamplingRate(double d) {
        Preconditions.checkArgument(d >= 0.0d, "Value must be non-negative");
        this.downSampleRate = d;
        return this;
    }

    /**
     * Sets the number of training iterations. Defaults to {@value #DEFAULT_ITERATIONS}.
     *
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public Word2VecTrainerBuilder setNumIterations(int i) {
        Preconditions.checkArgument(i > 0, "Value must be positive");
        this.iterations = i;
        return this;
    }

    /**
     * Sets the progress listener. Passing {@code null} (or never calling this) selects a default
     * listener that prints progress lines to {@code System.out}.
     */
    public Word2VecTrainerBuilder setListener(TrainingProgressListener trainingProgressListener) {
        this.listener = trainingProgressListener;
        return this;
    }

    /**
     * Builds a trainer from the current settings (with defaults for unset ones) and trains a
     * model on {@code iterable}. The builder's own settings are left unchanged.
     *
     * @param iterable training input; presumably each list is one sentence's tokens
     * @return the trained model
     * @throws InterruptedException propagated from the underlying trainer
     */
    public Word2VecModel train(Iterable<List<String>> iterable) throws InterruptedException {
        // Resolve defaults into locals rather than assigning them back to the fields. Writing
        // them back would pin, e.g., the CBOW learning-rate default into the builder, so a later
        // type(...) change followed by another train() would silently reuse a stale rate.
        NeuralNetworkType resolvedType = MoreObjects.firstNonNull(this.type, NeuralNetworkType.CBOW);
        double resolvedLearningRate = MoreObjects.firstNonNull(
                this.initialLearningRate, Double.valueOf(resolvedType.getDefaultInitialLearningRate()));
        int resolvedThreads = this.numThreads != null
                ? this.numThreads.intValue()
                : Runtime.getRuntime().availableProcessors();
        int resolvedIterations = MoreObjects.firstNonNull(this.iterations, DEFAULT_ITERATIONS);
        int resolvedLayerSize = MoreObjects.firstNonNull(this.layerSize, DEFAULT_LAYER_SIZE);
        int resolvedWindowSize = MoreObjects.firstNonNull(this.windowSize, DEFAULT_WINDOW_SIZE);
        double resolvedDownSampleRate = MoreObjects.firstNonNull(this.downSampleRate, DEFAULT_DOWN_SAMPLE_RATE);
        Integer resolvedMinFrequency = MoreObjects.firstNonNull(this.minFrequency, DEFAULT_MIN_FREQUENCY);
        TrainingProgressListener resolvedListener = this.listener != null ? this.listener : STDOUT_LISTENER;

        NeuralNetworkConfig config = new NeuralNetworkConfig(
                resolvedType,
                resolvedThreads,
                resolvedIterations,
                resolvedLayerSize,
                resolvedWindowSize,
                this.negativeSamples,
                resolvedDownSampleRate,
                resolvedLearningRate,
                this.useHierarchicalSoftmax);

        return new Word2VecTrainer(resolvedMinFrequency, Optional.fromNullable(this.vocab), config)
                .train(LOG, resolvedListener, iterable);
    }
}
