All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.cassandra.cql3.functions.AggregateFcts Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.cql3.functions;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.List;

import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.ByteType;
import org.apache.cassandra.db.marshal.CounterColumnType;
import org.apache.cassandra.db.marshal.DecimalType;
import org.apache.cassandra.db.marshal.DoubleType;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.marshal.Int32Type;
import org.apache.cassandra.db.marshal.IntegerType;
import org.apache.cassandra.db.marshal.LongType;
import org.apache.cassandra.db.marshal.ShortType;

/**
 * Factory methods for aggregate functions.
 */
public abstract class AggregateFcts
{
    /**
     * Checks if the specified function is the count rows (e.g. COUNT(*) or COUNT(1)) function.
     *
     * @param function the function to check
     * @return true if the specified function is the count rows one, false otherwise.
     */
    public static boolean isCountRows(Function function)
    {
        return function == countRowsFunction;
    }

    /**
     * The function used to count the number of rows of a result set. This function is called when COUNT(*) or COUNT(1)
     * is specified.
     */
    public static final AggregateFunction countRowsFunction =
            new NativeAggregateFunction("countRows", LongType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private long count;

                        public void reset()
                        {
                            count = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((LongType) returnType()).decompose(Long.valueOf(count));
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            count++;
                        }
                    };
                }
            };

    /**
     * The SUM function for decimal values.
     */
    public static final AggregateFunction sumFunctionForDecimal =
            new NativeAggregateFunction("sum", DecimalType.instance, DecimalType.instance)
            {
                @Override
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigDecimal sum = BigDecimal.ZERO;

                        public void reset()
                        {
                            sum = BigDecimal.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((DecimalType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value));
                            sum = sum.add(number);
                        }
                    };
                }
            };

    /**
     * The AVG function for decimal values.
     */
    public static final AggregateFunction avgFunctionForDecimal =
            new NativeAggregateFunction("avg", DecimalType.instance, DecimalType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigDecimal sum = BigDecimal.ZERO;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = BigDecimal.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            if (count == 0)
                                return ((DecimalType) returnType()).decompose(BigDecimal.ZERO);

                            return ((DecimalType) returnType()).decompose(sum.divide(BigDecimal.valueOf(count)));
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value));
                            sum = sum.add(number);
                        }
                    };
                }
            };

    /**
     * The SUM function for varint values.
     */
    public static final AggregateFunction sumFunctionForVarint =
            new NativeAggregateFunction("sum", IntegerType.instance, IntegerType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigInteger sum = BigInteger.ZERO;

                        public void reset()
                        {
                            sum = BigInteger.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((IntegerType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            BigInteger number = ((BigInteger) argTypes().get(0).compose(value));
                            sum = sum.add(number);
                        }
                    };
                }
            };

    /**
     * The AVG function for varint values.
     */
    public static final AggregateFunction avgFunctionForVarint =
            new NativeAggregateFunction("avg", IntegerType.instance, IntegerType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private BigInteger sum = BigInteger.ZERO;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = BigInteger.ZERO;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            if (count == 0)
                                return ((IntegerType) returnType()).decompose(BigInteger.ZERO);

                            return ((IntegerType) returnType()).decompose(sum.divide(BigInteger.valueOf(count)));
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            BigInteger number = ((BigInteger) argTypes().get(0).compose(value));
                            sum = sum.add(number);
                        }
                    };
                }
            };

    /**
     * The SUM function for byte values (tinyint).
     */
    public static final AggregateFunction sumFunctionForByte =
            new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private byte sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((ByteType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.byteValue();
                        }
                    };
                }
            };

    /**
     * AVG function for byte values (tinyint).
     */
    public static final AggregateFunction avgFunctionForByte =
            new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private byte sum;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            int avg = count == 0 ? 0 : sum / count;

                            return ((ByteType) returnType()).decompose((byte) avg);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.byteValue();
                        }
                    };
                }
            };

    /**
     * The SUM function for short values (smallint).
     */
    public static final AggregateFunction sumFunctionForShort =
            new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private short sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((ShortType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.shortValue();
                        }
                    };
                }
            };

    /**
     * AVG function for for short values (smallint).
     */
    public static final AggregateFunction avgFunctionForShort =
            new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private short sum;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            int avg = count == 0 ? 0 : sum / count;

                            return ((ShortType) returnType()).decompose((short) avg);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.shortValue();
                        }
                    };
                }
            };

    /**
     * The SUM function for int32 values.
     */
    public static final AggregateFunction sumFunctionForInt32 =
            new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private int sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((Int32Type) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.intValue();
                        }
                    };
                }
            };

    /**
     * AVG function for int32 values.
     */
    public static final AggregateFunction avgFunctionForInt32 =
            new NativeAggregateFunction("avg", Int32Type.instance, Int32Type.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private int sum;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            int avg = count == 0 ? 0 : sum / count;

                            return ((Int32Type) returnType()).decompose(avg);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.intValue();
                        }
                    };
                }
            };

    /**
     * The SUM function for long values.
     */
    public static final AggregateFunction sumFunctionForLong =
            new NativeAggregateFunction("sum", LongType.instance, LongType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new LongSumAggregate();
                }
            };

    /**
     * AVG function for long values.
     */
    public static final AggregateFunction avgFunctionForLong =
            new NativeAggregateFunction("avg", LongType.instance, LongType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new LongAvgAggregate();
                }
            };

    /**
     * The SUM function for float values.
     */
    public static final AggregateFunction sumFunctionForFloat =
            new NativeAggregateFunction("sum", FloatType.instance, FloatType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private float sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((FloatType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.floatValue();
                        }
                    };
                }
            };

    /**
     * AVG function for float values.
     */
    public static final AggregateFunction avgFunctionForFloat =
            new NativeAggregateFunction("avg", FloatType.instance, FloatType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private float sum;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            float avg = count == 0 ? 0 : sum / count;

                            return ((FloatType) returnType()).decompose(avg);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.floatValue();
                        }
                    };
                }
            };

    /**
     * The SUM function for double values.
     */
    public static final AggregateFunction sumFunctionForDouble =
            new NativeAggregateFunction("sum", DoubleType.instance, DoubleType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private double sum;

                        public void reset()
                        {
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            return ((DoubleType) returnType()).decompose(sum);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.doubleValue();
                        }
                    };
                }
            };

    /**
     * AVG function for double values.
     */
    public static final AggregateFunction avgFunctionForDouble =
            new NativeAggregateFunction("avg", DoubleType.instance, DoubleType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new Aggregate()
                    {
                        private double sum;

                        private int count;

                        public void reset()
                        {
                            count = 0;
                            sum = 0;
                        }

                        public ByteBuffer compute(int protocolVersion)
                        {
                            double avg = count == 0 ? 0 : sum / count;

                            return ((DoubleType) returnType()).decompose(avg);
                        }

                        public void addInput(int protocolVersion, List values)
                        {
                            ByteBuffer value = values.get(0);

                            if (value == null)
                                return;

                            count++;
                            Number number = ((Number) argTypes().get(0).compose(value));
                            sum += number.doubleValue();
                        }
                    };
                }
            };

    /**
     * The SUM function for counter column values.
     */
    public static final AggregateFunction sumFunctionForCounter =
            new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new LongSumAggregate();
                }
            };

    /**
     * AVG function for counter column values.
     */
    public static final AggregateFunction avgFunctionForCounter =
            new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance)
            {
                public Aggregate newAggregate()
                {
                    return new LongAvgAggregate();
                }
            };

    /**
     * Creates a MAX function for the specified type.
     *
     * @param inputType the function input and output type
     * @return a MAX function for the specified type.
     */
    public static AggregateFunction makeMaxFunction(final AbstractType inputType)
    {
        return new NativeAggregateFunction("max", inputType, inputType)
        {
            public Aggregate newAggregate()
            {
                return new Aggregate()
                {
                    private ByteBuffer max;

                    public void reset()
                    {
                        max = null;
                    }

                    public ByteBuffer compute(int protocolVersion)
                    {
                        return max;
                    }

                    public void addInput(int protocolVersion, List values)
                    {
                        ByteBuffer value = values.get(0);

                        if (value == null)
                            return;

                        if (max == null || returnType().compare(max, value) < 0)
                            max = value;
                    }
                };
            }
        };
    }

    /**
     * Creates a MIN function for the specified type.
     *
     * @param inputType the function input and output type
     * @return a MIN function for the specified type.
     */
    public static AggregateFunction makeMinFunction(final AbstractType inputType)
    {
        return new NativeAggregateFunction("min", inputType, inputType)
        {
            public Aggregate newAggregate()
            {
                return new Aggregate()
                {
                    private ByteBuffer min;

                    public void reset()
                    {
                        min = null;
                    }

                    public ByteBuffer compute(int protocolVersion)
                    {
                        return min;
                    }

                    public void addInput(int protocolVersion, List values)
                    {
                        ByteBuffer value = values.get(0);

                        if (value == null)
                            return;

                        if (min == null || returnType().compare(min, value) > 0)
                            min = value;
                    }
                };
            }
        };
    }

    /**
     * Creates a COUNT function for the specified type.
     *
     * @param inputType the function input type
     * @return a COUNT function for the specified type.
     */
    public static AggregateFunction makeCountFunction(AbstractType inputType)
    {
        return new NativeAggregateFunction("count", LongType.instance, inputType)
        {
            public Aggregate newAggregate()
            {
                return new Aggregate()
                {
                    private long count;

                    public void reset()
                    {
                        count = 0;
                    }

                    public ByteBuffer compute(int protocolVersion)
                    {
                        return ((LongType) returnType()).decompose(count);
                    }

                    public void addInput(int protocolVersion, List values)
                    {
                        ByteBuffer value = values.get(0);

                        if (value == null)
                            return;

                        count++;
                    }
                };
            }
        };
    }

    private static class LongSumAggregate implements AggregateFunction.Aggregate
    {
        private long sum;

        public void reset()
        {
            sum = 0;
        }

        public ByteBuffer compute(int protocolVersion)
        {
            return LongType.instance.decompose(sum);
        }

        public void addInput(int protocolVersion, List values)
        {
            ByteBuffer value = values.get(0);

            if (value == null)
                return;

            Number number = LongType.instance.compose(value);
            sum += number.longValue();
        }
    }

    private static class LongAvgAggregate implements AggregateFunction.Aggregate
    {
        private long sum;

        private int count;

        public void reset()
        {
            count = 0;
            sum = 0;
        }

        public ByteBuffer compute(int protocolVersion)
        {
            long avg = count == 0 ? 0 : sum / count;

            return LongType.instance.decompose(avg);
        }

        public void addInput(int protocolVersion, List values)
        {
            ByteBuffer value = values.get(0);

            if (value == null)
                return;

            count++;
            Number number = LongType.instance.compose(value);
            sum += number.longValue();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy