All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.models.embeddings.reader.impl;

import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.SetUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
 * Basic implementation for ModelUtils interface, suited for standalone use.
 *
 * PLEASE NOTE: This reader applies normalization to underlying lookup table.
 *
 * @author Adam Gibson
 */
public class BasicModelUtils implements ModelUtils {
    public static final String EXISTS = "exists";
    public static final String CORRECT = "correct";
    public static final String WRONG = "wrong";
    protected volatile VocabCache vocabCache;
    protected volatile WeightLookupTable lookupTable;

    protected volatile boolean normalized  = false;

    private static final Logger log = LoggerFactory.getLogger(BasicModelUtils.class);

    public BasicModelUtils() {

    }

    @Override
    public void init(@NonNull WeightLookupTable lookupTable) {
        this.vocabCache = lookupTable.getVocabCache();
        this.lookupTable = lookupTable;

        // reset normalization trigger on init call
        this.normalized = false;
    }

    /**
     * Returns the similarity of 2 words. Result value will be in range [-1,1], where -1.0 is exact opposite similarity, i.e. NO similarity, and 1.0 is total match of two word vectors.
     * However, most of time you'll see values in range [0,1], but that's something depends of training corpus.
     *
     * Returns NaN if any of labels not exists in vocab, or any label is null
     *
     * @param label1 the first word
     * @param label2 the second word
     * @return a normalized similarity (cosine similarity)
     */
    @Override
    public double similarity( String label1, String label2) {
        if (label1 == null || label2 == null) {
            log.debug("LABELS: " + label1 + ": " + (label1 == null ? "null": EXISTS)+ ";" + label2 +" vec2:" + (label2 == null ? "null": EXISTS));
            return Double.NaN;
        }

        INDArray vec1 = lookupTable.vector(label1).dup();
        INDArray vec2 = lookupTable.vector(label2).dup();


        if (vec1 == null || vec2 == null) {
            log.debug(label1 + ": " + (vec1 == null ? "null": EXISTS)+ ";" + label2 +" vec2:" + (vec2 == null ? "null": EXISTS));
            return Double.NaN;
        }

        if (label1.equals(label2)) return 1.0;

        return Transforms.cosineSim(vec1, vec2);
    }


    @Override
    public Collection wordsNearest(String label, int n) {
        Collection collection = wordsNearest(Arrays.asList(label),new ArrayList(),n + 1);
        if (collection.contains(label)) collection.remove(label);
        return collection;
    }

    /**
     * Accuracy based on questions which are a space separated list of strings
     * where the first word is the query word, the next 2 words are negative,
     * and the last word is the predicted word to be nearest
     * @param questions the questions to ask
     * @return the accuracy based on these questions
     */
    @Override
    public Map accuracy(List questions) {
        Map accuracy = new HashMap<>();
        Counter right = new Counter<>();
        String analogyType = "";
        for(String s : questions) {
            if(s.startsWith(":")) {
                double correct = right.getCount(CORRECT);
                double wrong = right.getCount(WRONG);
                if(analogyType.isEmpty()){
                    analogyType=s;
                    continue;
                }
                double accuracyRet = 100.0 * correct / (correct + wrong);
                accuracy.put(analogyType,accuracyRet);
                analogyType = s;
                right.clear();
            }
            else {
                String[] split = s.split(" ");
                String word = split[0];
                List positive = Arrays.asList(word);
                List negative = Arrays.asList(split[1],split[2]);
                String predicted = split[3];
                String w = wordsNearest(positive,negative,1).iterator().next();
                if(predicted.equals(w))
                    right.incrementCount(CORRECT,1.0);
                else
                    right.incrementCount(WRONG,1.0);

            }
        }
        if(!analogyType.isEmpty()){
            double correct = right.getCount(CORRECT);
            double wrong = right.getCount(WRONG);
            double accuracyRet = 100.0 * correct / (correct + wrong);
            accuracy.put(analogyType,accuracyRet);
        }
        return accuracy;
    }

    /**
     * Find all words with a similar characters
     * in the vocab
     * @param word the word to compare
     * @param accuracy the accuracy: 0 to 1
     * @return the list of words that are similar in the vocab
     */
    @Override
    public List similarWordsInVocabTo(String word, double accuracy) {
        List ret = new ArrayList<>();
        for(String s : vocabCache.words()) {
            if(MathUtils.stringSimilarity(word, s) >= accuracy)
                ret.add(s);
        }
        return ret;
    }

    public Collection wordsNearest(Collection positive, Collection negative, int top) {
        // Check every word is in the model
        for (String p : SetUtils.union(new HashSet<>(positive), new HashSet<>(negative))) {
            if (!vocabCache.containsWord(p)) {
                return new ArrayList<>();
            }
        }

        INDArray words = Nd4j.create(positive.size() + negative.size(), lookupTable.layerSize());
        int row = 0;
        //Set union = SetUtils.union(new HashSet<>(positive), new HashSet<>(negative));
        for (String s : positive) {
            words.putRow(row++, lookupTable.vector(s));
        }

        for (String s : negative) {
            words.putRow(row++, lookupTable.vector(s).mul(-1));
        }

        INDArray mean = words.isMatrix() ? words.mean(0) : words;

        return wordsNearest(mean, top);
    }

    /**
     * Get the top n words most similar to the given word
     * @param word the word to compare
     * @param n the n to get
     * @return the top n words
     */
    @Override
    public Collection wordsNearestSum(String word, int n) {
        //INDArray vec = Transforms.unitVec(this.lookupTable.vector(word));
        INDArray vec = this.lookupTable.vector(word);
        return wordsNearestSum(vec, n);
    }

    /**
     * Words nearest based on positive and negative words
     * * @param top the top n words
     * @return the words nearest the mean of the words
     */
    @Override
    public Collection wordsNearest(INDArray words, int top) {
        if(lookupTable instanceof InMemoryLookupTable) {
            InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;

            INDArray syn0 = l.getSyn0();

            if (!normalized) {
                synchronized (this) {
                    if (!normalized) {
                        syn0.diviColumnVector(syn0.norm1(1));
                        normalized = true;
                    }
                }
            }

            INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());

            List highToLowSimList = getTopN(similarity, top + 20);

            List result = new ArrayList<>();

            for (int i = 0; i < highToLowSimList.size(); i++) {
                String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
                if (word != null && !word.equals("UNK") && !word.equals("STOP") ) {
                    INDArray otherVec = lookupTable.vector(word);
                    double sim = Transforms.cosineSim(words, otherVec);

                    result.add(new WordSimilarity(word, sim));
                }
            }

            Collections.sort(result, new SimilarityComparator());

            return getLabels(result, top);
        }

        Counter distances = new Counter<>();

        for(String s : vocabCache.words()) {
            INDArray otherVec = lookupTable.vector(s);
            double sim = Transforms.cosineSim(words, otherVec);
            distances.incrementCount(s, sim);
        }


        distances.keepTopNKeys(top);
        return distances.keySet();


    }

    /**
     * Get top N elements
     *
     * @param vec the vec to extract the top elements from
     * @param N the number of elements to extract
     * @return the indices and the sorted top N elements
     */
    private List getTopN(INDArray vec, int N) {
        ArrayComparator comparator = new ArrayComparator();
        PriorityQueue queue = new PriorityQueue<>(vec.rows(),comparator);

        for (int j = 0; j < vec.length(); j++) {
            final Double[] pair = new Double[]{vec.getDouble(j), (double) j};
            if (queue.size() < N) {
                queue.add(pair);
            } else {
                Double[] head = queue.peek();
                if (comparator.compare(pair, head) > 0) {
                    queue.poll();
                    queue.add(pair);
                }
            }
        }

        List lowToHighSimLst = new ArrayList<>();

        while (!queue.isEmpty()) {
            double ind = queue.poll()[1];
            lowToHighSimLst.add(ind);
        }
        return Lists.reverse(lowToHighSimLst);
    }

    /**
     * Words nearest based on positive and negative words
     * * @param top the top n words
     * @return the words nearest the mean of the words
     */
    @Override
    public Collection wordsNearestSum(INDArray words,int top) {

        if(lookupTable instanceof InMemoryLookupTable) {
            InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
            INDArray syn0 = l.getSyn0();
            INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
            INDArray distances = syn0.mulRowVector(weights).sum(1);
            INDArray[] sorted = Nd4j.sortWithIndices(distances,0,false);
            INDArray sort = sorted[0];
            List ret = new ArrayList<>();
            if(top > sort.length())
                top = sort.length();
            //there will be a redundant word
            int end = top;
            for(int i = 0; i < end; i++) {
                String add = vocabCache.wordAtIndex(sort.getInt(i));
                if(add == null || add.equals("UNK") || add.equals("STOP")) {
                    end++;
                    if(end >= sort.length())
                        break;
                    continue;
                }
                ret.add(vocabCache.wordAtIndex(sort.getInt(i)));
            }
            return ret;
        }

        Counter distances = new Counter<>();

        for(String s : vocabCache.words()) {
            INDArray otherVec = lookupTable.vector(s);
            double sim = Transforms.cosineSim(words, otherVec);
            distances.incrementCount(s, sim);
        }

        distances.keepTopNKeys(top);
        return distances.keySet();
    }

    /**
     * Words nearest based on positive and negative words
     * @param positive the positive words
     * @param negative the negative words
     * @param top the top n words
     * @return the words nearest the mean of the words
     */
    @Override
    public Collection wordsNearestSum(Collection positive, Collection negative, int top) {
        INDArray words = Nd4j.create(lookupTable.layerSize());
    //    Set union = SetUtils.union(new HashSet<>(positive), new HashSet<>(negative));
        for(String s : positive)
            words.addi(lookupTable.vector(s));


        for(String s : negative)
            words.addi(lookupTable.vector(s).mul(-1));

        return wordsNearestSum(words,top);
    }


    private static class SimilarityComparator implements Comparator {
        @Override
        public int compare(WordSimilarity o1, WordSimilarity o2) {
            if (Double.isNaN(o1.getSimilarity()) && Double.isNaN(o2.getSimilarity())) {
                return 0;
            } else if (Double.isNaN(o1.getSimilarity()) && !Double.isNaN(o2.getSimilarity())) {
                return -1;
            } else if (!Double.isNaN(o1.getSimilarity()) && Double.isNaN(o2.getSimilarity())) {
                return 1;
            }
            return Double.compare(o2.getSimilarity(), o1.getSimilarity());
        }
    }

    private static class ArrayComparator implements Comparator {
        @Override
        public int compare(Double[] o1, Double[] o2) {
            if (Double.isNaN(o1[0]) && Double.isNaN(o2[0])) {
                return 0;
            } else if (Double.isNaN(o1[0]) && !Double.isNaN(o2[0])) {
                return -1;
            } else if (!Double.isNaN(o1[0]) && Double.isNaN(o2[0])) {
                return 1;
            }
            return Double.compare(o1[0], o2[0]);
        }
    }

    @Data
    @AllArgsConstructor
    private static class WordSimilarity {
        private String word;
        private double similarity;
    }

    private List getLabels(List results, int limit) {
        List result = new ArrayList<>();
        for (int x = 0; x < results.size(); x++) {
            result.add(results.get(x).getWord());
            if (result.size() >= limit)
                break;
        }

        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy