package com.rapidminer.word2vec;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.utils.ExampleSetBuilder;
import com.rapidminer.example.utils.ExampleSets;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.allenai.word2vec.Searcher;
import org.allenai.word2vec.Word2VecModel;
import org.allenai.word2vec.Word2VecTrainerBuilder;
import org.allenai.word2vec.neuralnetwork.NeuralNetworkType;
import org.allenai.word2vec.thrift.Word2VecModelThrift;
import org.allenai.word2vec.util.AutoLog;
import org.allenai.word2vec.util.Common;
import org.allenai.word2vec.util.Format;
import org.allenai.word2vec.util.ProfilingTimer;
import org.allenai.word2vec.util.ThriftUtils;
import org.apache.commons.logging.Log;
import org.apache.http.HttpStatus;
import org.apache.thrift.TException;

/* loaded from: input_file:com/rapidminer/word2vec/Word2Vec.class */
/**
 * Thin wrapper around the allenai word2vec trainer that trains a CBOW model from
 * in-memory sentences and converts a trained model into a RapidMiner {@link ExampleSet}.
 *
 * <p>Training hyper-parameters are held as instance state and changed through the
 * setters; the values they start with are the field initializers below. The model
 * produced by {@link #run} is stored in a static field, so it is shared by all
 * instances of this class.
 */
public class Word2Vec {
    private static final Log LOG = AutoLog.getLog();

    /** Tokenizes one sentence by splitting it on single space characters. */
    private static final Function<String, List<String>> SPLIT_ON_SPACE = new Function<String, List<String>>() {
        @Override // com.google.common.base.Function
        public List<String> apply(String sentence) {
            return Arrays.asList(sentence.split(" "));
        }
    };

    /** Model produced by the most recent call to {@link #run}; nothing in this class reads it. */
    private static Word2VecModel model;

    private int minVocabFrequency = 5;
    private int layerSize = 5;
    private int windowSize = 4;
    private int negativeSamples = 25;
    private double downSamplingRate = 1.0E-4d;
    private int numIterations = 1;
    private int numThreads = 20;

    /**
     * Trains a CBOW model on the given sentences using the currently configured
     * hyper-parameters and stores the result in the static model field.
     *
     * @param arrayList sentences to train on; each one is tokenized by splitting on {@code " "}
     * @param logger receives one INFO message per training progress update
     * @throws IOException declared by this method's original signature
     * @throws TException declared by this method's original signature
     * @throws InterruptedException declared by this method's original signature
     * @throws Searcher.UnknownWordException declared by this method's original signature
     */
    public void run(ArrayList<String> arrayList, final Logger logger) throws IOException, TException, InterruptedException, Searcher.UnknownWordException {
        Word2VecTrainerBuilder.TrainingProgressListener progressListener = new Word2VecTrainerBuilder.TrainingProgressListener() {
            @Override // org.allenai.word2vec.Word2VecTrainerBuilder.TrainingProgressListener
            public void update(Word2VecTrainerBuilder.TrainingProgressListener.Stage stage, double progress) {
                logger.log(Level.INFO, String.format("%s is %.2f%% complete", Format.formatEnum(stage), progress * 100.0d));
            }
        };
        model = Word2VecModel.trainer()
                .setMinVocabFrequency(this.minVocabFrequency)
                // Honor the configured thread count; this used to be hard-coded to 20,
                // which made setNumThreads(int) a no-op.
                .useNumThreads(this.numThreads)
                .setWindowSize(this.windowSize)
                .type(NeuralNetworkType.CBOW)
                .setLayerSize(this.layerSize)
                .useNegativeSamples(this.negativeSamples)
                .setDownSamplingRate(this.downSamplingRate)
                .setNumIterations(this.numIterations)
                .setListener(progressListener)
                .train(Lists.transform(arrayList, SPLIT_ON_SPACE));
    }

    /**
     * Builds an example set with one row per vocabulary word of the given model.
     *
     * <p>The first attribute is named {@code "word"}; it is followed by one attribute
     * {@code "dimension_<i>"} for every {@code i} below this instance's layer size.
     * Rows are sized from this instance's layer size, not from the model: a raw vector
     * longer than the layer size overflows the row array, and a shorter one leaves the
     * trailing columns at {@code 0.0}.
     *
     * @param word2VecModel model whose vocabulary and raw vectors are exported
     * @return the populated example set
     * @throws Searcher.UnknownWordException if the searcher cannot resolve a vocabulary word
     */
    public ExampleSet createExampleSet(Word2VecModel word2VecModel) throws Searcher.UnknownWordException {
        List<Attribute> attributes = new ArrayList<>();
        // NOTE(review): value types 1 and 4 are presumably Ontology.NOMINAL and
        // Ontology.REAL; confirm against the RapidMiner version in use.
        Attribute wordAttribute = AttributeFactory.createAttribute("word", 1);
        attributes.add(wordAttribute);
        for (int i = 0; i < this.layerSize; i++) {
            attributes.add(AttributeFactory.createAttribute("dimension_" + i, 4));
        }

        Searcher searcher = word2VecModel.forSearch();
        ExampleSetBuilder builder = ExampleSets.from(attributes);
        for (String word : word2VecModel.getVocab()) {
            double[] row = new double[this.layerSize + 1];
            ImmutableList<Double> rawVector = searcher.getRawVector(word);
            // Column 0 stores the index that the word attribute's mapping assigns to the word.
            row[0] = wordAttribute.getMapping().mapString(word);
            for (int i = 0; i < rawVector.size(); i++) {
                row[i + 1] = rawVector.get(i).doubleValue();
            }
            builder.addRow(row);
        }
        return builder.build();
    }

    /**
     * Reads {@code text8.model}, deserializes it from Thrift JSON and converts it to a
     * {@link Word2VecModel}, inside a profiling timer logged under "Loading model".
     *
     * <p>NOTE(review): the converted model is not stored or returned, so the only
     * observable effects are the timer output and any exception raised while loading.
     * Confirm whether the result was meant to be kept.
     *
     * @throws IOException declared by this method's original signature
     * @throws TException declared by this method's original signature
     * @throws Searcher.UnknownWordException declared by this method's original signature
     */
    public static void loadModel() throws IOException, TException, Searcher.UnknownWordException {
        // The timer is closed on both normal and exceptional exit.
        try (ProfilingTimer timer = ProfilingTimer.create(LOG, "Loading model")) {
            String json = Common.readFileToString(new File("text8.model"));
            Word2VecModel.fromThrift((Word2VecModelThrift) ThriftUtils.deserializeJson(new Word2VecModelThrift(), json));
        }
    }

    /**
     * Trains a skip-gram model with hierarchical softmax on the sentences read from
     * {@code sents.cleaned.word2vec.txt}, printing progress to standard output.
     *
     * <p>All hyper-parameters are fixed in this method; the instance configuration is
     * not used. The trained model is not stored or returned.
     *
     * @throws IOException declared by this method's original signature
     * @throws TException declared by this method's original signature
     * @throws InterruptedException declared by this method's original signature
     * @throws Searcher.UnknownWordException declared by this method's original signature
     */
    public static void skipGram() throws IOException, TException, InterruptedException, Searcher.UnknownWordException {
        Word2VecTrainerBuilder.TrainingProgressListener progressListener = new Word2VecTrainerBuilder.TrainingProgressListener() {
            @Override // org.allenai.word2vec.Word2VecTrainerBuilder.TrainingProgressListener
            public void update(Word2VecTrainerBuilder.TrainingProgressListener.Stage stage, double progress) {
                System.out.println(String.format("%s is %.2f%% complete", Format.formatEnum(stage), progress * 100.0d));
            }
        };
        List<String> sentences = Common.readToList(new File("sents.cleaned.word2vec.txt"));
        Word2VecModel.trainer()
                .setMinVocabFrequency(100)
                .useNumThreads(20)
                .setWindowSize(7)
                .type(NeuralNetworkType.SKIP_GRAM)
                .useHierarchicalSoftmax()
                // Plain layer size of 300; previously written as HttpStatus.SC_MULTIPLE_CHOICES,
                // an unrelated constant that happens to have the same value.
                .setLayerSize(300)
                .useNegativeSamples(0)
                .setDownSamplingRate(0.001d)
                .setNumIterations(5)
                .setListener(progressListener)
                .train(Lists.transform(sentences, SPLIT_ON_SPACE));
    }

    /** Sets the minimum vocabulary frequency passed to the trainer by {@link #run}. */
    public void setMinVocabFrequency(int i) {
        this.minVocabFrequency = i;
    }

    /**
     * Sets the layer size passed to the trainer by {@link #run}; it also determines the
     * number of dimension columns produced by {@link #createExampleSet}.
     */
    public void setLayerSize(int i) {
        this.layerSize = i;
    }

    /** Sets the window size passed to the trainer by {@link #run}. */
    public void setWindowSize(int i) {
        this.windowSize = i;
    }

    /** Sets the negative-sample count passed to the trainer by {@link #run}. */
    public void setUseNegativeSamples(int i) {
        this.negativeSamples = i;
    }

    /** Sets the down-sampling rate passed to the trainer by {@link #run}. */
    public void setDownSamplingRate(double d) {
        this.downSamplingRate = d;
    }

    /**
     * Sets the number of training iterations passed to the trainer by {@link #run}.
     * The misspelled method name is kept for compatibility with existing callers.
     */
    public void setNumInterations(int i) {
        this.numIterations = i;
    }

    /** Sets the number of training threads passed to the trainer by {@link #run}. */
    public void setNumThreads(int i) {
        this.numThreads = i;
    }
}
