All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.mitre.caasd.commons.collect.SetSearch Maven / Gradle / Ivy

/*
 *    Copyright 2022 The MITRE Corporation
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.mitre.caasd.commons.collect;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.PriorityQueue;

import org.mitre.caasd.commons.Pair;

/**
 * A Search iterates through a MetricSet and collects Keys that are close to the "search key".
 * 

* Search objects can perform a "k-nearest neighbors" or "all neighbors within range" search. Both * of these search types require providing a "search Key". *

* This class is package private because it is an implementation detail of the MetricSet class. * * @param This key class is used to measure distance between two objects */ class SetSearch { private enum SearchType { K_NEAREST_NEIGHBORS, RANGE } private final DistanceMetric metric; private final SearchType type; private final K searchKey; private final int maxNumResults; // only used for kNN searches private final double fixedRadius; // only used for range searches private final PriorityQueue> queue; /** * Create a kNN search query. * * @param searchKey Search for this * @param maxNumResults The "k" in k-Nearest-Neighbors * @param metric The distance metric used to determine how far objects are */ SetSearch(K searchKey, int maxNumResults, DistanceMetric metric) { this.metric = metric; this.type = SearchType.K_NEAREST_NEIGHBORS; this.searchKey = searchKey; this.maxNumResults = maxNumResults; this.fixedRadius = Double.POSITIVE_INFINITY; this.queue = new PriorityQueue<>(); } /** * Create a range query that returns all entries within range * * @param searchKey Search for this * @param metric The distance metric used to determine how far objects are * @param range Include results within this distance */ SetSearch(K searchKey, DistanceMetric metric, double range) { this.metric = metric; this.type = SearchType.RANGE; this.searchKey = searchKey; this.maxNumResults = Integer.MAX_VALUE; this.fixedRadius = range; this.queue = new PriorityQueue<>(); } /* * Note: This search process cannot be written as a recursive search. Searching recursivly can * produce a StackoverflowError when the underlying tree is deeper than the JVM's internal stack */ void startQuery(MetricSet.Sphere root) { Deque.Sphere> stack = new ArrayDeque<>(); stack.push(root); while (!stack.isEmpty()) { MetricSet.Sphere current = stack.pop(); // ignore this node (and all its sub-trees) because it cannot improve the current result if (!this.overlapsWith(current)) { continue; } if (current.isSphereOfPoints()) { ingestSphereOfPoints(current); } else { Pair.Sphere, MetricSet.Sphere> childSpheres = current.children(); double firstDist = metric.distanceBtw(searchKey, childSpheres.first().centerPoint); double secondDist = metric.distanceBtw(searchKey, childSpheres.second().centerPoint); /* * Submit the closest sphere second to reduce work (because this increases the * chance we can skip items in the sphere that are further away). */ if (firstDist < secondDist) { stack.push(childSpheres.second()); stack.push(childSpheres.first()); // will be popped first } else { stack.push(childSpheres.first()); stack.push(childSpheres.second()); // will be popped second } } } } private void ingestSphereOfPoints(MetricSet.Sphere inputSphere) { for (K key : inputSphere.points()) { SetSearchResult r = new SetSearchResult<>(key, metric.distanceBtw(searchKey, key)); if (r.distance <= this.radius()) { this.queue.offer(r); // enforce the "k" in kNN search if (queue.size() > this.maxNumResults) { // if too big, remove the worst result queue.poll(); } } } } /** * @return True when the "query sphere" and this sphere overlap. */ private boolean overlapsWith(MetricSet.Sphere s) { double distance = metric.distanceBtw(s.centerPoint, this.searchKey); double overlap = s.radius() + this.radius() - distance; return (overlap >= 0); } /** * @return The "inclusion radius" based on the type of query being executed and the quality of * the current results (so we can avoid processing spheres that cannot contain better * results) */ private double radius() { if (type == SearchType.K_NEAREST_NEIGHBORS) { if (queue.size() < maxNumResults) { // radius is still large because we haven't found "k" results yet return Double.POSITIVE_INFINITY; } else { return queue.peek().distance; // must beat this to improve } } else if (type == SearchType.RANGE) { return this.fixedRadius; // includes everything within this radius } else { throw new AssertionError("Should never get here"); } } Collection> results() { return queue; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy