// All downloads are free. Search and download functionality uses the official Maven repository.

// org.jinq.orm.stream.NonQueryJinqStream (Maven / Gradle / Ivy)

package org.jinq.orm.stream;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.jinq.tuples.Pair;
import org.jinq.tuples.Tuple;
import org.jinq.tuples.Tuple3;
import org.jinq.tuples.Tuple4;
import org.jinq.tuples.Tuple5;

public class NonQueryJinqStream extends LazyWrappedStream implements JinqStream
{
   public NonQueryJinqStream(Stream wrapped)
   {
      this(wrapped, null);
   }

   protected InQueryStreamSource inQueryStreamSource;
   public NonQueryJinqStream(Stream wrapped, InQueryStreamSource inQueryStreamSource)
   {
      super(wrapped);
      this.inQueryStreamSource = inQueryStreamSource;
   }

   NonQueryJinqStream()
   {
      this((InQueryStreamSource)null);
   }
   
   NonQueryJinqStream(InQueryStreamSource inQueryStreamSource)
   {
      super();
      this.inQueryStreamSource = inQueryStreamSource;
   }
   
   protected  JinqStream wrap(Stream toWrap)
   {
      return new NonQueryJinqStream<>(toWrap, inQueryStreamSource);
   }

   
   @Override
   public  JinqStream where(Where test)
   {
      return wrap(filter(val -> { 
            try { 
               return test.where(val); 
            } catch (Exception e) {
               // Record that an exception occurred
               propagateException(test, e);
               // Throw a runtime exception to try and kill the stream?
               throw new RuntimeException(e);
            }} ));
   }

   @Override
   public  JinqStream select(Select select)
   {
      return wrap(map( val -> select.select(val) ));
   }

   @Override
   public  JinqStream> join(Join join)
   {
      // TODO: This stream should be constructed on the fly
      final Stream.Builder> streamBuilder = Stream.builder();
      forEach( left -> {
         join.join(left).forEach( right -> 
            { streamBuilder.accept(new Pair<>(left, right)); });
         });
      return wrap(streamBuilder.build());
   }

   @Override
   public  JinqStream> join(JoinWithSource join)
   {
      // TODO: This stream should be constructed on the fly
      final Stream.Builder> streamBuilder = Stream.builder();
      forEach( left -> {
         join.join(left, inQueryStreamSource).forEach( right -> 
            { streamBuilder.accept(new Pair<>(left, right)); });
         });
      return wrap(streamBuilder.build());
   }

   @Override
   public JinqStream unique()
   {
      return wrap(distinct());
   }
   
   protected  JinqStream groupToTuple(Select select, AggregateGroup[] aggregates)
   {
      Map> groups = collect(Collectors.groupingBy(in -> select.select(in)));
      final Stream.Builder streamBuilder = Stream.builder();
      for (Map.Entry> entry: groups.entrySet())
      {
         Object[] groupAggregates = new Object[aggregates.length + 1];
         for (int n = 0; n < aggregates.length; n++)
            groupAggregates[n + 1] = aggregates[n].aggregateSelect(entry.getKey(), wrap(entry.getValue().stream()));
         groupAggregates[0] = (Object)entry.getKey();
         streamBuilder.accept(Tuple.createTuple(groupAggregates));
      }
      return (JinqStream) wrap(streamBuilder.build());
   }
   
   @Override
   public  JinqStream> group(Select select, AggregateGroup aggregate)
   {
      @SuppressWarnings("unchecked")
      AggregateGroup[] aggregates = new AggregateGroup[] {
            aggregate
      };
      return groupToTuple(select, aggregates);
   }
   
   @Override
   public  JinqStream> group(
         JinqStream.Select select,
         JinqStream.AggregateGroup aggregate1,
         JinqStream.AggregateGroup aggregate2)
   {
      @SuppressWarnings("unchecked")
      AggregateGroup[] aggregates = new AggregateGroup[] {
            aggregate1, aggregate2,
      };
      return groupToTuple(select, aggregates);
   }

   @Override
   public  JinqStream> group(
         JinqStream.Select select,
         JinqStream.AggregateGroup aggregate1,
         JinqStream.AggregateGroup aggregate2,
         JinqStream.AggregateGroup aggregate3)
   {
      @SuppressWarnings("unchecked")
      AggregateGroup[] aggregates = new AggregateGroup[] {
            aggregate1, aggregate2, aggregate3,
      };
      return groupToTuple(select, aggregates);
   }

   @Override
   public  JinqStream> group(
         JinqStream.Select select,
         JinqStream.AggregateGroup aggregate1,
         JinqStream.AggregateGroup aggregate2,
         JinqStream.AggregateGroup aggregate3,
         JinqStream.AggregateGroup aggregate4)
   {
      @SuppressWarnings("unchecked")
      AggregateGroup[] aggregates = new AggregateGroup[] {
            aggregate1, aggregate2, aggregate3, aggregate4,
      };
      return groupToTuple(select, aggregates);
   }

   @SuppressWarnings("unchecked")
   private static  V genericSum(V a, V b)
   {
      if (a == null) return b;
      if (b == null) return a;
      if (!a.getClass().equals(b.getClass())) throw new IllegalArgumentException("Mismatched number types");
      if (a instanceof Long) return (V)Long.valueOf(a.longValue() + b.longValue());
      if (a instanceof Integer) return (V)Integer.valueOf(a.intValue() + b.intValue());
      if (a instanceof Double) return (V)Double.valueOf(a.doubleValue() + b.doubleValue());
      if (a instanceof BigDecimal) return (V)((BigDecimal)a).add((BigDecimal)b);
      if (a instanceof BigInteger) return (V)((BigInteger)a).add((BigInteger)b);
      throw new IllegalArgumentException("Summing unknown number types");
   }
   
   @Override
   public Long sumInteger(CollectInteger aggregate)
   {
      return reduce((Long)null, 
            (accum, val) -> genericSum(accum, (long)aggregate.aggregate(val)),
            (accum1, accum2) -> genericSum(accum1, accum2));
   }
   @Override
   public Long sumLong(CollectLong aggregate)
   {
      return reduce((Long)null, 
            (accum, val) -> genericSum(accum, aggregate.aggregate(val)),
            (accum1, accum2) -> genericSum(accum1, accum2));
   }
   @Override
   public Double sumDouble(CollectDouble aggregate)
   {
      return reduce((Double)null, 
            (accum, val) -> genericSum(accum, aggregate.aggregate(val)),
            (accum1, accum2) -> genericSum(accum1, accum2));
   }
   @Override
   public BigDecimal sumBigDecimal(CollectBigDecimal aggregate)
   {
      return reduce((BigDecimal)null, 
            (accum, val) -> genericSum(accum, aggregate.aggregate(val)),
            (accum1, accum2) -> genericSum(accum1, accum2));
   }
   @Override
   public BigInteger sumBigInteger(CollectBigInteger aggregate)
   {
      return reduce((BigInteger)null, 
            (accum, val) -> genericSum(accum, aggregate.aggregate(val)),
            (accum1, accum2) -> genericSum(accum1, accum2));
   }

   private static > V genericCompare(boolean isMax, V a, V b)
   {
      if (a == null) return b;
      if (b == null) return a;
      if (isMax)
         return a.compareTo(b) <= 0 ? b : a;
      else
         return a.compareTo(b) >= 0 ? b : a;
   }
   
   @Override
   public > V max(
         org.jinq.orm.stream.JinqStream.CollectComparable aggregate)
   {
      return reduce((V)null,
            (accum, val) -> genericCompare(true, accum, aggregate.aggregate(val)),
            (accum1, accum2) -> genericCompare(true, accum1, accum2));
   }

   @Override
   public > V min(
         org.jinq.orm.stream.JinqStream.CollectComparable aggregate)
   {
      return reduce((V)null,
            (accum, val) -> genericCompare(false, accum, aggregate.aggregate(val)),
            (accum1, accum2) -> genericCompare(false, accum1, accum2));
   }

   private static class GenericAverage
   {
      double sum = 0;
      int count = 0;
      synchronized  void accumulate(V a)
      {
         if (a == null) return;
         sum += a.doubleValue();
         count++;
      }
   }
   
   @Override
   public > Double avg(CollectNumber aggregate)
   {
      final GenericAverage avg = new GenericAverage();
      forEach(val -> avg.accumulate(aggregate.aggregate(val)));
      if (avg.count == 0) return null;
      return avg.sum / avg.count;
   }

//   @Override
//   public  U selectAggregates(AggregateSelect aggregate)
//   {
//      return aggregate.aggregateSelect(this);
//   }
//
   @Override
   public > JinqStream sortedBy(
         JinqStream.CollectComparable sortField)
   {
      return wrap(sorted(
            (o1, o2) -> sortField.aggregate(o1).compareTo(sortField.aggregate(o2))));
   }

   @Override
   public > JinqStream sortedDescendingBy(
         JinqStream.CollectComparable sortField)
   {
      return wrap(sorted(
            (o1, o2) -> -sortField.aggregate(o1).compareTo(sortField.aggregate(o2))));
   }

   @Override 
   public JinqStream skip(long n)
   {
      return wrap(super.skip(n));
   }
   
   @Override 
   public JinqStream limit(long n)
   {
      return wrap(super.limit(n));
   }

   @Override
   public T getOnlyValue()
   {
      List vals = collect(Collectors.toList());
      if (vals.size() == 1) return vals.get(0);
      throw new NoSuchElementException();
   }
   
   @Override
   public JinqStream with(T toAdd)
   {
      return wrap(
            Stream.concat(this, Stream.of(toAdd)));
   }
   
   @Override
   public List toList()
   {
      return collect(Collectors.toList());
   }
   
   @Override
   public String getDebugQueryString()
   {
      // TODO: It would be nice if this could follow the stream chain
      //    down to get the underlying query (the stream chain isn't currently
      //    recorded, so this is not possible at the moment).
      return null;
   }
   
   protected Map recordedExceptions = new HashMap<>();
   
   @Override
   public void propagateException(Object source, Throwable exception)
   {
      if (!recordedExceptions.containsKey(source))
         recordedExceptions.put(source, exception);
   }

