All downloads are free. Search and download functionality uses the official Maven repository.

com.google.common.collect.TopKSelector Maven / Gradle / Ivy

Go to download

Guava is a suite of core and expanded libraries that include utility classes, Google's collections, I/O classes, and much more. This project includes GWT-friendly sources.

The newest version!
/*
 * Copyright (C) 2014 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.common.collect;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.NullnessCasts.uncheckedCastNullableTToT;
import static java.lang.Math.max;
import static java.util.Arrays.asList;
import static java.util.Arrays.sort;
import static java.util.Collections.unmodifiableList;

import com.google.common.annotations.GwtCompatible;
import com.google.common.math.IntMath;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import javax.annotation.CheckForNull;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * An accumulator that selects the "top" {@code k} elements added to it, relative to a provided
 * comparator. "Top" can mean the greatest or the lowest elements, specified in the factory used to
 * create the {@code TopKSelector} instance.
 *
 * 

If your input data is available as a {@link Stream}, prefer passing {@link * Comparators#least(int)} to {@link Stream#collect(java.util.stream.Collector)}. If it is available * as an {@link Iterable} or {@link Iterator}, prefer {@link Ordering#leastOf(Iterable, int)}. * *

This uses the same efficient implementation as {@link Ordering#leastOf(Iterable, int)}, * offering expected O(n + k log k) performance (worst case O(n log k)) for n calls to {@link * #offer} and a call to {@link #topK}, with O(k) memory. In comparison, quickselect has the same * asymptotics but requires O(n) memory, and a {@code PriorityQueue} implementation takes O(n log * k). In benchmarks, this implementation performs at least as well as either implementation, and * degrades more gracefully for worst-case input. * *

The implementation does not necessarily use a stable sorting algorithm; when multiple * equivalent elements are added to it, it is undefined which will come first in the output. * * @author Louis Wasserman */ @GwtCompatible @ElementTypesAreNonnullByDefault final class TopKSelector< T extends Object> { /** * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it, * relative to the natural ordering of the elements, and returns them via {@link #topK} in * ascending order. * * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} */ public static > TopKSelector least(int k) { return least(k, Ordering.natural()); } /** * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it, * relative to the specified comparator, and returns them via {@link #topK} in ascending order. * * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} */ public static TopKSelector least( int k, Comparator comparator) { return new TopKSelector<>(comparator, k); } /** * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it, * relative to the natural ordering of the elements, and returns them via {@link #topK} in * descending order. * * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} */ public static > TopKSelector greatest(int k) { return greatest(k, Ordering.natural()); } /** * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it, * relative to the specified comparator, and returns them via {@link #topK} in descending order. 
* * @throws IllegalArgumentException if {@code k < 0} or {@code k > Integer.MAX_VALUE / 2} */ public static TopKSelector greatest( int k, Comparator comparator) { return new TopKSelector<>(Ordering.from(comparator).reverse(), k); } private final int k; private final Comparator comparator; /* * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the * range [0, k) and ignore the remaining elements. */ private final T[] buffer; private int bufferSize; /** * The largest of the lowest k elements we've seen so far relative to this comparator. If * bufferSize ≥ k, then we can ignore any elements greater than this value. */ @CheckForNull private T threshold; @SuppressWarnings("unchecked") // TODO(cpovirk): Consider storing Object[] instead of T[]. private TopKSelector(Comparator comparator, int k) { this.comparator = checkNotNull(comparator, "comparator"); this.k = k; checkArgument(k >= 0, "k (%s) must be >= 0", k); checkArgument(k <= Integer.MAX_VALUE / 2, "k (%s) must be <= Integer.MAX_VALUE / 2", k); this.buffer = (T[]) new Object[IntMath.checkedMultiply(k, 2)]; this.bufferSize = 0; this.threshold = null; } /** * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes amortized * O(1) time. */ public void offer(@ParametricNullness T elem) { if (k == 0) { return; } else if (bufferSize == 0) { buffer[0] = elem; threshold = elem; bufferSize = 1; } else if (bufferSize < k) { buffer[bufferSize++] = elem; // uncheckedCastNullableTToT is safe because bufferSize > 0. if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) > 0) { threshold = elem; } // uncheckedCastNullableTToT is safe because bufferSize > 0. } else if (comparator.compare(elem, uncheckedCastNullableTToT(threshold)) < 0) { // Otherwise, we can ignore elem; we've seen k better elements. 
buffer[bufferSize++] = elem; if (bufferSize == 2 * k) { trim(); } } } /** * Quickselects the top k elements from the 2k elements in the buffer. O(k) expected time, O(k log * k) worst case. */ private void trim() { int left = 0; int right = 2 * k - 1; int minThresholdPosition = 0; // The leftmost position at which the greatest of the k lower elements // -- the new value of threshold -- might be found. int iterations = 0; int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3; while (left < right) { int pivotIndex = (left + right + 1) >>> 1; int pivotNewIndex = partition(left, right, pivotIndex); if (pivotNewIndex > k) { right = pivotNewIndex - 1; } else if (pivotNewIndex < k) { left = max(pivotNewIndex, left + 1); minThresholdPosition = pivotNewIndex; } else { break; } iterations++; if (iterations >= maxIterations) { @SuppressWarnings("nullness") // safe because we pass sort() a range that contains real Ts T[] castBuffer = (T[]) buffer; // We've already taken O(k log k), let's make sure we don't take longer than O(k log k). sort(castBuffer, left, right + 1, comparator); break; } } bufferSize = k; threshold = uncheckedCastNullableTToT(buffer[minThresholdPosition]); for (int i = minThresholdPosition + 1; i < k; i++) { if (comparator.compare( uncheckedCastNullableTToT(buffer[i]), uncheckedCastNullableTToT(threshold)) > 0) { threshold = buffer[i]; } } } /** * Partitions the contents of buffer in the range [left, right] around the pivot element * previously stored in buffer[pivotValue]. Returns the new index of the pivot element, * pivotNewIndex, so that everything in [left, pivotNewIndex] is ≤ pivotValue and everything in * (pivotNewIndex, right] is greater than pivotValue. 
*/ private int partition(int left, int right, int pivotIndex) { T pivotValue = uncheckedCastNullableTToT(buffer[pivotIndex]); buffer[pivotIndex] = buffer[right]; int pivotNewIndex = left; for (int i = left; i < right; i++) { if (comparator.compare(uncheckedCastNullableTToT(buffer[i]), pivotValue) < 0) { swap(pivotNewIndex, i); pivotNewIndex++; } } buffer[right] = buffer[pivotNewIndex]; buffer[pivotNewIndex] = pivotValue; return pivotNewIndex; } private void swap(int i, int j) { T tmp = buffer[i]; buffer[i] = buffer[j]; buffer[j] = tmp; } TopKSelector combine(TopKSelector other) { for (int i = 0; i < other.bufferSize; i++) { this.offer(uncheckedCastNullableTToT(other.buffer[i])); } return this; } /** * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This * operation takes amortized linear time in the length of {@code elements}. * *

If all input data to this {@code TopKSelector} is in a single {@code Iterable}, prefer * {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use case. */ public void offerAll(Iterable elements) { offerAll(elements.iterator()); } /** * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This * operation takes amortized linear time in the length of {@code elements}. The iterator is * consumed after this operation completes. * *

If all input data to this {@code TopKSelector} is in a single {@code Iterator}, prefer * {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use case. */ public void offerAll(Iterator elements) { while (elements.hasNext()) { offer(elements.next()); } } /** * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if * fewer than {@code k} have been offered, in the order specified by the factory used to create * this {@code TopKSelector}. * *

The returned list is an unmodifiable copy and will not be affected by further changes to * this {@code TopKSelector}. This method returns in O(k log k) time. */ public List topK() { @SuppressWarnings("nullness") // safe because we pass sort() a range that contains real Ts T[] castBuffer = (T[]) buffer; sort(castBuffer, 0, bufferSize, comparator); if (bufferSize > k) { Arrays.fill(buffer, k, buffer.length, null); bufferSize = k; threshold = buffer[k - 1]; } // Up to bufferSize, all elements of buffer are real Ts (not null unless T includes null) T[] topK = Arrays.copyOf(castBuffer, bufferSize); // we have to support null elements, so no ImmutableList for us return unmodifiableList(asList(topK)); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy