All downloads are free. Search and download functionality uses the official Maven repository.

com.flipkart.fdp.ml.transformer.CountVectorizerTransformer Maven / Gradle / Ivy

There is a newer version: 0.4.0
Show newest version
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.CountVectorizerModelInfo;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Encodes a tokenized document as a vector of per-term counts, using the
 * vocabulary and minimum term frequency captured by
 * {@link com.flipkart.fdp.ml.modelinfo.CountVectorizerModelInfo}.
 */
public class CountVectorizerTransformer implements Transformer {
    private final CountVectorizerModelInfo modelInfo;

    /** Maps each vocabulary term to its position in the output vector. */
    private final Map<String, Integer> vocabulary;

    /**
     * Builds the term-to-position lookup from the model's vocabulary array.
     * If a term occurs more than once in the array, its last position wins.
     *
     * @param modelInfo model description supplying the vocabulary, the minimum
     *                  term frequency and the input/output keys
     */
    public CountVectorizerTransformer(final CountVectorizerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
        // Fetch the array once instead of calling the getter on every iteration.
        final String[] terms = modelInfo.getVocabulary();
        vocabulary = new HashMap<String, Integer>();
        for (int i = 0; i < terms.length; i++) {
            vocabulary.put(terms[i], i);
        }
    }

    /**
     * Counts how often each vocabulary term occurs in {@code input} and writes
     * the counts into a vector indexed by vocabulary position.
     *
     * <p>Terms that are not in the vocabulary are ignored. Terms whose count is
     * below the effective minimum term frequency are left at {@code 0.0}. A
     * configured minimum of at least 1.0 is used as an absolute count; a smaller
     * value is treated as a fraction of the number of tokens in {@code input}.
     *
     * @param input the tokens of a single document
     * @return an array with one entry per vocabulary term
     */
    double[] predict(final String[] input) {
        final Map<String, Integer> termFrequencies = new HashMap<String, Integer>();
        final int tokenCount = input.length;
        for (final String term : input) {
            if (!vocabulary.containsKey(term)) {
                // ignore terms not in vocabulary
                continue;
            }
            final Integer seen = termFrequencies.get(term);
            termFrequencies.put(term, seen == null ? 1 : seen + 1);
        }

        // Keep the threshold as a double: truncating it to an int would let a term
        // through whose count is below a fractional threshold (e.g. 2 vs. 2.5).
        final double minTF = modelInfo.getMinTF();
        final double effectiveMinTF = (minTF >= 1.0) ? minTF : minTF * tokenCount;

        // A freshly allocated double[] is already zero-filled.
        final double[] encoding = new double[modelInfo.getVocabulary().length];

        for (final Map.Entry<String, Integer> entry : termFrequencies.entrySet()) {
            final int count = entry.getValue();
            //filter out terms with freq < effectiveMinTF
            if (count >= effectiveMinTF) {
                final int position = vocabulary.get(entry.getKey());
                encoding[position] = count;
            }
        }
        return encoding;
    }

    /**
     * Reads the token array stored under the first input key reported by the
     * model info, encodes it with {@link #predict(String[])} and stores the
     * result in the same map under the model's output key.
     *
     * @param input map holding the {@code String[]} of tokens; it is updated in
     *              place with the encoded vector
     */
    // NOTE(review): the parameter is kept as a raw Map, matching the current
    // declaration; if Transformer declares a parameterized map, mirror it here.
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void transform(Map input) {
        final Object inputKey = modelInfo.getInputKeys().iterator().next();
        final String[] tokens = (String[]) input.get(inputKey);
        input.put(modelInfo.getOutputKey(), predict(tokens));
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy