
// com.flipkart.fdp.ml.transformer.CountVectorizerTransformer (Maven / Gradle / Ivy)
package com.flipkart.fdp.ml.transformer;
import com.flipkart.fdp.ml.modelinfo.CountVectorizerModelInfo;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
 * Transforms input / predicts for a Count vectorizer model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.CountVectorizerModelInfo}.
 *
 * <p>At construction the model's vocabulary is indexed (term to position). Each
 * prediction turns an array of tokens into a dense term-count vector with one slot
 * per vocabulary entry. Tokens outside the vocabulary are ignored, and counts below
 * the model's effective minimum term frequency are reported as {@code 0.0}.
 */
public class CountVectorizerTransformer implements Transformer {
    private final CountVectorizerModelInfo modelInfo;

    /** Maps each vocabulary term to its position in the output vector. */
    private final Map<String, Integer> vocabulary;

    /**
     * Creates a transformer for the given model and indexes its vocabulary.
     *
     * @param modelInfo the count vectorizer model; its vocabulary is indexed once here
     */
    public CountVectorizerTransformer(final CountVectorizerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
        final String[] terms = modelInfo.getVocabulary();
        vocabulary = new HashMap<>();
        for (int i = 0; i < terms.length; i++) {
            // If a term occurs more than once in the vocabulary, its last position wins.
            vocabulary.put(terms[i], i);
        }
    }

    /**
     * Encodes a tokenized document as a term-count vector.
     *
     * @param input the document's tokens; tokens not in the vocabulary are ignored,
     *              but they still count towards the token total used when minTF
     *              is a fraction
     * @return a vector with one entry per vocabulary slot holding that term's count,
     *         or {@code 0.0} when the count is below the effective minimum term
     *         frequency
     */
    double[] predict(final String[] input) {
        // Java zero-initialises arrays, so only in-vocabulary tokens need counting.
        final double[] encoding = new double[modelInfo.getVocabulary().length];
        for (final String term : input) {
            final Integer position = vocabulary.get(term);
            if (position != null) {
                encoding[position] += 1.0;
            }
            // terms not in the vocabulary are ignored
        }

        // minTF >= 1.0 is an absolute count; below 1.0 it is a fraction of the
        // document's token count (out-of-vocabulary tokens included).
        final double minTF = modelInfo.getMinTF();
        final double effectiveMinTF = (minTF >= 1.0) ? minTF : minTF * input.length;

        // Compare against the exact (double) threshold. Truncating it to an int would
        // let through counts between floor(threshold) and threshold, e.g. minTF = 0.3
        // with 5 tokens gives 1.5, which must reject a count of 1.
        for (int i = 0; i < encoding.length; i++) {
            if (encoding[i] < effectiveMinTF) {
                encoding[i] = 0.0;
            }
        }
        return encoding;
    }

    /**
     * Reads the {@code String[]} stored under the key returned first by the model's
     * {@code getInputKeys()} and stores its term-count encoding ({@code double[]})
     * under the model's output key.
     *
     * <p>The raw {@code Map} parameter type is kept so this method overrides
     * {@code Transformer#transform} whether that interface declares {@code Map} or a
     * parameterized {@code Map}.
     *
     * @param input record map; modified in place by adding the output entry
     */
    @Override
    @SuppressWarnings("unchecked") // put() on the raw Map is an unchecked call
    public void transform(final Map input) {
        final String[] tokens = (String[]) input.get(modelInfo.getInputKeys().iterator().next());
        input.put(modelInfo.getOutputKey(), predict(tokens));
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy