// org.apache.calcite.linq4j.EnumerableDefaults (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.linq4j;
import org.apache.calcite.linq4j.function.BigDecimalFunction1;
import org.apache.calcite.linq4j.function.DoubleFunction1;
import org.apache.calcite.linq4j.function.EqualityComparer;
import org.apache.calcite.linq4j.function.FloatFunction1;
import org.apache.calcite.linq4j.function.Function0;
import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.function.Function2;
import org.apache.calcite.linq4j.function.Functions;
import org.apache.calcite.linq4j.function.IntegerFunction1;
import org.apache.calcite.linq4j.function.LongFunction1;
import org.apache.calcite.linq4j.function.NullableBigDecimalFunction1;
import org.apache.calcite.linq4j.function.NullableDoubleFunction1;
import org.apache.calcite.linq4j.function.NullableFloatFunction1;
import org.apache.calcite.linq4j.function.NullableIntegerFunction1;
import org.apache.calcite.linq4j.function.NullableLongFunction1;
import org.apache.calcite.linq4j.function.Predicate1;
import org.apache.calcite.linq4j.function.Predicate2;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import org.apiguardian.api.API;
import org.checkerframework.checker.nullness.qual.KeyFor;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.checker.nullness.qual.PolyNull;
import org.checkerframework.dataflow.qual.Pure;
import org.checkerframework.framework.qual.HasQualifierParameter;
import java.math.BigDecimal;
import java.util.AbstractList;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.RandomAccess;
import java.util.Set;
import java.util.TreeMap;
import static org.apache.calcite.linq4j.Linq4j.CollectionEnumerable;
import static org.apache.calcite.linq4j.Linq4j.ListEnumerable;
import static org.apache.calcite.linq4j.Nullness.castNonNull;
import static org.apache.calcite.linq4j.function.Functions.adapt;
import static java.util.Objects.requireNonNull;
/**
* Default implementations of methods in the {@link Enumerable} interface.
*/
public abstract class EnumerableDefaults {
/**
 * Folds a sequence into a single value, using the first element as the
 * initial result and applying {@code func} to each following element.
 *
 * @param source sequence to fold
 * @param func combines the running result with the next element
 * @return the folded value, or {@code null} if the sequence is empty
 */
public static @Nullable TSource aggregate(Enumerable source,
    Function2<@Nullable TSource, TSource, TSource> func) {
  try (Enumerator cursor = source.enumerator()) {
    if (cursor.moveNext()) {
      // The first element seeds the running result.
      TSource acc = cursor.current();
      while (cursor.moveNext()) {
        TSource next = cursor.current();
        acc = func.apply(acc, next);
      }
      return acc;
    }
    // Nothing to fold.
    return null;
  }
}
/**
 * Folds a sequence into a single value, starting from {@code seed} and
 * applying {@code func} to the running result and each element in turn.
 *
 * @param source sequence to fold
 * @param seed initial accumulator value; returned as-is for an empty sequence
 * @param func combines the running result with the next element
 * @return the final accumulator value
 */
public static TAccumulate aggregate(
    Enumerable source, TAccumulate seed,
    Function2 func) {
  try (Enumerator cursor = source.enumerator()) {
    TAccumulate acc = seed;
    for (boolean more = cursor.moveNext(); more; more = cursor.moveNext()) {
      TSource item = cursor.current();
      acc = func.apply(acc, item);
    }
    return acc;
  }
}
/**
 * Folds a sequence starting from {@code seed}, then maps the final
 * accumulator through {@code selector} to produce the result.
 *
 * @param source sequence to fold
 * @param seed initial accumulator value
 * @param func combines the running accumulator with the next element
 * @param selector converts the final accumulator into the result
 * @return {@code selector} applied to the final accumulator
 */
public static TResult aggregate(
    Enumerable source, TAccumulate seed,
    Function2 func,
    Function1 selector) {
  try (Enumerator cursor = source.enumerator()) {
    TAccumulate acc = seed;
    for (boolean more = cursor.moveNext(); more; more = cursor.moveNext()) {
      TSource item = cursor.current();
      acc = func.apply(acc, item);
    }
    // The selector runs while the enumerator is still open.
    return selector.apply(acc);
  }
}
/**
 * Tests whether every element of a sequence satisfies a condition.
 * Enumeration stops at the first element that fails the predicate.
 *
 * @param enumerable sequence to test
 * @param predicate condition each element must satisfy
 * @return {@code true} if no element fails the predicate (including when
 *     the sequence is empty)
 */
public static boolean all(Enumerable enumerable,
    Predicate1 predicate) {
  try (Enumerator cursor = enumerable.enumerator()) {
    boolean satisfied = true;
    while (satisfied && cursor.moveNext()) {
      TSource item = cursor.current();
      satisfied = predicate.apply(item);
    }
    return satisfied;
  }
}
/**
 * Determines whether a sequence contains any elements.
 *
 * <p>The enumerator opened to probe the sequence is closed before
 * returning, so a source that holds resources is not left open.
 *
 * @param enumerable sequence to probe
 * @return {@code true} if the sequence has at least one element
 */
public static boolean any(Enumerable enumerable) {
  try (Enumerator os = enumerable.enumerator()) {
    return os.moveNext();
  }
}
/**
 * Tests whether at least one element of a sequence satisfies a condition.
 * Enumeration stops at the first matching element.
 *
 * @param enumerable sequence to test
 * @param predicate condition to look for
 * @return {@code true} if some element satisfies the predicate
 */
public static boolean any(Enumerable enumerable,
    Predicate1 predicate) {
  try (Enumerator cursor = enumerable.enumerator()) {
    boolean found = false;
    while (!found && cursor.moveNext()) {
      TSource item = cursor.current();
      found = predicate.apply(item);
    }
    return found;
  }
}
/**
* Returns the input typed as {@code Enumerable}.
*
* <p>This method has no effect other than to change the compile-time type of
* source from a type that implements {@code Enumerable} to
* {@code Enumerable} itself; the argument is returned unchanged.
*
* <p>{@code AsEnumerable(Enumerable)} can be used to choose
* between query implementations when a sequence implements
* {@code Enumerable} but also has a different set of public query
* methods available. For example, given a generic class {@code Table} that
* implements {@code Enumerable} and has its own methods such as
* {@code where}, {@code select}, and {@code selectMany}, a call to
* {@code where} would invoke the public {@code where} method of
* {@code Table}. A {@code Table} type that represents a database table could
* have a {@code where} method that takes the predicate argument as an
* expression tree and converts the tree to SQL for remote execution. If
* remote execution is not desired, for example because the predicate invokes
* a local method, the {@code asEnumerable} method can be used to
* hide the custom methods and instead make the standard query operators
* available.
*
* @param enumerable sequence to re-type
* @return the same {@code enumerable} instance
*/
public static Enumerable asEnumerable(
Enumerable enumerable) {
return enumerable;
}
/**
* Converts an Enumerable to an IQueryable.
*
* <p>Analogous to LINQ's Enumerable.AsQueryable extension method.
*
* <p>Not implemented: the method body unconditionally throws the exception
* produced by {@code Extensions.todo()}.
*
* @param enumerable Enumerable to convert
*
* @return A queryable (never returned at present, see above)
*/
public static Queryable asQueryable(
Enumerable enumerable) {
throw Extensions.todo();
}
/**
 * Computes the average of a sequence of {@link BigDecimal} values that are
 * obtained by invoking a transform function on each element of the input
 * sequence.
 *
 * <p>The quotient is exact whenever it has a terminating decimal expansion;
 * otherwise it is rounded to {@code MathContext.DECIMAL128} instead of
 * failing.
 *
 * @param source sequence whose projected values are averaged
 * @param selector extracts the value from each element
 * @return the sum of the projected values divided by the element count
 * @throws ArithmeticException if the element count is zero
 */
public static BigDecimal average(Enumerable source,
    BigDecimalFunction1 selector) {
  final BigDecimal total = sum(source, selector);
  final BigDecimal n = BigDecimal.valueOf(longCount(source));
  try {
    return total.divide(n);
  } catch (ArithmeticException e) {
    if (n.signum() == 0) {
      // Division by zero: keep failing exactly as before.
      throw e;
    }
    // Non-terminating expansion (for example 5/3): the exact divide cannot
    // represent the quotient, so round rather than throw.
    return total.divide(n, java.math.MathContext.DECIMAL128);
  }
}
/**
 * Computes the average of a sequence of nullable {@link BigDecimal} values
 * that are obtained by invoking a transform function on each element of the
 * input sequence.
 *
 * <p>The quotient is exact whenever it has a terminating decimal expansion;
 * otherwise it is rounded to {@code MathContext.DECIMAL128} instead of
 * failing.
 *
 * @param source sequence whose projected values are averaged
 * @param selector extracts a (possibly null) value from each element
 * @return the sum of the projected values divided by the element count
 * @throws ArithmeticException if the element count is zero
 */
public static BigDecimal average(Enumerable source,
    NullableBigDecimalFunction1 selector) {
  final BigDecimal total = sum(source, selector);
  final BigDecimal n = BigDecimal.valueOf(longCount(source));
  try {
    return total.divide(n);
  } catch (ArithmeticException e) {
    if (n.signum() == 0) {
      // Division by zero: keep failing exactly as before.
      throw e;
    }
    // Non-terminating expansion (for example 5/3): the exact divide cannot
    // represent the quotient, so round rather than throw.
    return total.divide(n, java.math.MathContext.DECIMAL128);
  }
}
/**
 * Computes the average of the {@code double} values produced by applying
 * {@code selector} to each element of the input sequence.
 *
 * @param source sequence whose projected values are averaged
 * @param selector extracts the value from each element
 * @return the sum of the projected values divided by the element count
 */
public static double average(Enumerable source,
    DoubleFunction1 selector) {
  final double total = sum(source, selector);
  final long n = longCount(source);
  return total / n;
}
/**
* Computes the average of a sequence of nullable
* Double values that are obtained by invoking a transform
* function on each element of the input sequence.
*
* <p>The result is {@code sum(source, selector)} divided by
* {@code longCount(source)}.
*
* <p>NOTE(review): the sum is presumably a boxed {@code Double}; if it can be
* null the division would fail with a NullPointerException when unboxing.
* The contract of {@code sum} is not visible here; confirm.
*
* @param source sequence whose projected values are averaged
* @param selector extracts a (possibly null) value from each element
* @return the sum divided by the element count
*/
public static Double average(Enumerable source,
NullableDoubleFunction1 selector) {
return sum(source, selector) / longCount(source);
}
/**
 * Computes the average of the {@code int} values produced by applying
 * {@code selector} to each element of the input sequence. The division is
 * integer division, so any fractional part is discarded.
 *
 * @param source sequence whose projected values are averaged
 * @param selector extracts the value from each element
 * @return the sum of the projected values divided by the element count
 * @throws ArithmeticException if the element count is zero
 */
public static int average(Enumerable source,
    IntegerFunction1 selector) {
  final int total = sum(source, selector);
  final int n = count(source);
  return total / n;
}
/**
* Computes the average of a sequence of nullable
* int values that are obtained by invoking a transform function
* on each element of the input sequence.
*
* <p>The result is {@code sum(source, selector)} divided by
* {@code count(source)}.
*
* <p>NOTE(review): the sum is presumably a boxed {@code Integer}; if it can
* be null the division would fail with a NullPointerException when unboxing,
* and a zero count would presumably raise an ArithmeticException. The
* contract of {@code sum} is not visible here; confirm.
*
* @param source sequence whose projected values are averaged
* @param selector extracts a (possibly null) value from each element
* @return the sum divided by the element count
*/
public static Integer average(Enumerable source,
NullableIntegerFunction1 selector) {
return sum(source, selector) / count(source);
}
/**
 * Computes the average of the {@code long} values produced by applying
 * {@code selector} to each element of the input sequence. The division is
 * integer division, so any fractional part is discarded.
 *
 * @param source sequence whose projected values are averaged
 * @param selector extracts the value from each element
 * @return the sum of the projected values divided by the element count
 */
public static long average(Enumerable source,
    LongFunction1 selector) {
  final long total = sum(source, selector);
  final long n = longCount(source);
  return total / n;
}
/**
* Computes the average of a sequence of nullable
* long values that are obtained by invoking a transform function
* on each element of the input sequence.
*
* <p>The result is {@code sum(source, selector)} divided by
* {@code longCount(source)}.
*
* <p>NOTE(review): the sum is presumably a boxed {@code Long}; if it can be
* null the division would fail with a NullPointerException when unboxing.
* The contract of {@code sum} is not visible here; confirm.
*
* @param source sequence whose projected values are averaged
* @param selector extracts a (possibly null) value from each element
* @return the sum divided by the element count
*/
public static Long average(Enumerable source,
NullableLongFunction1 selector) {
return sum(source, selector) / longCount(source);
}
/**
 * Computes the average of the {@code float} values produced by applying
 * {@code selector} to each element of the input sequence.
 *
 * @param source sequence whose projected values are averaged
 * @param selector extracts the value from each element
 * @return the sum of the projected values divided by the element count
 */
public static float average(Enumerable source,
    FloatFunction1 selector) {
  final float total = sum(source, selector);
  final long n = longCount(source);
  return total / n;
}
/**
* Computes the average of a sequence of nullable
* Float values that are obtained by invoking a transform
* function on each element of the input sequence.
*
* <p>The result is {@code sum(source, selector)} divided by
* {@code longCount(source)}.
*
* <p>NOTE(review): the sum is presumably a boxed {@code Float}; if it can be
* null the division would fail with a NullPointerException when unboxing.
* The contract of {@code sum} is not visible here; confirm.
*
* @param source sequence whose projected values are averaged
* @param selector extracts a (possibly null) value from each element
* @return the sum divided by the element count
*/
public static Float average(Enumerable source,
NullableFloatFunction1 selector) {
return sum(source, selector) / longCount(source);
}
/**
* Analogous to LINQ's Enumerable.Cast extension method.
*
* <p>The returned enumerable is a view: each call to its
* {@code enumerator()} opens a fresh enumerator on {@code source} and wraps
* it in a {@code CastingEnumerator} for {@code clazz}.
*
* @param source Sequence whose elements are to be cast
* @param clazz Target type
* @param <T2> Target type
*
* @return Collection of T2
*/
public static Enumerable cast(
final Enumerable source, final Class clazz) {
return new AbstractEnumerable() {
@Override public Enumerator enumerator() {
return new CastingEnumerator<>(source.enumerator(), clazz);
}
};
}
/**
* Concatenates two sequences.
*
* <p>Delegates to {@code Linq4j.concat} with a two-element list holding
* {@code enumerable0} followed by {@code enumerable1}.
*
* @param enumerable0 first sequence
* @param enumerable1 second sequence
* @return the result of {@code Linq4j.concat} for the two sequences
*/
public static Enumerable concat(
Enumerable enumerable0, Enumerable enumerable1) {
//noinspection unchecked
return Linq4j.concat(
Arrays.asList(enumerable0, enumerable1));
}
/**
 * Tests whether a sequence contains {@code element}, comparing with
 * {@link Objects#equals(Object, Object)}. Enumeration stops at the first
 * match.
 *
 * @param enumerable sequence to search
 * @param element value to look for
 * @return {@code true} if some element equals {@code element}
 */
public static boolean contains(Enumerable enumerable,
    TSource element) {
  // Enumerables backed by a Collection use Collection.contains instead of
  // this method, which may be more efficient.
  try (Enumerator cursor = enumerable.enumerator()) {
    boolean found = false;
    while (!found && cursor.moveNext()) {
      TSource candidate = cursor.current();
      found = Objects.equals(candidate, element);
    }
    return found;
  }
}
/**
 * Determines whether a sequence contains a specified element by using a
 * specified {@code EqualityComparer}.
 *
 * <p>The sequence is walked through an explicitly managed enumerator so
 * that it is closed even when the search stops early on a match.
 *
 * @param enumerable sequence to search
 * @param element value to look for
 * @param comparer decides whether an element matches {@code element}
 * @return {@code true} if {@code comparer} reports some element equal to
 *     {@code element}
 */
public static boolean contains(Enumerable enumerable,
    TSource element, EqualityComparer comparer) {
  try (Enumerator os = enumerable.enumerator()) {
    while (os.moveNext()) {
      TSource o = os.current();
      if (comparer.equal(o, element)) {
        return true;
      }
    }
    return false;
  }
}
/**
 * Returns the number of elements in a sequence.
 *
 * <p>The count is obtained from {@code longCount} with an always-true
 * predicate and narrowed to {@code int} with a plain cast.
 *
 * @param enumerable sequence to count
 * @return the element count as an {@code int}
 */
public static int count(Enumerable enumerable) {
  final long total = longCount(enumerable, Functions.truePredicate1());
  return (int) total;
}
/**
 * Returns how many elements of the sequence satisfy a condition.
 *
 * <p>The count is obtained from {@code longCount} and narrowed to
 * {@code int} with a plain cast.
 *
 * @param enumerable sequence to count
 * @param predicate condition an element must satisfy to be counted
 * @return the number of matching elements as an {@code int}
 */
public static int count(Enumerable enumerable,
    Predicate1 predicate) {
  final long matching = longCount(enumerable, predicate);
  return (int) matching;
}
/**
* Returns the elements of the specified sequence or
* the type parameter's default value in a singleton collection if
* the sequence is empty.
*
* <p>Delegates to {@code defaultIfEmpty(enumerable, null)}, so the default
* value used here is {@code null}.
*
* @param enumerable sequence to return when it is not empty
* @return the result of the two-argument overload with a null default
*/
public static Enumerable<@Nullable TSource> defaultIfEmpty(
Enumerable enumerable) {
return defaultIfEmpty(enumerable, null);
}
/**
 * Returns the elements of the specified sequence or the specified value in
 * a singleton collection if the sequence is empty.
 *
 * <p>If {@code value} is not null, the result is never null.
 *
 * <p>Emptiness is checked when this method is called, using a probe
 * enumerator that is closed before returning. A non-empty sequence is
 * returned as-is, so the result never reads from a closed enumerator and
 * can be enumerated as often as {@code enumerable} itself.
 *
 * @param enumerable sequence to return when it is not empty
 * @param value element of the singleton returned for an empty sequence
 * @return {@code enumerable} if it has at least one element, otherwise a
 *     singleton enumerable containing {@code value}
 */
@SuppressWarnings("return.type.incompatible")
public static Enumerable<@PolyNull TSource> defaultIfEmpty(
    Enumerable enumerable,
    @PolyNull TSource value) {
  final boolean hasElements;
  try (Enumerator os = enumerable.enumerator()) {
    hasElements = os.moveNext();
  }
  if (hasElements) {
    return enumerable;
  }
  return Linq4j.singletonEnumerable(value);
}
/**
 * Returns distinct elements from a sequence by using the default
 * {@link EqualityComparer} to compare values.
 *
 * <p>The sequence is read eagerly into a {@link HashSet}, so the result
 * iterates in that set's order. The source enumerator is closed even if
 * enumeration fails part-way.
 *
 * @param enumerable sequence to de-duplicate
 * @return an enumerable over the set of distinct elements
 */
public static Enumerable distinct(
    Enumerable enumerable) {
  final Set set = new HashSet<>();
  try (Enumerator os = enumerable.enumerator()) {
    while (os.moveNext()) {
      set.add(os.current());
    }
  }
  return Linq4j.asEnumerable(set);
}
/**
* Returns distinct elements from a sequence by using
* a specified {@link EqualityComparer} to compare values.
*
* <p>If {@code comparer} is the identity comparer, this delegates to the
* single-argument overload. Otherwise each element is wrapped via
* {@code wrapperFor(comparer)}, collected into a {@link HashSet} of wrapped
* values, and unwrapped again on the way out.
*
* @param enumerable sequence to de-duplicate
* @param comparer comparer that defines element equality
* @return an enumerable over the distinct elements
*/
public static Enumerable distinct(
Enumerable enumerable, EqualityComparer comparer) {
// Reference comparison: only the shared identity comparer takes the fast path.
if (comparer == Functions.identityComparer()) {
return distinct(enumerable);
}
final Set> set = new HashSet<>();
Function1> wrapper = wrapperFor(comparer);
Function1, TSource> unwrapper = unwrapper();
enumerable.select(wrapper).into(set);
return Linq4j.asEnumerable(set).select(unwrapper);
}
/**
 * Returns the element at a given position in a sequence.
 *
 * @param enumerable sequence to read
 * @param index zero-based position of the wanted element
 * @return the element at {@code index}
 * @throws IndexOutOfBoundsException if {@code index} is negative or the
 *     sequence ends before reaching it
 */
public static TSource elementAt(Enumerable enumerable,
    int index) {
  if (enumerable instanceof ListEnumerable) {
    // Direct lookup in the backing list; index checking is left to get().
    final ListEnumerable list = (ListEnumerable) enumerable;
    return list.toList().get(index);
  }
  if (index < 0) {
    throw new IndexOutOfBoundsException();
  }
  try (Enumerator os = enumerable.enumerator()) {
    // Count down from index; reaching zero means we are on the target.
    for (int remaining = index; os.moveNext(); remaining--) {
      if (remaining == 0) {
        return os.current();
      }
    }
    throw new IndexOutOfBoundsException();
  }
}
/**
 * Returns the element at a given position in a sequence, or {@code null}
 * when the position is out of range.
 *
 * @param enumerable sequence to read
 * @param index zero-based position of the wanted element
 * @return the element at {@code index}, or {@code null} if {@code index} is
 *     negative or not less than the number of elements
 */
public static @Nullable TSource elementAtOrDefault(
    Enumerable enumerable, int index) {
  if (index < 0) {
    return null;
  }
  if (enumerable instanceof ListEnumerable) {
    // Direct lookup in the backing list, guarded by its size.
    final List rawList = ((ListEnumerable) enumerable).toList();
    if (index < rawList.size()) {
      return rawList.get(index);
    }
    return null;
  }
  try (Enumerator os = enumerable.enumerator()) {
    // Count down from index; reaching zero means we are on the target.
    for (int remaining = index; os.moveNext(); remaining--) {
      if (remaining == 0) {
        return os.current();
      }
    }
  }
  return null;
}
/**
* Produces the set difference of two sequences by
* using the default equality comparer to compare values,
* eliminating duplicates. (Defined by Enumerable.)
*
* <p>Delegates to {@code except(source0, source1, false)}.
*
* @param source0 sequence to subtract from
* @param source1 sequence whose elements are removed
* @return elements of {@code source0} that do not occur in {@code source1}
*/
public static Enumerable except(
Enumerable source0, Enumerable source1) {
return except(source0, source1, false);
}
/**
 * Produces the set difference of two sequences using default equality,
 * with {@code all} selecting whether duplicates are kept.
 * (Defined by Enumerable.)
 *
 * <p>{@code source0} is collected first; then each element of
 * {@code source1} is removed from that collection.
 *
 * @param source0 sequence to subtract from
 * @param source1 sequence whose elements are removed
 * @param all if true, collect into a multiset (duplicates kept); if false,
 *     collect into a {@link HashSet}
 * @return an enumerable over what remains of the collection
 */
public static Enumerable except(
    Enumerable source0, Enumerable source1, boolean all) {
  final Collection remaining;
  if (all) {
    remaining = HashMultiset.create();
  } else {
    remaining = new HashSet<>();
  }
  source0.into(remaining);
  try (Enumerator subtrahend = source1.enumerator()) {
    while (subtrahend.moveNext()) {
      TSource o = subtrahend.current();
      @SuppressWarnings("argument.type.incompatible")
      boolean unused = remaining.remove(o);
    }
    return Linq4j.asEnumerable(remaining);
  }
}
/**
* Produces the set difference of two sequences by
* using the specified {@code EqualityComparer} to compare
* values, eliminating duplicates.
*
* <p>Delegates to {@code except(source0, source1, comparer, false)}.
*
* @param source0 sequence to subtract from
* @param source1 sequence whose elements are removed
* @param comparer comparer that defines element equality
* @return elements of {@code source0} that do not occur in {@code source1}
*/
public static Enumerable except(
Enumerable source0, Enumerable source1,
EqualityComparer comparer) {
return except(source0, source1, comparer, false);
}
/**
* Produces the set difference of two sequences by
* using the specified {@code EqualityComparer} to compare
* values, using {@code all} to indicate whether to eliminate duplicates.
*
* <p>If {@code comparer} is the identity comparer, this delegates to the
* overload without a comparer. Otherwise elements of both sequences are
* wrapped via {@code wrapperFor(comparer)}; the wrapped elements of
* {@code source0} are collected, the wrapped elements of {@code source1}
* are removed, and the remainder is unwrapped.
*
* @param source0 sequence to subtract from
* @param source1 sequence whose elements are removed
* @param comparer comparer that defines element equality
* @param all if true, collect into a multiset (duplicates kept); if false,
* collect into a {@link HashSet}
* @return an enumerable over the unwrapped remainder
*/
public static Enumerable except(
Enumerable source0, Enumerable source1,
EqualityComparer comparer, boolean all) {
// Reference comparison: only the shared identity comparer takes the fast path.
if (comparer == Functions.identityComparer()) {
return except(source0, source1, all);
}
Collection> collection = all ? HashMultiset.create() : new HashSet<>();
Function1> wrapper = wrapperFor(comparer);
source0.select(wrapper).into(collection);
try (Enumerator> os =
source1.select(wrapper).enumerator()) {
while (os.moveNext()) {
Wrapped o = os.current();
collection.remove(o);
}
}
Function1, TSource> unwrapper = unwrapper();
return Linq4j.asEnumerable(collection).select(unwrapper);
}
/**
 * Returns the first element of a sequence. (Defined by Enumerable.)
 *
 * @param enumerable sequence to read
 * @return the first element
 * @throws NoSuchElementException if the sequence is empty
 */
public static TSource first(Enumerable enumerable) {
  try (Enumerator cursor = enumerable.enumerator()) {
    if (!cursor.moveNext()) {
      throw new NoSuchElementException();
    }
    return cursor.current();
  }
}
/**
 * Returns the first element in a sequence that satisfies a specified
 * condition.
 *
 * <p>The sequence is walked through an explicitly managed enumerator so
 * that it is closed even when the search stops early on a match.
 *
 * @param enumerable sequence to search
 * @param predicate condition the element must satisfy
 * @return the first element for which {@code predicate} returns true
 * @throws NoSuchElementException if no element satisfies the predicate
 */
public static TSource first(Enumerable enumerable,
    Predicate1 predicate) {
  try (Enumerator os = enumerable.enumerator()) {
    while (os.moveNext()) {
      TSource o = os.current();
      if (predicate.apply(o)) {
        return o;
      }
    }
  }
  throw new NoSuchElementException();
}
/**
 * Returns the first element of a sequence, or {@code null} if the sequence
 * contains no elements.
 *
 * @param enumerable sequence to read
 * @return the first element, or {@code null} for an empty sequence
 */
public static @Nullable TSource firstOrDefault(
    Enumerable enumerable) {
  try (Enumerator cursor = enumerable.enumerator()) {
    if (!cursor.moveNext()) {
      return null;
    }
    return cursor.current();
  }
}
/**
 * Returns the first element of the sequence that satisfies a condition, or
 * {@code null} if no such element is found.
 *
 * <p>The sequence is walked through an explicitly managed enumerator so
 * that it is closed even when the search stops early on a match.
 *
 * @param enumerable sequence to search
 * @param predicate condition the element must satisfy
 * @return the first matching element, or {@code null} if there is none
 */
public static @Nullable TSource firstOrDefault(Enumerable enumerable,
    Predicate1 predicate) {
  try (Enumerator os = enumerable.enumerator()) {
    while (os.moveNext()) {
      TSource o = os.current();
      if (predicate.apply(o)) {
        return o;
      }
    }
  }
  return null;
}
/**
* Groups the elements of a sequence according to a
* specified key selector function.
*
* <p>Delegates to {@code enumerable.toLookup(keySelector)}.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @return the lookup built by {@code toLookup}, viewed as an enumerable of
* groupings
*/
public static Enumerable> groupBy(
final Enumerable enumerable,
final Function1 keySelector) {
return enumerable.toLookup(keySelector);
}
/**
* Groups the elements of a sequence according to a
* specified key selector function and compares the keys by using
* a specified comparer.
*
* <p>Delegates to {@code enumerable.toLookup(keySelector, comparer)}.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param comparer comparer that defines key equality
* @return the lookup built by {@code toLookup}, viewed as an enumerable of
* groupings
*/
public static Enumerable> groupBy(
Enumerable enumerable, Function1 keySelector,
EqualityComparer comparer) {
return enumerable.toLookup(keySelector, comparer);
}
/**
* Groups the elements of a sequence according to a
* specified key selector function and projects the elements for
* each group by using a specified function.
*
* <p>Delegates to {@code enumerable.toLookup(keySelector, elementSelector)}.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param elementSelector projects each element before it is added to its
* group
* @return the lookup built by {@code toLookup}, viewed as an enumerable of
* groupings
*/
public static Enumerable> groupBy(
Enumerable enumerable, Function1 keySelector,
Function1 elementSelector) {
return enumerable.toLookup(keySelector, elementSelector);
}
/**
* Groups the elements of a sequence according to a
* key selector function. The keys are compared by using a
* comparer and each group's elements are projected by using a
* specified function.
*
* <p>Delegates to
* {@code enumerable.toLookup(keySelector, elementSelector, comparer)}.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param elementSelector projects each element before it is added to its
* group
* @param comparer comparer that defines key equality
* @return the lookup built by {@code toLookup}, viewed as an enumerable of
* groupings
*/
public static Enumerable> groupBy(
Enumerable enumerable, Function1 keySelector,
Function1 elementSelector,
EqualityComparer comparer) {
return enumerable.toLookup(keySelector, elementSelector, comparer);
}
/**
* Groups the elements of a sequence according to a
* specified key selector function and creates a result value from
* each group and its key.
*
* <p>Builds a lookup with {@code enumerable.toLookup(keySelector)} and maps
* each group through {@code resultSelector}, passing the group's key and the
* group itself.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param resultSelector creates a result from a key and its group
* @return one result per group
*/
public static Enumerable groupBy(
Enumerable enumerable, Function1 keySelector,
final Function2, TResult> resultSelector) {
return enumerable.toLookup(keySelector)
.select(group -> resultSelector.apply(group.getKey(), group));
}
/**
* Groups the elements of a sequence according to a
* specified key selector function and creates a result value from
* each group and its key. The keys are compared by using a
* specified comparer.
*
* <p>Builds a lookup with {@code enumerable.toLookup(keySelector, comparer)}
* and maps each group through {@code resultSelector}, passing the group's
* key and the group itself.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param resultSelector creates a result from a key and its group
* @param comparer comparer that defines key equality
* @return one result per group
*/
public static Enumerable groupBy(
Enumerable enumerable, Function1 keySelector,
final Function2, TResult> resultSelector,
EqualityComparer comparer) {
return enumerable.toLookup(keySelector, comparer)
.select(group -> resultSelector.apply(group.getKey(), group));
}
/**
* Groups the elements of a sequence according to a
* specified key selector function and creates a result value from
* each group and its key. The elements of each group are
* projected by using a specified function.
*
* <p>Builds a lookup with
* {@code enumerable.toLookup(keySelector, elementSelector)} and maps each
* group through {@code resultSelector}, passing the group's key and the
* group itself.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param elementSelector projects each element before it is added to its
* group
* @param resultSelector creates a result from a key and its group
* @return one result per group
*/
public static Enumerable groupBy(
Enumerable enumerable, Function1 keySelector,
Function1 elementSelector,
final Function2, TResult> resultSelector) {
return enumerable.toLookup(keySelector, elementSelector)
.select(group -> resultSelector.apply(group.getKey(), group));
}
/**
* Groups the elements of a sequence according to a
* specified key selector function and creates a result value from
* each group and its key. Key values are compared by using a
* specified comparer, and the elements of each group are
* projected by using a specified function.
*
* <p>Builds a lookup with
* {@code enumerable.toLookup(keySelector, elementSelector, comparer)} and
* maps each group through {@code resultSelector}, passing the group's key
* and the group itself.
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param elementSelector projects each element before it is added to its
* group
* @param resultSelector creates a result from a key and its group
* @param comparer comparer that defines key equality
* @return one result per group
*/
public static Enumerable groupBy(
Enumerable enumerable, Function1 keySelector,
Function1 elementSelector,
final Function2, TResult> resultSelector,
EqualityComparer comparer) {
return enumerable.toLookup(keySelector, elementSelector, comparer)
.select(group -> resultSelector.apply(group.getKey(), group));
}
/**
* Groups the elements of a sequence according to a
* specified key selector function, initializing an accumulator for each
* group and adding to it each time an element with the same key is seen.
* Creates a result value from each accumulator and its key using a
* specified function.
*
* <p>Delegates to {@code groupBy_}, handing it a new {@link HashMap}
* (presumably the store for the per-key accumulators; confirm against
* {@code groupBy_}).
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param accumulatorInitializer creates the accumulator for a new group
* @param accumulatorAdder folds an element into its group's accumulator
* @param resultSelector creates a result from a key and its accumulator
* @return the enumerable produced by {@code groupBy_}
*/
public static Enumerable groupBy(
Enumerable enumerable, Function1 keySelector,
Function0 accumulatorInitializer,
Function2 accumulatorAdder,
final Function2 resultSelector) {
return groupBy_(new HashMap<>(), enumerable, keySelector,
accumulatorInitializer, accumulatorAdder, resultSelector);
}
/**
* Groups the elements of a sequence according to a list of
* specified key selector functions, initializing an accumulator for each
* group and adding to it each time an element with the same key is seen.
* Creates a result value from each accumulator and its key using a
* specified function.
*
* <p>This method exists to support SQL {@code GROUPING SETS}.
* It does not correspond to any method in {@link Enumerable}.
*
* <p>Delegates to {@code groupByMultiple_}, handing it a new
* {@link HashMap} (presumably the store for the per-key accumulators;
* confirm against {@code groupByMultiple_}).
*
* @param enumerable sequence to group
* @param keySelectors key selector functions, one per grouping
* @param accumulatorInitializer creates the accumulator for a new group
* @param accumulatorAdder folds an element into its group's accumulator
* @param resultSelector creates a result from a key and its accumulator
* @return the enumerable produced by {@code groupByMultiple_}
*/
public static Enumerable groupByMultiple(
Enumerable enumerable, List> keySelectors,
Function0 accumulatorInitializer,
Function2 accumulatorAdder,
final Function2 resultSelector) {
return groupByMultiple_(
new HashMap<>(),
enumerable,
keySelectors,
accumulatorInitializer,
accumulatorAdder,
resultSelector);
}
/**
* Groups the elements of a sequence according to a
* specified key selector function, initializing an accumulator for each
* group and adding to it each time an element with the same key is seen.
* Creates a result value from each accumulator and its key using a
* specified function. Key values are compared by using a
* specified comparer.
*
* <p>Delegates to {@code groupBy_}, handing it a {@code WrapMap} built from
* a {@link HashMap} supplier and {@code comparer} (presumably so that keys
* are compared through the comparer; confirm against {@code WrapMap}).
*
* @param enumerable sequence to group
* @param keySelector extracts the grouping key from each element
* @param accumulatorInitializer creates the accumulator for a new group
* @param accumulatorAdder folds an element into its group's accumulator
* @param resultSelector creates a result from a key and its accumulator
* @param comparer comparer that defines key equality
* @return the enumerable produced by {@code groupBy_}
*/
public static Enumerable groupBy(
Enumerable enumerable, Function1 keySelector,
Function0 accumulatorInitializer,
Function2 accumulatorAdder,
Function2 resultSelector,
EqualityComparer comparer) {
return groupBy_(
new WrapMap<>(
// Java 8 cannot infer the return type when HashMap::new is used
() -> new HashMap, TAccumulate>(),
comparer),
enumerable,
keySelector,
accumulatorInitializer,
accumulatorAdder,
resultSelector);
}
/**
* Groups the elements of a sequence whose group keys are already sorted.
*
* <p>Key values are compared by using a specified comparator. Elements are
* grouped according to a specified key selector function, with one
* accumulator initialized at a time: the input is read sequentially, adding
* to the accumulator each time an element with the same key is seen. When
* the key changes, a result value is created from the accumulator and the
* accumulator is re-initialized. In the case of NULL values in group keys,
* the comparator must be able to support NULL values by giving a consistent
* sort ordering.
*
* <p>The returned enumerable creates a new {@code SortedAggregateEnumerator}
* over the arguments on each call to {@code enumerator()}.
*
* @param enumerable input sequence, sorted by group key
* @param keySelector extracts the grouping key from each element
* @param accumulatorInitializer creates the accumulator for a new group
* @param accumulatorAdder folds an element into the current accumulator
* @param resultSelector creates a result from a key and its accumulator
* @param comparator comparator used to compare key values
* @return an enumerable of results, one per group
*/
public static Enumerable sortedGroupBy(
Enumerable enumerable,
Function1 keySelector,
Function0 accumulatorInitializer,
Function2 accumulatorAdder,
final Function2 resultSelector,
final Comparator comparator) {
return new AbstractEnumerable() {
@Override public Enumerator enumerator() {
return new SortedAggregateEnumerator(
enumerable, keySelector, accumulatorInitializer,
accumulatorAdder, resultSelector, comparator);
}
};
}
/** Enumerator that evaluates aggregate functions over an input that is sorted
* by the group key.
*
* @param left input record type
* @param key type
* @param accumulator type
* @param result type */
private static class SortedAggregateEnumerator
implements Enumerator {
@SuppressWarnings("unused")
private final Enumerable enumerable;
private final Function1 keySelector;
private final Function0 accumulatorInitializer;
private final Function2 accumulatorAdder;
private final Function2 resultSelector;
private final Comparator comparator;
private boolean isInitialized;
private boolean isLastMoveNextFalse;
private @Nullable TAccumulate curAccumulator;
private Enumerator