All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.ml.feature.lsh.MinHashLSHModelData Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.feature.lsh;

import org.apache.flink.api.common.serialization.Encoder;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Preconditions;

import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.Random;

/**
 * Model data of {@link MinHashLSHModel}.
 *
 * 

This class also provides classes to save/load model data. */ public class MinHashLSHModelData extends LSHModelData { // A large prime smaller than sqrt(2^63 − 1). private static final int HASH_PRIME = 2038074743; public int numHashTables; public int numHashFunctionsPerTable; public int[] randCoefficientA; public int[] randCoefficientB; public MinHashLSHModelData() {} public MinHashLSHModelData( int numHashTables, int numHashFunctionsPerTable, int[] randCoefficientA, int[] randCoefficientB) { this.numHashTables = numHashTables; this.numHashFunctionsPerTable = numHashFunctionsPerTable; this.randCoefficientA = randCoefficientA; this.randCoefficientB = randCoefficientB; } public static MinHashLSHModelData generateModelData( int numHashTables, int numHashFunctionsPerTable, int dim, long seed) { Preconditions.checkArgument( dim <= HASH_PRIME, "The input vector dimension %d exceeds the threshold %s.", dim, HASH_PRIME); Random random = new Random(seed); int numHashFunctions = numHashTables * numHashFunctionsPerTable; int[] randCoeffA = new int[numHashFunctions]; int[] randCoeffB = new int[numHashFunctions]; for (int i = 0; i < numHashFunctions; i += 1) { randCoeffA[i] = 1 + random.nextInt(HASH_PRIME - 1); randCoeffB[i] = random.nextInt(HASH_PRIME - 1); } return new MinHashLSHModelData( numHashTables, numHashFunctionsPerTable, randCoeffA, randCoeffB); } static class ModelDataDecoder extends SimpleStreamFormat { @Override public Reader createReader( Configuration configuration, FSDataInputStream fsDataInputStream) throws IOException { return new Reader() { @Override public MinHashLSHModelData read() throws IOException { try { DataInputViewStreamWrapper source = new DataInputViewStreamWrapper(fsDataInputStream); int numHashTables = IntSerializer.INSTANCE.deserialize(source); int numHashFunctionsPerTable = IntSerializer.INSTANCE.deserialize(source); int[] randCoeffA = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); int[] randCoeffB = 
IntPrimitiveArraySerializer.INSTANCE.deserialize(source); return new MinHashLSHModelData( numHashTables, numHashFunctionsPerTable, randCoeffA, randCoeffB); } catch (EOFException e) { return null; } } @Override public void close() throws IOException { fsDataInputStream.close(); } }; } @Override public TypeInformation getProducedType() { return TypeInformation.of(MinHashLSHModelData.class); } } @Override public DenseVector[] hashFunction(Vector vec) { int[] indices = vec.toSparse().indices; Preconditions.checkArgument(indices.length > 0, "Must have at least 1 non zero entry."); double[][] hashValues = new double[numHashTables][numHashFunctionsPerTable]; for (int i = 0; i < numHashTables; i += 1) { for (int j = 0; j < numHashFunctionsPerTable; j += 1) { // For each hash function, the hash value is computed by // min(((1 + index) * randCoefficientA + randCoefficientB) % HASH_PRIME). int coeffA = randCoefficientA[i * numHashFunctionsPerTable + j]; int coeffB = randCoefficientB[i * numHashFunctionsPerTable + j]; long minv = HASH_PRIME; for (int index : indices) { minv = Math.min(minv, ((1L + index) * coeffA + coeffB) % HASH_PRIME); } hashValues[i][j] = minv; } } return Arrays.stream(hashValues).map(DenseVector::new).toArray(DenseVector[]::new); } @Override public double keyDistance(Vector x, Vector y) { int[] xIndices = x.toSparse().indices; int[] yIndices = y.toSparse().indices; Preconditions.checkArgument( xIndices.length + yIndices.length > 0, "The union of two input sets must have at least 1 elements"); int px = 0, py = 0; int intersectionSize = 0; while (px < xIndices.length && py < yIndices.length) { if (xIndices[px] == yIndices[py]) { intersectionSize += 1; px += 1; py += 1; } else if (xIndices[px] < yIndices[py]) { px += 1; } else { py += 1; } } int unionSize = xIndices.length + yIndices.length - intersectionSize; return 1. - 1. * intersectionSize / unionSize; } /** Encoder for {@link MinHashLSHModelData}. 
*/ public static class ModelDataEncoder implements Encoder { @Override public void encode(MinHashLSHModelData modelData, OutputStream outputStream) throws IOException { DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream); IntSerializer.INSTANCE.serialize(modelData.numHashTables, dataOutputView); IntSerializer.INSTANCE.serialize(modelData.numHashFunctionsPerTable, dataOutputView); IntPrimitiveArraySerializer.INSTANCE.serialize( modelData.randCoefficientA, dataOutputView); IntPrimitiveArraySerializer.INSTANCE.serialize( modelData.randCoefficientB, dataOutputView); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy