All downloads are free. Search and download functionality uses the official Maven repository.

com.davidbracewell.apollo.ml.preprocess.transform.TFIDFTransform Maven / Gradle / Ivy

The newest version!
package com.davidbracewell.apollo.ml.preprocess.transform;

import com.davidbracewell.apollo.ml.Feature;
import com.davidbracewell.apollo.ml.Instance;
import com.davidbracewell.apollo.ml.preprocess.RestrictedInstancePreprocessor;
import com.davidbracewell.collection.counter.Counter;
import com.davidbracewell.collection.counter.Counters;
import com.davidbracewell.conversion.Val;
import com.davidbracewell.json.JsonReader;
import com.davidbracewell.json.JsonTokenType;
import com.davidbracewell.json.JsonWriter;
import com.davidbracewell.stream.MStream;
import com.davidbracewell.stream.accumulator.MDoubleAccumulator;
import com.davidbracewell.string.StringUtils;
import com.davidbracewell.tuple.Tuple2;
import lombok.NonNull;

import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;

/**
 * 

Transform values using Tf-idf

* * @author David B. Bracewell */ public class TFIDFTransform extends RestrictedInstancePreprocessor implements TransformProcessor, Serializable { private static final long serialVersionUID = 1L; private volatile Counter documentFrequencies = Counters.newCounter(); private volatile double totalDocs = 0; /** * Instantiates a new Tfidf transform. */ public TFIDFTransform() { super(StringUtils.EMPTY); } /** * Instantiates a new Tfidf transform. * * @param featureNamePrefix the feature name prefix */ public TFIDFTransform(String featureNamePrefix) { super(featureNamePrefix); } @Override public String describe() { if (applyToAll()) { return "TFIDFTransform{totalDocuments=" + totalDocs + ", vocabSize=" + documentFrequencies.size() + "}"; } return "TFIDFTransform[" + getRestriction() + "]{totalDocuments=" + totalDocs + ", vocabSize=" + documentFrequencies .size() + "}"; } @Override public void fromJson(@NonNull JsonReader reader) throws IOException { reset(); while (reader.peek() != JsonTokenType.END_OBJECT) { switch (reader.peekName()) { case "restriction": setRestriction(reader.nextKeyValue().v2.asString()); break; case "totalDocuments": this.totalDocs = reader.nextKeyValue().v2.asDoubleValue(); break; case "documentCounts": reader.beginObject(); while (reader.peek() != JsonTokenType.END_OBJECT) { Tuple2 kv = reader.nextKeyValue(); documentFrequencies.set(kv.getKey(), kv.getValue().asDoubleValue()); } reader.endObject(); break; } } } @Override public void reset() { totalDocs = 0; documentFrequencies.clear(); } @Override protected void restrictedFitImpl(MStream> stream) { MDoubleAccumulator docCount = stream.getContext().doubleAccumulator(0d); this.documentFrequencies.merge(stream.flatMap(instance -> { docCount.add(1d); return instance.stream().map(Feature::getFeatureName).distinct(); } ).countByValue() ); this.totalDocs = docCount.value(); } @Override protected Stream restrictedProcessImpl(Stream featureStream, Instance originalExample) { double dSum = 
originalExample.getFeatures().stream().mapToDouble(Feature::getValue).sum(); return featureStream.map(f -> { double value = f.getValue() / dSum * Math.log(totalDocs / documentFrequencies.get(f.getFeatureName())); if (value != 0) { return Feature.real(f.getFeatureName(), value); } return null; } ) .filter(Objects::nonNull); } @Override public void toJson(@NonNull JsonWriter writer) throws IOException { if (!applyToAll()) { writer.property("restriction", getRestriction()); } writer.property("totalDocuments", totalDocs); writer.beginObject("documentCounts"); for (Map.Entry entry : documentFrequencies.entries()) { writer.property(entry.getKey(), entry.getValue()); } writer.endObject(); } }// END OF TFIDFTransform




© 2015 - 2025 Weber Informatics LLC | Privacy Policy