com.github.jelmerk.knn.examples.FastText2 Maven / Gradle / Ivy
Examples for the core java library
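The FastText2 listing below stores its vectors as instances of a Word class from the same examples package that is not reproduced on this page: it needs a (String, float[]) constructor and an id() accessor, and it must implement the library's Item interface so it can be put into the index. Here is a minimal sketch of what such a class can look like, assuming the Item interface of recent hnswlib-core versions (id(), vector() and dimensions()); the real Word class shipped with the examples module may differ in detail.

package com.github.jelmerk.knn.examples;

import com.github.jelmerk.knn.Item;

// Sketch of the Word item the FastText2 example stores in the index; the examples
// module ships its own Word class, so treat this as an assumption of its shape.
public class Word implements Item<String, float[]> {

    private final String id;
    private final float[] vector;

    public Word(String id, float[] vector) {
        this.id = id;
        this.vector = vector;
    }

    @Override
    public String id() {
        return id;
    }

    @Override
    public float[] vector() {
        return vector;
    }

    @Override
    public int dimensions() {
        // dimensions() is declared on Item in recent hnswlib-core versions.
        return vector.length;
    }
}

The FastText2 example itself follows.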
package com.github.jelmerk.knn.examples;
import com.github.jelmerk.knn.DistanceFunctions;
import com.github.jelmerk.knn.Index;
import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.hnsw.HnswIndex;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.OptionalDouble;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPInputStream;
import static com.github.jelmerk.knn.util.VectorUtils.normalize;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
/**
* Example application that downloads the English fastText word vectors, inserts them into an HNSW index and lets
* you query them.
*/
public class FastText2 {
private static final String WORDS_FILE_URL = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz";
private static final Path TMP_PATH = Paths.get(System.getProperty("java.io.tmpdir"));
public static void main(String[] args) throws Exception {
List<Integer> mValues = Arrays.asList(10, 16, 32, 48, 64);
for (Integer m : mValues) {
Path file = TMP_PATH.resolve("cc.en.300.vec.gz");
if (!Files.exists(file)) {
downloadFile(WORDS_FILE_URL, file);
} else {
System.out.printf("Input file already downloaded. Using %s%n", file);
}
List<Word> words = loadWordVectors(file);
System.out.println("Constructing index.");
HnswIndex<String, float[], Word, Float> hnswIndex = HnswIndex
.newBuilder(300, DistanceFunctions.FLOAT_INNER_PRODUCT, words.size())
.withM(m)
.build();
long start = System.currentTimeMillis();
hnswIndex.addAll(words, (workDone, max) -> System.out.printf("Added %d out of %d words to the index.%n", workDone, max));
long end = System.currentTimeMillis();
long duration = end - start;
System.out.printf("Creating index with %d words took %d millis which is %d minutes.%n", hnswIndex.size(), duration, MILLISECONDS.toMinutes(duration));
Index<String, float[], Word, Float> groundTruthIndex = hnswIndex.asExactIndex();
List<Integer> efValues = Arrays.asList(10, 50, 100, 150, 200);
for (Integer ef : efValues) {
Random random = new Random(1000L);
hnswIndex.setEf(ef);
OptionalDouble average = IntStream.range(0, 2000).mapToDouble(i -> {
String randomWord = words.get(random.nextInt(words.size())).id();
List<SearchResult<Word, Float>> expectedResults = groundTruthIndex.findNeighbors(randomWord, 10);
List<SearchResult<Word, Float>> actualResults = hnswIndex.findNeighbors(randomWord, 10);
int correct = expectedResults.stream().mapToInt(r -> actualResults.contains(r) ? 1 : 0).sum();
return (double) correct / (double) expectedResults.size();
}).average();
System.out.println("Average precession for m = " + m + " and ef " + ef + " : " + average.getAsDouble());
}
}
// Console console = System.console();
//
// while (true) {
// System.out.println("Enter an english word : ");
//
// String input = console.readLine();
//
// List<SearchResult<Word, Float>> nearest = index.findNeighbors(input, 10);
//
// System.out.println("Most similar words : ");
//
// for (SearchResult<Word, Float> result : nearest) {
// System.out.printf("%s %.4f%n", result.item().id(), result.distance());
// }
// }
}
private static void downloadFile(String url, Path path) throws IOException {
System.out.printf("Downloading %s to %s. This may take a while.%n", url, path);
try (InputStream in = new URL(url).openStream()) {
Files.copy(in, path);
}
}
private static List<Word> loadWordVectors(Path path) throws IOException {
System.out.printf("Loading words from %s%n", path);
try (BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(Files.newInputStream(path)), StandardCharsets.UTF_8))) {
return reader.lines()
.skip(1)
.map(line -> {
String[] tokens = line.split(" ");
String word = tokens[0];
float[] vector = new float[tokens.length - 1];
for (int i = 1; i < tokens.length; i++) {
vector[i - 1] = Float.parseFloat(tokens[i]);
}
return new Word(word, normalize(vector)); // normalize the vector so we can do inner product search
})
.collect(Collectors.toList());
}
}
}
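Building the index over the full set of crawl vectors takes on the order of minutes (the example prints the elapsed build time), so in practice you would want to persist the index between runs instead of rebuilding it every time. Below is a minimal sketch, assuming the save/load helpers exposed by HnswIndex (verify them against the hnswlib-core version you depend on) and the Word item sketched above.

package com.github.jelmerk.knn.examples;

import com.github.jelmerk.knn.hnsw.HnswIndex;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class IndexPersistence {

    // Serializes the index to disk so later runs can skip the slow build step.
    static void saveIndex(HnswIndex<String, float[], Word, Float> index, Path path) throws IOException {
        index.save(path);
    }

    // Reloads a previously saved index, or returns null when no file exists yet.
    static HnswIndex<String, float[], Word, Float> loadIndexIfPresent(Path path) throws IOException {
        if (!Files.exists(path)) {
            return null;
        }
        return HnswIndex.load(path);
    }
}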