org.apache.cassandra.cql3.functions.AggregateFcts Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of cassandra-all Show documentation
The Apache Cassandra Project develops a highly scalable second-generation distributed database, bringing together Dynamo's fully distributed design and Bigtable's ColumnFamily-based data model.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.cql3.functions;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.List;
import org.apache.cassandra.db.marshal.*;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.transport.ProtocolVersion;
/**
* Factory methods for aggregate functions.
*/
public abstract class AggregateFcts
{
public static void addFunctionsTo(NativeFunctions functions)
{
functions.add(countRowsFunction);
// sum for primitives
functions.add(sumFunctionForByte);
functions.add(sumFunctionForShort);
functions.add(sumFunctionForInt32);
functions.add(sumFunctionForLong);
functions.add(sumFunctionForFloat);
functions.add(sumFunctionForDouble);
functions.add(sumFunctionForDecimal);
functions.add(sumFunctionForVarint);
functions.add(sumFunctionForCounter);
// avg for primitives
functions.add(avgFunctionForByte);
functions.add(avgFunctionForShort);
functions.add(avgFunctionForInt32);
functions.add(avgFunctionForLong);
functions.add(avgFunctionForFloat);
functions.add(avgFunctionForDouble);
functions.add(avgFunctionForDecimal);
functions.add(avgFunctionForVarint);
functions.add(avgFunctionForCounter);
// count for all types
functions.add(makeCountFunction(BytesType.instance));
// max for all types
functions.add(new FunctionFactory("max", FunctionParameter.anyType(true))
{
@Override
protected NativeFunction doGetOrCreateFunction(List> argTypes, AbstractType> receiverType)
{
AbstractType> type = argTypes.get(0);
return type.isCounter() ? maxFunctionForCounter : makeMaxFunction(type);
}
});
// min for all types
functions.add(new FunctionFactory("min", FunctionParameter.anyType(true))
{
@Override
protected NativeFunction doGetOrCreateFunction(List> argTypes, AbstractType> receiverType)
{
AbstractType> type = argTypes.get(0);
return type.isCounter() ? minFunctionForCounter : makeMinFunction(type);
}
});
}
/**
* The function used to count the number of rows of a result set. This function is called when COUNT(*) or COUNT(1)
* is specified.
*/
public static final CountRowsFunction countRowsFunction = new CountRowsFunction(false);
public static class CountRowsFunction extends NativeAggregateFunction
{
private CountRowsFunction(boolean useLegacyName)
{
super(useLegacyName ? "countRows" : "count_rows", LongType.instance);
}
@Override
public Aggregate newAggregate()
{
return new Aggregate()
{
private long count;
public void reset()
{
count = 0;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return LongType.instance.decompose(count);
}
@Override
public void addInput(Arguments arguments)
{
count++;
}
};
}
@Override
public String columnName(List columnNames)
{
return "count";
}
@Override
public NativeFunction withLegacyName()
{
return new CountRowsFunction(true);
}
}
/**
* The SUM function for decimal values.
*/
public static final NativeAggregateFunction sumFunctionForDecimal =
new NativeAggregateFunction("sum", DecimalType.instance, DecimalType.instance)
{
@Override
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigDecimal sum = BigDecimal.ZERO;
public void reset()
{
sum = BigDecimal.ZERO;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return ((DecimalType) returnType()).decompose(sum);
}
@Override
public void addInput(Arguments arguments)
{
BigDecimal number = arguments.get(0);
if (number == null)
return;
sum = sum.add(number);
}
};
}
};
/**
* The AVG function for decimal values.
*
* The average of an empty value set returns zero.
*/
public static final NativeAggregateFunction avgFunctionForDecimal =
new NativeAggregateFunction("avg", DecimalType.instance, DecimalType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigDecimal avg = BigDecimal.ZERO;
private int count;
public void reset()
{
count = 0;
avg = BigDecimal.ZERO;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return DecimalType.instance.decompose(avg);
}
@Override
public void addInput(Arguments arguments)
{
BigDecimal number = arguments.get(0);
if (number == null)
return;
count++;
// avg = avg + (value - sum) / count.
avg = avg.add(number.subtract(avg).divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN));
}
};
}
};
/**
* The SUM function for varint values.
*/
public static final NativeAggregateFunction sumFunctionForVarint =
new NativeAggregateFunction("sum", IntegerType.instance, IntegerType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigInteger sum = BigInteger.ZERO;
public void reset()
{
sum = BigInteger.ZERO;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return ((IntegerType) returnType()).decompose(sum);
}
@Override
public void addInput(Arguments arguments)
{
BigInteger number = arguments.get(0);
if (number == null)
return;
sum = sum.add(number);
}
};
}
};
/**
* The AVG function for varint values.
*
* The average of an empty value set returns zero. The returned value is of the same type as the input values,
* so the returned average won't have a decimal part.
*/
public static final NativeAggregateFunction avgFunctionForVarint =
new NativeAggregateFunction("avg", IntegerType.instance, IntegerType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private BigInteger sum = BigInteger.ZERO;
private int count;
public void reset()
{
count = 0;
sum = BigInteger.ZERO;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
if (count == 0)
return IntegerType.instance.decompose(BigInteger.ZERO);
return IntegerType.instance.decompose(sum.divide(BigInteger.valueOf(count)));
}
@Override
public void addInput(Arguments arguments)
{
BigInteger number = arguments.get(0);
if (number == null)
return;
count++;
sum = sum.add(number);
}
};
}
};
/**
* The SUM function for byte values (tinyint).
*
* The returned value is of the same type as the input values, so there is a risk of overflow if the sum of the
* values exceeds the maximum value that the type can represent.
*/
public static final NativeAggregateFunction sumFunctionForByte =
new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private byte sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return ((ByteType) returnType()).decompose(sum);
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
sum += number.byteValue();
}
};
}
};
/**
* AVG function for byte values (tinyint).
*
* The average of an empty value set returns zero. The returned value is of the same type as the input values,
* so the returned average won't have a decimal part.
*/
public static final NativeAggregateFunction avgFunctionForByte =
new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException
{
return ByteType.instance.decompose((byte) computeInternal());
}
};
}
};
/**
* The SUM function for short values (smallint).
*
* The returned value is of the same type as the input values, so there is a risk of overflow if the sum of the
* values exceeds the maximum value that the type can represent.
*/
public static final NativeAggregateFunction sumFunctionForShort =
new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private short sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return ((ShortType) returnType()).decompose(sum);
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
sum += number.shortValue();
}
};
}
};
/**
* AVG function for for short values (smallint).
*
* The average of an empty value set returns zero. The returned value is of the same type as the input values,
* so the returned average won't have a decimal part.
*/
public static final NativeAggregateFunction avgFunctionForShort =
new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return ShortType.instance.decompose((short) computeInternal());
}
};
}
};
/**
* The SUM function for int32 values.
*
* The returned value is of the same type as the input values, so there is a risk of overflow if the sum of the
* values exceeds the maximum value that the type can represent.
*/
public static final NativeAggregateFunction sumFunctionForInt32 =
new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private int sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return ((Int32Type) returnType()).decompose(sum);
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
sum += number.intValue();
}
};
}
};
/**
* AVG function for int32 values.
*
* The average of an empty value set returns zero. The returned value is of the same type as the input values,
* so the returned average won't have a decimal part.
*/
public static final NativeAggregateFunction avgFunctionForInt32 =
new NativeAggregateFunction("avg", Int32Type.instance, Int32Type.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return Int32Type.instance.decompose((int) computeInternal());
}
};
}
};
/**
* The SUM function for long values.
*
* The returned value is of the same type as the input values, so there is a risk of overflow if the sum of the
* values exceeds the maximum value that the type can represent.
*/
public static final NativeAggregateFunction sumFunctionForLong =
new NativeAggregateFunction("sum", LongType.instance, LongType.instance)
{
public Aggregate newAggregate()
{
return new LongSumAggregate();
}
};
/**
* AVG function for long values.
*
* The average of an empty value set returns zero. The returned value is of the same type as the input values,
* so the returned average won't have a decimal part.
*/
public static final NativeAggregateFunction avgFunctionForLong =
new NativeAggregateFunction("avg", LongType.instance, LongType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return LongType.instance.decompose(computeInternal());
}
};
}
};
/**
* The SUM function for float values.
*
* The returned value is of the same type as the input values, so there is a risk of overflow if the sum of the
* values exceeds the maximum value that the type can represent.
*/
public static final NativeAggregateFunction sumFunctionForFloat =
new NativeAggregateFunction("sum", FloatType.instance, FloatType.instance)
{
public Aggregate newAggregate()
{
return new FloatSumAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException
{
return FloatType.instance.decompose((float) computeInternal());
}
};
}
};
/**
* AVG function for float values.
*
* The average of an empty value set returns zero.
*/
public static final NativeAggregateFunction avgFunctionForFloat =
new NativeAggregateFunction("avg", FloatType.instance, FloatType.instance)
{
public Aggregate newAggregate()
{
return new FloatAvgAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException
{
return FloatType.instance.decompose((float) computeInternal());
}
};
}
};
/**
* The SUM function for double values.
*
* The returned value is of the same type as the input values, so there is a risk of overflow if the sum of the
* values exceeds the maximum value that the type can represent.
*/
public static final NativeAggregateFunction sumFunctionForDouble =
new NativeAggregateFunction("sum", DoubleType.instance, DoubleType.instance)
{
public Aggregate newAggregate()
{
return new FloatSumAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException
{
return DoubleType.instance.decompose(computeInternal());
}
};
}
};
/**
* Sum aggregate function for floating point numbers, using double arithmetics and
* Kahan's algorithm to improve result precision.
*/
private static abstract class FloatSumAggregate implements AggregateFunction.Aggregate
{
private double sum;
private double compensation;
private double simpleSum;
public void reset()
{
sum = 0;
compensation = 0;
simpleSum = 0;
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
double d = number.doubleValue();
simpleSum += d;
double tmp = d - compensation;
double rounded = sum + tmp;
compensation = (rounded - sum) - tmp;
sum = rounded;
}
public double computeInternal()
{
// correctly compute final sum if it's NaN from consequently
// adding same-signed infinite values.
double tmp = sum + compensation;
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
return simpleSum;
else
return tmp;
}
}
/**
* Average aggregate for floating point umbers, using double arithmetics and Kahan's algorithm
* to calculate sum by default, switching to BigDecimal on sum overflow. Resulting number is
* converted to corresponding representation by concrete implementations.
*/
private static abstract class FloatAvgAggregate implements AggregateFunction.Aggregate
{
private double sum;
private double compensation;
private double simpleSum;
private int count;
private BigDecimal bigSum = null;
private boolean overflow = false;
public void reset()
{
sum = 0;
compensation = 0;
simpleSum = 0;
count = 0;
bigSum = null;
overflow = false;
}
public double computeInternal()
{
if (count == 0)
return 0d;
if (overflow)
{
return bigSum.divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN).doubleValue();
}
else
{
// correctly compute final sum if it's NaN from consequently
// adding same-signed infinite values.
double tmp = sum + compensation;
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
sum = simpleSum;
else
sum = tmp;
return sum / count;
}
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
count++;
double d = number.doubleValue();
if (overflow)
{
bigSum = bigSum.add(BigDecimal.valueOf(d));
}
else
{
simpleSum += d;
double prev = sum;
double tmp = d - compensation;
double rounded = sum + tmp;
compensation = (rounded - sum) - tmp;
sum = rounded;
if (Double.isInfinite(sum) && !Double.isInfinite(d))
{
overflow = true;
bigSum = BigDecimal.valueOf(prev).add(BigDecimal.valueOf(d));
}
}
}
}
/**
* AVG function for double values.
*
* The average of an empty value set returns zero.
*/
public static final NativeAggregateFunction avgFunctionForDouble =
new NativeAggregateFunction("avg", DoubleType.instance, DoubleType.instance)
{
public Aggregate newAggregate()
{
return new FloatAvgAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException
{
return DoubleType.instance.decompose(computeInternal());
}
};
}
};
/**
* The SUM function for counter column values.
*/
public static final NativeAggregateFunction sumFunctionForCounter =
new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new LongSumAggregate();
}
};
/**
* AVG function for counter column values.
*/
public static final NativeAggregateFunction avgFunctionForCounter =
new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new AvgAggregate()
{
public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException
{
return CounterColumnType.instance.decompose(computeInternal());
}
};
}
};
/**
* The MIN function for counter column values.
*/
public static final NativeAggregateFunction minFunctionForCounter =
new NativeAggregateFunction("min", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private Long min;
public void reset()
{
min = null;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return min != null ? LongType.instance.decompose(min) : null;
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
long lval = number.longValue();
if (min == null || lval < min)
min = lval;
}
};
}
};
/**
* MAX function for counter column values.
*/
public static final NativeAggregateFunction maxFunctionForCounter =
new NativeAggregateFunction("max", CounterColumnType.instance, CounterColumnType.instance)
{
public Aggregate newAggregate()
{
return new Aggregate()
{
private Long max;
public void reset()
{
max = null;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return max != null ? LongType.instance.decompose(max) : null;
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
long lval = number.longValue();
if (max == null || lval > max)
max = lval;
}
};
}
};
/**
* Creates a MAX function for the specified type.
*
* @param inputType the function input and output type
* @return a MAX function for the specified type.
*/
public static NativeAggregateFunction makeMaxFunction(final AbstractType> inputType)
{
return new NativeAggregateFunction("max", inputType, inputType)
{
@Override
public Arguments newArguments(ProtocolVersion version)
{
return FunctionArguments.newNoopInstance(version, 1);
}
@Override
public Aggregate newAggregate()
{
return new Aggregate()
{
private ByteBuffer max;
public void reset()
{
max = null;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return max;
}
@Override
public void addInput(Arguments arguments)
{
ByteBuffer value = arguments.get(0);
if (value == null)
return;
if (max == null || returnType().compare(max, value) < 0)
max = value;
}
};
}
};
}
/**
* Creates a MIN function for the specified type.
*
* @param inputType the function input and output type
* @return a MIN function for the specified type.
*/
public static NativeAggregateFunction makeMinFunction(final AbstractType> inputType)
{
return new NativeAggregateFunction("min", inputType, inputType)
{
@Override
public Arguments newArguments(ProtocolVersion version)
{
return FunctionArguments.newNoopInstance(version, 1);
}
@Override
public Aggregate newAggregate()
{
return new Aggregate()
{
private ByteBuffer min;
public void reset()
{
min = null;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return min;
}
@Override
public void addInput(Arguments arguments)
{
ByteBuffer value = arguments.get(0);
if (value == null)
return;
if (min == null || returnType().compare(min, value) > 0)
min = value;
}
};
}
};
}
/**
* Creates a COUNT function for the specified type.
*
* @param inputType the function input type
* @return a COUNT function for the specified type.
*/
public static NativeAggregateFunction makeCountFunction(AbstractType> inputType)
{
return new NativeAggregateFunction("count", LongType.instance, inputType)
{
@Override
public Arguments newArguments(ProtocolVersion version)
{
return FunctionArguments.newNoopInstance(version, 1);
}
@Override
public Aggregate newAggregate()
{
return new Aggregate()
{
private long count;
public void reset()
{
count = 0;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return ((LongType) returnType()).decompose(count);
}
@Override
public void addInput(Arguments arguments)
{
if (arguments.get(0) == null)
return;
count++;
}
};
}
};
}
private static class LongSumAggregate implements AggregateFunction.Aggregate
{
private long sum;
public void reset()
{
sum = 0;
}
public ByteBuffer compute(ProtocolVersion protocolVersion)
{
return LongType.instance.decompose(sum);
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
sum += number.longValue();
}
}
/**
* Average aggregate class, collecting the sum using long arithmetics, falling back
* to BigInteger on long overflow. Resulting number is converted to corresponding
* representation by concrete implementations.
*/
private static abstract class AvgAggregate implements AggregateFunction.Aggregate
{
private long sum;
private int count;
private BigInteger bigSum = null;
private boolean overflow = false;
public void reset()
{
count = 0;
sum = 0L;
overflow = false;
bigSum = null;
}
long computeInternal()
{
if (overflow)
{
return bigSum.divide(BigInteger.valueOf(count)).longValue();
}
else
{
return count == 0 ? 0 : (sum / count);
}
}
@Override
public void addInput(Arguments arguments)
{
Number number = arguments.get(0);
if (number == null)
return;
count++;
long l = number.longValue();
if (overflow)
{
bigSum = bigSum.add(BigInteger.valueOf(l));
}
else
{
long prev = sum;
sum += l;
if (((prev ^ sum) & (l ^ sum)) < 0)
{
overflow = true;
bigSum = BigInteger.valueOf(prev).add(BigInteger.valueOf(l));
}
}
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy