All downloads are free. Search and download functionality uses the official Maven repository.

org.rx.core.NQuery Maven / Gradle / Ivy

There is a newer version: 3.0.0
Show newest version
package org.rx.core;

import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import io.netty.util.internal.ThreadLocalRandom;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.IterableUtils;
import org.apache.commons.collections4.IteratorUtils;
import org.rx.annotation.ErrorCode;
import org.rx.bean.$;
import org.rx.bean.Decimal;
import org.rx.exception.ApplicationException;
import org.rx.exception.InvalidException;
import org.rx.util.function.*;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.*;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static org.rx.bean.$.$;
import static org.rx.core.Constants.NON_RAW_TYPES;
import static org.rx.core.Constants.NON_UNCHECKED;
import static org.rx.core.Extends.*;

/**
 * LINQ-style query operators over an {@link Iterable}, modeled on .NET LINQ:
 * https://msdn.microsoft.com/en-us/library/bb738550(v=vs.110).aspx
 * <p>
 * Each instance wraps a source {@code Iterable} plus a flag choosing sequential or parallel
 * {@link Stream} evaluation. Transforming operators are evaluated eagerly: the pipeline is run and
 * its results are collected into a new list, which backs the returned query (see {@code me(Stream)}).
 * Returned queries keep this query's parallel flag unless stated otherwise.
 */
@Slf4j
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public final class NQuery implements Iterable, Serializable {
    private static final long serialVersionUID = -7167070585936243198L;
    /** Initial capacity for result containers when the source size cannot be known without traversal. */
    private static final int DEFAULT_CAPACITY = 16;

    //region staticMembers

    /**
     * Tells whether values of {@code type} are handled natively by {@link #asList(Object, boolean)},
     * i.e. without falling back to single-element wrapping.
     *
     * @param type the class to test
     * @return {@code true} for {@link Iterable} and {@link Iterator} implementations and for array types
     */
    public static boolean couldBeCollection(Class type) {
        return Iterable.class.isAssignableFrom(type)
                || type.isArray()
                || Iterator.class.isAssignableFrom(type);
    }

    /**
     * Copies the elements of an {@link Iterable}, an array (primitive elements are boxed by
     * {@link Array#get}) or an {@link Iterator} (which is consumed) into a new {@link List}.
     *
     * @param collection  the source object; must not be {@code null}
     * @param throwOnFail what to do when {@code collection} is none of the supported kinds:
     *                    {@code true} throws, {@code false} returns a one-element list holding
     *                    {@code collection} itself
     * @return a list holding the source's elements in iteration/index order
     * @throws ApplicationException if the type is unsupported and {@code throwOnFail} is {@code true};
     *                              the simple class name is passed as the error-message value
     */
    @SuppressWarnings(NON_UNCHECKED)
    @ErrorCode
    public static  List asList(@NonNull Object collection, boolean throwOnFail) {
        Iterable iterable;
        if ((iterable = as(collection, Iterable.class)) != null) {
            return IterableUtils.toList(iterable);
        }

        Class type = collection.getClass();
        if (type.isArray()) {
            int length = Array.getLength(collection);
            List list = new ArrayList<>(length);
            for (int i = 0; i < length; i++) {
                list.add((T) Array.get(collection, i));
            }
            return list;
        }

        Iterator iterator;
        if ((iterator = as(collection, Iterator.class)) != null) {
            return IteratorUtils.toList(iterator);
        }

        if (throwOnFail) {
            throw new ApplicationException(values(type.getSimpleName()));
        }
        ArrayList list = new ArrayList<>();
        list.add((T) collection);
        return list;
    }

    /**
     * Creates a sequential query over {@code collection} as converted by
     * {@link #asList(Object, boolean) asList(collection, true)}.
     *
     * @param collection an {@link Iterable}, array or {@link Iterator}
     * @return a sequential query over a snapshot of the elements
     * @throws ApplicationException if {@code collection} is none of the supported kinds
     */
    public static  NQuery ofCollection(Object collection) {
        return new NQuery<>(asList(collection, true), false);
    }

    /**
     * Creates a sequential query over {@code Arrays.toList(one)}.
     * NOTE(review): presumably a one-element list; confirm against {@code org.rx.core.Arrays}.
     *
     * @param one the element
     * @return a sequential query
     */
    public static  NQuery of(T one) {
        return of(Arrays.toList(one));
    }

    /**
     * Creates a sequential query over the given elements, via {@code Arrays.toList(array)}.
     *
     * @param array the elements
     * @return a sequential query
     */
    @SafeVarargs
    public static  NQuery of(T... array) {
        return of(Arrays.toList(array));
    }

    /**
     * Creates a query backed lazily by {@code stream}, inheriting its parallel flag.
     * <p>
     * Caution: every traversal of the returned query calls {@code stream.iterator()} again, and a
     * {@link Stream} can be consumed only once. A second traversal therefore fails with
     * {@link IllegalStateException}. Materialize with {@link #toList()} when several operations
     * must run against the same data.
     *
     * @param stream the source stream; must not be {@code null}
     * @return a query whose single traversal drains {@code stream}
     */
    public static  NQuery of(@NonNull Stream stream) {
        return of(stream::iterator, stream.isParallel());
    }

    /**
     * Creates a sequential query over {@code iterable}, without copying it.
     *
     * @param iterable the source; {@code null} is treated as empty
     * @return a sequential query
     */
    public static  NQuery of(Iterable iterable) {
        return of(iterable, false);
    }

    /**
     * Creates a query over {@code iterable}, without copying it.
     *
     * @param iterable   the source; {@code null} is treated as empty
     * @param isParallel whether operators evaluate with parallel streams
     * @return the query
     */
    public static  NQuery of(Iterable iterable, boolean isParallel) {
        if (iterable == null) {
            iterable = Collections.emptyList();
        }
        return new NQuery<>(iterable, isParallel);
    }
    //endregion

    //region Member
    private final Iterable data;
    private final boolean isParallel;

    /**
     * Opens a new {@link Stream} over the source, parallel if this query is parallel.
     *
     * @return a fresh stream; each call traverses the source again
     */
    public Stream stream() {
        return StreamSupport.stream(data.spliterator(), isParallel);
    }

    /**
     * Returns a cheap initial-capacity hint for result containers.
     * <p>
     * Only a {@link Collection} source reports its size. Counting any other {@link Iterable} would
     * traverse it an extra time, which is wasteful and fails outright for single-use sources such as
     * the one built by {@link #of(Stream)}.
     */
    private int sizeHint() {
        return data instanceof Collection ? ((Collection) data).size() : DEFAULT_CAPACITY;
    }

    /** Creates an empty result list: {@code newConcurrentList(...)} when parallel, else an {@link ArrayList}. */
    private  List newList() {
        int capacity = sizeHint();
        return isParallel ? newConcurrentList(capacity, false) : new ArrayList<>(capacity);
    }

    /** Creates an empty insertion-ordered result set, wrapped with {@link Collections#synchronizedSet} when parallel. */
    private  Set newSet() {
        int capacity = sizeHint();
        return isParallel ? Collections.synchronizedSet(new LinkedHashSet<>(capacity)) : new LinkedHashSet<>(capacity);
    }

    /** Creates an empty insertion-ordered result map, wrapped with {@link Collections#synchronizedMap} when parallel. */
    private  Map newMap() {
        int capacity = sizeHint();
        return isParallel ? Collections.synchronizedMap(new LinkedHashMap<>(capacity)) : new LinkedHashMap<>(capacity);
    }

    /** Copies the source into a fresh {@link ArrayList} that can safely be reordered in place. */
    private List toMutableList() {
        List list = new ArrayList<>(sizeHint());
        Iterables.addAll(list, data);
        return list;
    }

    /** Streams an arbitrary iterable with this query's parallelism. */
    private  Stream newStream(Iterable iterable) {
        return StreamSupport.stream(iterable.spliterator(), isParallel);
    }

    /** Wraps {@code set} (without copying) in a new query that keeps this query's parallel flag. */
    private  NQuery me(Iterable set) {
        return of(set, isParallel);
    }

    /** Runs {@code stream} eagerly, collecting it into a list that backs the returned query. */
    private  NQuery me(Stream stream) {
        return me(stream.collect(Collectors.toList()));
    }

    /**
     * Filters the source element by element with {@code func}. It receives each element with its
     * zero-based index and returns {@link EachFunc} flags: {@code ACCEPT} keeps the element and
     * {@code BREAK} ends the traversal after this element.
     * <p>
     * The running index and the break state are shared across the traversal, so this is meant for
     * sequential queries. A warning is logged when the query is parallel.
     *
     * @param func       the per-element decision
     * @param prevMethod name of the calling operator, used only in the warning
     * @return a query over the accepted elements
     */
    @SneakyThrows
    private NQuery me(EachFunc func, String prevMethod) {
        if (isParallel) {
            log.warn("Not supported parallel {}", prevMethod);
        }

        Spliterator spliterator = data.spliterator();
        // Elements may be dropped or the traversal cut short, so the wrapper must not report an exact size.
        int characteristics = spliterator.characteristics() & ~(Spliterator.SIZED | Spliterator.SUBSIZED);
        Stream r = StreamSupport.stream(new Spliterators.AbstractSpliterator(spliterator.estimateSize(), characteristics) {
            final AtomicBoolean breaker = new AtomicBoolean();
            final AtomicInteger counter = new AtomicInteger();

            @Override
            public boolean tryAdvance(Consumer action) {
                return spliterator.tryAdvance(p -> {
                    int flags = func.each(p, counter.getAndIncrement());
                    if ((flags & EachFunc.ACCEPT) == EachFunc.ACCEPT) {
                        action.accept(p);
                    }
                    if ((flags & EachFunc.BREAK) == EachFunc.BREAK) {
                        breaker.set(true);
                    }
                }) && !breaker.get();
            }
        }, isParallel);
        return me(r);
    }

    /** Per-element decision used by {@code me(EachFunc, String)}; flags may be OR-ed together. */
    @FunctionalInterface
    interface EachFunc {
        /** Drop the element and continue. */
        int NONE = 0;
        /** Keep the element. */
        int ACCEPT = 1;
        /** Stop the traversal after this element. */
        int BREAK = 1 << 1;

        /**
         * @param item  the current element
         * @param index its zero-based position in the traversal
         * @return a combination of {@link #NONE}, {@link #ACCEPT} and {@link #BREAK}
         */
        int each(T item, int index);
    }
    //endregion

    /** Returns an iterator over the source, without copying it. */
    @Override
    public Iterator iterator() {
        return data.iterator();
    }

    /** Runs {@code action} on every element; in parallel mode the order and thread are unspecified. */
    @Override
    public void forEach(Consumer action) {
        stream().forEach(action);
    }

    /** Runs {@code action} on every element in encounter order, even in parallel mode. */
    public void forEachOrdered(Consumer action) {
        stream().forEachOrdered(action);
    }

    /** Projects each element through {@code selector}. */
    public  NQuery select(BiFunc selector) {
        return me(stream().map(selector.toFunction()));
    }

    /**
     * Projects each element through {@code selector}, which also receives a running index.
     * Indices follow processing order, so they match encounter order only for sequential queries.
     */
    public  NQuery select(BiFuncWithIndex selector) {
        AtomicInteger counter = new AtomicInteger();
        return me(stream().map(p -> {
            try {
                return selector.invoke(p, counter.getAndIncrement());
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        }));
    }

    /** Maps each element to an {@link Iterable} and flattens the results. */
    public  NQuery selectMany(BiFunc> selector) {
        return me(stream().flatMap(p -> newStream(sneakyInvoke(() -> selector.invoke(p)))));
    }

    /** Maps each element and its running index to an {@link Iterable} and flattens the results. */
    public  NQuery selectMany(BiFuncWithIndex> selector) {
        AtomicInteger counter = new AtomicInteger();
        return me(stream().flatMap(p -> newStream(sneakyInvoke(() -> selector.invoke(p, counter.getAndIncrement())))));
    }

    /** Keeps the elements matching {@code predicate}; checked exceptions are rethrown unchecked. */
    public NQuery where(PredicateFunc predicate) {
        return me(stream().filter(p -> {
            try {
                return predicate.invoke(p);
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        }));
    }

    /**
     * Keeps the elements matching {@code predicate}, which also receives a running index.
     * Indices follow processing order, so they match encounter order only for sequential queries.
     */
    public NQuery where(PredicateFuncWithIndex predicate) {
        AtomicInteger counter = new AtomicInteger();
        return me(stream().filter(p -> {
            try {
                return predicate.invoke(p, counter.getAndIncrement());
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        }));
    }

    /**
     * Inner join as a nested loop. Each pair {@code (outer, inner)} for which
     * {@code keySelector.test(outer, inner)} holds yields {@code resultSelector.apply(outer, inner)}.
     * {@code inner} is traversed once per element of this query.
     */
    public  NQuery join(Iterable inner, BiPredicate keySelector, BiFunction resultSelector) {
        return me(stream().flatMap(p -> newStream(inner).filter(p2 -> keySelector.test(p, p2)).map(p3 -> resultSelector.apply(p, p3))));
    }

    /** Joins this query with the sequence obtained by mapping each of its elements through {@code innerSelector}. */
    public  NQuery join(BiFunc innerSelector, BiPredicate keySelector, BiFunction resultSelector) {
        return join(stream().map(innerSelector.toFunction()).collect(Collectors.toList()), keySelector, resultSelector);
    }

    /** Joins this query with the flattened sequence of the iterables produced by {@code innerSelector}. */
    public  NQuery joinMany(BiFunc> innerSelector, BiPredicate keySelector, BiFunction resultSelector) {
        return join(stream().flatMap(p -> newStream(sneakyInvoke(() -> innerSelector.invoke(p)))).collect(Collectors.toList()), keySelector, resultSelector);
    }

    /**
     * Left outer join. Works like {@link #join(Iterable, BiPredicate, BiFunction)}, except that an outer
     * element with no matching inner element yields one result, {@code resultSelector.apply(outer, null)}.
     */
    public  NQuery leftJoin(Iterable inner, BiPredicate keySelector, BiFunction resultSelector) {
        return me(stream().flatMap(p -> {
            if (!newStream(inner).anyMatch(p2 -> keySelector.test(p, p2))) {
                return Stream.of(resultSelector.apply(p, null));
            }
            return newStream(inner).filter(p2 -> keySelector.test(p, p2)).map(p3 -> resultSelector.apply(p, p3));
        }));
    }

    /** Left-joins this query with the sequence obtained by mapping each of its elements through {@code innerSelector}. */
    public  NQuery leftJoin(BiFunc innerSelector, BiPredicate keySelector, BiFunction resultSelector) {
        return leftJoin(stream().map(innerSelector.toFunction()).collect(Collectors.toList()), keySelector, resultSelector);
    }

    /** Left-joins this query with the flattened sequence of the iterables produced by {@code innerSelector}. */
    public  NQuery leftJoinMany(BiFunc> innerSelector, BiPredicate keySelector, BiFunction resultSelector) {
        return leftJoin(stream().flatMap(p -> newStream(sneakyInvoke(() -> innerSelector.invoke(p)))).collect(Collectors.toList()), keySelector, resultSelector);
    }

    /** Returns {@code true} if every element matches {@code predicate}; {@code true} for an empty query. */
    public boolean all(PredicateFunc predicate) {
        return stream().allMatch(predicate.toPredicate());
    }

    /** Returns {@code true} if the query has at least one element (null elements count). */
    public boolean any() {
        // Iterator-based: Stream#findAny throws NullPointerException when the element it finds is null.
        return iterator().hasNext();
    }

    /** Returns {@code true} if at least one element matches {@code predicate}. */
    public boolean any(PredicateFunc predicate) {
        return stream().anyMatch(predicate.toPredicate());
    }

    /** Returns {@code true} if some element equals {@code item} ({@link Objects#equals}; nulls are supported). */
    public boolean contains(T item) {
        return stream().anyMatch(p -> Objects.equals(p, item));
    }

    /** Appends the elements of {@code set} after this query's elements. */
    public NQuery concat(Iterable set) {
        return me(Stream.concat(stream(), newStream(set)));
    }

    /** Removes duplicates (by {@code equals}/{@code hashCode}), keeping the first occurrence in ordered mode. */
    public NQuery distinct() {
        return me(stream().distinct());
    }

    /**
     * Keeps the elements that do not occur in {@code set} (by {@link Objects#equals}). Duplicates in this
     * query are not removed, and {@code set} is scanned linearly for every element.
     */
    public NQuery except(Iterable set) {
        return me(stream().filter(p -> !newStream(set).anyMatch(p2 -> Objects.equals(p2, p))));
    }

    /**
     * Keeps the elements that also occur in {@code set} (by {@link Objects#equals}). Duplicates in this
     * query are not removed, and {@code set} is scanned linearly for every element.
     */
    public NQuery intersection(Iterable set) {
        return me(stream().filter(p -> newStream(set).anyMatch(p2 -> Objects.equals(p2, p))));
    }

    /**
     * Symmetric difference with {@code set}, computed by {@link CollectionUtils#disjunction}. Each element
     * appears {@code |count in this - count in set|} times; order is determined by CollectionUtils.
     */
    public NQuery difference(Iterable set) {
        return me(CollectionUtils.disjunction(this, set));
    }

    /**
     * Multiset union with {@code set}, computed by {@link CollectionUtils#union}. Each element appears
     * {@code max(count in this, count in set)} times; order is determined by CollectionUtils.
     */
    public NQuery union(Iterable set) {
        return me(CollectionUtils.union(this, set));
    }

    /** Returns the elements in random order (a uniform shuffle). */
    public NQuery orderByRand() {
        // Sorting by a random key per comparison breaks the Comparator contract (TimSort may throw);
        // Collections.shuffle is a proper Fisher-Yates shuffle.
        List list = toMutableList();
        Collections.shuffle(list, ThreadLocalRandom.current());
        return me(list);
    }

    /** Stable ascending sort by key; null keys sort last (see {@link #getComparator(BiFunc)}). */
    public  NQuery orderBy(BiFunc keySelector) {
        return me(stream().sorted(getComparator(keySelector)));
    }

    /**
     * Builds an ascending comparator over the key extracted by {@code keySelector}.
     * <p>
     * Keys are viewed as {@link Comparable} through {@code as(...)}. A key that comes back as
     * {@code null} sorts after any non-null key, and two such keys compare equal. NOTE(review):
     * {@code as} presumably also yields {@code null} for non-Comparable keys; confirm in {@code Extends}.
     * Exceptions from {@code keySelector} are rethrown unchecked.
     */
    @SuppressWarnings(NON_RAW_TYPES)
    public static  Comparator getComparator(BiFunc keySelector) {
        return (p1, p2) -> {
            try {
                Comparable c1 = as(keySelector.invoke(p1), Comparable.class);
                Comparable c2 = as(keySelector.invoke(p2), Comparable.class);
                if (c1 == null || c2 == null) {
                    return c1 == null ? (c2 == null ? 0 : 1) : -1;
                }
                return c1.compareTo(c2);
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        };
    }

    /** Stable descending sort by key; because the whole comparator is reversed, null keys sort first. */
    public  NQuery orderByDescending(BiFunc keySelector) {
        return me(stream().sorted(getComparator(keySelector).reversed()));
    }

    /** Stable ascending sort by a list of keys compared lexicographically (see {@link #getComparatorMany(BiFunc)}). */
    public NQuery orderByMany(BiFunc> keySelector) {
        return me(stream().sorted(getComparatorMany(keySelector)));
    }

    /**
     * Builds a lexicographic comparator over the key lists extracted by {@code keySelector}.
     * <p>
     * Keys are compared position by position and the first non-tie decides. At a single position, a
     * key that {@code as(...)} turns into {@code null} sorts after a non-null key, and two null keys
     * tie, so the next position decides. When all shared positions tie, the shorter key list sorts
     * first, which keeps the comparator antisymmetric. Exceptions from {@code keySelector} are
     * rethrown unchecked.
     */
    @SuppressWarnings(NON_RAW_TYPES)
    public static  Comparator getComparatorMany(BiFunc> keySelector) {
        return (p1, p2) -> {
            try {
                List k1s = keySelector.invoke(p1), k2s = keySelector.invoke(p2);
                int common = Math.min(k1s.size(), k2s.size());
                for (int i = 0; i < common; i++) {
                    Comparable c1 = as(k1s.get(i), Comparable.class);
                    Comparable c2 = as(k2s.get(i), Comparable.class);
                    if (c1 == null || c2 == null) {
                        if (c1 == null && c2 == null) {
                            // tie on this key: let the next key decide
                            continue;
                        }
                        // nulls sort last, matching getComparator
                        return c1 == null ? 1 : -1;
                    }
                    int r = c1.compareTo(c2);
                    if (r != 0) {
                        return r;
                    }
                }
                return Integer.compare(k1s.size(), k2s.size());
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        };
    }

    /** Stable descending sort by a list of keys; the reversed comparator puts null keys first. */
    public NQuery orderByDescendingMany(BiFunc> keySelector) {
        return me(stream().sorted(getComparatorMany(keySelector).reversed()));
    }

    /** Returns the elements in reverse encounter order (no sorting involved). */
    public NQuery reverse() {
        List list = toMutableList();
        Collections.reverse(list);
        return me(list);
    }

    /**
     * Groups elements by key and maps each group through {@code resultSelector(key, group)}. Each group
     * is passed as a sequential query. In sequential mode, groups follow the first occurrence of their
     * key. {@link Collectors#groupingBy} rejects {@code null} keys with {@link NullPointerException}.
     */
    public  NQuery groupBy(BiFunc keySelector, BiFunction, TR> resultSelector) {
        Map> map = stream().collect(Collectors.groupingBy(keySelector.toFunction(), this::newMap, Collectors.toList()));
        List result = newList();
        for (Map.Entry> entry : map.entrySet()) {
            result.add(resultSelector.apply(entry.getKey(), of(entry.getValue())));
        }
        return me(result);
    }

    /**
     * Like {@link #groupBy(BiFunc, BiFunction)}, but returns an insertion-ordered map from each key to
     * {@code resultSelector(key, group)}.
     */
    public  Map groupByIntoMap(BiFunc keySelector, BiFunction, TR> resultSelector) {
        Map> map = stream().collect(Collectors.groupingBy(keySelector.toFunction(), this::newMap, Collectors.toList()));
        Map result = newMap();
        for (Map.Entry> entry : map.entrySet()) {
            result.put(entry.getKey(), resultSelector.apply(entry.getKey(), NQuery.of(entry.getValue())));
        }
        return result;
    }

    /** Groups by a composite key given as a list (compared with {@code List.equals}); otherwise like {@code groupBy}. */
    public  NQuery groupByMany(BiFunc> keySelector, BiFunction, NQuery, TR> resultSelector) {
        Map, List> map = stream().collect(Collectors.groupingBy(keySelector.toFunction(), this::newMap, Collectors.toList()));
        List result = newList();
        for (Map.Entry, List> entry : map.entrySet()) {
            result.add(resultSelector.apply(entry.getKey(), of(entry.getValue())));
        }
        return me(result);
    }

    /** Returns the arithmetic mean of the selected values, or {@code null} if the query is empty. */
    public Double average(ToDoubleFunction selector) {
        OptionalDouble q = stream().mapToDouble(selector).average();
        return q.isPresent() ? q.getAsDouble() : null;
    }

    /** Returns the element count: {@code size()} for {@link Collection} sources, otherwise by traversal. */
    public int count() {
        if (data instanceof Collection) {
            return ((Collection) data).size();
        }
        return (int) stream().count();
    }

    /** Counts the elements matching {@code predicate}; checked exceptions are rethrown unchecked. */
    public int count(PredicateFunc predicate) {
        return (int) stream().filter(p -> {
            try {
                return predicate.invoke(p);
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        }).count();
    }

    /** Returns the largest element in natural order, or {@code null} if empty; elements must be non-null and Comparable. */
    public T max() {
        return max(stream());
    }

    /** Natural-order maximum of {@code stream}, or {@code null} if it is empty. */
    @SuppressWarnings(NON_UNCHECKED)
    private  TR max(Stream stream) {
        return stream.max((Comparator) Comparator.naturalOrder()).orElse(null);
    }

    /** Returns the largest selected value in natural order, or {@code null} if empty. */
    public  TR max(BiFunc selector) {
        return max(stream().map(selector.toFunction()));
    }

    /** Returns the smallest element in natural order, or {@code null} if empty; elements must be non-null and Comparable. */
    public T min() {
        return min(stream());
    }

    /** Natural-order minimum of {@code stream}, or {@code null} if it is empty. */
    @SuppressWarnings(NON_UNCHECKED)
    private  TR min(Stream stream) {
        return stream.min((Comparator) Comparator.naturalOrder()).orElse(null);
    }

    /** Returns the smallest selected value in natural order, or {@code null} if empty. */
    public  TR min(BiFunc selector) {
        return min(stream().map(selector.toFunction()));
    }

    /** Sums the selected values as {@code double}; {@code 0.0} for an empty query. */
    public double sum(ToDoubleFunction selector) {
        return stream().mapToDouble(selector).sum();
    }

    /** Sums the selected values with {@link Decimal} arithmetic, starting from {@code Decimal.ZERO}. */
    public Decimal sumDecimal(BiFunc selector) {
        $ sumValue = $(Decimal.ZERO);
        // forEachOrdered: each action happens-before the next, so the unsynchronized accumulator is
        // safe even for parallel queries (plain forEach would race on sumValue.v).
        stream().forEachOrdered(p -> {
            try {
                sumValue.v = sumValue.v.add(selector.invoke(p));
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        });
        return sumValue.v;
    }

    /** Unchecked cast of this query to another element type; elements are not checked and no copy is made. */
    @SuppressWarnings(NON_UNCHECKED)
    public  NQuery cast() {
        return (NQuery) this;
    }

    /** Keeps the elements for which {@code Reflects.isInstance(element, type)} holds, typed as {@code TR}. */
    @SuppressWarnings(NON_UNCHECKED)
    public  NQuery ofType(Class type) {
        return where(p -> Reflects.isInstance(p, type)).select(p -> (TR) p);
    }

    /**
     * Returns the first element.
     *
     * @throws NoSuchElementException if the query is empty
     * @throws NullPointerException   if the first element is {@code null} ({@link Stream#findFirst} contract)
     */
    public T first() {
        return stream().findFirst().get();
    }

    /**
     * Returns the first element matching {@code predicate}.
     *
     * @throws NoSuchElementException if no element matches
     * @throws NullPointerException   if the selected element is {@code null}
     */
    public T first(PredicateFunc predicate) {
        return stream().filter(predicate.toPredicate()).findFirst().get();
    }

    /** Returns the first element, or {@code null} if the query is empty. */
    public T firstOrDefault() {
        return firstOrDefault((T) null);
    }

    /** Returns the first element, or {@code defaultValue} if empty; throws NullPointerException if the first element is null. */
    public T firstOrDefault(T defaultValue) {
        return stream().findFirst().orElse(defaultValue);
    }

    /** Returns the first element, or the value of {@code defaultValue} (invoked only when empty). */
    public T firstOrDefault(Supplier defaultValue) {
        return stream().findFirst().orElseGet(defaultValue);
    }

    /** Returns the first element matching {@code predicate}, or {@code null} if none matches. */
    public T firstOrDefault(PredicateFunc predicate) {
        return stream().filter(predicate.toPredicate()).findFirst().orElse(null);
    }

    /**
     * Returns the last element (via Guava {@code Streams.findLast}).
     *
     * @throws NoSuchElementException if the query is empty
     * @throws NullPointerException   if the last element is {@code null}
     */
    public T last() {
        return Streams.findLast(stream()).get();
    }

    /**
     * Returns the last element matching {@code predicate}.
     *
     * @throws NoSuchElementException if no element matches
     */
    public T last(PredicateFunc predicate) {
        return Streams.findLast(stream().filter(predicate.toPredicate())).get();
    }

    /** Returns the last element, or {@code null} if the query is empty. */
    public T lastOrDefault() {
        return lastOrDefault((T) null);
    }

    /** Returns the last element, or {@code defaultValue} if the query is empty. */
    public T lastOrDefault(T defaultValue) {
        return Streams.findLast(stream()).orElse(defaultValue);
    }

    /** Returns the last element, or the value of {@code defaultValue} (invoked only when empty). */
    public T lastOrDefault(Supplier defaultValue) {
        return Streams.findLast(stream()).orElseGet(defaultValue);
    }

    /** Returns the last element matching {@code predicate}, or {@code null} if none matches. */
    public T lastOrDefault(PredicateFunc predicate) {
        return Streams.findLast(stream().filter(predicate.toPredicate())).orElse(null);
    }

    /** Returns the only element; see {@link #single(PredicateFunc)}. */
    public T single() {
        return single(null);
    }

    /**
     * Returns the only element (matching {@code predicate}, when one is given).
     *
     * @param predicate optional filter; {@code null} means all elements
     * @throws ApplicationException unless exactly one element qualifies. The count found is passed as
     *                              the error-message value (0, or 2 standing for "more than one").
     */
    @ErrorCode
    public T single(PredicateFunc predicate) {
        Stream stream = stream();
        if (predicate != null) {
            stream = stream.filter(predicate.toPredicate());
        }
        List list = stream.limit(2).collect(Collectors.toList());
        if (list.size() != 1) {
            throw new ApplicationException(values(list.size()));
        }
        return list.get(0);
    }

    /** Returns the only element, or {@code null} if empty; see {@link #singleOrDefault(PredicateFunc)}. */
    public T singleOrDefault() {
        return singleOrDefault(null);
    }

    /**
     * Returns the only element (matching {@code predicate}, when one is given), or {@code null} if none.
     *
     * @param predicate optional filter; {@code null} means all elements
     * @throws ApplicationException if more than one element qualifies (the value 2 is passed to the message)
     */
    @ErrorCode
    public T singleOrDefault(PredicateFunc predicate) {
        Stream stream = stream();
        if (predicate != null) {
            stream = stream.filter(predicate.toPredicate());
        }
        List list = stream.limit(2).collect(Collectors.toList());
        if (list.size() > 1) {
            throw new ApplicationException(values(list.size()));
        }
        return list.isEmpty() ? null : list.get(0);
    }

    /** Drops the first {@code count} elements; a negative count is rejected by {@link Stream#skip}. */
    public NQuery skip(int count) {
        return me(stream().skip(count));
    }

    /** Drops leading elements while {@code predicate} holds and keeps everything from the first failure on. */
    public NQuery skipWhile(PredicateFunc predicate) {
        return skipWhile((p, i) -> predicate.invoke(p));
    }

    /**
     * Drops leading elements while {@code predicate(element, index)} holds. Once it fails, that element
     * and all later ones are kept, and the predicate is not evaluated again. Intended for sequential queries.
     */
    public NQuery skipWhile(PredicateFuncWithIndex predicate) {
        AtomicBoolean doAccept = new AtomicBoolean();
        return me((p, i) -> {
            int flags = EachFunc.NONE;
            if (doAccept.get()) {
                flags |= EachFunc.ACCEPT;
                return flags;
            }
            if (!sneakyInvoke(() -> predicate.invoke(p, i))) {
                doAccept.set(true);
                flags |= EachFunc.ACCEPT;
            }
            return flags;
        }, "skipWhile");
    }

    /** Keeps at most the first {@code count} elements; a negative count is rejected by {@link Stream#limit}. */
    public NQuery take(int count) {
        return me(stream().limit(count));
    }

    /** Keeps leading elements while {@code predicate} holds and stops at the first failure. */
    public NQuery takeWhile(PredicateFunc predicate) {
        return takeWhile((p, i) -> predicate.invoke(p));
    }

    /**
     * Keeps leading elements while {@code predicate(element, index)} holds. The first failing element is
     * dropped and the traversal stops there. Intended for sequential queries.
     */
    public NQuery takeWhile(PredicateFuncWithIndex predicate) {
        return me((p, i) -> {
            int flags = EachFunc.NONE;
            if (!sneakyInvoke(() -> predicate.invoke(p, i))) {
                flags |= EachFunc.BREAK;
                return flags;
            }
            flags |= EachFunc.ACCEPT;
            return flags;
        }, "takeWhile");
    }

    /** Joins the selected {@link CharSequence} values with {@code delimiter} via {@link String#join}. */
    public String toJoinString(String delimiter, BiFunc selector) {
        return String.join(delimiter, select(selector));
    }

    /**
     * Copies the elements into an array whose component type is the runtime class of the first non-null
     * element ({@code Object} if there is none). Throws {@link ArrayStoreException} if a later element is
     * not an instance of that class.
     */
    @SuppressWarnings(NON_UNCHECKED)
    public T[] toArray() {
        List result = toList();
        Class type = null;
        for (T t : result) {
            if (t == null) {
                continue;
            }
            type = t.getClass();
            break;
        }
        if (type == null) {
            type = Object.class;
        }
        T[] array = (T[]) Array.newInstance(type, result.size());
        result.toArray(array);
        return array;
    }

    /** Copies the elements into a new array with component type {@code type}. */
    @SuppressWarnings(NON_UNCHECKED)
    public T[] toArray(Class type) {
        List result = toList();
        T[] array = (T[]) Array.newInstance(type, result.size());
        result.toArray(array);
        return array;
    }

    /** Copies the elements into a new list, which is a concurrent list when this query is parallel. */
    public List toList() {
        List result = newList();
        Iterables.addAll(result, data);
        return result;
    }

    /** Copies the elements into a new insertion-ordered set, synchronized when this query is parallel. */
    public Set toSet() {
        Set result = newSet();
        Iterables.addAll(result, data);
        return result;
    }

    /** Builds an insertion-ordered map from {@code keySelector(element)} to the element itself; later duplicates win. */
    public  Map toMap(BiFunc keySelector) {
        return toMap(keySelector, p -> p);
    }

    /**
     * Builds an insertion-ordered map from {@code keySelector(element)} to {@code resultSelector(element)}.
     * For duplicate keys, the value from the element later in encounter order wins.
     * <p>
     * The map is filled manually because {@link Collectors#toMap} rejects {@code null} values.
     */
    @SneakyThrows
    public  Map toMap(BiFunc keySelector, BiFunc resultSelector) {
        Map result = newMap();
        // forEachOrdered keeps insertion order and "last duplicate wins" deterministic in parallel mode too
        stream().forEachOrdered(item -> {
            try {
                result.put(keySelector.invoke(item), resultSelector.invoke(item));
            } catch (Throwable e) {
                throw InvalidException.sneaky(e);
            }
        });
        return result;
    }
}