   @Override
   public Collection getExceptions()
   {
      return recordedExceptions.values();
   }

//   @Override
//   public  U aggregate(AggregateSelect aggregate1)
//   {
//      AggregateSelect[] aggregates = new AggregateSelect[]
//            {
//               aggregate1
//            };
//      Object [] results = multiaggregate(aggregates);
//      return (U)results[0];
//   }
   
   @Override
   public  Pair aggregate(AggregateSelect aggregate1, AggregateSelect aggregate2)
   {
      @SuppressWarnings("unchecked")
      AggregateSelect[] aggregates = new AggregateSelect[]
            {
               aggregate1, aggregate2
            };
      return  multiaggregate(aggregates);
   }

   @Override
   public  Tuple3 aggregate(AggregateSelect aggregate1,
         AggregateSelect aggregate2, AggregateSelect aggregate3)
   {
      @SuppressWarnings("unchecked")
      AggregateSelect[] aggregates = new AggregateSelect[]
            {
               aggregate1, aggregate2, aggregate3
            };
      return multiaggregate(aggregates);
   }


   @Override
   public  Tuple4 aggregate(
         JinqStream.AggregateSelect aggregate1, JinqStream.AggregateSelect aggregate2,
         JinqStream.AggregateSelect aggregate3, JinqStream.AggregateSelect aggregate4)
   {
      @SuppressWarnings("unchecked")
      AggregateSelect[] aggregates = new AggregateSelect[]
            {
               aggregate1, aggregate2, aggregate3, aggregate4
            };
      return multiaggregate(aggregates);
   }

   @Override
   public  Tuple5 aggregate(
         JinqStream.AggregateSelect aggregate1, JinqStream.AggregateSelect aggregate2,
         JinqStream.AggregateSelect aggregate3, JinqStream.AggregateSelect aggregate4,
         JinqStream.AggregateSelect aggregate5)
   {
      @SuppressWarnings("unchecked")
      AggregateSelect[] aggregates = new AggregateSelect[]
            {
               aggregate1, aggregate2, aggregate3, aggregate4, aggregate5
            };
      return multiaggregate(aggregates);
   }

    U multiaggregate(AggregateSelect[] aggregates)
   {
      IteratorTee tee = new IteratorTee<>(this, aggregates.length);
      
      // Run each aggregator in a separate thread so that we can
      // use producer-consumer queues and hence avoid using too much
      // memory.
      Thread [] aggregateThreads = new Thread[aggregates.length];
      final Object [] results = new Object[aggregates.length];
      for (int n = 0; n < aggregates.length; n++)
      {
         final int idx = n;
         final AggregateSelect fn = aggregates[idx];
         aggregateThreads[n] = new Thread() {
            @Override public void run()
            {
               JinqStream stream = 
                     wrap(StreamSupport.stream(
                           Spliterators.spliteratorUnknownSize(
                                 tee.createIterator(idx), 
                                 Spliterator.CONCURRENT), 
                           false));
               results[idx] = fn.aggregateSelect(stream);
            }
         };
         aggregateThreads[n].start();
      }
      for (int n = 0; n < aggregateThreads.length; n++)
      {
         try {
            aggregateThreads[n].join();
         } catch (InterruptedException e)
         {
            Thread.currentThread().interrupt();
         }
      }
      return Tuple.createTuple(results);
   }
   
   public static class IteratorTee 
   {
      static final int MAX_QUEUE_SIZE = 100;
      final Object DONE = new Object();
      
      ArrayBlockingQueue[] outputQueues;
      Stream inputStream;
      public IteratorTee(Stream inputStream, int size)
      {
         this.inputStream = inputStream;
         outputQueues = new ArrayBlockingQueue[size];
         for (int n = 0; n < size; n++)
            outputQueues[n] = new ArrayBlockingQueue<>(MAX_QUEUE_SIZE);
      }
      
      boolean isStarted = false;
      synchronized void startInputStreamPump()
      {
         if (isStarted) return;
         isStarted = true;
         new Thread() {
            @Override public void run()
            {
               inputStream.forEach( val -> {
                  for (int n = 0; n < outputQueues.length; n++)
                  {
                     try {
                        outputQueues[n].put(val);
                     } catch (InterruptedException e)
                     {
                        Thread.currentThread().interrupt();
                     }
                  }
               });
               try {
                  for (int n = 0; n < outputQueues.length; n++)
                     outputQueues[n].put(DONE);
               } catch (InterruptedException e)
               {
                  Thread.currentThread().interrupt();
               }
            }
         }.start();
      }
      public Iterator createIterator(int idx)
      {
         return new NextOnlyIterator()
               {
                  @Override
                  protected void generateNext()
                  {
                     startInputStreamPump();
                     Object taken = DONE;
                     try {
                        taken = outputQueues[idx].take();
                     } catch (InterruptedException e)
                     {
                        Thread.currentThread().interrupt();
                     }
                     if (taken == DONE)
                        noMoreElements();
                     else
                        nextElement((T)taken);
                  }
               };
      }
   }
   
   @Override
   public JinqStream setHint(String name, Object value)
   {
      return this;
   }
}