All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.alibaba.alink.operator.batch.nlp.DocHashCountVectorizerTrainBatchOp Maven / Gradle / Ivy

package com.alibaba.alink.operator.batch.nlp;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.nlp.DocHashCountVectorizerModelData;
import com.alibaba.alink.operator.common.nlp.DocHashCountVectorizerModelDataConverter;
import com.alibaba.alink.operator.common.nlp.NLPConstant;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.nlp.DocHashCountVectorizerTrainParams;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;

import static org.apache.flink.shaded.guava18.com.google.common.hash.Hashing.murmur3_32;

/**
 * Hash every word as a number, and save the inverse document frequency(IDF) of every word in the document.
 * 

* It's used together with DocHashCountVectorizerPredictBatchOp. */ public class DocHashCountVectorizerTrainBatchOp extends BatchOperator implements DocHashCountVectorizerTrainParams { public DocHashCountVectorizerTrainBatchOp() { super(new Params()); } public DocHashCountVectorizerTrainBatchOp(Params params) { super(params); } @Override public DocHashCountVectorizerTrainBatchOp linkFrom(BatchOperator... inputs) { BatchOperator in = checkAndGetFirst(inputs); int index = TableUtil.findColIndex(in.getColNames(), this.getSelectedCol()); if (index < 0) { throw new RuntimeException("Can not find column: " + this.getSelectedCol()); } DataSet out = in .getDataSet() .mapPartition(new HashingTF(index, this.getNumFeatures())) .reduce(new ReduceFunction>>() { @Override public Tuple2> reduce(Tuple2> map1, Tuple2> map2) { map2.f1.forEach((k2, v1) -> map1.f1.merge(k2, v1, Double::sum)); map1.f0 += map2.f0; return map1; } }).flatMap(new BuildModel(getParams())); this.setOutput(out, new DocHashCountVectorizerModelDataConverter().getModelSchema()); return this; } /** * The dense vector contains the document frequency(DF) of words. Calculate the IDF, and replace the original DF * value. The dense vector is saved into DocHashIDFVectorizerModel. */ static class BuildModel implements FlatMapFunction>, Row> { private double minDocFrequency; private int numFeatures; private String featureType; private double minTF; public BuildModel(Params params) { this.minDocFrequency = params.get(DocHashCountVectorizerTrainParams.MIN_DF); this.numFeatures = params.get(DocHashCountVectorizerTrainParams.NUM_FEATURES); this.featureType = params.get(DocHashCountVectorizerTrainParams.FEATURE_TYPE); this.minTF = params.get(DocHashCountVectorizerTrainParams.MIN_TF); } @Override public void flatMap(Tuple2> vec, Collector collector) throws Exception { long cnt = vec.f0; minDocFrequency = minDocFrequency >= 1.0 ? 
minDocFrequency : minDocFrequency * cnt; for (int key : vec.f1.keySet()) { vec.f1.compute(key, (k, v) -> { if (v >= minDocFrequency) { return Math.log((cnt + 1.0) / (v + 1.0)); } else { return null; } }); } DocHashCountVectorizerModelData model = new DocHashCountVectorizerModelData(); model.numFeatures = numFeatures; model.minTF = minTF; model.featureType = featureType; model.idfMap = vec.f1; new DocHashCountVectorizerModelDataConverter().save(model, collector); } } /** * Transform a word to a number by using MurMurHash3. The number is used as the index of a dense vector. The value * of the vector is document frequency of the corresponding word. */ static class HashingTF implements MapPartitionFunction>> { private int index, numFeatures; private static final HashFunction HASH = murmur3_32(0); public HashingTF(int index, int numFeatures) { this.index = index; this.numFeatures = numFeatures; } @Override public void mapPartition(Iterable iterable, Collector>> collector) throws Exception { HashMap map = new HashMap<>(numFeatures); long count = 0; for (Row row : iterable) { count++; String content = (String)row.getField(index); String[] words = content.split(NLPConstant.WORD_DELIMITER); for (String word : words) { int hashValue = Math.abs(HASH.hashUnencodedChars(word).asInt()); int index = Math.floorMod(hashValue, numFeatures); map.merge(index, 1.0, Double::sum); } } collector.collect(Tuple2.of(count, map)); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy