All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.github.jnthnclt.os.lab.nn.LABNG Maven / Gradle / Ivy

package com.github.jnthnclt.os.lab.nn;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Experiment: builds a directed top-N nearest-neighbor graph over random points and explores
 * greedy graph walks as an alternative to a brute-force nearest-neighbor scan.
 *
 * <p>All distances are squared Euclidean distances (see {@link #euclidianDistance}).
 */
public class LABNG {

    /**
     * When {@code false}, {@link #main} returns right after printing the set of mutual-neighbor
     * nodes and skips the graph dumps, the reverse graph and the search comparison.
     */
    private static final boolean RUN_SEARCH_COMPARISON = false;

    /** A point: an id plus its feature vector (the array is stored as given, not copied). */
    static class N {
        private final long id;
        private final double[] features;

        N(long id, double[] features) {
            this.id = id;
            this.features = features;
        }

        @Override
        public String toString() {
            return "N{" +
                "id=" + id +
                ", features=" + Arrays.toString(features) +
                '}';
        }
    }

    /** DOT edge colors indexed by edge position; positions past the end reuse the last color. */
    static final String[] color = new String[] { "green", "orange", "red", "purple", "blue", "cyan", "gray", "black" };

    /** A graph node: wraps a point and holds its outgoing edges. */
    static class NG {
        private final N n;
        /** Outgoing edges; must be assigned before {@link #asString} or {@link #find} is called. */
        NG[] nearest;
        /** Experimental counter; only copied by the reverse-graph construction. */
        int depth = 0;

        NG(N n) {
            this.n = n;
        }

        /**
         * Prints one DOT edge statement per outgoing edge, e.g. {@code 3 -> 7[ color="green"];},
         * colored by the edge's position in {@link #nearest}. Despite its name it returns nothing.
         */
        public void asString() {
            for (int i = 0; i < nearest.length; i++) {
                System.out.println(n.id + " -> " + nearest[i].n.id + "[ color=\"" + color[Math.min(i, color.length - 1)] + "\"];");
            }
        }

        /**
         * One greedy step toward a query: returns whichever of this node and its direct neighbors
         * is closest to {@code find}. This node wins ties against its neighbors, and among equally
         * close neighbors the earlier one wins.
         *
         * @param find   the query point
         * @param count  incremented once per distance evaluation (1 + number of neighbors)
         * @param starts currently unused
         * @return this node when no neighbor is strictly closer, otherwise the closest neighbor
         */
        public NG find(N find, AtomicLong count, Set<?> starts) {
            count.incrementAndGet();
            NG best = this;
            double bestDistance = euclidianDistance(n.features, find.features, -1);
            for (NG candidate : nearest) {
                count.incrementAndGet();
                double candidateDistance = euclidianDistance(candidate.n.features, find.features, -1);
                if (candidateDistance < bestDistance) {
                    bestDistance = candidateDistance;
                    best = candidate;
                }
            }
            return best;
        }
    }

    /**
     * Experiment driver. Builds a top-N nearest-neighbor graph over {@code numN} random points,
     * prints every mutual pair as {@code a<->b}, then the set of ids involved in one. If
     * {@link #RUN_SEARCH_COMPARISON} is enabled it continues: prints the graph and its reverse as
     * DOT edges, prints the nodes nothing points to, and compares a greedy walk over the reversed
     * graph (entered from the mutual nodes) against a brute-force scan for random queries.
     */
    public static void main(String[] args) {
        long seed = System.currentTimeMillis();
        Random rand = new Random(seed);

        int topN = 1;
        int numFeatures = 2;
        int numN = 1024;
        int numQueries = 1;

        N[] neighbors = new N[numN];
        for (int i = 0; i < neighbors.length; i++) {
            neighbors[i] = randomN(rand, i, numFeatures);
        }

        NG[] ngs = buildGraph(neighbors, topN);

        // Nodes taking part in at least one mutual pair; used below as search entry points.
        Set<Integer> ends = findMutual(ngs);
        System.out.println(ends);
        if (!RUN_SEARCH_COMPARISON) {
            return;
        }

        System.out.println();
        for (NG ng : ngs) {
            ng.asString();
        }
        System.out.println();

        NG[] reversed = reverse(ngs);

        System.out.println();
        for (NG ng : reversed) {
            ng.asString();
        }
        System.out.println();
        System.out.println();

        Set<Integer> starts = findUnreferenced(ngs);
        System.out.println();
        System.out.println(starts);
        System.out.println(ends);
        System.out.println(starts.size() + " " + ngs.length);

        long naiveElapse = 0;
        long fancyElapse = 0;
        N naiveFound = null;
        NG fancyFound = null;
        N query = null;
        AtomicLong naiveCount = new AtomicLong();
        AtomicLong fancyCount = new AtomicLong();

        for (int q = 0; q < numQueries; q++) {
            query = randomN(rand, -1, numFeatures);

            long timestamp = System.currentTimeMillis();
            naiveFound = naive(query, neighbors);
            naiveCount.addAndGet(neighbors.length);
            naiveElapse += System.currentTimeMillis() - timestamp;

            timestamp = System.currentTimeMillis();
            fancyFound = greedySearch(query, reversed, ends, fancyCount);
            fancyElapse += System.currentTimeMillis() - timestamp;
        }

        if (fancyFound == null) {
            // No mutual pairs means no entry points, so the greedy search found nothing.
            System.out.println("fancy:" + fancyCount.get() + " " + fancyElapse + " | no start nodes");
        } else {
            System.out.println(
                "fancy:" + fancyCount.get() + " " + fancyElapse + " | " + euclidianDistance(fancyFound.n.features, query.features, -1) + " answer:" + fancyFound.n);
        }
        System.out.println(
            "naive:" + naiveCount.get() + " " + naiveElapse + " | " + euclidianDistance(naiveFound.features, query.features, -1) + " answer:" + naiveFound);
        System.out.println("query:" + query);
        System.out.println("seed:" + seed);
    }

    /**
     * Builds one node per point, each with edges to its {@code topN} nearest other points,
     * nearest first. Assumes every point's id equals its array index, since {@link #topN} returns ids.
     */
    private static NG[] buildGraph(N[] neighbors, int topN) {
        NG[] ngs = new NG[neighbors.length];
        for (int i = 0; i < neighbors.length; i++) {
            ngs[i] = new NG(neighbors[i]);
        }
        for (int i = 0; i < neighbors.length; i++) {
            int[] is = topN(i, neighbors, topN);
            NG[] edges = new NG[is.length];
            for (int j = 0; j < is.length; j++) {
                edges[j] = ngs[is[j]];
            }
            ngs[i].nearest = edges;
        }
        return ngs;
    }

    /**
     * Finds every pair of nodes that point at each other, printing each as {@code a<->b}. A pair is
     * reached, and printed, from both endpoints.
     *
     * @return the ids of all nodes that belong to at least one mutual pair
     */
    private static Set<Integer> findMutual(NG[] ngs) {
        Set<Integer> mutual = new HashSet<>();
        for (NG at : ngs) {
            for (NG next : at.nearest) {
                for (NG back : next.nearest) {
                    if (back == at) {
                        System.out.println(at.n.id + "<->" + next.n.id);
                        mutual.add((int) at.n.id);
                        mutual.add((int) next.n.id);
                    }
                }
            }
        }
        return mutual;
    }

    /**
     * Returns a copy of the graph with every edge flipped: for each {@code a -> b} in {@code ngs},
     * the result has {@code b -> a}. Depth is copied. Assumes node ids equal their array index.
     */
    private static NG[] reverse(NG[] ngs) {
        NG[] reversed = new NG[ngs.length];
        for (int i = 0; i < ngs.length; i++) {
            reversed[i] = new NG(ngs[i].n);
            reversed[i].depth = ngs[i].depth;
            reversed[i].nearest = new NG[0];
        }
        for (int i = 0; i < ngs.length; i++) {
            for (NG target : ngs[i].nearest) {
                NG next = reversed[(int) target.n.id];
                NG[] grow = Arrays.copyOf(next.nearest, next.nearest.length + 1);
                grow[grow.length - 1] = reversed[i];
                next.nearest = grow;
            }
        }
        return reversed;
    }

    /**
     * Returns the ids of nodes that no edge points to. Each node is counted once for itself and once
     * per incoming edge, so a count of exactly 1 means nothing references it.
     */
    private static Set<Integer> findUnreferenced(NG[] ngs) {
        Map<Long, Integer> tos = new HashMap<>();
        for (NG ng : ngs) {
            tos.merge(ng.n.id, 1, Integer::sum);
            for (NG target : ng.nearest) {
                tos.merge(target.n.id, 1, Integer::sum);
            }
        }
        Set<Integer> starts = new HashSet<>();
        for (Map.Entry<Long, Integer> entry : tos.entrySet()) {
            if (entry.getValue() == 1) {
                starts.add((int) (long) entry.getKey());
            }
        }
        return starts;
    }

    /**
     * Greedy graph search for the node of {@code graph} closest to {@code query}. Each start id is
     * scored. Whenever a start beats the best distance so far, the search walks from it with
     * {@link NG#find} until it reaches a node with no strictly closer neighbor. Every step of the
     * walk, including the final self step, is printed as {@code a->b;}.
     *
     * @param query    the point to search for
     * @param graph    nodes indexed by id
     * @param startIds ids of the entry points
     * @param count    incremented once per counted distance evaluation
     * @return the closest node reached, or {@code null} if no start scored below
     *         {@code Double.MAX_VALUE} (e.g. {@code startIds} is empty)
     */
    private static NG greedySearch(N query, NG[] graph, Set<Integer> startIds, AtomicLong count) {
        NG winner = null;
        double best = Double.MAX_VALUE;
        for (Integer startId : startIds) {
            NG start = graph[startId];
            count.incrementAndGet();
            double startDistance = euclidianDistance(start.n.features, query.features, -1);
            if (startDistance < best) {
                best = startDistance;
                winner = start;

                NG at = start;
                while (true) {
                    NG step = at.find(query, count, startIds);
                    System.out.println(at.n.id + "->" + step.n.id + ";");
                    if (step == at) {
                        break;
                    }
                    at = step;
                }

                // Not added to count: find already evaluated this distance when it stopped at 'at'.
                double reached = euclidianDistance(at.n.features, query.features, -1);
                if (reached < best) {
                    best = reached;
                    winner = at;
                }
            }
        }
        return winner;
    }

    /**
     * Returns the ids of the {@code topN} points closest to {@code neighbors[i]}, nearest first.
     * The point at index {@code i} itself is excluded. If fewer than {@code topN} other points
     * exist, the result holds all of them; it is never padded.
     *
     * @param i         index of the reference point
     * @param neighbors the candidate points
     * @param topN      maximum number of neighbors to return
     * @return ids (not indices) of the nearest points, closest first
     * @throws IllegalArgumentException if {@code topN} is negative
     */
    static int[] topN(int i, N[] neighbors, int topN) {
        if (topN < 0) {
            throw new IllegalArgumentException("topN must be >= 0: " + topN);
        }
        N n = neighbors[i];
        // Max-heap by distance: the head is the farthest candidate kept, so evicting it whenever the
        // heap overflows leaves the topN closest points seen so far.
        Comparator<TN> farthestFirst = (o1, o2) -> Double.compare(o2.v, o1.v);
        PriorityQueue<TN> heap = new PriorityQueue<>(Math.min(topN, neighbors.length) + 1, farthestFirst);
        for (int j = 0; j < neighbors.length; j++) {
            if (j != i) {
                heap.add(new TN((int) neighbors[j].id, euclidianDistance(n.features, neighbors[j].features, -1)));
                if (heap.size() > topN) {
                    heap.poll();
                }
            }
        }
        List<TN> kept = new ArrayList<>(heap);
        kept.sort(Comparator.comparingDouble(tn -> tn.v));
        int[] is = new int[kept.size()];
        for (int k = 0; k < is.length; k++) {
            is[k] = kept.get(k).i;
        }
        return is;
    }

    /** A candidate neighbor: point id plus its distance to the reference point. */
    static class TN {
        private final int i;
        private final double v;

        TN(int i, double v) {
            this.i = i;
            this.v = v;
        }
    }

    /** Returns a point with the given id and {@code numFeatures} coordinates from {@code random.nextDouble()}. */
    public static N randomN(Random random, long id, int numFeatures) {
        double[] features = new double[numFeatures];
        for (int i = 0; i < features.length; i++) {
            features[i] = random.nextDouble();
        }
        return new N(id, features);
    }

    /**
     * Brute-force nearest neighbor: scans every point and returns the closest one (the first wins
     * ties).
     *
     * @return the nearest point, or {@code null} if {@code neighbors} is empty
     */
    public static N naive(N query, N[] neighbors) {
        double min = Double.MAX_VALUE;
        N nearest = null;
        for (N neighbor : neighbors) {
            double d = euclidianDistance(query.features, neighbor.features, -1);
            if (d < min) {
                min = d;
                nearest = neighbor;
            }
        }
        return nearest;
    }

    /**
     * Returns the SQUARED Euclidean distance between {@code a} and {@code b}, skipping dimension
     * {@code exclude} (pass -1 to use every dimension). No square root is taken, which preserves the
     * ordering of distances while saving work.
     *
     * @throws IllegalArgumentException if {@code a} and {@code b} have different lengths
     */
    public static double euclidianDistance(double[] a, double[] b, int exclude) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch: " + a.length + " vs " + b.length);
        }
        double v = 0;
        for (int i = 0; i < a.length; i++) {
            if (i != exclude) {
                double d = a[i] - b[i];
                v += d * d;
            }
        }
        return v;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy