// All downloads are free. The search and download features use the official Maven repository.

// fr.insee.vtl.model.AggregationExpression (Maven / Gradle / Ivy)

package fr.insee.vtl.model;

import java.util.*;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;

/**
 * The AggregationExpression class is an abstract representation of an aggregation expression.
 */
public class AggregationExpression implements Collector, TypedExpression {

    private final Collector aggregation;
    private final Class type;

    /**
     * Constructor taking a collector of data points and an intended type for the aggregation results.
     *
     * @param aggregation Collector of data points.
     * @param type        Expected type for aggregation results.
     */
    public  AggregationExpression(Collector aggregation, Class type) {
        this.aggregation = aggregation;
        this.type = type;
    }

    /**
     * Constructor based on an input expression, a data point collector and an expected type.
     * The input expression is applied to each data point before it is accepted by the data point collector.
     *
     * @param expression The input resolvable expression.
     * @param collector  The data point collector.
     * @param type       The expected type of the aggregation expression results.
     */
    public  AggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
        this(Collectors.mapping(expression::resolve, collector), type);
    }

    /**
     * Returns an aggregation expression that counts data points and returns a long integer.
     *
     * @return The counting expression.
     */
    public static AggregationExpression count() {
        return new CountAggregationExpression(Collectors.counting(), Long.class);
    }

    public static class CountAggregationExpression extends AggregationExpression {
        private  CountAggregationExpression(Collector aggregation, Class type) {
            super(aggregation, type);
        }
    }

    /**
     * Returns an aggregation expression that averages an expression on data points and returns a double number.
     *
     * @param expression The expression on data points.
     * @return The averaging expression.
     */
    public static AggregationExpression avg(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            return new AverageAggregationExpression(
                    expression,
                    Collectors.averagingLong(value -> (Long) value),
                    Double.class
            );
        } else if (Double.class.equals(expression.getType())) {
            return new AverageAggregationExpression(
                    expression,
                    Collectors.averagingDouble(value -> (Double) value),
                    Double.class
            );
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class AverageAggregationExpression extends AggregationExpression {
        public  AverageAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }

    /**
     * Returns an aggregation expression that sums an expression on data points and returns a long integer or double number.
     *
     * @param expression The expression on data points.
     * @return The summing expression.
     */
    public static AggregationExpression sum(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            return new SumAggregationExpression(expression, Collectors.summingLong(value -> (Long) value), Long.class);
        } else if (Double.class.equals(expression.getType())) {
            return new SumAggregationExpression(expression, Collectors.summingDouble(value -> (Double) value), Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class SumAggregationExpression extends AggregationExpression {
        public  SumAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }

    /**
     * Returns an aggregation expression that give median of an expression on data points and returns a double number.
     *
     * @param expression The expression on data points.
     * @return The median expression.
     */
    public static AggregationExpression median(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            return new MedianAggregationExpression(expression, Collectors.mapping(v -> (Long) v, medianCollectorLong()), Double.class);
        } else if (Double.class.equals(expression.getType())) {
            return new MedianAggregationExpression(expression, Collectors.mapping(v -> (Double) v, medianCollectorDouble()), Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class MedianAggregationExpression extends AggregationExpression {
        public  MedianAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }

    /**
     * Returns an aggregation expression that give max of an expression on data points and returns a long integer or double number.
     *
     * @param expression The expression on data points.
     * @return The max expression.
     */
    public static AggregationExpression max(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            Collector> maxBy = Collectors.maxBy(Comparator.nullsFirst(Comparator.naturalOrder()));
            Collector> mapping = Collectors.mapping(v -> (Long) v, maxBy);
            Collector res = Collectors.collectingAndThen(mapping, v -> v.orElse(null));
            return new MaxAggregationExpression(expression, res, Long.class);
        } else if (Double.class.equals(expression.getType())) {
            Collector> maxBy = Collectors.maxBy(Comparator.nullsFirst(Comparator.naturalOrder()));
            Collector> mapping = Collectors.mapping(v -> (Double) v, maxBy);
            Collector res = Collectors.collectingAndThen(mapping, v -> v.orElse(null));
            return new MaxAggregationExpression(expression, res, Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class MaxAggregationExpression extends AggregationExpression {
        public  MaxAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }

    /**
     * Returns an aggregation expression that give min of an expression on data points and returns a long integer or double number.
     *
     * @param expression The expression on data points.
     * @return The min expression.
     */
    public static AggregationExpression min(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            Collector> maxBy = Collectors.minBy(Comparator.nullsFirst(Comparator.naturalOrder()));
            Collector> mapping = Collectors.mapping(v -> (Long) v, maxBy);
            Collector res = Collectors.collectingAndThen(mapping, v -> v.orElse(null));
            return new MinAggregationExpression(expression, res, Long.class);
        } else if (Double.class.equals(expression.getType())) {
            Collector> maxBy = Collectors.minBy(Comparator.nullsFirst(Comparator.naturalOrder()));
            Collector> mapping = Collectors.mapping(v -> (Double) v, maxBy);
            Collector res = Collectors.collectingAndThen(mapping, v -> v.orElse(null));
            return new MinAggregationExpression(expression, res, Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class MinAggregationExpression extends AggregationExpression {
        public  MinAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }

    /**
     * Returns an aggregation expression that give population standard deviation of an expression on data points and returns a double number.
     *
     * @param expression The expression on data points.
     * @return The population standard deviation expression.
     */
    public static AggregationExpression stdDevPop(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            return new StdDevPopAggregationExpression(expression, Collectors.mapping(v -> (Long) v, stdDevPopCollectorLong()), Double.class);
        } else if (Double.class.equals(expression.getType())) {
            return new StdDevPopAggregationExpression(expression, Collectors.mapping(v -> (Double) v, stdDevPopCollectorDouble()), Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class StdDevPopAggregationExpression extends AggregationExpression {
        public  StdDevPopAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }


    /**
     * Returns an aggregation expression that give sample standard deviation of an expression on data points and returns a double number.
     *
     * @param expression The expression on data points.
     * @return The sample standard deviation expression.
     */
    public static AggregationExpression stdDevSamp(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            return new StdDevSampAggregationExpression(expression, Collectors.mapping(v -> (Long) v, stdDevSampCollectorLong()), Double.class);
        } else if (Double.class.equals(expression.getType())) {
            return new StdDevSampAggregationExpression(expression, Collectors.mapping(v -> (Double) v, stdDevSampCollectorDouble()), Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class StdDevSampAggregationExpression extends AggregationExpression {
        public  StdDevSampAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }


    /**
     * Returns an aggregation expression that give population variance of an expression on data points and returns a double number.
     *
     * @param expression The expression on data points.
     * @return The population variance expression.
     */
    public static AggregationExpression varPop(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            return new VarPopAggregationExpression(expression, Collectors.mapping(v -> (Long) v, varPopCollectorLong()), Double.class);
        } else if (Double.class.equals(expression.getType())) {
            return new VarPopAggregationExpression(expression, Collectors.mapping(v -> (Double) v, varPopCollectorDouble()), Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class VarPopAggregationExpression extends AggregationExpression {
        public  VarPopAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }


    /**
     * Returns an aggregation expression that give sample variance of an expression on data points and returns a double number.
     *
     * @param expression The expression on data points.
     * @return The sample variance expression.
     */
    public static AggregationExpression varSamp(ResolvableExpression expression) {
        if (Long.class.equals(expression.getType())) {
            return new VarSampAggregationExpression(expression, Collectors.mapping(v -> (Long) v, varSampCollectorLong()), Double.class);
        } else if (Double.class.equals(expression.getType())) {
            return new VarSampAggregationExpression(expression, Collectors.mapping(v -> (Double) v, varSampCollectorDouble()), Double.class);
        } else {
            // Type asserted in visitor.
            throw new Error("unexpected type");
        }
    }

    public static class VarSampAggregationExpression extends AggregationExpression {
        public  VarSampAggregationExpression(ResolvableExpression expression, Collector collector, Class type) {
            super(expression, collector, type);
        }
    }

    private static Collector, Double> medianCollectorLong() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                longs -> {
                    if (longs.contains(null)) return null;
                    Collections.sort(longs);
                    if (longs.size() % 2 == 0) {
                        return (double) (longs.get(longs.size() / 2 - 1) + longs.get(longs.size() / 2)) / 2;
                    } else {
                        return (double) longs.get(longs.size() / 2);
                    }
                }
        );
    }

    private static Collector, Double> medianCollectorDouble() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                longs -> {
                    if (longs.contains(null)) return null;
                    Collections.sort(longs);
                    if (longs.size() % 2 == 0) {
                        return (longs.get(longs.size() / 2 - 1) + longs.get(longs.size() / 2)) / 2;
                    } else {
                        return longs.get(longs.size() / 2);
                    }
                }
        );
    }

    private static Collector, Double> stdDevPopCollectorLong() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getDeviationLongFn(true)
        );
    }

    private static Collector, Double> stdDevPopCollectorDouble() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getDeviationDoubleFn(true)
        );
    }

    private static Collector, Double> stdDevSampCollectorLong() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getDeviationLongFn(false)
        );
    }

    private static Collector, Double> stdDevSampCollectorDouble() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getDeviationDoubleFn(false)
        );
    }

    private static Collector, Double> varPopCollectorLong() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getVarLongFn(true)
        );
    }

    private static Collector, Double> varPopCollectorDouble() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getVarDoubleFn(true)
        );
    }

    private static Collector, Double> varSampCollectorLong() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getVarLongFn(false)
        );
    }

    private static Collector, Double> varSampCollectorDouble() {
        return Collector.of(
                ArrayList::new,
                List::add,
                (longs, longs2) -> {
                    longs.addAll(longs2);
                    return longs;
                },
                getVarDoubleFn(false)
        );
    }

    private static Function, Double> getDeviationLongFn(Boolean usePopulation) {
        return longs -> {
            if (longs.contains(null)) return null;
            if (longs.size() <= 1) return 0D;
            Double avg = longs.stream().collect(Collectors.averagingLong(v -> v));
            return Math.sqrt(
                    longs.stream().map(v -> Math.pow(((double) v) - avg, 2)).mapToDouble(v -> v).sum()
                            / (longs.size() - (usePopulation ? 0D : 1D))
            );
        };
    }

    private static Function, Double> getDeviationDoubleFn(Boolean usePopulation) {
        return doubles -> {
            if (doubles.contains(null)) return null;
            if (doubles.size() <= 1) return 0D;
            Double avg = doubles.stream().collect(Collectors.averagingDouble(v -> v));
            return Math.sqrt(
                    doubles.stream().map(v -> Math.pow(v - avg, 2)).mapToDouble(v -> v).sum()
                            / (doubles.size() - (usePopulation ? 0D : 1D))
            );
        };
    }

    private static Function, Double> getVarLongFn(Boolean usePopulation) {
        return longs -> {
            if (longs.contains(null)) return null;
            if (longs.size() <= 1) return 0D;
            Double avg = longs.stream().collect(Collectors.averagingLong(v -> v));
            return longs.stream().map(v -> Math.pow(((double) v) - avg, 2)).mapToDouble(v -> v).sum()
                    / (longs.size() - (usePopulation ? 0D : 1D));
        };
    }

    private static Function, Double> getVarDoubleFn(Boolean usePopulation) {
        return doubles -> {
            if (doubles.contains(null)) return null;
            if (doubles.size() <= 1) return 0D;
            Double avg = doubles.stream().collect(Collectors.averagingDouble(v -> v));
            return doubles.stream().map(v -> Math.pow(v - avg, 2)).mapToDouble(v -> v).sum()
                    / (doubles.size() - (usePopulation ? 0D : 1D));
        };
    }

    @Override
    public Class getType() {
        return type;
    }

    @Override
    public Supplier supplier() {
        return (Supplier) aggregation.supplier();
    }

    @Override
    public BiConsumer accumulator() {
        return (BiConsumer) aggregation.accumulator();
    }

    @Override
    public BinaryOperator combiner() {
        return (BinaryOperator) aggregation.combiner();
    }

    @Override
    public Function finisher() {
        return (Function) aggregation.finisher();
    }

    @Override
    public Set characteristics() {
        return aggregation.characteristics();
    }

}