package org.allenai.word2vec;

import java.io.File;
import java.io.IOException;
import java.nio.DoubleBuffer;
import org.allenai.word2vec.thrift.Word2VecModelThrift;

/* loaded from: input_file:org/allenai/word2vec/NormalizedWord2VecModel.class */
/**
 * A {@link Word2VecModel} whose stored vectors have each been scaled to unit Euclidean length.
 *
 * <p>Normalization is performed once, in the constructor, by rewriting the vector buffers handed to
 * the superclass. All-zero vectors have no direction and are left as zeros.
 */
public class NormalizedWord2VecModel extends Word2VecModel {
    /**
     * Builds a model over the given buffers and normalizes their contents in place.
     *
     * @param vocab the vocabulary, passed straight to the superclass
     * @param layerSize number of components per vector
     * @param vectorBuffers vector storage; its contents are overwritten by normalization, so callers
     *     must not pass buffers that share content with a model they still use
     */
    private NormalizedWord2VecModel(Iterable<String> vocab, int layerSize, DoubleBuffer[] vectorBuffers) {
        super(vocab, layerSize, vectorBuffers);
        normalize();
    }

    /**
     * Builds a model from a primitive array and normalizes the resulting vectors.
     *
     * <p>NOTE(review): not called from within this file; how the array is laid out is decided by the
     * superclass constructor — confirm before relying on it.
     *
     * @param vocab the vocabulary, passed straight to the superclass
     * @param layerSize number of components per vector
     * @param vectorData raw vector data, passed straight to the superclass
     */
    private NormalizedWord2VecModel(Iterable<String> vocab, int layerSize, double[] vectorData) {
        super(vocab, layerSize, vectorData);
        normalize();
    }

    /**
     * Returns a normalized version of {@code word2VecModel}.
     *
     * <p>The vector data is copied before normalizing, so {@code word2VecModel} itself is left
     * unchanged. (Previously the buffers were only {@code duplicate()}d, which shares content, so
     * the source model's vectors were silently normalized as well.)
     *
     * @param word2VecModel the model to normalize; not modified
     * @return a new model with the same vocabulary and layer size whose non-zero vectors have unit
     *     length
     */
    public static NormalizedWord2VecModel fromWord2VecModel(Word2VecModel word2VecModel) {
        return fromModel(word2VecModel, true);
    }

    /**
     * Deserializes a model from its Thrift form and normalizes it.
     *
     * @param word2VecModelThrift the serialized model
     * @return the normalized model
     */
    public static NormalizedWord2VecModel fromThrift(Word2VecModelThrift word2VecModelThrift) {
        // The intermediate model is unreachable afterwards, so its storage can be reused.
        return fromModel(Word2VecModel.fromThrift(word2VecModelThrift), false);
    }

    /**
     * Loads a model with {@link Word2VecModel#fromBinFile(File)} and normalizes it.
     *
     * @param file the file to load
     * @return the normalized model
     * @throws IOException propagated from {@link Word2VecModel#fromBinFile(File)}
     */
    public static NormalizedWord2VecModel fromBinFile(File file) throws IOException {
        // The intermediate model is unreachable afterwards, so its storage can be reused.
        return fromModel(Word2VecModel.fromBinFile(file), false);
    }

    /**
     * Wraps {@code model}'s vectors in a new normalized model.
     *
     * @param model source model
     * @param copyVectors {@code true} to deep-copy every buffer so {@code model} is untouched;
     *     {@code false} to share storage (only safe when nobody else uses {@code model}).
     *     Read-only buffers are always copied, since normalization must write to them.
     */
    private static NormalizedWord2VecModel fromModel(Word2VecModel model, boolean copyVectors) {
        DoubleBuffer[] buffers = new DoubleBuffer[model.vectors.length];
        for (int i = 0; i < buffers.length; i++) {
            DoubleBuffer source = model.vectors[i];
            buffers[i] = (copyVectors || source.isReadOnly()) ? copyOf(source) : source.duplicate();
        }
        return new NormalizedWord2VecModel(model.vocab, model.layerSize, buffers);
    }

    /** Copies a buffer's full contents into a new heap buffer with the same position and limit. */
    private static DoubleBuffer copyOf(DoubleBuffer source) {
        DoubleBuffer whole = source.duplicate();
        whole.clear(); // view the entire capacity; does not affect source's position/limit
        DoubleBuffer copy = DoubleBuffer.allocate(whole.capacity());
        copy.put(whole);
        // Restore limit before position so the position is never clamped by a stale limit.
        copy.limit(source.limit());
        copy.position(source.position());
        return copy;
    }

    /**
     * Scales every stored vector to unit Euclidean length, in place.
     *
     * <p>Each buffer holds up to {@code vectorsPerBuffer} consecutive vectors of {@code layerSize}
     * components starting at index 0; only whole vectors below the buffer's limit are processed.
     * An all-zero vector is skipped, because dividing it by its zero norm would produce NaNs.
     */
    private void normalize() {
        for (DoubleBuffer buffer : this.vectors) {
            int vectorCount = Math.min(this.vectorsPerBuffer, buffer.limit() / this.layerSize);
            for (int v = 0; v < vectorCount; v++) {
                int start = v * this.layerSize;
                int end = start + this.layerSize;

                double sumOfSquares = 0.0d;
                for (int j = start; j < end; j++) {
                    double component = buffer.get(j);
                    sumOfSquares += component * component;
                }

                double norm = Math.sqrt(sumOfSquares);
                if (norm == 0.0d) {
                    continue; // 0/0 would turn the whole vector into NaN
                }
                for (int j = start; j < end; j++) {
                    buffer.put(j, buffer.get(j) / norm);
                }
            }
        }
    }
}
