All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding.NGramFeatureExtractor Maven / Gradle / Ivy

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 *
 * This Java port of CLD3 was derived from Google's CLD3 project at https://github.com/google/cld3
 */
package org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembedding;

import org.apache.lucene.util.Counter;

import java.util.Map;
import java.util.TreeMap;

/**
 * This provides an array of {@link FeatureValue} for the given nGram size and dimensionId
 *
 * Each feature value contains the relative frequency of a character nGram (its occurrence count divided by the
 * total number of nGrams extracted from the text) and its "id". This id is determined via a custom hash
 * ({@link Hash32}) of the nGram, reduced modulo the provided dimensionId.
 */
public class NGramFeatureExtractor implements FeatureExtractor {

    private static final Hash32 hashing = new Hash32();

    private final int nGrams;
    private final int dimensionId;

    /**
     * @param nGrams      the length, in chars, of each extracted nGram
     * @param dimensionId the value each nGram hash is reduced modulo to produce the feature id; must be non-zero
     */
    public NGramFeatureExtractor(int nGrams, int dimensionId) {
        this.nGrams = nGrams;
        this.dimensionId = dimensionId;
    }

    /**
     * Extracts the character nGram features of {@code text}.
     *
     * The text is first wrapped with token terminators ({@code '^'} before and {@code '$'} after every
     * space-delimited token). Then every run of {@code nGrams} consecutive chars that contains no space is counted.
     * Each distinct nGram yields one {@link ContinuousFeatureValue} whose weight is its count divided by the total
     * number of counted nGrams. Results are ordered by the natural {@link String} order of the nGrams.
     *
     * @param text the input text; only {@code ' '} separates tokens, other whitespace is kept inside tokens
     * @return one feature value per distinct nGram; empty if no nGram could be extracted
     */
    @Override
    public FeatureValue[] extractFeatures(String text) {
        String textWithTerminators = addTerminators(text);

        // Find the char ngrams, e.g. for
        // "^$ ^this$ ^text$ ^is$ ^written$ ^in$ ^english$ ^$" and nGrams = 2:
        // [{h$},{sh},{li},{gl},{in},{en},{^$},...]
        // A TreeMap keeps the keys sorted so the returned array has a deterministic order.
        Map<String, Counter> charNGrams = new TreeMap<>();

        int countSum = 0;
        int end = textWithTerminators.length() - nGrams;
        for (int start = 0; start <= end; ++start) {
            StringBuilder charNGram = new StringBuilder();

            int index;
            for (index = 0; index < nGrams; ++index) {
                char currentChar = textWithTerminators.charAt(start + index);
                if (currentChar == ' ') {
                    // An nGram never spans two tokens: discard any window containing a space.
                    break;
                }
                charNGram.append(currentChar);
            }

            // Only windows that were read completely (no space encountered) are counted.
            if (index == nGrams) {
                charNGrams.computeIfAbsent(charNGram.toString(), ngram -> Counter.newCounter()).addAndGet(1);
                ++countSum;
            }
        }

        FeatureValue[] results = new FeatureValue[charNGrams.size()];
        int resultIndex = 0;
        for (Map.Entry<String, Counter> entry : charNGrams.entrySet()) {
            String key = entry.getKey();
            long value = entry.getValue().get();

            // countSum > 0 whenever this loop runs, since every map entry was counted at least once.
            double weight = (double) value / (double) countSum;
            // We need to use the special hashing so that we choose the appropriate weight+ quantile
            // when building the feature vector.
            int id = (int) (hashing.hash(key) % dimensionId);

            results[resultIndex++] = new ContinuousFeatureValue(id, weight);
        }
        return results;
    }

    /**
     * Adds terminators to every space-delimited token: "^" at its beginning and "$" at its end.
     * e.g. " this text is written in english " goes to
     * "^$ ^this$ ^text$ ^is$ ^written$ ^in$ ^english$ ^$"
     * A leading or trailing space (or two consecutive spaces) therefore produces an empty "^$" token,
     * and the empty string becomes "^$".
     */
    private static String addTerminators(String text) {
        return "^" + text.replace(" ", "$ ^") + "$";
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy