All downloads are free. Search and download functionality uses the official Maven repository.

com.github.jelmerk.knn.examples.FastText2 Maven / Gradle / Ivy

There is a newer version: 1.0.1
Show newest version
package com.github.jelmerk.knn.examples;

import com.github.jelmerk.knn.DistanceFunctions;
import com.github.jelmerk.knn.Index;
import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.hnsw.HnswIndex;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Arrays;
import java.util.List;
import java.util.OptionalDouble;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPInputStream;

import static com.github.jelmerk.knn.util.VectorUtils.normalize;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
 * Example application that downloads the English fastText word vectors, inserts them into HNSW indexes built with
 * several values of {@code m}, and prints, for each configuration, the average precision@10 of the approximate index
 * measured against exact (brute-force) search over the same items.
 */
public class FastText2 {

    private static final String WORDS_FILE_URL = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz";

    private static final Path TMP_PATH = Paths.get(System.getProperty("java.io.tmpdir"));

    /** Dimensionality of the vectors in {@link #WORDS_FILE_URL} (the "300" in {@code cc.en.300.vec.gz}). */
    private static final int DIMENSIONS = 300;

    /** Number of neighbours retrieved per query when measuring precision. */
    private static final int K = 10;

    /** Number of random query words sampled per (m, ef) combination. */
    private static final int NUM_QUERIES = 2000;

    /** Fixed seed so that every (m, ef) combination is evaluated on the same sequence of query words. */
    private static final long QUERY_SEED = 1000L;

    public static void main(String[] args) throws Exception {
        Path file = TMP_PATH.resolve("cc.en.300.vec.gz");

        if (!Files.exists(file)) {
            downloadFile(WORDS_FILE_URL, file);
        } else {
            System.out.printf("Input file already downloaded. Using %s%n", file);
        }

        // The word list does not depend on m, so download/load it once rather than once per index.
        List<Word> words = loadWordVectors(file);
        if (words.isEmpty()) {
            throw new IllegalStateException("No word vectors found in " + file);
        }

        List<Integer> mValues = Arrays.asList(10, 16, 32, 48, 64);
        List<Integer> efValues = Arrays.asList(10, 50, 100, 150, 200);

        for (int m : mValues) {
            HnswIndex<String, float[], Word, Float> hnswIndex = buildIndex(words, m);
            Index<String, float[], Word, Float> groundTruthIndex = hnswIndex.asExactIndex();

            for (int ef : efValues) {
                // NOTE(review): applying ef (hnswIndex.setEf(ef)) was commented out in the original, presumably
                // because this hnswlib version has no such setter — confirm. As written, every ef iteration
                // searches the same index configuration, so the printed numbers should not vary with ef.
                double precision = averagePrecision(words, hnswIndex, groundTruthIndex);

                System.out.println("Average precision for m = " + m + " and ef " + ef + " : " + precision);
            }
        }
    }

    /**
     * Builds an HNSW index with the given {@code m} over {@code words}, printing progress and the elapsed time.
     *
     * @param words the items to insert
     * @param m     the HNSW {@code m} parameter passed to the builder
     * @return the populated index
     * @throws InterruptedException if insertion is interrupted
     */
    private static HnswIndex<String, float[], Word, Float> buildIndex(List<Word> words, int m)
            throws InterruptedException {
        System.out.println("Constructing index.");

        HnswIndex<String, float[], Word, Float> hnswIndex = HnswIndex
                .newBuilder(DIMENSIONS, DistanceFunctions.FLOAT_INNER_PRODUCT, words.size())
                .withM(m)
                .build();

        long start = System.currentTimeMillis();

        hnswIndex.addAll(words, (workDone, max) -> System.out.printf("Added %d out of %d words to the index.%n", workDone, max));

        long duration = System.currentTimeMillis() - start;

        System.out.printf("Creating index with %d words took %d millis which is %d minutes.%n", hnswIndex.size(), duration, MILLISECONDS.toMinutes(duration));

        return hnswIndex;
    }

    /**
     * Computes the mean precision@{@link #K} of {@code approximateIndex} over {@link #NUM_QUERIES} query words drawn
     * (with a fixed seed) from {@code words}, using {@code exactIndex} as ground truth.
     *
     * <p>Results are matched by item id, so the metric does not depend on how {@code SearchResult} or {@code Word}
     * implement {@code equals}.
     *
     * @param words            the pool of words to draw queries from; must not be empty
     * @param approximateIndex the index under test
     * @param exactIndex       the brute-force index providing the expected neighbours
     * @return the average fraction of expected neighbours that the approximate index also returned
     */
    private static double averagePrecision(List<Word> words,
                                           Index<String, float[], Word, Float> approximateIndex,
                                           Index<String, float[], Word, Float> exactIndex) {
        // A fresh generator per call keeps the query sequence identical across (m, ef) combinations.
        Random random = new Random(QUERY_SEED);

        OptionalDouble average = IntStream.range(0, NUM_QUERIES).mapToDouble(i -> {
            String randomWord = words.get(random.nextInt(words.size())).id();

            List<SearchResult<Word, Float>> expectedResults = exactIndex.findNeighbors(randomWord, K);

            Set<String> actualIds = approximateIndex.findNeighbors(randomWord, K).stream()
                    .map(result -> result.item().id())
                    .collect(Collectors.toSet());

            long correct = expectedResults.stream()
                    .filter(result -> actualIds.contains(result.item().id()))
                    .count();

            return (double) correct / (double) expectedResults.size();
        }).average();

        // NUM_QUERIES > 0, so the stream is never empty.
        return average.getAsDouble();
    }

    /**
     * Downloads {@code url} to {@code path}.
     *
     * <p>The data is first written to a sibling {@code .part} file and only moved to {@code path} once the transfer
     * completes. {@code main} decides whether to download purely by checking whether {@code path} exists, so writing
     * directly to {@code path} would let an interrupted download masquerade as a finished one on the next run.
     *
     * @param url  the URL to fetch
     * @param path the final destination of the downloaded file
     * @throws IOException if the download or the final move fails; the partial file is removed on download failure
     */
    private static void downloadFile(String url, Path path) throws IOException {
        System.out.printf("Downloading %s to %s. This may take a while.%n", url, path);

        Path partial = path.resolveSibling(path.getFileName() + ".part");

        try (InputStream in = new URL(url).openStream()) {
            Files.copy(in, partial, StandardCopyOption.REPLACE_EXISTING);
        } catch (IOException e) {
            try {
                Files.deleteIfExists(partial);
            } catch (IOException cleanupFailure) {
                e.addSuppressed(cleanupFailure);
            }
            throw e;
        }

        Files.move(partial, path, StandardCopyOption.REPLACE_EXISTING);
    }

    /**
     * Reads a gzipped fastText {@code .vec} file and returns one {@link Word} per line after the first.
     *
     * @param path the gzipped vector file
     * @return the parsed words, in file order
     * @throws IOException if the file cannot be opened or read
     */
    private static List<Word> loadWordVectors(Path path) throws IOException {
        System.out.printf("Loading words from %s%n", path);

        try (BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(Files.newInputStream(path)), StandardCharsets.UTF_8))) {
            return reader.lines()
                    .skip(1) // the first line is a header, not a word vector
                    .map(FastText2::parseWord)
                    .collect(Collectors.toList());
        }
    }

    /**
     * Parses one space-separated {@code .vec} line of the form {@code word v1 v2 ... vN}.
     *
     * @param line a single data line from the vector file
     * @return the word with its normalized vector
     */
    private static Word parseWord(String line) {
        String[] tokens = line.split(" ");

        String word = tokens[0];

        // tokens[1..length-1] are the components; shift by one so component 1 lands at vector[0]
        // and the last component is not dropped.
        float[] vector = new float[tokens.length - 1];
        for (int i = 1; i < tokens.length; i++) {
            vector[i - 1] = Float.parseFloat(tokens[i]);
        }

        return new Word(word, normalize(vector)); // normalize the vector so we can do inner product search
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy