org.latinolib.bow.BowSpace

package org.latinolib.bow;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.latinolib.VectorEntry;
import org.latinolib.model.ModelUtils;
import org.latinolib.SparseVector;
import org.latinolib.stemmer.Stemmer;
import org.latinolib.stopwords.StopWords;
import org.latinolib.tokenizer.SimpleTokenizer;
import org.latinolib.tokenizer.SimpleTokenizerType;
import org.latinolib.tokenizer.Token;
import org.latinolib.tokenizer.Tokenizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.*;

/**
 * @author saxo
 */
public class BowSpace implements Serializable
{
    private static final long serialVersionUID = 7695534861034247430L;

    private static final Logger logger = LoggerFactory.getLogger(BowSpace.class);

    private Tokenizer tokenizer = new SimpleTokenizer(SimpleTokenizerType.ALL_CHARS);
    private StopWords stopWords = null;
    private Stemmer stemmer = null;
    private Map<String, Word> wordInfo = new HashMap<String, Word>();
    private List<Word> idxInfo = Lists.newArrayList();
    private int maxNGramLen = 2;
    private int minWordFreq = 5;
    private WordWeightType wordWeightType = WordWeightType.TERM_FREQ;
    private double cutLowWeightsPerc = 0.2;
    private boolean normalizeVectors = true;
    private boolean keepWordForms = false;

    public Tokenizer getTokenizer() {
        return tokenizer;
    }

    public void setTokenizer(Tokenizer tokenizer) {
        this.tokenizer = Preconditions.checkNotNull(tokenizer);
    }

    public StopWords getStopWords() {
        return stopWords;
    }

    public void setStopWords(StopWords stopWords) {
        this.stopWords = stopWords;
    }

    public Stemmer getStemmer() {
        return stemmer;
    }

    public void setStemmer(Stemmer stemmer) {
        this.stemmer = stemmer;
    }

    public Map<String, Word> getWordInfo() {
        return wordInfo;
    }

    public void setWordInfo(Map<String, Word> wordInfo) {
        this.wordInfo = wordInfo;
    }

    public List<Word> getIdxInfo() {
        return Collections.unmodifiableList(idxInfo);
    }

    public int getMaxNGramLen() {
        return maxNGramLen;
    }

    public void setMaxNGramLen(int maxNGramLen) {
        Preconditions.checkArgument(maxNGramLen >= 1);
        this.maxNGramLen = maxNGramLen;
    }

    public int getMinWordFreq() {
        return minWordFreq;
    }

    public void setMinWordFreq(int minWordFreq) {
        Preconditions.checkArgument(minWordFreq >= 1);
        this.minWordFreq = minWordFreq;
    }

    public WordWeightType getWordWeightType() {
        return wordWeightType;
    }

    public void setWordWeightType(WordWeightType wordWeightType) {
        this.wordWeightType = wordWeightType;
    }

    public double getCutLowWeightsPerc() {
        return cutLowWeightsPerc;
    }

    public void setCutLowWeightsPerc(double cutLowWeightsPerc) {
        Preconditions.checkArgument(cutLowWeightsPerc >= 0 && cutLowWeightsPerc < 1);
        this.cutLowWeightsPerc = cutLowWeightsPerc;
    }

    public boolean isNormalizeVectors() {
        return normalizeVectors;
    }

    public void setNormalizeVectors(boolean normalizeVectors) {
        this.normalizeVectors = normalizeVectors;
    }

    public boolean isKeepWordForms() {
        return keepWordForms;
    }

    public void setKeepWordForms(boolean keepWordForms) {
        this.keepWordForms = keepWordForms;
    }

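    // Vocabulary pass: registers every n-gram that starts at startIdx in the current token
    // window, creating a Word entry for a new stem and updating corpus frequency, document
    // frequency, and surface-form counts for a known one.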
    private void processNGramsPass1(List<WordStem> nGrams, int startIdx, Set<String> docWords) {
        String nGramStem = "";
        String nGram = "";
        for (int i = startIdx; i < nGrams.size(); i++) {
            nGram += nGrams.get(i).word;
            nGramStem += nGrams.get(i).stem;
            if (!wordInfo.containsKey(nGramStem)) {
                Word nGramInfo = new Word(nGram, nGramStem);
                wordInfo.put(nGramStem, nGramInfo);
                docWords.add(nGramStem);
            } else {
                Word nGramInfo = wordInfo.get(nGramStem);
                if (!docWords.contains(nGramStem)) {
                    docWords.add(nGramStem);
                    nGramInfo.docFreq++;
                }
                nGramInfo.freq++;
                Integer freq = nGramInfo.forms.get(nGram);
                nGramInfo.forms.put(nGram, freq == null ? 1 : ++freq);
            }
            nGram += " ";
            nGramStem += " ";
        }
    }

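    // Term-frequency pass: counts vocabulary n-grams starting at startIdx, assigning a vector
    // index to each Word the first time it is encountered; stops at the first n-gram whose
    // stem is not in the vocabulary.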
    private void processNGramsPass2(List<WordStem> nGrams, int startIdx, Map<Integer, Integer> tfVec) {
        String nGramStem = "";
        for (int i = startIdx; i < nGrams.size(); i++) {
            nGramStem += nGrams.get(i).stem;
            if (wordInfo.containsKey(nGramStem)) {
                Word wordInfo = this.wordInfo.get(nGramStem);
                if (wordInfo.idx == -1) {
                    wordInfo.idx = idxInfo.size();
                    tfVec.put(wordInfo.idx, 1);
                    idxInfo.add(wordInfo);
                } else {
                    Integer tfVal = tfVec.get(wordInfo.idx);
                    tfVec.put(wordInfo.idx, tfVal == null ? 1 : ++tfVal);
                }
            } else {
                break;
            }
            nGramStem += " ";
        }
    }

    public List<SparseVector> initialize(Iterable<String> documents) {
        return initialize(documents, true, false);
    }

    private String normalizeText(String text, boolean doNormalize) {
        return doNormalize ? text.trim().toLowerCase() : text;
    }

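    /**
     * Builds the vocabulary from the given documents and returns one bag-of-words vector per
     * document. In large-scale mode the vocabulary is built in maxNGramLen passes, pruning
     * n-grams whose (n-1)-gram constituents are missing or infrequent; this needs less memory
     * but is slower.
     */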
    public List<SparseVector> initialize(Iterable<String> documents, boolean normalizeTokens, boolean largeScale) {
        Preconditions.checkNotNull(documents);

        wordInfo.clear();
        idxInfo.clear();
        List<SparseVector> bows = Lists.newArrayList();

        // build vocabulary
        logger.debug("Building vocabulary ...");
        int docCount = 0;
        if (!largeScale) {
            for (String document : documents) {
                docCount++;
                logger.debug("Document {} ...", docCount);
                Set<String> docWords = new HashSet<String>();
                List<WordStem> nGrams = Lists.newArrayListWithCapacity(maxNGramLen);
                for (Token token : tokenizer.getTokens(document)) {
                    String word = normalizeText(token.getText(), normalizeTokens);
                    if (stopWords == null || !stopWords.isStopWord(word)) {
                        String stem = stemmer == null ? word : normalizeText(stemmer.getStem(word), normalizeTokens);
                        if (nGrams.size() < maxNGramLen) {
                            WordStem wordStem = new WordStem();
                            wordStem.word = word;
                            wordStem.stem = stem;
                            nGrams.add(wordStem);
                            if (nGrams.size() < maxNGramLen) {
                                continue;
                            }
                        } else {
                            WordStem wordStem = nGrams.get(0);
                            wordStem.word = word;
                            wordStem.stem = stem;
                            for (int i = 0; i < maxNGramLen - 1; i++) {
                                nGrams.set(i, nGrams.get(i + 1));
                            }
                            nGrams.set(maxNGramLen - 1, wordStem);
                        }
                        processNGramsPass1(nGrams, 0, docWords);
                    }
                }
                int startIdx = nGrams.size() == maxNGramLen ? 1 : 0;
                for (int i = startIdx; i < nGrams.size(); i++) {
                    processNGramsPass1(nGrams, i, docWords);
                }
            }
            logger.debug("Document {} ...", docCount);
        } else // large-scale mode (needs less memory, slower)
        {
            for (int n = 1; n <= maxNGramLen; n++) {
                docCount = 0;
                logger.debug("Pass {} of {} ...", n, maxNGramLen);
                for (String document : documents) {
                    docCount++;
                    logger.debug("initialize", "Document {} ...", docCount);
                    List nGrams = Lists.newArrayListWithCapacity(n);
                    Set docWords = new HashSet();
                    for (Token token : tokenizer.getTokens(document)) {
                        String word = normalizeText(token.getText(), normalizeTokens);
                        if (stopWords == null || !stopWords.isStopWord(word)) {
                            String stem = stemmer == null ? word : normalizeText(stemmer.getStem(word), normalizeTokens);
                            if (nGrams.size() < n) {
                                WordStem wordStem = new WordStem();
                                wordStem.word = word;
                                wordStem.stem = stem;
                                nGrams.add(wordStem);
                                if (nGrams.size() < n) {
                                    continue;
                                }
                            } else {
                                WordStem wordStem = nGrams.get(0);
                                wordStem.word = word;
                                wordStem.stem = stem;
                                for (int i = 0; i < n - 1; i++) {
                                    nGrams.set(i, nGrams.get(i + 1));
                                }
                                nGrams.set(n - 1, wordStem);
                            }
                            String nGram = nGrams.get(0).word;
                            String nGramStem = nGrams.get(0).stem;
                            if (n > 1) {
                                for (int i = 1; i < n - 1; i++) {
                                    nGram += " " + nGrams.get(i).word;
                                    nGramStem += " " + nGrams.get(i).stem;
                                }
                                if (!wordInfo.containsKey(nGramStem)) {
                                    continue;
                                }
                                if (wordInfo.get(nGramStem).freq < minWordFreq) {
                                    continue;
                                }
                                String nGramStem2 = "";
                                for (int i = 1; i < n - 1; i++) {
                                    nGramStem2 += nGrams.get(i).stem + " ";
                                }
                                nGramStem2 += nGrams.get(n - 1).stem;
                                if (!wordInfo.containsKey(nGramStem2)) {
                                    continue;
                                }
                                if (wordInfo.get(nGramStem2).freq < minWordFreq) {
                                    continue;
                                }
                                nGram += " " + nGrams.get(n - 1).word;
                                nGramStem += " " + nGrams.get(n - 1).stem;
                            }
                            if (!wordInfo.containsKey(nGramStem)) {
                                Word nGramInfo = new Word(nGram, nGramStem);
                                wordInfo.put(nGramStem, nGramInfo);
                                docWords.add(nGramStem);
                            } else {
                                Word nGramInfo = wordInfo.get(nGramStem);
                                if (!docWords.contains(nGramStem)) {
                                    nGramInfo.docFreq++;
                                    docWords.add(nGramStem);
                                }
                                nGramInfo.freq++;
                                Integer count = nGramInfo.forms.get(nGram);
                                nGramInfo.forms.put(nGram, count == null ? 1 : ++count);
                            }
                        }
                    }
                }
                logger.debug("Document {} ...", docCount);
            }
        }
        // remove infrequent words and n-grams, precompute IDF
        List<String> removeList = Lists.newArrayList();
        for (Map.Entry<String, Word> info : wordInfo.entrySet()) {
            if (info.getValue().freq < minWordFreq) {
                removeList.add(info.getKey());
            } else {
                info.getValue().idf = Math.log((double) docCount / (double) info.getValue().docFreq);
            }
        }
        for (String key : removeList) {
            wordInfo.remove(key);
        }

        // determine most frequent word and n-gram forms
        for (Word info : wordInfo.values()) {
            int max = 0;
            for (Map.Entry<String, Integer> wordForm : info.forms.entrySet()) {
                if (wordForm.getValue() > max) {
                    max = wordForm.getValue();
                    info.mostFrequentForm = wordForm.getKey();
                }
            }
            if (!keepWordForms) {
                info.forms.clear();
            }
        }

        // compute bag-of-words vectors
        logger.debug("Computing bag-of-words vectors ...");
        int docNum = 1;
        for (String document : documents) {
            logger.debug("initialize", "Document {} / {} ...", docNum++, docCount);
            Map tfVec = new HashMap();
            List nGrams = Lists.newArrayListWithCapacity(maxNGramLen);
            for (Token token : tokenizer.getTokens(document)) {
                String word = normalizeText(token.getText(), normalizeTokens);
                if (stopWords == null || !stopWords.isStopWord(word)) {
                    String stem = stemmer == null ? word : normalizeText(stemmer.getStem(word), normalizeTokens);
                    if (nGrams.size() < maxNGramLen) {
                        WordStem wordStem = new WordStem();
                        wordStem.word = word;
                        wordStem.stem = stem;
                        nGrams.add(wordStem);
                        if (nGrams.size() < maxNGramLen) {
                            continue;
                        }
                    } else {
                        WordStem wordStem = nGrams.get(0);
                        wordStem.word = word;
                        wordStem.stem = stem;
                        for (int i = 0; i < maxNGramLen - 1; i++) {
                            nGrams.set(i, nGrams.get(i + 1));
                        }
                        nGrams.set(maxNGramLen - 1, wordStem);
                    }
                    processNGramsPass2(nGrams, 0, tfVec);
                }
            }
            int startIdx = nGrams.size() == maxNGramLen ? 1 : 0;
            for (int i = startIdx; i < nGrams.size(); i++) {
                processNGramsPass2(nGrams, i, tfVec);
            }
            SparseVector docVec = new SparseVector(0);
            switch (wordWeightType) {
                case TERM_FREQ:
                    for (Map.Entry<Integer, Integer> tfItem : tfVec.entrySet()) {
                        docVec.add(tfItem.getKey(), tfItem.getValue());
                    }
                    break;
                case TF_IDF:
                    for (Map.Entry<Integer, Integer> tfItem : tfVec.entrySet()) {
                        double tfIdf = (double) tfItem.getValue() * idxInfo.get(tfItem.getKey()).idf;
                        if (tfIdf > 0) {
                            docVec.add(tfItem.getKey(), tfIdf);
                        }
                    }
                    break;
                case LOG_DF_TF_IDF:
                    for (Map.Entry<Integer, Integer> tfItem : tfVec.entrySet()) {
                        double tfIdf = (double) tfItem.getValue() * idxInfo.get(tfItem.getKey()).idf;
                        if (tfIdf > 0) {
                            docVec.add(tfItem.getKey(), Math.log(1 + idxInfo.get(tfItem.getKey()).docFreq) * tfIdf);
                        }
                    }
                    break;
            }
            docVec.sort();
            docVec = ModelUtils.cutLowWeights(docVec, cutLowWeightsPerc);
            if (normalizeVectors) {
                ModelUtils.tryNrmVecL2(docVec);
            }
            bows.add(docVec);
        }
        return bows;
    }

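    // Counts vocabulary n-grams starting at startIdx using the indices assigned during
    // initialize(); unknown n-grams are skipped rather than ending the scan as in
    // processNGramsPass2.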
    private void processDocumentNGrams(List<WordStem> nGrams, int startIdx, Map<Integer, Integer> tfVec) {
        String nGramStem = "";
        String nGram = "";
        for (int i = startIdx; i < nGrams.size(); i++) {
            nGram += nGrams.get(i).word;
            nGramStem += nGrams.get(i).stem;
            if (wordInfo.containsKey(nGramStem)) {
                int stemIdx = wordInfo.get(nGramStem).idx;
                Integer idx = tfVec.get(stemIdx);
                tfVec.put(stemIdx, idx == null ? 1 : ++idx);
            }
            nGram += " ";
            nGramStem += " ";
        }
    }

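    /**
     * Projects a single document into the bag-of-words space built by initialize(), applying
     * the configured weighting scheme, low-weight cutoff, and optional L2 normalization.
     */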
    public SparseVector processDocument(String document) {
        return processDocument(document, true);
    }

    public SparseVector processDocument(String document, boolean normalizeTokens) {
        return processDocument(document, stemmer, normalizeTokens);
    }

    public SparseVector processDocument(String document, Stemmer stemmer, boolean normalizeTokens) {
        Preconditions.checkNotNull(document);
        Map<Integer, Integer> tfVec = new HashMap<Integer, Integer>();
        List<WordStem> nGrams = Lists.newArrayListWithCapacity(maxNGramLen);
        for (Token token : tokenizer.getTokens(document)) {
            String word = normalizeText(token.getText(), normalizeTokens);
            if (stopWords == null || !stopWords.isStopWord(word)) {
                String stem = stemmer == null ? word : normalizeText(stemmer.getStem(word), normalizeTokens);
                if (nGrams.size() < maxNGramLen) {
                    WordStem wordStem = new WordStem();
                    wordStem.word = word;
                    wordStem.stem = stem;
                    nGrams.add(wordStem);
                    if (nGrams.size() < maxNGramLen) {
                        continue;
                    }
                } else {
                    WordStem wordStem = nGrams.get(0);
                    wordStem.word = word;
                    wordStem.stem = stem;
                    for (int i = 0; i < maxNGramLen - 1; i++) {
                        nGrams.set(i, nGrams.get(i + 1));
                    }
                    nGrams.set(maxNGramLen - 1, wordStem);
                }
                processDocumentNGrams(nGrams, 0, tfVec);
            }
        }
        int startIdx = nGrams.size() == maxNGramLen ? 1 : 0;
        for (int i = startIdx; i < nGrams.size(); i++) {
            processDocumentNGrams(nGrams, i, tfVec);
        }
        SparseVector docVec = new SparseVector();
        switch (wordWeightType) {
            case TERM_FREQ:
                for (Map.Entry<Integer, Integer> tfItem : tfVec.entrySet()) {
                    docVec.add(tfItem.getKey(), tfItem.getValue());
                }
                break;
            case TF_IDF:
                for (Map.Entry<Integer, Integer> tfItem : tfVec.entrySet()) {
                    double tfIdf = (double) tfItem.getValue() * idxInfo.get(tfItem.getKey()).idf;
                    if (tfIdf > 0) {
                        docVec.add(tfItem.getKey(), tfIdf);
                    }
                }
                break;
            case LOG_DF_TF_IDF:
                for (Map.Entry<Integer, Integer> tfItem : tfVec.entrySet()) {
                    double tfIdf = (double) tfItem.getValue() * idxInfo.get(tfItem.getKey()).idf;
                    if (tfIdf > 0) {
                        docVec.add(tfItem.getKey(), Math.log(1 + idxInfo.get(tfItem.getKey()).docFreq) * tfIdf);
                    }
                }
                break;
        }
        docVec.sort();
        docVec = ModelUtils.cutLowWeights(docVec, cutLowWeightsPerc);
        if (normalizeVectors) {
            ModelUtils.tryNrmVecL2(docVec);
        }
        return docVec;
    }

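    /**
     * Returns the vocabulary entries of the given vector as keywords sorted by descending weight.
     */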
    public List<Keyword> getKeywords(SparseVector bowVec) {
        Preconditions.checkNotNull(bowVec);
        List<Keyword> keywords = Lists.newArrayListWithCapacity(bowVec.size());
        for (VectorEntry item : bowVec) {
            keywords.add(new Keyword(item.getData(), idxInfo.get(item.getIndex())));
        }
        Collections.sort(keywords, Collections.reverseOrder());
        return keywords;
    }

    public List<Word> getKeywords(SparseVector bowVec, int n) {
        Preconditions.checkNotNull(bowVec);
        Preconditions.checkArgument(n > 0);
        List<Keyword> keywords = getKeywords(bowVec);
        int keywordCount = Math.min(n, keywords.size());
        List<Word> result = Lists.newArrayListWithCapacity(keywordCount);
        for (int i = 0; i < keywordCount; i++) {
            result.add(keywords.get(i).getWord());
        }
        return result;
    }

    public String getKeywordsStr(SparseVector bowVec, int n) {
        List<Word> keywords = getKeywords(bowVec, n);
        if (keywords.size() == 0) {
            return "";
        }
        String keywordsStr = keywords.get(0).mostFrequentForm;
        for (int i = 1; i < keywords.size(); i++) {
            keywordsStr += ", " + keywords.get(i).mostFrequentForm;
        }
        return keywordsStr;
    }

    private static class WordStem
    {
        public String word;
        public String stem;
    }

    public static class Keyword implements Comparable<Keyword>
    {
        private final double factor;
        private final Word word;

        public Keyword(double factor, Word word) {
            this.factor = factor;
            this.word = word;
        }

        public double getFactor() {
            return factor;
        }

        public Word getWord() {
            return word;
        }

        @Override
        public int compareTo(Keyword o) {
            return Double.compare(factor, o.factor);
        }
    }
}
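
A minimal usage sketch (not part of the library source): it builds a space over a tiny in-memory corpus using the default tokenizer, no stop words, and no stemmer, then projects an unseen document into the same space. The class name BowSpaceExample, the sample corpus, and the chosen parameter values are illustrative assumptions.

import com.google.common.collect.Lists;
import org.latinolib.SparseVector;
import org.latinolib.bow.BowSpace;
import org.latinolib.bow.WordWeightType;

import java.util.List;

public class BowSpaceExample
{
    public static void main(String[] args) {
        List<String> corpus = Lists.newArrayList(
            "the quick brown fox jumps over the lazy dog",
            "the quick brown fox is quick",
            "a lazy dog sleeps all day");

        BowSpace bowSpace = new BowSpace();
        bowSpace.setMaxNGramLen(2);   // unigrams and bigrams
        bowSpace.setMinWordFreq(1);   // keep everything; the corpus is tiny
        bowSpace.setWordWeightType(WordWeightType.TF_IDF);

        // One sparse vector per training document; this call also builds the vocabulary.
        List<SparseVector> trainVectors = bowSpace.initialize(corpus);

        // Project an unseen document into the same space and print its top keywords.
        SparseVector docVec = bowSpace.processDocument("the quick dog barks");
        System.out.println(bowSpace.getKeywordsStr(docVec, 3));
        System.out.println("training vectors: " + trainVectors.size());
    }
}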