package com.medallia.word2vec.neuralnetwork;

import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.medallia.word2vec.Word2VecTrainerBuilder;
import com.medallia.word2vec.huffman.HuffmanCoding;
import com.medallia.word2vec.util.CallableVoid;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/* loaded from: input_file:com/medallia/word2vec/neuralnetwork/NeuralNetworkTrainer.class */
/**
 * Base class for word2vec neural network trainers.
 *
 * <p>Owns the weight matrices, the unigram table used for negative sampling, and the
 * multi-threaded training loop in {@link #train(Iterable)}. Subclasses supply the
 * per-sentence training step by returning a {@link Worker} from
 * {@link #createWorker(int, int, Iterable)}.
 *
 * <p>Workers submitted by {@link #train(Iterable)} all read and write the same weight
 * matrices; no synchronization around those updates is visible in this class.
 */
public abstract class NeuralNetworkTrainer {
    /** Filtered sentences are split into chunks of at most this many words before training. */
    private static final int MAX_SENTENCE_LENGTH = 1000;
    /** Dot products outside {@code [-MAX_EXP, MAX_EXP]} bypass the {@link #EXP_TABLE} lookup. */
    static final int MAX_EXP = 6;
    /** Number of entries in {@link #EXP_TABLE}. */
    static final int EXP_TABLE_SIZE = 1000;
    /** Pre-computed values of {@code e^x / (e^x + 1)} for x spread over {@code [-MAX_EXP, MAX_EXP)}. */
    static final double[] EXP_TABLE = new double[EXP_TABLE_SIZE];
    /** Number of entries in the unigram {@link #table} used to draw negative samples. */
    private static final int TABLE_SIZE = 100000000;

    static {
        for (int i = 0; i < EXP_TABLE_SIZE; i++) {
            // Map the table index onto [-MAX_EXP, MAX_EXP) and store the logistic function of it.
            double exp = Math.exp((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
            EXP_TABLE[i] = exp / (exp + 1);
        }
    }

    private final Word2VecTrainerBuilder.TrainingProgressListener listener;
    final NeuralNetworkConfig config;
    final Map<String, HuffmanCoding.HuffmanNode> huffmanNodes;
    private final int vocabSize;
    final int layer1_size;
    final int window;
    /** Size of the vocabulary multiset, plus the number of sentences of every {@link #train} call. */
    int numTrainedTokens;
    /** Current learning rate; rewritten by workers as training progresses. */
    volatile double alpha;
    /** Weights returned as the model's vectors; randomly initialized in the constructor. */
    final double[][] syn0;
    /** Not used in this class; presumably the hierarchical-softmax weights used by subclasses. */
    final double[][] syn1;
    /** Weights updated by {@link Worker#handleNegativeSampling}. */
    private final double[][] syn1neg;
    long startNano;
    /** Number of words processed so far, accumulated across all workers. */
    protected final AtomicInteger actualWordCount = new AtomicInteger();
    /** Unigram table: each slot holds a word index, with slots allotted in proportion to count^0.75. */
    private final int[] table = new int[TABLE_SIZE];

    /** The result of training: the learned vectors and their dimensionality. */
    public interface NeuralNetworkModel {
        /** @return the size of each vector */
        int layerSize();

        /** @return the learned vectors, one row per vocabulary entry */
        double[][] vectors();
    }

    /**
     * A unit of training work over one batch of sentences.
     *
     * <p>{@link #run()} filters and chunks each sentence and periodically refreshes the shared
     * learning rate; subclasses implement the actual update in {@link #trainSentence(List)}.
     */
    abstract class Worker extends CallableVoid {
        /** Minimum number of newly processed words between two learning rate updates. */
        private static final int LEARNING_RATE_UPDATE_FREQUENCY = 10000;
        /** State of this worker's pseudo-random generator, see {@link NeuralNetworkTrainer#incrementRandom}. */
        long nextRandom;
        final int iter;
        final Iterable<List<String>> batch;
        /** Words seen by this worker so far. */
        int wordCount;
        /** Value of {@link #wordCount} when it was last added to the shared counter. */
        int lastWordCount;
        final double[] neu1;
        final double[] neu1e;

        /**
         * @param randomSeed initial state of this worker's pseudo-random generator
         * @param iter the iteration number this worker belongs to
         * @param batch the sentences this worker trains on
         */
        public Worker(int randomSeed, int iter, Iterable<List<String>> batch) {
            this.neu1 = new double[layer1_size];
            this.neu1e = new double[layer1_size];
            this.nextRandom = randomSeed;
            this.iter = iter;
            this.batch = batch;
        }

        /**
         * Trains on every sentence of the batch.
         *
         * <p>Words missing from the vocabulary are dropped, frequent words are randomly
         * dropped when down-sampling is enabled, and the remainder is trained in chunks of
         * at most {@link NeuralNetworkTrainer#MAX_SENTENCE_LENGTH} words.
         *
         * @throws InterruptedException if the current thread is interrupted between chunks
         */
        @Override
        public void run() throws InterruptedException {
            for (List<String> sentence : batch) {
                List<String> filteredSentence = new ArrayList<>(sentence.size());
                for (String word : sentence) {
                    if (!huffmanNodes.containsKey(word)) {
                        continue;
                    }
                    wordCount++;
                    if (config.downSampleRate > 0) {
                        HuffmanCoding.HuffmanNode node = huffmanNodes.get(word);
                        double threshold = config.downSampleRate * numTrainedTokens;
                        double keepProbability = (Math.sqrt(node.count / threshold) + 1) * threshold / node.count;
                        nextRandom = incrementRandom(nextRandom);
                        // Drop the word when the random draw in [0, 1) exceeds its keep probability.
                        if (keepProbability < (nextRandom & 0xFFFF) / 65536.0) {
                            continue;
                        }
                    }
                    filteredSentence.add(word);
                }

                // One extra count per sentence; presumably stands in for the end-of-sentence
                // token of the reference implementation -- TODO confirm.
                wordCount++;

                for (List<String> chunk : Iterables.partition(filteredSentence, MAX_SENTENCE_LENGTH)) {
                    if (Thread.currentThread().isInterrupted()) {
                        throw new InterruptedException("Interrupted while training word2vec model");
                    }
                    if (wordCount - lastWordCount > LEARNING_RATE_UPDATE_FREQUENCY) {
                        updateAlpha(iter);
                    }
                    trainSentence(chunk);
                }
            }
            // Flush the words counted since the last learning rate update.
            actualWordCount.addAndGet(wordCount - lastWordCount);
        }

        /**
         * Publishes this worker's newly processed words to the shared counter, then
         * recomputes the shared learning rate and reports progress to the listener.
         *
         * @param iteration the current iteration; not used by the computation
         */
        private void updateAlpha(int iteration) {
            int currentActual = actualWordCount.addAndGet(wordCount - lastWordCount);
            lastWordCount = wordCount;
            // Computed in floating point: integer arithmetic would truncate the ratio to 0
            // and could overflow the product.
            double totalWords = (double) config.iterations * numTrainedTokens;
            alpha = config.initialLearningRate * Math.max(1 - currentActual / totalWords, 0.0001);
            listener.update(
                    Word2VecTrainerBuilder.TrainingProgressListener.Stage.TRAIN_NEURAL_NETWORK,
                    currentActual / (totalWords + 1));
        }

        /**
         * Applies one positive update for {@code huffmanNode} and up to
         * {@code config.negativeSamples} negative updates against words drawn from the
         * unigram table. Gradients are accumulated into {@link #neu1e} and the negative
         * sampling weights of each target are updated in place.
         *
         * @param huffmanNode the node of the word being predicted
         */
        public void handleNegativeSampling(HuffmanCoding.HuffmanNode huffmanNode) {
            for (int d = 0; d <= config.negativeSamples; d++) {
                final int target;
                final int label;
                if (d == 0) {
                    // First round is the positive example.
                    target = huffmanNode.idx;
                    label = 1;
                } else {
                    nextRandom = incrementRandom(nextRandom);
                    // The double modulo keeps the index non-negative for negative random values.
                    int sampled = table[(int) (((nextRandom >> 16) % TABLE_SIZE) + TABLE_SIZE) % TABLE_SIZE];
                    if (sampled == 0) {
                        sampled = (int) (((nextRandom % (vocabSize - 1)) + vocabSize - 1) % (vocabSize - 1)) + 1;
                    }
                    // A negative sample must not be the word we are predicting.
                    if (sampled == huffmanNode.idx) {
                        continue;
                    }
                    target = sampled;
                    label = 0;
                }

                double[] targetWeights = syn1neg[target];
                double f = 0;
                for (int c = 0; c < layer1_size; c++) {
                    f += neu1[c] * targetWeights[c];
                }

                final double g;
                if (f > MAX_EXP) {
                    g = (label - 1) * alpha;
                } else if (f < -MAX_EXP) {
                    g = label * alpha;
                } else {
                    // Integer division is intentional: EXP_TABLE_SIZE / MAX_EXP / 2 == 83.
                    g = (label - EXP_TABLE[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
                }

                for (int c = 0; c < layer1_size; c++) {
                    neu1e[c] += g * targetWeights[c];
                }
                for (int c = 0; c < layer1_size; c++) {
                    targetWeights[c] += g * neu1[c];
                }
            }
        }

        /**
         * Trains the network on a single sentence chunk.
         *
         * @param list the words of the chunk, all present in the vocabulary
         */
        abstract void trainSentence(List<String> list);
    }

    /**
     * Allocates the weight matrices, randomly initializes {@link #syn0} and builds the
     * unigram table.
     *
     * @param neuralNetworkConfig training parameters
     * @param multiset the vocabulary counts; its size seeds {@link #numTrainedTokens}
     * @param map vocabulary word to Huffman node
     * @param trainingProgressListener receives progress updates during training
     */
    public NeuralNetworkTrainer(NeuralNetworkConfig neuralNetworkConfig, Multiset<String> multiset, Map<String, HuffmanCoding.HuffmanNode> map, Word2VecTrainerBuilder.TrainingProgressListener trainingProgressListener) {
        this.config = neuralNetworkConfig;
        this.huffmanNodes = map;
        this.listener = trainingProgressListener;
        this.vocabSize = map.size();
        this.numTrainedTokens = multiset.size();
        this.layer1_size = neuralNetworkConfig.layerSize;
        this.window = neuralNetworkConfig.windowSize;
        this.alpha = neuralNetworkConfig.initialLearningRate;
        this.syn0 = new double[this.vocabSize][this.layer1_size];
        this.syn1 = new double[this.vocabSize][this.layer1_size];
        this.syn1neg = new double[this.vocabSize][this.layer1_size];
        initializeSyn0();
        initializeUnigramTable();
    }

    /**
     * Fills {@link #table} so that the n-th node in the iteration order of
     * {@link #huffmanNodes} occupies a share of the slots proportional to its count
     * raised to the power 0.75.
     */
    private void initializeUnigramTable() {
        final double power = 0.75;
        long trainWordsPow = 0;
        for (HuffmanCoding.HuffmanNode node : huffmanNodes.values()) {
            // Compound assignment truncates the running sum back to a long at every step.
            trainWordsPow += Math.pow(node.count, power);
        }

        Iterator<HuffmanCoding.HuffmanNode> nodes = huffmanNodes.values().iterator();
        HuffmanCoding.HuffmanNode last = nodes.next();
        double cumulativeShare = Math.pow(last.count, power) / trainWordsPow;
        int wordIndex = 0;
        for (int slot = 0; slot < TABLE_SIZE; slot++) {
            table[slot] = wordIndex;
            if (slot / (double) TABLE_SIZE > cumulativeShare) {
                wordIndex++;
                // Once the iterator is exhausted, keep re-using the last node's weight.
                HuffmanCoding.HuffmanNode next = nodes.hasNext() ? nodes.next() : last;
                cumulativeShare += Math.pow(next.count, power) / trainWordsPow;
                last = next;
            }
        }
    }

    /** Fills {@link #syn0} with pseudo-random values in {@code [-0.5, 0.5) / layer1_size}. */
    private void initializeSyn0() {
        long random = 1;
        for (int word = 0; word < huffmanNodes.size(); word++) {
            // One extra generator step per word, before the per-component draws.
            random = incrementRandom(random);
            for (int c = 0; c < layer1_size; c++) {
                random = incrementRandom(random);
                syn0[word][c] = (((random & 0xFFFF) / 65536.0) - 0.5) / layer1_size;
            }
        }
    }

    /**
     * Advances the linear congruential pseudo-random generator by one step.
     *
     * @param j the current generator state
     * @return the next generator state
     */
    public static long incrementRandom(long j) {
        return (j * 25214903917L) + 11;
    }

    /**
     * Trains the network on the given sentences using {@code config.numThreads} threads
     * for {@code config.iterations} iterations.
     *
     * @param iterable the sentences to train on, each a list of words
     * @return a view of the trained model backed by this trainer's {@link #syn0}
     * @throws InterruptedException if the calling thread is interrupted while waiting for workers
     * @throws IllegalStateException if a worker fails; its cause is the worker's exception
     */
    public NeuralNetworkModel train(Iterable<List<String>> iterable) throws InterruptedException {
        ListeningExecutorService executor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads));
        try {
            int numSentences = Iterables.size(iterable);
            numTrainedTokens += numSentences;

            // Spread the sentences over at most numThreads batches.
            Iterable<List<List<String>>> batches = Iterables.partition(iterable, numSentences / config.numThreads + 1);

            listener.update(Word2VecTrainerBuilder.TrainingProgressListener.Stage.TRAIN_NEURAL_NETWORK, 0.0);
            for (int iter = config.iterations; iter > 0; iter--) {
                List<Worker> workers = new ArrayList<>();
                int randomSeed = 0;
                for (List<List<String>> batch : batches) {
                    workers.add(createWorker(randomSeed, iter, batch));
                    randomSeed++;
                }

                List<ListenableFuture<?>> futures = new ArrayList<>(workers.size());
                for (Worker worker : workers) {
                    // Explicit cast selects the Callable overload of submit.
                    futures.add(executor.submit((Callable<?>) worker));
                }
                try {
                    Futures.allAsList(futures).get();
                } catch (ExecutionException e) {
                    throw new IllegalStateException("Error training neural network", e.getCause());
                }
            }
        } finally {
            // Stops the pool on success and cancels outstanding work on failure.
            executor.shutdownNow();
        }

        return new NeuralNetworkModel() {
            @Override
            public int layerSize() {
                return config.layerSize;
            }

            @Override
            public double[][] vectors() {
                return syn0;
            }
        };
    }

    /**
     * Creates the worker that trains on one batch.
     *
     * @param i initial state of the worker's pseudo-random generator
     * @param i2 the iteration number the worker belongs to
     * @param iterable the sentences of the batch
     * @return a new worker for the batch
     */
    abstract Worker createWorker(int i, int i2, Iterable<List<String>> iterable);
}
