All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.flink.api.java.aggregation.SumAggregationFunction Maven / Gradle / Ivy

There is a newer version: 1.20.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.java.aggregation;

import org.apache.flink.annotation.Internal;
import org.apache.flink.types.ByteValue;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.FloatValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.ShortValue;

/**
 * Definitions of sum functions for different numerical types.
 *
 * <p>The concrete implementations are private nested classes; instances are handed out by
 * {@link SumAggregationFunctionFactory} based on the element class.
 *
 * @param <T> type of elements being summed
 */
@Internal
public abstract class SumAggregationFunction<T> extends AggregationFunction<T> {

	private static final long serialVersionUID = 1L;

	/**
	 * Returns the display name of this aggregation.
	 *
	 * @return the constant string {@code "SUM"}
	 */
	@Override
	public String toString() {
		return "SUM";
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Sum over {@link Byte} values.
	 *
	 * <p>The running total is kept in a {@code long} and narrowed to {@code byte} when read, so a
	 * total outside the {@code byte} range wraps instead of failing.
	 */
	private static final class ByteSumAgg extends SumAggregationFunction<Byte> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0;
		}

		@Override
		public void aggregate(Byte value) {
			agg += value;
		}

		@Override
		public Byte getAggregate() {
			return (byte) agg;
		}
	}

	/**
	 * Sum over {@link ByteValue} values.
	 *
	 * <p>The running total is kept in a {@code long} and narrowed to {@code byte} when read, so a
	 * total outside the {@code byte} range wraps instead of failing.
	 */
	private static final class ByteValueSumAgg extends SumAggregationFunction<ByteValue> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0;
		}

		@Override
		public void aggregate(ByteValue value) {
			agg += value.getValue();
		}

		@Override
		public ByteValue getAggregate() {
			return new ByteValue((byte) agg);
		}
	}

	/**
	 * Sum over {@link Short} values.
	 *
	 * <p>The running total is kept in a {@code long} and narrowed to {@code short} when read, so a
	 * total outside the {@code short} range wraps instead of failing.
	 */
	private static final class ShortSumAgg extends SumAggregationFunction<Short> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0;
		}

		@Override
		public void aggregate(Short value) {
			agg += value;
		}

		@Override
		public Short getAggregate() {
			return (short) agg;
		}
	}

	/**
	 * Sum over {@link ShortValue} values.
	 *
	 * <p>The running total is kept in a {@code long} and narrowed to {@code short} when read, so a
	 * total outside the {@code short} range wraps instead of failing.
	 */
	private static final class ShortValueSumAgg extends SumAggregationFunction<ShortValue> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0;
		}

		@Override
		public void aggregate(ShortValue value) {
			agg += value.getValue();
		}

		@Override
		public ShortValue getAggregate() {
			return new ShortValue((short) agg);
		}
	}

	/**
	 * Sum over {@link Integer} values.
	 *
	 * <p>The running total is kept in a {@code long} and narrowed to {@code int} when read, so a
	 * total outside the {@code int} range wraps instead of failing.
	 */
	private static final class IntSumAgg extends SumAggregationFunction<Integer> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0;
		}

		@Override
		public void aggregate(Integer value) {
			agg += value;
		}

		@Override
		public Integer getAggregate() {
			return (int) agg;
		}
	}

	/**
	 * Sum over {@link IntValue} values.
	 *
	 * <p>The running total is kept in a {@code long} and narrowed to {@code int} when read, so a
	 * total outside the {@code int} range wraps instead of failing.
	 */
	private static final class IntValueSumAgg extends SumAggregationFunction<IntValue> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0;
		}

		@Override
		public void aggregate(IntValue value) {
			agg += value.getValue();
		}

		@Override
		public IntValue getAggregate() {
			return new IntValue((int) agg);
		}
	}

	/** Sum over {@link Long} values, accumulated in a {@code long}. */
	private static final class LongSumAgg extends SumAggregationFunction<Long> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0L;
		}

		@Override
		public void aggregate(Long value) {
			agg += value;
		}

		@Override
		public Long getAggregate() {
			return agg;
		}
	}

	/** Sum over {@link LongValue} values, accumulated in a {@code long}. */
	private static final class LongValueSumAgg extends SumAggregationFunction<LongValue> {
		private static final long serialVersionUID = 1L;

		private long agg;

		@Override
		public void initializeAggregate() {
			agg = 0L;
		}

		@Override
		public void aggregate(LongValue value) {
			agg += value.getValue();
		}

		@Override
		public LongValue getAggregate() {
			return new LongValue(agg);
		}
	}

	/**
	 * Sum over {@link Float} values.
	 *
	 * <p>The running total is kept in a {@code double} and narrowed to {@code float} when read.
	 */
	private static final class FloatSumAgg extends SumAggregationFunction<Float> {
		private static final long serialVersionUID = 1L;

		private double agg;

		@Override
		public void initializeAggregate() {
			agg = 0.0;
		}

		@Override
		public void aggregate(Float value) {
			agg += value;
		}

		@Override
		public Float getAggregate() {
			return (float) agg;
		}
	}

	/**
	 * Sum over {@link FloatValue} values.
	 *
	 * <p>The running total is kept in a {@code double} and narrowed to {@code float} when read.
	 */
	private static final class FloatValueSumAgg extends SumAggregationFunction<FloatValue> {
		private static final long serialVersionUID = 1L;

		private double agg;

		@Override
		public void initializeAggregate() {
			agg = 0.0;
		}

		@Override
		public void aggregate(FloatValue value) {
			agg += value.getValue();
		}

		@Override
		public FloatValue getAggregate() {
			return new FloatValue((float) agg);
		}
	}

	/** Sum over {@link Double} values, accumulated in a {@code double}. */
	private static final class DoubleSumAgg extends SumAggregationFunction<Double> {
		private static final long serialVersionUID = 1L;

		private double agg;

		@Override
		public void initializeAggregate() {
			agg = 0.0;
		}

		@Override
		public void aggregate(Double value) {
			agg += value;
		}

		@Override
		public Double getAggregate() {
			return agg;
		}
	}

	/** Sum over {@link DoubleValue} values, accumulated in a {@code double}. */
	private static final class DoubleValueSumAgg extends SumAggregationFunction<DoubleValue> {
		private static final long serialVersionUID = 1L;

		private double agg;

		@Override
		public void initializeAggregate() {
			agg = 0.0;
		}

		@Override
		public void aggregate(DoubleValue value) {
			agg += value.getValue();
		}

		@Override
		public DoubleValue getAggregate() {
			return new DoubleValue(agg);
		}
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Factory for {@link SumAggregationFunction}.
	 */
	public static final class SumAggregationFunctionFactory implements AggregationFunctionFactory {
		private static final long serialVersionUID = 1L;

		/**
		 * Creates the sum aggregator matching the given element class.
		 *
		 * <p>The class is compared by identity against the supported boxed types
		 * ({@link Long}, {@link Integer}, {@link Double}, {@link Float}, {@link Byte},
		 * {@link Short}) and their {@code *Value} counterparts.
		 *
		 * @param type class of the elements to be summed
		 * @param <T> type of the elements to be summed
		 * @return a new sum aggregator for {@code type}
		 * @throws UnsupportedAggregationTypeException if {@code type} is none of the supported
		 *     classes
		 */
		// The casts are unchecked only for the compiler: each branch is guarded by an identity
		// check on 'type', so T is exactly the element type of the aggregator being returned.
		@SuppressWarnings("unchecked")
		@Override
		public <T> AggregationFunction<T> createAggregationFunction(Class<T> type) {
			if (type == Long.class) {
				return (AggregationFunction<T>) new LongSumAgg();
			}
			else if (type == LongValue.class) {
				return (AggregationFunction<T>) new LongValueSumAgg();
			}
			else if (type == Integer.class) {
				return (AggregationFunction<T>) new IntSumAgg();
			}
			else if (type == IntValue.class) {
				return (AggregationFunction<T>) new IntValueSumAgg();
			}
			else if (type == Double.class) {
				return (AggregationFunction<T>) new DoubleSumAgg();
			}
			else if (type == DoubleValue.class) {
				return (AggregationFunction<T>) new DoubleValueSumAgg();
			}
			else if (type == Float.class) {
				return (AggregationFunction<T>) new FloatSumAgg();
			}
			else if (type == FloatValue.class) {
				return (AggregationFunction<T>) new FloatValueSumAgg();
			}
			else if (type == Byte.class) {
				return (AggregationFunction<T>) new ByteSumAgg();
			}
			else if (type == ByteValue.class) {
				return (AggregationFunction<T>) new ByteValueSumAgg();
			}
			else if (type == Short.class) {
				return (AggregationFunction<T>) new ShortSumAgg();
			}
			else if (type == ShortValue.class) {
				return (AggregationFunction<T>) new ShortValueSumAgg();
			}
			else {
				throw new UnsupportedAggregationTypeException("The type " + type.getName() +
					" is currently not supported for built-in sum aggregations.");
			}
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy