All downloads are free. Search and download functionality uses the official Maven repository.

org.wikibrain.matrix.knn.BruteForceKNNFinder Maven / Gradle / Ivy

package org.wikibrain.matrix.knn;

import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;

import java.io.*;
import java.util.Arrays;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Brute force implementations of knn for a dense matrix.
 *
 * @author Shilad Sen
 */
public class BruteForceKNNFinder implements KNNFinder {
    private static final Logger LOG = LoggerFactory.getLogger(BruteForceKNNFinder.class);

    /** Matrix whose rows are scored exhaustively on every query. */
    private final DenseMatrix matrix;

    /**
     * Creates a finder that answers queries by scanning {@code matrix} directly.
     * No index is built.
     *
     * @param matrix the dense matrix whose rows are the candidate neighbors
     * @throws IOException not actually thrown by this constructor; the clause is
     *         part of the existing signature and is kept so callers that catch it
     *         still compile
     */
    public BruteForceKNNFinder(DenseMatrix matrix) throws IOException {
        this.matrix = matrix;
    }

    /**
     * Scores candidate rows against {@code vector} and returns the best ones
     * collected by a {@link NeighborhoodAccumulator} of size {@code k}.
     *
     * <p>If {@code validIds} is {@code null}, every row of the matrix is scored.
     * Otherwise only the rows whose ids are in {@code validIds} are fetched and
     * scored. Ids for which {@code matrix.getRow} returns {@code null} are skipped.
     *
     * <p>NOTE(review): the two branches use different similarity measures. The
     * full scan uses {@code row.dot(vector)}, while the filtered scan uses
     * {@code KmeansKNNFinder.cosine(vector, row)}. If {@code cosine} is the usual
     * norm-divided dot product, the branches agree in ranking only when all rows
     * share the same norm, and their scores differ in scale. Confirm which
     * measure is intended.
     *
     * @param vector       query vector
     * @param k            capacity passed to the {@link NeighborhoodAccumulator}
     * @param maxTraversal ignored; a brute-force scan always visits every candidate
     * @param validIds     ids to restrict the search to, or {@code null} for all rows
     * @return the accumulated neighborhood
     * @throws RuntimeException wrapping an {@link IOException} raised while reading
     *         a row in the filtered branch
     */
    @Override
    public Neighborhood query(final float[] vector, int k, int maxTraversal, TIntSet validIds) {
        final NeighborhoodAccumulator accum = new NeighborhoodAccumulator(k);
        if (validIds == null) {
            for (DenseMatrixRow row : matrix) {
                double sim = row.dot(vector);
                accum.visit(row.getRowIndex(), sim);
            }
        } else {
            validIds.forEach(new TIntProcedure() {
                @Override
                public boolean execute(int id) {
                    DenseMatrixRow row;
                    try {
                        row = matrix.getRow(id);
                    } catch (IOException e) {
                        // TIntProcedure.execute cannot throw checked exceptions,
                        // so wrap the failure and keep the original cause.
                        throw new RuntimeException("Failed to read matrix row " + id, e);
                    }
                    if (row == null) {
                        // NOTE(review): assumes getRow signals a missing id with null
                        // (confirm against DenseMatrix). Skip the id instead of
                        // failing inside the similarity computation.
                        return true;
                    }
                    double sim = KmeansKNNFinder.cosine(vector, row);
                    accum.visit(id, sim);
                    return true; // keep iterating over the remaining ids
                }
            });
        }
        return accum.get();
    }

    /** No-op: brute-force search needs no precomputed structure. */
    @Override
    public void build() throws IOException {}

    /** No-op: there is no index to persist. */
    @Override
    public void save(File path) throws IOException {}

    /**
     * Reports success without reading anything, since there is no index to load.
     *
     * @param path ignored
     * @return always {@code true}
     */
    @Override
    public boolean load(File path) throws IOException { return true; }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy