Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.apache.mahout.math.hadoop.similarity.cooccurrence.RowSimilarityJob Maven / Gradle / Ivy
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.hadoop.similarity.cooccurrence;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.mapreduce.VectorSumCombiner;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
import org.apache.mahout.math.map.OpenIntIntHashMap;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
public class RowSimilarityJob extends AbstractJob {
public static final double NO_THRESHOLD = Double.MIN_VALUE;
public static final long NO_FIXED_RANDOM_SEED = Long.MIN_VALUE;
private static final String SIMILARITY_CLASSNAME = RowSimilarityJob.class + ".distributedSimilarityClassname";
private static final String NUMBER_OF_COLUMNS = RowSimilarityJob.class + ".numberOfColumns";
private static final String MAX_SIMILARITIES_PER_ROW = RowSimilarityJob.class + ".maxSimilaritiesPerRow";
private static final String EXCLUDE_SELF_SIMILARITY = RowSimilarityJob.class + ".excludeSelfSimilarity";
private static final String THRESHOLD = RowSimilarityJob.class + ".threshold";
private static final String NORMS_PATH = RowSimilarityJob.class + ".normsPath";
private static final String MAXVALUES_PATH = RowSimilarityJob.class + ".maxWeightsPath";
private static final String NUM_NON_ZERO_ENTRIES_PATH = RowSimilarityJob.class + ".nonZeroEntriesPath";
private static final int DEFAULT_MAX_SIMILARITIES_PER_ROW = 100;
private static final String OBSERVATIONS_PER_COLUMN_PATH = RowSimilarityJob.class + ".observationsPerColumnPath";
private static final String MAX_OBSERVATIONS_PER_ROW = RowSimilarityJob.class + ".maxObservationsPerRow";
private static final String MAX_OBSERVATIONS_PER_COLUMN = RowSimilarityJob.class + ".maxObservationsPerColumn";
private static final String RANDOM_SEED = RowSimilarityJob.class + ".randomSeed";
private static final int DEFAULT_MAX_OBSERVATIONS_PER_ROW = 500;
private static final int DEFAULT_MAX_OBSERVATIONS_PER_COLUMN = 500;
private static final int NORM_VECTOR_MARKER = Integer.MIN_VALUE;
private static final int MAXVALUE_VECTOR_MARKER = Integer.MIN_VALUE + 1;
private static final int NUM_NON_ZERO_ENTRIES_VECTOR_MARKER = Integer.MIN_VALUE + 2;
enum Counters { ROWS, USED_OBSERVATIONS, NEGLECTED_OBSERVATIONS, COOCCURRENCES, PRUNED_COOCCURRENCES }
public static void main(String[] args) throws Exception {
ToolRunner.run(new RowSimilarityJob(), args);
}
@Override
public int run(String[] args) throws Exception {
addInputOption();
addOutputOption();
addOption("numberOfColumns", "r", "Number of columns in the input matrix", false);
addOption("similarityClassname", "s", "Name of distributed similarity class to instantiate, alternatively use "
+ "one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')');
addOption("maxSimilaritiesPerRow", "m", "Number of maximum similarities per row (default: "
+ DEFAULT_MAX_SIMILARITIES_PER_ROW + ')', String.valueOf(DEFAULT_MAX_SIMILARITIES_PER_ROW));
addOption("excludeSelfSimilarity", "ess", "compute similarity of rows to themselves?", String.valueOf(false));
addOption("threshold", "tr", "discard row pairs with a similarity value below this", false);
addOption("maxObservationsPerRow", null, "sample rows down to this number of entries",
String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_ROW));
addOption("maxObservationsPerColumn", null, "sample columns down to this number of entries",
String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_COLUMN));
addOption("randomSeed", null, "use this seed for sampling", false);
addOption(DefaultOptionCreator.overwriteOption().create());
Map> parsedArgs = parseArguments(args);
if (parsedArgs == null) {
return -1;
}
int numberOfColumns;
if (hasOption("numberOfColumns")) {
// Number of columns explicitly specified via CLI
numberOfColumns = Integer.parseInt(getOption("numberOfColumns"));
} else {
// else get the number of columns by determining the cardinality of a vector in the input matrix
numberOfColumns = getDimensions(getInputPath());
}
String similarityClassnameArg = getOption("similarityClassname");
String similarityClassname;
try {
similarityClassname = VectorSimilarityMeasures.valueOf(similarityClassnameArg).getClassname();
} catch (IllegalArgumentException iae) {
similarityClassname = similarityClassnameArg;
}
// Clear the output and temp paths if the overwrite option has been set
if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
// Clear the temp path
HadoopUtil.delete(getConf(), getTempPath());
// Clear the output path
HadoopUtil.delete(getConf(), getOutputPath());
}
int maxSimilaritiesPerRow = Integer.parseInt(getOption("maxSimilaritiesPerRow"));
boolean excludeSelfSimilarity = Boolean.parseBoolean(getOption("excludeSelfSimilarity"));
double threshold = hasOption("threshold")
? Double.parseDouble(getOption("threshold")) : NO_THRESHOLD;
long randomSeed = hasOption("randomSeed")
? Long.parseLong(getOption("randomSeed")) : NO_FIXED_RANDOM_SEED;
int maxObservationsPerRow = Integer.parseInt(getOption("maxObservationsPerRow"));
int maxObservationsPerColumn = Integer.parseInt(getOption("maxObservationsPerColumn"));
Path weightsPath = getTempPath("weights");
Path normsPath = getTempPath("norms.bin");
Path numNonZeroEntriesPath = getTempPath("numNonZeroEntries.bin");
Path maxValuesPath = getTempPath("maxValues.bin");
Path pairwiseSimilarityPath = getTempPath("pairwiseSimilarity");
Path observationsPerColumnPath = getTempPath("observationsPerColumn.bin");
AtomicInteger currentPhase = new AtomicInteger();
Job countObservations = prepareJob(getInputPath(), getTempPath("notUsed"), CountObservationsMapper.class,
NullWritable.class, VectorWritable.class, SumObservationsReducer.class, NullWritable.class,
VectorWritable.class);
countObservations.setCombinerClass(VectorSumCombiner.class);
countObservations.getConfiguration().set(OBSERVATIONS_PER_COLUMN_PATH, observationsPerColumnPath.toString());
countObservations.setNumReduceTasks(1);
countObservations.waitForCompletion(true);
if (shouldRunNextPhase(parsedArgs, currentPhase)) {
Job normsAndTranspose = prepareJob(getInputPath(), weightsPath, VectorNormMapper.class, IntWritable.class,
VectorWritable.class, MergeVectorsReducer.class, IntWritable.class, VectorWritable.class);
normsAndTranspose.setCombinerClass(MergeVectorsCombiner.class);
Configuration normsAndTransposeConf = normsAndTranspose.getConfiguration();
normsAndTransposeConf.set(THRESHOLD, String.valueOf(threshold));
normsAndTransposeConf.set(NORMS_PATH, normsPath.toString());
normsAndTransposeConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
normsAndTransposeConf.set(MAXVALUES_PATH, maxValuesPath.toString());
normsAndTransposeConf.set(SIMILARITY_CLASSNAME, similarityClassname);
normsAndTransposeConf.set(OBSERVATIONS_PER_COLUMN_PATH, observationsPerColumnPath.toString());
normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_ROW, String.valueOf(maxObservationsPerRow));
normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_COLUMN, String.valueOf(maxObservationsPerColumn));
normsAndTransposeConf.set(RANDOM_SEED, String.valueOf(randomSeed));
boolean succeeded = normsAndTranspose.waitForCompletion(true);
if (!succeeded) {
return -1;
}
}
if (shouldRunNextPhase(parsedArgs, currentPhase)) {
Job pairwiseSimilarity = prepareJob(weightsPath, pairwiseSimilarityPath, CooccurrencesMapper.class,
IntWritable.class, VectorWritable.class, SimilarityReducer.class, IntWritable.class, VectorWritable.class);
pairwiseSimilarity.setCombinerClass(VectorSumReducer.class);
Configuration pairwiseConf = pairwiseSimilarity.getConfiguration();
pairwiseConf.set(THRESHOLD, String.valueOf(threshold));
pairwiseConf.set(NORMS_PATH, normsPath.toString());
pairwiseConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
pairwiseConf.set(MAXVALUES_PATH, maxValuesPath.toString());
pairwiseConf.set(SIMILARITY_CLASSNAME, similarityClassname);
pairwiseConf.setInt(NUMBER_OF_COLUMNS, numberOfColumns);
pairwiseConf.setBoolean(EXCLUDE_SELF_SIMILARITY, excludeSelfSimilarity);
boolean succeeded = pairwiseSimilarity.waitForCompletion(true);
if (!succeeded) {
return -1;
}
}
if (shouldRunNextPhase(parsedArgs, currentPhase)) {
Job asMatrix = prepareJob(pairwiseSimilarityPath, getOutputPath(), UnsymmetrifyMapper.class,
IntWritable.class, VectorWritable.class, MergeToTopKSimilaritiesReducer.class, IntWritable.class,
VectorWritable.class);
asMatrix.setCombinerClass(MergeToTopKSimilaritiesReducer.class);
asMatrix.getConfiguration().setInt(MAX_SIMILARITIES_PER_ROW, maxSimilaritiesPerRow);
boolean succeeded = asMatrix.waitForCompletion(true);
if (!succeeded) {
return -1;
}
}
return 0;
}
public static class CountObservationsMapper extends Mapper {
private Vector columnCounts = new RandomAccessSparseVector(Integer.MAX_VALUE);
@Override
protected void map(IntWritable rowIndex, VectorWritable rowVectorWritable, Context ctx)
throws IOException, InterruptedException {
Vector row = rowVectorWritable.get();
for (Vector.Element elem : row.nonZeroes()) {
columnCounts.setQuick(elem.index(), columnCounts.getQuick(elem.index()) + 1);
}
}
@Override
protected void cleanup(Context ctx) throws IOException, InterruptedException {
ctx.write(NullWritable.get(), new VectorWritable(columnCounts));
}
}
public static class SumObservationsReducer extends Reducer {
@Override
protected void reduce(NullWritable nullWritable, Iterable partialVectors, Context ctx)
throws IOException, InterruptedException {
Vector counts = Vectors.sum(partialVectors.iterator());
Vectors.write(counts, new Path(ctx.getConfiguration().get(OBSERVATIONS_PER_COLUMN_PATH)), ctx.getConfiguration());
}
}
public static class VectorNormMapper extends Mapper {
private VectorSimilarityMeasure similarity;
private Vector norms;
private Vector nonZeroEntries;
private Vector maxValues;
private double threshold;
private OpenIntIntHashMap observationsPerColumn;
private int maxObservationsPerRow;
private int maxObservationsPerColumn;
private Random random;
@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
Configuration conf = ctx.getConfiguration();
similarity = ClassUtils.instantiateAs(conf.get(SIMILARITY_CLASSNAME), VectorSimilarityMeasure.class);
norms = new RandomAccessSparseVector(Integer.MAX_VALUE);
nonZeroEntries = new RandomAccessSparseVector(Integer.MAX_VALUE);
maxValues = new RandomAccessSparseVector(Integer.MAX_VALUE);
threshold = Double.parseDouble(conf.get(THRESHOLD));
observationsPerColumn = Vectors.readAsIntMap(new Path(conf.get(OBSERVATIONS_PER_COLUMN_PATH)), conf);
maxObservationsPerRow = conf.getInt(MAX_OBSERVATIONS_PER_ROW, DEFAULT_MAX_OBSERVATIONS_PER_ROW);
maxObservationsPerColumn = conf.getInt(MAX_OBSERVATIONS_PER_COLUMN, DEFAULT_MAX_OBSERVATIONS_PER_COLUMN);
long seed = Long.parseLong(conf.get(RANDOM_SEED));
if (seed == NO_FIXED_RANDOM_SEED) {
random = RandomUtils.getRandom();
} else {
random = RandomUtils.getRandom(seed);
}
}
private Vector sampleDown(Vector rowVector, Context ctx) {
int observationsPerRow = rowVector.getNumNondefaultElements();
double rowSampleRate = (double) Math.min(maxObservationsPerRow, observationsPerRow) / (double) observationsPerRow;
Vector downsampledRow = rowVector.like();
long usedObservations = 0;
long neglectedObservations = 0;
for (Vector.Element elem : rowVector.nonZeroes()) {
int columnCount = observationsPerColumn.get(elem.index());
double columnSampleRate = (double) Math.min(maxObservationsPerColumn, columnCount) / (double) columnCount;
if (random.nextDouble() <= Math.min(rowSampleRate, columnSampleRate)) {
downsampledRow.setQuick(elem.index(), elem.get());
usedObservations++;
} else {
neglectedObservations++;
}
}
ctx.getCounter(Counters.USED_OBSERVATIONS).increment(usedObservations);
ctx.getCounter(Counters.NEGLECTED_OBSERVATIONS).increment(neglectedObservations);
return downsampledRow;
}
@Override
protected void map(IntWritable row, VectorWritable vectorWritable, Context ctx)
throws IOException, InterruptedException {
Vector sampledRowVector = sampleDown(vectorWritable.get(), ctx);
Vector rowVector = similarity.normalize(sampledRowVector);
int numNonZeroEntries = 0;
double maxValue = Double.MIN_VALUE;
for (Vector.Element element : rowVector.nonZeroes()) {
RandomAccessSparseVector partialColumnVector = new RandomAccessSparseVector(Integer.MAX_VALUE);
partialColumnVector.setQuick(row.get(), element.get());
ctx.write(new IntWritable(element.index()), new VectorWritable(partialColumnVector));
numNonZeroEntries++;
if (maxValue < element.get()) {
maxValue = element.get();
}
}
if (threshold != NO_THRESHOLD) {
nonZeroEntries.setQuick(row.get(), numNonZeroEntries);
maxValues.setQuick(row.get(), maxValue);
}
norms.setQuick(row.get(), similarity.norm(rowVector));
ctx.getCounter(Counters.ROWS).increment(1);
}
@Override
protected void cleanup(Context ctx) throws IOException, InterruptedException {
ctx.write(new IntWritable(NORM_VECTOR_MARKER), new VectorWritable(norms));
ctx.write(new IntWritable(NUM_NON_ZERO_ENTRIES_VECTOR_MARKER), new VectorWritable(nonZeroEntries));
ctx.write(new IntWritable(MAXVALUE_VECTOR_MARKER), new VectorWritable(maxValues));
}
}
private static class MergeVectorsCombiner extends Reducer {
@Override
protected void reduce(IntWritable row, Iterable partialVectors, Context ctx)
throws IOException, InterruptedException {
ctx.write(row, new VectorWritable(Vectors.merge(partialVectors)));
}
}
public static class MergeVectorsReducer extends Reducer {
private Path normsPath;
private Path numNonZeroEntriesPath;
private Path maxValuesPath;
@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
normsPath = new Path(ctx.getConfiguration().get(NORMS_PATH));
numNonZeroEntriesPath = new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH));
maxValuesPath = new Path(ctx.getConfiguration().get(MAXVALUES_PATH));
}
@Override
protected void reduce(IntWritable row, Iterable partialVectors, Context ctx)
throws IOException, InterruptedException {
Vector partialVector = Vectors.merge(partialVectors);
if (row.get() == NORM_VECTOR_MARKER) {
Vectors.write(partialVector, normsPath, ctx.getConfiguration());
} else if (row.get() == MAXVALUE_VECTOR_MARKER) {
Vectors.write(partialVector, maxValuesPath, ctx.getConfiguration());
} else if (row.get() == NUM_NON_ZERO_ENTRIES_VECTOR_MARKER) {
Vectors.write(partialVector, numNonZeroEntriesPath, ctx.getConfiguration(), true);
} else {
ctx.write(row, new VectorWritable(partialVector));
}
}
}
public static class CooccurrencesMapper extends Mapper {
private VectorSimilarityMeasure similarity;
private OpenIntIntHashMap numNonZeroEntries;
private Vector maxValues;
private double threshold;
private static final Comparator BY_INDEX = new Comparator() {
@Override
public int compare(Vector.Element one, Vector.Element two) {
return Ints.compare(one.index(), two.index());
}
};
@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
VectorSimilarityMeasure.class);
numNonZeroEntries = Vectors.readAsIntMap(new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH)),
ctx.getConfiguration());
maxValues = Vectors.read(new Path(ctx.getConfiguration().get(MAXVALUES_PATH)), ctx.getConfiguration());
threshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
}
private boolean consider(Vector.Element occurrenceA, Vector.Element occurrenceB) {
int numNonZeroEntriesA = numNonZeroEntries.get(occurrenceA.index());
int numNonZeroEntriesB = numNonZeroEntries.get(occurrenceB.index());
double maxValueA = maxValues.get(occurrenceA.index());
double maxValueB = maxValues.get(occurrenceB.index());
return similarity.consider(numNonZeroEntriesA, numNonZeroEntriesB, maxValueA, maxValueB, threshold);
}
@Override
protected void map(IntWritable column, VectorWritable occurrenceVector, Context ctx)
throws IOException, InterruptedException {
Vector.Element[] occurrences = Vectors.toArray(occurrenceVector);
Arrays.sort(occurrences, BY_INDEX);
int cooccurrences = 0;
int prunedCooccurrences = 0;
for (int n = 0; n < occurrences.length; n++) {
Vector.Element occurrenceA = occurrences[n];
Vector dots = new RandomAccessSparseVector(Integer.MAX_VALUE);
for (int m = n; m < occurrences.length; m++) {
Vector.Element occurrenceB = occurrences[m];
if (threshold == NO_THRESHOLD || consider(occurrenceA, occurrenceB)) {
dots.setQuick(occurrenceB.index(), similarity.aggregate(occurrenceA.get(), occurrenceB.get()));
cooccurrences++;
} else {
prunedCooccurrences++;
}
}
ctx.write(new IntWritable(occurrenceA.index()), new VectorWritable(dots));
}
ctx.getCounter(Counters.COOCCURRENCES).increment(cooccurrences);
ctx.getCounter(Counters.PRUNED_COOCCURRENCES).increment(prunedCooccurrences);
}
}
public static class SimilarityReducer extends Reducer {
private VectorSimilarityMeasure similarity;
private int numberOfColumns;
private boolean excludeSelfSimilarity;
private Vector norms;
private double treshold;
@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
VectorSimilarityMeasure.class);
numberOfColumns = ctx.getConfiguration().getInt(NUMBER_OF_COLUMNS, -1);
Preconditions.checkArgument(numberOfColumns > 0, "Number of columns must be greater then 0! But numberOfColumns = " + numberOfColumns);
excludeSelfSimilarity = ctx.getConfiguration().getBoolean(EXCLUDE_SELF_SIMILARITY, false);
norms = Vectors.read(new Path(ctx.getConfiguration().get(NORMS_PATH)), ctx.getConfiguration());
treshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
}
@Override
protected void reduce(IntWritable row, Iterable partialDots, Context ctx)
throws IOException, InterruptedException {
Iterator partialDotsIterator = partialDots.iterator();
Vector dots = partialDotsIterator.next().get();
while (partialDotsIterator.hasNext()) {
Vector toAdd = partialDotsIterator.next().get();
for (Element nonZeroElement : toAdd.nonZeroes()) {
dots.setQuick(nonZeroElement.index(), dots.getQuick(nonZeroElement.index()) + nonZeroElement.get());
}
}
Vector similarities = dots.like();
double normA = norms.getQuick(row.get());
for (Element b : dots.nonZeroes()) {
double similarityValue = similarity.similarity(b.get(), normA, norms.getQuick(b.index()), numberOfColumns);
if (similarityValue >= treshold) {
similarities.set(b.index(), similarityValue);
}
}
if (excludeSelfSimilarity) {
similarities.setQuick(row.get(), 0);
}
ctx.write(row, new VectorWritable(similarities));
}
}
public static class UnsymmetrifyMapper extends Mapper {
private int maxSimilaritiesPerRow;
@Override
protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
Preconditions.checkArgument(maxSimilaritiesPerRow > 0, "Maximum number of similarities per row must be greater then 0!");
}
@Override
protected void map(IntWritable row, VectorWritable similaritiesWritable, Context ctx)
throws IOException, InterruptedException {
Vector similarities = similaritiesWritable.get();
// For performance, the creation of transposedPartial is moved out of the while loop and it is reused inside
Vector transposedPartial = new RandomAccessSparseVector(similarities.size(), 1);
TopElementsQueue topKQueue = new TopElementsQueue(maxSimilaritiesPerRow);
for (Element nonZeroElement : similarities.nonZeroes()) {
MutableElement top = topKQueue.top();
double candidateValue = nonZeroElement.get();
if (candidateValue > top.get()) {
top.setIndex(nonZeroElement.index());
top.set(candidateValue);
topKQueue.updateTop();
}
transposedPartial.setQuick(row.get(), candidateValue);
ctx.write(new IntWritable(nonZeroElement.index()), new VectorWritable(transposedPartial));
transposedPartial.setQuick(row.get(), 0.0);
}
Vector topKSimilarities = new RandomAccessSparseVector(similarities.size(), maxSimilaritiesPerRow);
for (Vector.Element topKSimilarity : topKQueue.getTopElements()) {
topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
}
ctx.write(row, new VectorWritable(topKSimilarities));
}
}
public static class MergeToTopKSimilaritiesReducer
extends Reducer {
private int maxSimilaritiesPerRow;
@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
Preconditions.checkArgument(maxSimilaritiesPerRow > 0, "Maximum number of similarities per row must be greater then 0!");
}
@Override
protected void reduce(IntWritable row, Iterable partials, Context ctx)
throws IOException, InterruptedException {
Vector allSimilarities = Vectors.merge(partials);
Vector topKSimilarities = Vectors.topKElements(maxSimilaritiesPerRow, allSimilarities);
ctx.write(row, new VectorWritable(topKSimilarities));
}
}
}