All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.models.embeddings.wordvectors;

import com.google.common.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.commons.lang.ArrayUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.io.CollectionUtils;

import java.util.*;

/**
 * Common word vector operations.
 * <p>
 * Holds a vocabulary ({@link VocabCache}) and a weight table ({@link WeightLookupTable}) and exposes
 * vocabulary lookups, vector retrieval and similarity queries. Similarity and nearest-neighbour
 * queries are delegated to a {@link ModelUtils} instance that is initialized against the current
 * lookup table.
 *
 * @author Adam Gibson
 */
public class WordVectorsImpl implements WordVectors {
    private static final long serialVersionUID = 78249242142L;

    /** Default label looked up in place of out-of-vocabulary words (see {@code useUnknown}). */
    public static final String DEFAULT_UNK = "UNK";

    //number of times the word must occur in the vocab to appear in the calculations, otherwise treat as unknown
    @Getter
    protected int minWordFrequency = 5;
    @Getter
    protected WeightLookupTable lookupTable;
    @Getter
    protected VocabCache vocab;
    // Fallback vector size, reported by getLayerSize() when no lookup table weights are available.
    protected int layerSize = 100;
    // Delegate for similarity / nearest-neighbour queries; transient, so it is not serialized.
    @Getter
    protected transient ModelUtils modelUtils = new BasicModelUtils<>();
    // Ensures update(...) reports the heartbeat event at most once per instance.
    private boolean initDone = false;

    // Training hyperparameters. Not read anywhere in this class; presumably consumed by subclasses.
    protected int numIterations = 1;
    protected int numEpochs = 1;
    protected double negative = 0;
    protected double sampling = 0;
    protected AtomicDouble learningRate = new AtomicDouble(0.025);
    protected double minLearningRate = 0.01;
    @Getter
    protected int window = 5;
    protected int batchSize;
    protected int learningRateDecayWords;
    protected boolean resetModel;
    protected boolean useAdeGrad;
    protected int workers = 1;
    protected boolean trainSequenceVectors = false;
    protected boolean trainElementsVectors = true;
    protected long seed;
    // When true, getWordVectors substitutes the UNK vector for labels missing from the vocabulary.
    protected boolean useUnknown = false;
    protected int[] variableWindows;

    // Label used for unknown words; Lombok generates getUNK()/setUNK().
    @Getter
    @Setter
    private String UNK = DEFAULT_UNK;

    @Getter
    protected Collection<String> stopWords = new ArrayList<>();

    /**
     * Returns the word vector size.
     *
     * @return the number of columns of the lookup table weights when they are available,
     *         otherwise the configured {@code layerSize}
     */
    public int getLayerSize() {
        if (lookupTable != null && lookupTable.getWeights() != null) {
            return lookupTable.getWeights().columns();
        }
        return layerSize;
    }

    /**
     * Returns true if the model has this word in the vocab.
     *
     * @param word the word to test for
     * @return true if the vocabulary reports a non-negative index for the word
     */
    public boolean hasWord(String word) {
        return vocab().indexOf(word) >= 0;
    }

    /**
     * Words nearest based on positive and negative words.
     *
     * @param positive the positive words
     * @param negative the negative words
     * @param top the top n words
     * @return the words nearest the combination of the given words, as computed by {@link ModelUtils}
     */
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        return modelUtils.wordsNearestSum(positive, negative, top);
    }

    /**
     * Returns the {@code top} words nearest to the given vector(s), as computed by
     * {@link ModelUtils#wordsNearestSum(INDArray, int)}.
     *
     * @param words the query vector(s)
     * @param top the top n words
     * @return the nearest words
     */
    @Override
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        return modelUtils.wordsNearestSum(words, top);
    }

    /**
     * Returns the {@code top} words nearest to the given vector(s), as computed by
     * {@link ModelUtils#wordsNearest(INDArray, int)}.
     *
     * @param words the query vector(s)
     * @param top the top n words
     * @return the nearest words
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        return modelUtils.wordsNearest(words, top);
    }

    /**
     * Get the top n words most similar to the given word, using {@link ModelUtils#wordsNearestSum(String, int)}.
     *
     * @param word the word to compare
     * @param n the number of words to return
     * @return the top n words
     */
    public Collection<String> wordsNearestSum(String word, int n) {
        return modelUtils.wordsNearestSum(word, n);
    }

    /**
     * Accuracy based on questions which are a space separated list of strings
     * where the first word is the query word, the next 2 words are negative,
     * and the last word is the predicted word to be nearest.
     *
     * @param questions the questions to ask
     * @return the accuracy based on these questions
     */
    public Map<String, Double> accuracy(List<String> questions) {
        return modelUtils.accuracy(questions);
    }

    /**
     * Returns the vocabulary index of a word.
     *
     * @param word the word to look up
     * @return the index reported by the vocabulary
     */
    @Override
    public int indexOf(String word) {
        return vocab().indexOf(word);
    }

    /**
     * Find all words in the vocab with similar characters.
     *
     * @param word the word to compare
     * @param accuracy the accuracy: 0 to 1
     * @return the list of words that are similar in the vocab
     */
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        return this.modelUtils.similarWordsInVocabTo(word, accuracy);
    }

    /**
     * Get the word vector for a given word as a primitive array.
     *
     * @param word the word to get the vector for
     * @return a copy of the vector's values, or {@code null} if the lookup table has no vector for the word
     */
    public double[] getWordVector(String word) {
        INDArray r = getWordVectorMatrix(word);
        if (r == null)
            return null;
        // dup() first so data() presumably covers only this vector rather than a larger backing buffer of a view.
        return r.dup().data().asDouble();
    }

    /**
     * Returns the word vector divided by its L2 norm.
     *
     * @param word the word to get the vector for
     * @return the normalized vector, or {@code null} if the lookup table has no vector for the word
     */
    public INDArray getWordVectorMatrixNormalized(String word) {
        INDArray r = getWordVectorMatrix(word);
        if (r == null)
            return null;

        return r.div(Nd4j.getBlasWrapper().nrm2(r));
    }

    /**
     * Returns the vector stored in the lookup table for a word.
     *
     * @param word the word to get the vector for
     * @return the result of {@code lookupTable().vector(word)}
     */
    @Override
    public INDArray getWordVectorMatrix(String word) {
        return lookupTable().vector(word);
    }

    /**
     * Words nearest based on positive and negative words.
     *
     * @param positive the positive words
     * @param negative the negative words
     * @param top the top n words
     * @return the words nearest the combination of the given words, as computed by {@link ModelUtils}
     */
    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        return modelUtils.wordsNearest(positive, negative, top);
    }

    /**
     * Returns a 2D array where each row is the vector of the corresponding label.
     * <p>
     * Labels missing from the vocabulary are replaced by the UNK vector when {@code useUnknown} is
     * enabled and the UNK label is in the vocabulary; otherwise they are skipped. Row order follows
     * the iteration order of {@code labels}.
     *
     * @param labels the labels to look up
     * @return the pulled rows, or an empty array (with the weights' data type) if no label resolved
     */
    @Override
    public INDArray getWordVectors(@NonNull Collection<String> labels) {
        // Resolve the UNK index once instead of per label.
        int unknownIndex = -1;
        if (useUnknown && vocab.containsWord(getUNK())) {
            unknownIndex = vocab.indexOf(getUNK());
        }

        int[] indexes = new int[labels.size()];
        int found = 0;
        for (String label : labels) {
            int index = vocab.containsWord(label) ? vocab.indexOf(label) : unknownIndex;
            // Skip unresolved labels in a single pass (the previous removeElement loop was O(n^2)).
            if (index >= 0) {
                indexes[found++] = index;
            }
        }

        if (found == 0) {
            // Take the dtype from the weights themselves rather than casting to one specific table implementation.
            return Nd4j.empty(lookupTable.getWeights().dataType());
        }

        return Nd4j.pullRows(lookupTable.getWeights(), 1, Arrays.copyOf(indexes, found));
    }

    /**
     * Returns the mean vector built from the words/labels passed in.
     * <p>
     * NOTE(review): when none of the labels resolve, getWordVectors returns an empty array; confirm
     * that {@code mean(0)} behaves as intended in that case.
     *
     * @param labels the labels to average
     * @return the mean of the rows returned by {@link #getWordVectors(Collection)} along dimension 0
     */
    @Override
    public INDArray getWordVectorsMean(Collection<String> labels) {
        INDArray array = getWordVectors(labels);
        return array.mean(0);
    }

    /**
     * Get the top n words most similar to the given word, using {@link ModelUtils#wordsNearest(String, int)}.
     *
     * @param word the word to compare
     * @param n the number of words to return
     * @return the top n words
     */
    public Collection<String> wordsNearest(String word, int n) {
        return modelUtils.wordsNearest(word, n);
    }

    /**
     * Returns similarity of two elements, provided by ModelUtils.
     *
     * @param word the first word
     * @param word2 the second word
     * @return a normalized similarity (cosine similarity)
     */
    public double similarity(String word, String word2) {
        return modelUtils.similarity(word, word2);
    }

    @Override
    public VocabCache vocab() {
        return vocab;
    }

    @Override
    public WeightLookupTable lookupTable() {
        return lookupTable;
    }

    /**
     * Replaces the {@link ModelUtils} used for similarity queries.
     * <p>
     * The instance is always stored. It is initialized immediately if a lookup table is already
     * present; otherwise {@link #setLookupTable(WeightLookupTable)} initializes it later. Previously
     * the argument was silently discarded when no lookup table was set.
     *
     * @param modelUtils the new delegate
     */
    @Override
    @SuppressWarnings("unchecked")
    public void setModelUtils(@NonNull ModelUtils modelUtils) {
        if (lookupTable != null) {
            modelUtils.init(lookupTable);
        }
        this.modelUtils = modelUtils;
    }

    /**
     * Sets the lookup table and (re)initializes the ModelUtils delegate against it, creating a
     * {@link BasicModelUtils} first if none is present.
     *
     * @param lookupTable the new lookup table
     */
    public void setLookupTable(@NonNull WeightLookupTable lookupTable) {
        this.lookupTable = lookupTable;
        if (modelUtils == null)
            this.modelUtils = new BasicModelUtils<>();

        this.modelUtils.init(lookupTable);
    }

    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    /** Reports a STANDALONE heartbeat event for the current environment (at most once per instance). */
    protected void update() {
        update(EnvironmentUtils.buildEnvironment(), Event.STANDALONE);
    }

    /**
     * Reports a heartbeat event describing this model the first time it is called; later calls do nothing.
     *
     * @param env the environment to report
     * @param event the event type to report
     */
    protected void update(Environment env, Event event) {
        if (!initDone) {
            initDone = true;

            Heartbeat heartbeat = Heartbeat.getInstance();
            Task task = new Task();
            task.setNumFeatures(layerSize);
            if (vocab != null)
                task.setNumSamples(vocab.numWords());
            task.setNetworkType(Task.NetworkType.DenseNetwork);
            task.setArchitectureType(Task.ArchitectureType.WORDVECTORS);

            heartbeat.reportEvent(event, env, task);
        }
    }

    /**
     * Copies the lookup table weights into {@code array} via {@code assign}.
     *
     * @param array the destination array
     */
    @Override
    public void loadWeightsInto(INDArray array) {
        array.assign(lookupTable.getWeights());
    }

    /** @return the number of rows in the lookup table weights */
    @Override
    public long vocabSize() {
        return lookupTable.getWeights().size(0);
    }

    /** @return the layer size reported by the lookup table */
    @Override
    public int vectorSize() {
        return lookupTable.layerSize();
    }

    @Override
    public boolean jsonSerializable() {
        return false;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy