// All downloads are free. Search and download functionality uses the official Maven repository.

// com.github.jnthnclt.os.lab.nn.LABNN (Maven / Gradle / Ivy)

package com.github.jnthnclt.os.lab.nn;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Experimental nearest-neighbor search over fixed-length {@code double[]} feature vectors.
 *
 * <p>{@link #naive} is the brute-force reference implementation. {@link #fancy} and
 * {@link #fancy_hmm} are heuristics that walk per-feature sorted indexes ({@code Dimension})
 * outward from the query position. Every distance in this class is a <em>squared</em> Euclidean
 * distance; the square root is never taken because it does not change the ordering.
 */
public class LABNN {

    /** A point: an id plus its feature vector. The feature array is stored by reference, not copied. */
    static class N {
        private final long id;
        private final double[] features;
        // Not read or written anywhere in this class; kept because it is a public field.
        public boolean used = false;

        N(long id, double[] features) {
            this.id = id;
            this.features = features;
        }

        @Override
        public String toString() {
            return "N{" +
                "id=" + id +
                ", features=" + Arrays.toString(features) +
                '}';
        }
    }

    /**
     * Creates a point whose features are successive {@link Random#nextDouble()} values.
     *
     * @param random      source of the feature values
     * @param id          id given to the new point
     * @param numFeatures length of the feature vector
     * @return the new point
     */
    public static N randomN(Random random, long id, int numFeatures) {
        double[] features = new double[numFeatures];
        for (int i = 0; i < features.length; i++) {
            features[i] = random.nextDouble();
        }
        return new N(id, features);
    }

    /** A copy of the neighbor array sorted ascending by a single feature. */
    static class Dimension {
        private final N[] neighbors;
        private final Comparator<N> nComparator;
        private final double min;
        private final double max;

        /**
         * @param featureIndex index of the feature this dimension sorts by
         * @param neighbors    points to index; the array is copied, the points are shared
         * @throws IllegalArgumentException if {@code neighbors} is empty
         */
        Dimension(int featureIndex, N[] neighbors) {
            if (neighbors.length == 0) {
                throw new IllegalArgumentException("neighbors must not be empty");
            }
            this.nComparator = Comparator.comparingDouble(o -> o.features[featureIndex]);

            N[] sort = Arrays.copyOf(neighbors, neighbors.length);
            Arrays.sort(sort, nComparator);
            this.neighbors = sort;
            this.min = sort[0].features[featureIndex];
            this.max = sort[sort.length - 1].features[featureIndex];
        }

        /**
         * Returns the index of an entry equal to {@code n} on this feature, or otherwise the
         * insertion point for {@code n}, clamped to the last valid index.
         */
        int closest(N n) {
            return Math.min(absInsertion(Arrays.binarySearch(neighbors, n, nComparator)), neighbors.length - 1);
        }

        /** Returns the entry at {@code i}, or {@code null} when {@code i} is out of range. */
        N n(int i) {
            return i < 0 ? null : i >= neighbors.length ? null : neighbors[i];
        }

        /** Returns the entry at {@code i}, clamping out-of-range indexes to the first/last entry. */
        N nn(int i) {
            return i < 0 ? neighbors[0] : i >= neighbors.length ? neighbors[neighbors.length - 1] : neighbors[i];
        }
    }


    /**
     * Brute-force nearest neighbor: compares the query against every candidate.
     *
     * @param query     the point to search for
     * @param neighbors candidate points
     * @return the candidate with the smallest squared distance to {@code query} (the first one
     *     on ties), or {@code null} when no candidate has a distance below {@code Double.MAX_VALUE}
     *     (in particular when {@code neighbors} is empty)
     */
    public static N naive(N query, N[] neighbors) {

        double min = Double.MAX_VALUE;
        N nearest = null;
        for (N neighbor : neighbors) {
            double d = euclidianDistance(query.features, neighbor.features, -1);
            if (d < min) {
                min = d;
                nearest = neighbor;
            }
        }
        return nearest;
    }


    /**
     * Converts an {@code Arrays.binarySearch} result into a non-negative index.
     *
     * @param binarysearchIndex value returned by a binary search
     * @return the index itself when non-negative, otherwise the encoded insertion point
     *     {@code -(binarysearchIndex) - 1}
     */
    public static int absInsertion(int binarysearchIndex) {
        return binarysearchIndex < 0 ? (-(binarysearchIndex) - 1) : binarysearchIndex;
    }

    /**
     * Demo driver: builds random points and one random query, then prints the answer and the
     * elapsed milliseconds for {@link #fancy} followed by {@link #naive}, plus the query and seed.
     */
    public static void main(String[] args) {

        long seed = System.currentTimeMillis();
        Random rand = new Random(seed);

        int numFeatures = 100;
        int numN = 100;
        int loops = 1;

        N[] neighbors = new N[numN];
        for (int i = 0; i < neighbors.length; i++) {
            neighbors[i] = randomN(rand, i, numFeatures);
        }


        N query = randomN(rand, -1, numFeatures);


        Dimension[] dimensions = new Dimension[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            dimensions[i] = new Dimension(i, neighbors);
        }
        N n = null;
        long timestamp = System.currentTimeMillis();
        // Summing the ids keeps the result of every loop iteration observable.
        long ids = 0;
        for (int i = 0; i < loops; i++) {
            n = fancy(query, dimensions);
            ids += n.id;
        }
        long elapse = System.currentTimeMillis() - timestamp;

        System.out.println(euclidianDistance(n.features, query.features, -1) + "|" + elapse + " fancy answer:" + n + " " + ids);


        timestamp = System.currentTimeMillis();
        ids = 0;
        for (int i = 0; i < loops; i++) {
            n = naive(query, neighbors);
            ids += n.id;
        }
        elapse = System.currentTimeMillis() - timestamp;
        System.out.println(euclidianDistance(n.features, query.features, -1) + "|" + elapse + " naive answer:" + n + " " + ids);
        System.out.println("query:" + query);
        System.out.println("seed:" + seed);

    }

    /**
     * Heuristic nearest-neighbor search that expands outward from the query in every dimension.
     *
     * <p>Each dimension starts at the query's position in its sorted index and, once per round,
     * steps one entry lower and one entry higher. Every visited entry bumps a per-id counter; an
     * id that has been seen in all dimensions becomes a candidate and its full distance to the
     * query is computed. The search stops after the second round that produced candidates but no
     * improvement, or when no dimension can expand any further. The result is therefore not
     * guaranteed to be the exact nearest neighbor.
     *
     * <p>Prints a trace line per visited entry and the total number of visits to standard output.
     *
     * @param query      the point to search for
     * @param dimensions one sorted index per feature; must be non-empty, and the neighbor count
     *     of the first dimension is used as the bound for all of them
     * @return the best candidate found, or {@code null} if no candidate was ever evaluated
     */
    public static N fancy(N query, Dimension[] dimensions) {

        double min = Double.MAX_VALUE;
        N nearest = null;

        int gets = 0;

        int dl = dimensions.length;
        int[] highs = new int[dl];
        int[] lows = new int[dl];

        int max_neighbors = dimensions[0].neighbors.length;

        Map<Long, Integer> counts = new HashMap<>();
        List<N> nearests = new ArrayList<>();

        // Seed every dimension at the query's position. The starting entries are counted like
        // any other visit so that a point which is the closest entry in several dimensions can
        // still reach the "seen in all dimensions" threshold.
        for (int i = 0; i < dl; i++) {
            gets++;
            lows[i] = highs[i] = dimensions[i].closest(query);
            N seen = dimensions[i].nn(lows[i]);
            int c = countSighting(seen, dl, counts, nearests);
            System.out.println("S " + i + " i=" + lows[i] + " id=" + seen.id + " count=" + c);
        }

        int topN = 1;
        boolean expanded = true;
        // Stop when the improvement budget is spent or when every index has been fully walked;
        // without the second condition an exhausted search would spin forever.
        while (expanded && topN >= 0) {
            expanded = false;

            for (int i = 0; i < dl; i++) {
                if (lows[i] > 0) {
                    lows[i]--;
                    gets++;
                    expanded = true;
                    N seen = dimensions[i].nn(lows[i]);
                    int c = countSighting(seen, dl, counts, nearests);
                    System.out.println("L " + i + " i=" + lows[i] + " id=" + seen.id + " count=" + c);
                }

                if (highs[i] < max_neighbors - 1) {
                    highs[i]++;
                    gets++;
                    expanded = true;
                    N seen = dimensions[i].nn(highs[i]);
                    int c = countSighting(seen, dl, counts, nearests);
                    System.out.println("H " + i + " i=" + highs[i] + " id=" + seen.id + " count=" + c);
                }
            }

            if (!nearests.isEmpty()) {
                int better = 0;
                for (N n : nearests) {
                    double d = euclidianDistance(n.features, query.features, -1);
                    if (d < min) {
                        min = d;
                        nearest = n;
                        better++;
                    }
                }
                // A round with candidates but no improvement uses up one unit of budget.
                if (better == 0) {
                    topN--;
                }
                nearests.clear();
            }

        }


        System.out.println("gets:" + gets);
        return nearest;

    }

    /**
     * Records one sighting of {@code seen}; when that makes it seen in all {@code dl}
     * dimensions it is appended to {@code complete}. Returns the updated sighting count.
     */
    private static int countSighting(N seen, int dl, Map<Long, Integer> counts, List<N> complete) {
        int c = counts.merge(seen.id, 1, Integer::sum);
        if (c == dl) {
            complete.add(seen);
        }
        return c;
    }


    /**
     * Alternative heuristic: processes the dimensions one after another, scanning each sorted
     * index downward and upward from the current best point while the entry's value on that
     * feature stays inside the window given by {@link #low} and {@link #high}.
     *
     * <p>The query is never modified. Prints trace lines, and finally the number of distance
     * computations alongside the neighbor count, to standard output.
     *
     * @param query      the point to search for
     * @param dimensions one sorted index per feature; must be non-empty, and the neighbor count
     *     of the first dimension is used as the bound for all of them
     * @return the best point found; {@code query} itself if no entry ever had a distance below
     *     {@code Double.MAX_VALUE}
     */
    public static N fancy_hmm(N query, Dimension[] dimensions) {

        double min = Double.MAX_VALUE;
        N nearest = query;
        int compares = 0;

        int dl = dimensions.length;
        int max_neighbors = dimensions[0].neighbors.length;

        for (int di = 0; di < dl; di++) {
            Dimension dimension = dimensions[di];

            // Start at the position of the current best point (the query itself at first).
            int lo = dimension.closest(nearest);
            int hi = lo;

            N n = dimension.nn(lo);

            // pd is the squared distance over every feature except di; adding the squared
            // difference on di gives the full squared distance.
            double pd = euclidianDistance(query.features, n.features, di);
            double delta = query.features[di] - n.features[di];
            double d = pd + (delta * delta);
            compares++;
            if (d < min) {
                nearest = n;
                min = d;
            }

            double low = low(min, nearest.features[di], pd * dl);
            double high = high(min, nearest.features[di], pd * dl);
            System.out.println(di + " " + low + " " + nearest.features[di] + " " + high);
            boolean lowDone = false;
            boolean highDone = false;
            while ((!lowDone && lo > 0) || (!highDone && hi < max_neighbors - 1)) {
                if (!lowDone && lo > 0) {
                    lo--;
                    n = dimension.nn(lo);
                    if (n.features[di] >= low) {
                        pd = euclidianDistance(query.features, n.features, di);
                        compares++;
                        delta = query.features[di] - n.features[di];
                        d = pd + (delta * delta);
                        if (d < min) {
                            // A better point moves the window, so both directions are reopened.
                            lowDone = false;
                            highDone = false;
                            nearest = n;
                            min = d;
                            low = low(min, nearest.features[di], pd * dl);
                            high = high(min, nearest.features[di], pd * dl);
                            System.out.println(di + " LOW CLOSER:" + d + " " + low + " " + nearest.features[di] + " " + high);
                        }
                    } else {
                        lowDone = true;
                        System.out.println(di + " LOW DONE:" + n.features[di] + " >= " + low);
                    }
                }
                if (!highDone && hi < max_neighbors - 1) {
                    hi++;
                    n = dimension.nn(hi);
                    if (n.features[di] <= high) {
                        pd = euclidianDistance(query.features, n.features, di);
                        compares++;
                        delta = query.features[di] - n.features[di];
                        d = pd + (delta * delta);
                        if (d < min) {
                            // A better point moves the window, so both directions are reopened.
                            lowDone = false;
                            highDone = false;
                            nearest = n;
                            min = d;
                            low = low(min, nearest.features[di], pd * dl);
                            high = high(min, nearest.features[di], pd * dl);
                            System.out.println(di + " HIGH CLOSER:" + d + " " + low + " " + nearest.features[di] + " " + high);
                        }
                    } else {
                        highDone = true;
                        System.out.println(di + " HIGH DONE:" + n.features[di] + " <= " + high);
                    }
                }
            }
        }

        System.out.println(compares + " " + max_neighbors);
        return nearest;
    }

    /**
     * Debug helper: prints {@code lows[i]} and {@code highs[i]} pairwise, tab separated, on one
     * line of standard output. {@code context} is currently unused.
     */
    private static void hl(String context, double[] lows, double[] highs) {
        for (int i = 0; i < lows.length; i++) {
            System.out.print((i > 0 ? "\t" : "") + lows[i] + "\t" + highs[i]);
        }
        System.out.println();
    }

    /**
     * Lower edge of the scan window on one feature: {@code dim - sqrt(best - partialDistance)}.
     *
     * <p>Reasoning: with squared distances, a point whose other features already contribute
     * {@code partialDistance} can only beat {@code best} if
     * {@code (dim - x)^2 < best - partialDistance}, i.e.
     * {@code x > dim - sqrt(best - partialDistance)}.
     *
     * @param best            best squared distance found so far
     * @param dim             value on this feature that the window is centred on
     * @param partialDistance squared distance contributed by the other features
     * @return the lower edge, or {@code 0.0} when the computation yields NaN (for example when
     *     {@code best < partialDistance})
     */
    public static double low(double best, double dim, double partialDistance) {
        double v = -(Math.sqrt(best - partialDistance) - dim);
        return Double.isNaN(v) ? 0.0 : v;
    }


    /**
     * Upper edge of the scan window on one feature: {@code dim + sqrt(best + partialDistance)}.
     *
     * <p>NOTE(review): the mirror image of {@link #low} would be
     * {@code dim + sqrt(best - partialDistance)}. This method adds {@code partialDistance}
     * instead, which for non-negative inputs is never a tighter edge, so it can only widen the
     * scan. Left numerically unchanged - confirm whether the asymmetry is intended.
     *
     * @param best            best squared distance found so far
     * @param dim             value on this feature that the window is centred on
     * @param partialDistance squared distance contributed by the other features
     * @return the upper edge, or {@code 0.0} when the computation yields NaN
     */
    public static double high(double best, double dim, double partialDistance) {
        double v = -(-Math.sqrt(best + partialDistance) - dim);
        return Double.isNaN(v) ? 0.0 : v;
    }


    /**
     * Squared Euclidean distance between two vectors, optionally ignoring one index.
     *
     * @param a       first vector; its length determines how many indexes are visited
     * @param b       second vector; must be at least as long as {@code a}
     * @param exclude index to leave out of the sum, or a value such as {@code -1} that matches
     *     no index to include every feature
     * @return the sum of squared differences over the visited indexes (no square root applied)
     */
    static public double euclidianDistance(double[] a, double[] b, int exclude) {
        double v = 0;
        for (int i = 0; i < a.length; i++) {
            if (i != exclude) {
                double d = a[i] - b[i];
                v += (d * d);
            }
        }
        return v;
    }

}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy