com.davidbracewell.apollo.ml.preprocess.transform.TFIDFTransform Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of apollo Show documentation
Show all versions of apollo Show documentation
A machine learning library for Java.
The newest version!
package com.davidbracewell.apollo.ml.preprocess.transform;
import com.davidbracewell.apollo.ml.Feature;
import com.davidbracewell.apollo.ml.Instance;
import com.davidbracewell.apollo.ml.preprocess.RestrictedInstancePreprocessor;
import com.davidbracewell.collection.counter.Counter;
import com.davidbracewell.collection.counter.Counters;
import com.davidbracewell.conversion.Val;
import com.davidbracewell.json.JsonReader;
import com.davidbracewell.json.JsonTokenType;
import com.davidbracewell.json.JsonWriter;
import com.davidbracewell.stream.MStream;
import com.davidbracewell.stream.accumulator.MDoubleAccumulator;
import com.davidbracewell.string.StringUtils;
import com.davidbracewell.tuple.Tuple2;
import lombok.NonNull;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
/**
* Transform values using Tf-idf
*
* @author David B. Bracewell
*/
public class TFIDFTransform extends RestrictedInstancePreprocessor implements TransformProcessor, Serializable {
private static final long serialVersionUID = 1L;
private volatile Counter documentFrequencies = Counters.newCounter();
private volatile double totalDocs = 0;
/**
* Instantiates a new Tfidf transform.
*/
public TFIDFTransform() {
super(StringUtils.EMPTY);
}
/**
* Instantiates a new Tfidf transform.
*
* @param featureNamePrefix the feature name prefix
*/
public TFIDFTransform(String featureNamePrefix) {
super(featureNamePrefix);
}
@Override
public String describe() {
if (applyToAll()) {
return "TFIDFTransform{totalDocuments=" + totalDocs + ", vocabSize=" + documentFrequencies.size() + "}";
}
return "TFIDFTransform[" + getRestriction() + "]{totalDocuments=" + totalDocs + ", vocabSize=" + documentFrequencies
.size() + "}";
}
@Override
public void fromJson(@NonNull JsonReader reader) throws IOException {
reset();
while (reader.peek() != JsonTokenType.END_OBJECT) {
switch (reader.peekName()) {
case "restriction":
setRestriction(reader.nextKeyValue().v2.asString());
break;
case "totalDocuments":
this.totalDocs = reader.nextKeyValue().v2.asDoubleValue();
break;
case "documentCounts":
reader.beginObject();
while (reader.peek() != JsonTokenType.END_OBJECT) {
Tuple2 kv = reader.nextKeyValue();
documentFrequencies.set(kv.getKey(), kv.getValue().asDoubleValue());
}
reader.endObject();
break;
}
}
}
@Override
public void reset() {
totalDocs = 0;
documentFrequencies.clear();
}
@Override
protected void restrictedFitImpl(MStream> stream) {
MDoubleAccumulator docCount = stream.getContext().doubleAccumulator(0d);
this.documentFrequencies.merge(stream.flatMap(instance -> {
docCount.add(1d);
return instance.stream().map(Feature::getFeatureName).distinct();
}
).countByValue()
);
this.totalDocs = docCount.value();
}
@Override
protected Stream restrictedProcessImpl(Stream featureStream, Instance originalExample) {
double dSum = originalExample.getFeatures().stream().mapToDouble(Feature::getValue).sum();
return featureStream.map(f -> {
double value = f.getValue() / dSum * Math.log(totalDocs / documentFrequencies.get(f.getFeatureName()));
if (value != 0) {
return Feature.real(f.getFeatureName(), value);
}
return null;
}
)
.filter(Objects::nonNull);
}
@Override
public void toJson(@NonNull JsonWriter writer) throws IOException {
if (!applyToAll()) {
writer.property("restriction", getRestriction());
}
writer.property("totalDocuments", totalDocs);
writer.beginObject("documentCounts");
for (Map.Entry entry : documentFrequencies.entries()) {
writer.property(entry.getKey(), entry.getValue());
}
writer.endObject();
}
}// END OF TFIDFTransform
© 2015 - 2025 Weber Informatics LLC | Privacy Policy