
org.apache.flink.ml.feature.lsh.LSHModel Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.feature.lsh;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.typeinfo.PriorityQueueTypeInfo;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.apache.commons.lang3.ArrayUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
* Base class for LSH model.
*
* In addition to transforming input feature vectors to multiple hash values, it also supports
* approximate nearest neighbors search within a dataset regarding a key vector and approximate
* similarity join between two datasets.
*
* @param class type of the LSHModel implementation itself.
*/
abstract class LSHModel> implements Model, LSHModelParams {
private static final String MODEL_DATA_BC_KEY = "modelData";
private final Map, Object> paramMap = new HashMap<>();
/** Stores the corresponding model data class of T. */
private final Class extends LSHModelData> modelDataClass;
protected Table modelDataTable;
public LSHModel(Class extends LSHModelData> modelDataClass) {
this.modelDataClass = modelDataClass;
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
@Override
public T setModelData(Table... inputs) {
Preconditions.checkArgument(inputs.length == 1);
modelDataTable = inputs[0];
return (T) this;
}
@Override
public Table[] getModelData() {
return new Table[] {modelDataTable};
}
@Override
public Map, Object> getParamMap() {
return paramMap;
}
@Override
public Table[] transform(Table... inputs) {
Preconditions.checkArgument(inputs.length == 1);
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
DataStream extends LSHModelData> modelData =
tEnv.toDataStream(modelDataTable, modelDataClass);
RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
TypeInformation> outputType = TypeInformation.of(DenseVector[].class);
RowTypeInfo outputTypeInfo =
new RowTypeInfo(
ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType),
ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
DataStream output =
BroadcastUtils.withBroadcastStream(
Collections.singletonList(tEnv.toDataStream(inputs[0])),
Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
inputList -> {
//noinspection unchecked
DataStream data = (DataStream) inputList.get(0);
return data.map(new PredictFunction(getInputCol()), outputTypeInfo);
});
return new Table[] {tEnv.fromDataStream(output)};
}
/**
* Approximately finds at most k items from a dataset which have the closest distance to a given
* item. If the `outputCol` is missing in the given dataset, this method transforms the dataset
* with the model at first.
*
* @param dataset The dataset in which to to search for nearest neighbors.
* @param key The item to search for.
* @param k The maximum number of nearest neighbors.
* @param distCol The output column storing the distance between each neighbor and the key.
* @return A dataset containing at most k items closest to the key with a column named `distCol`
* appended.
*/
public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) {
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment();
Table transformedTable =
(dataset.getResolvedSchema().getColumnNames().contains(getOutputCol()))
? dataset
: transform(dataset)[0];
DataStream extends LSHModelData> modelData =
tEnv.toDataStream(modelDataTable, modelDataClass);
RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema());
RowTypeInfo outputTypeInfo =
new RowTypeInfo(
ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE),
ArrayUtils.addAll(inputTypeInfo.getFieldNames(), distCol));
// Fetches items in the same bucket with key's, and calculates their distances to key.
DataStream filteredData =
BroadcastUtils.withBroadcastStream(
Collections.singletonList(tEnv.toDataStream(transformedTable)),
Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
inputList -> {
//noinspection unchecked
DataStream data = (DataStream) inputList.get(0);
return data.flatMap(
new FilterByBucketFunction(getInputCol(), getOutputCol(), key),
outputTypeInfo);
});
TopKFunction topKFunction = new TopKFunction(distCol, k);
DataStream> topKList =
DataStreamUtils.aggregate(
filteredData,
topKFunction,
new PriorityQueueTypeInfo(topKFunction.getComparator(), outputTypeInfo),
Types.LIST(outputTypeInfo));
DataStream topKData =
topKList.flatMap(
(value, out) -> {
for (Row row : value) {
out.collect(row);
}
});
topKData.getTransformation().setOutputType(outputTypeInfo);
return tEnv.fromDataStream(topKData);
}
/**
* An overloaded version of `approxNearestNeighbors` with "distCol" as default value of
* `distCol`.
*/
public Table approxNearestNeighbors(Table dataset, Vector key, int k) {
return approxNearestNeighbors(dataset, key, k, "distCol");
}
/**
* Joins two datasets to approximately find all pairs of rows whose distance are smaller than or
* equal to the threshold. If the `outputCol` is missing in either dataset, this method
* transforms the dataset at first.
*
* @param datasetA One dataset.
* @param datasetB The other dataset.
* @param threshold The distance threshold.
* @param idCol A column in the two datasets to identify each row.
* @param distCol The output column storing the distance between each pair of rows.
* @return A joined dataset containing pairs of rows. The original rows are in columns
* "datasetA" and "datasetB", and a column "distCol" is added to show the distance between
* each pair.
*/
public Table approxSimilarityJoin(
Table datasetA, Table datasetB, double threshold, String idCol, String distCol) {
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) datasetA).getTableEnvironment();
DataStream explodedA = preprocessData(datasetA, idCol);
DataStream explodedB = preprocessData(datasetB, idCol);
RowTypeInfo inputTypeInfo = getOutputType(datasetA, idCol);
RowTypeInfo outputTypeInfo =
new RowTypeInfo(
inputTypeInfo.getTypeAt(0),
inputTypeInfo.getTypeAt(0),
inputTypeInfo.getTypeAt(1),
inputTypeInfo.getTypeAt(1));
DataStream extends LSHModelData> modelData =
tEnv.toDataStream(modelDataTable, modelDataClass);
DataStream sameBucketPairs =
explodedA
.join(explodedB)
.where(new IndexHashValueKeySelector())
.equalTo(new IndexHashValueKeySelector())
.window(EndOfStreamWindows.get())
.apply(
(r0, r1) ->
Row.of(
r0.getField(0),
r1.getField(0),
r0.getField(1),
r1.getField(1)),
outputTypeInfo);
DataStream distinctSameBucketPairs =
DataStreamUtils.reduce(
sameBucketPairs.keyBy(
new KeySelector>() {
@Override
public Tuple2 getKey(Row r) {
return Tuple2.of(r.getFieldAs(0), r.getFieldAs(1));
}
}),
(r0, r1) -> r0,
outputTypeInfo);
TypeInformation> idColType =
TableUtils.getRowTypeInfo(datasetA.getResolvedSchema()).getTypeAt(idCol);
DataStream pairsWithDists =
BroadcastUtils.withBroadcastStream(
Collections.singletonList(distinctSameBucketPairs),
Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
inputList -> {
DataStream data = (DataStream) inputList.get(0);
return data.flatMap(
new FilterByDistanceFunction(threshold),
new RowTypeInfo(
new TypeInformation[] {
idColType, idColType, Types.DOUBLE
},
new String[] {"datasetA.id", "datasetB.id", distCol}));
});
return tEnv.fromDataStream(pairsWithDists);
}
/**
* An overloaded version of `approxNearestNeighbors` with "distCol" as default value of
* `distCol`.
*/
public Table approxSimilarityJoin(
Table datasetA, Table datasetB, double threshold, String idCol) {
return approxSimilarityJoin(datasetA, datasetB, threshold, idCol, "distCol");
}
private DataStream preprocessData(Table dataTable, String idCol) {
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment();
dataTable =
(dataTable.getResolvedSchema().getColumnNames().contains(getOutputCol()))
? dataTable
: transform(dataTable)[0];
RowTypeInfo outputTypeInfo = getOutputType(dataTable, idCol);
return tEnv.toDataStream(dataTable)
.flatMap(
new ExplodeHashValuesFunction(idCol, getInputCol(), getOutputCol()),
outputTypeInfo);
}
private RowTypeInfo getOutputType(Table dataTable, String idCol) {
final String indexCol = "index";
final String hashValueCol = "hashValue";
RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema());
TypeInformation> idColType = inputTypeInfo.getTypeAt(idCol);
RowTypeInfo outputTypeInfo =
new RowTypeInfo(
new TypeInformation[] {
idColType,
VectorTypeInfo.INSTANCE,
Types.INT,
DenseVectorTypeInfo.INSTANCE
},
new String[] {idCol, getInputCol(), indexCol, hashValueCol});
return outputTypeInfo;
}
private static class PredictFunction extends RichMapFunction {
private final String inputCol;
private LSHModelData modelData;
public PredictFunction(String inputCol) {
this.inputCol = inputCol;
}
@Override
public Row map(Row value) throws Exception {
if (null == modelData) {
modelData =
(LSHModelData)
getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0);
}
Vector[] hashValues = modelData.hashFunction(value.getFieldAs(inputCol));
return Row.join(value, Row.of((Object) hashValues));
}
}
private static class FilterByBucketFunction extends RichFlatMapFunction {
private final String inputCol;
private final String outputCol;
private final Vector key;
private LSHModelData modelData;
private DenseVector[] keyHashes;
public FilterByBucketFunction(String inputCol, String outputCol, Vector key) {
this.inputCol = inputCol;
this.outputCol = outputCol;
this.key = key;
}
@Override
public void flatMap(Row value, Collector out) throws Exception {
if (null == modelData) {
modelData =
(LSHModelData)
getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0);
keyHashes = modelData.hashFunction(key);
}
DenseVector[] hashes = value.getFieldAs(outputCol);
boolean sameBucket = false;
for (int i = 0; i < keyHashes.length; i += 1) {
if (keyHashes[i].equals(hashes[i])) {
sameBucket = true;
break;
}
}
if (!sameBucket) {
return;
}
Vector vec = value.getFieldAs(inputCol);
double dist = modelData.keyDistance(key, vec);
out.collect(Row.join(value, Row.of(dist)));
}
}
private static class TopKFunction
implements AggregateFunction, List> {
private final int numNearestNeighbors;
private final String distCol;
private static class DistColComparator implements Comparator, Serializable {
private final String distCol;
private DistColComparator(String distCol) {
this.distCol = distCol;
}
@Override
public int compare(Row o1, Row o2) {
return Double.compare(o1.getFieldAs(distCol), o2.getFieldAs(distCol));
}
}
public TopKFunction(String distCol, int numNearestNeighbors) {
this.distCol = distCol;
this.numNearestNeighbors = numNearestNeighbors;
}
@Override
public PriorityQueue createAccumulator() {
return new PriorityQueue<>(numNearestNeighbors, getComparator());
}
@Override
public PriorityQueue add(Row value, PriorityQueue accumulator) {
if (accumulator.size() == numNearestNeighbors) {
Row peek = accumulator.peek();
if (accumulator.comparator().compare(value, peek) < 0) {
accumulator.poll();
}
}
accumulator.add(value);
return accumulator;
}
@Override
public List getResult(PriorityQueue accumulator) {
return new ArrayList<>(accumulator);
}
@Override
public PriorityQueue merge(PriorityQueue a, PriorityQueue b) {
PriorityQueue merged = new PriorityQueue<>(a);
for (Row row : b) {
add(row, merged);
}
return merged;
}
private Comparator getComparator() {
return new DistColComparator(distCol);
}
}
private static class ExplodeHashValuesFunction implements FlatMapFunction {
private final String idCol;
private final String inputCol;
private final String outputCol;
public ExplodeHashValuesFunction(String idCol, String inputCol, String outputCol) {
this.idCol = idCol;
this.inputCol = inputCol;
this.outputCol = outputCol;
}
@Override
public void flatMap(Row value, Collector out) throws Exception {
Row kept = Row.of(value.getField(idCol), value.getField(inputCol));
DenseVector[] hashValues = value.getFieldAs(outputCol);
for (int i = 0; i < hashValues.length; i += 1) {
out.collect(Row.join(kept, Row.of(i, hashValues[i])));
}
}
}
private static class IndexHashValueKeySelector
implements KeySelector> {
@Override
public Tuple2 getKey(Row value) throws Exception {
return Tuple2.of(value.getFieldAs(2), value.getFieldAs(3));
}
}
private static class FilterByDistanceFunction extends RichFlatMapFunction {
private final double threshold;
private LSHModelData modelData;
public FilterByDistanceFunction(double threshold) {
this.threshold = threshold;
}
@Override
public void flatMap(Row value, Collector out) throws Exception {
if (null == modelData) {
modelData =
(LSHModelData)
getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0);
}
double dist = modelData.keyDistance(value.getFieldAs(2), value.getFieldAs(3));
if (dist <= threshold) {
out.collect(Row.of(value.getFieldAs(0), value.getFieldAs(1), dist));
}
}
}
}