com.marklogic.flux.impl.importdata.AggregationParams

/*
 * Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
 */
package com.marklogic.flux.impl.importdata;

import com.marklogic.flux.api.FluxException;
import org.apache.spark.sql.*;
import picocli.CommandLine;

import java.util.*;
import java.util.stream.Collectors;

class AggregationParams implements CommandLine.ITypeConverter<AggregationParams.Aggregation> {

    private static final String AGGREGATE_DELIMITER = ",";

    @CommandLine.Option(
        names = "--group-by",
        description = "Name of a column to group the rows by before constructing documents. Typically used with at " +
            "least one instance of the --aggregate option.")
    private String groupBy;

    @CommandLine.Option(
        names = "--aggregate",
        description = "Define an aggregation of multiple columns into a new column. Each aggregation must be of the " +
            "form newColumnName=column1,column2,etc. Requires the use of --group-by.",
        converter = AggregationParams.class
    )
    private List<Aggregation> aggregations = new ArrayList<>();

    public static class Aggregation {
        private String newColumnName;
        private List<String> columnNamesToGroup;

        public Aggregation(String newColumnName, List<String> columnNamesToGroup) {
            this.newColumnName = newColumnName;
            this.columnNamesToGroup = columnNamesToGroup;
        }
    }

    @Override
    public Aggregation convert(String value) {
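        // For example, "names=first,last" yields an Aggregation with newColumnName "names"
        // and columnNamesToGroup ["first", "last"].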
        String[] parts = value.split("=");
        if (parts.length != 2) {
            throw new FluxException(String.format("Invalid aggregation: %s; must be of " +
                "the form newColumnName=columnToGroup1,columnToGroup2,etc.", value));
        }
        final String newColumnName = parts[0];
        String[] columnNamesToAggregate = parts[1].split(AGGREGATE_DELIMITER);
        return new Aggregation(newColumnName, Arrays.asList(columnNamesToAggregate));
    }

    public void setGroupBy(String groupBy) {
        this.groupBy = groupBy;
    }

    public void addAggregationExpression(String newColumnName, String... columns) {
        if (this.aggregations == null) {
            this.aggregations = new ArrayList<>();
        }
        this.aggregations.add(new Aggregation(newColumnName, Arrays.asList(columns)));
    }

    public Dataset<Row> applyGroupBy(Dataset<Row> dataset) {
        if (groupBy == null || groupBy.trim().isEmpty()) {
            return dataset;
        }

        final RelationalGroupedDataset groupedDataset = dataset.groupBy(this.groupBy);

        List<Column> columns = getColumnsNotInAggregation(dataset);
        List<Column> aggregationColumns = makeAggregationColumns();
        columns.addAll(aggregationColumns);
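        // RelationalGroupedDataset.agg accepts the first aggregation column separately from
        // the remaining varargs, so split the combined list accordingly.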
        final Column aliasColumn = columns.get(0);
        final Column[] columnsToGroup = columns.subList(1, columns.size()).toArray(new Column[]{});
        try {
            return groupedDataset.agg(aliasColumn, columnsToGroup);
        } catch (Exception e) {
            String columnNames = aggregations.stream().map(agg -> agg.columnNamesToGroup.toString()).collect(Collectors.joining(", "));
            throw new FluxException(String.format("Unable to aggregate columns: %s; please ensure that each column " +
                "name will be present in the data read from the data source.", columnNames), e);
        }
    }

    /**
     * @param dataset the dataset read from the data source
     * @return a list of columns reflecting each column that is not referenced in an aggregation and is also not the
     * "groupBy" column. These columns are assumed to have the same value in every row, and thus only the first value
     * is needed for each column.
     */
    private List<Column> getColumnsNotInAggregation(Dataset<Row> dataset) {
        Set<String> aggregatedColumnNames = new HashSet<>();
        aggregations.forEach(agg -> aggregatedColumnNames.addAll(agg.columnNamesToGroup));

        List<Column> columns = new ArrayList<>();
        for (String name : dataset.schema().names()) {
            if (!aggregatedColumnNames.contains(name) && !groupBy.equals(name)) {
                columns.add(functions.first(name).alias(name));
            }
        }
        return columns;
    }

    /**
     * @return a list of columns, one per aggregation.
     */
    private List<Column> makeAggregationColumns() {
        List<Column> columns = new ArrayList<>();
        aggregations.forEach(aggregation -> {
            final List<String> columnNames = aggregation.columnNamesToGroup;
            if (columnNames.size() == 1) {
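                // Collect the column's values for each group into an array for the new column.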
                Column column = new Column(columnNames.get(0));
                Column listOfValuesColumn = functions.collect_list(functions.concat(column));
                columns.add(listOfValuesColumn.alias(aggregation.newColumnName));
            } else {
                Column[] structColumns = columnNames.stream().map(functions::col).toArray(Column[]::new);
                Column arrayColumn = functions.collect_list(functions.struct(structColumns));
                // array_distinct removes duplicate objects that can result from 2+ joins existing in the query.
                // See https://www.sparkreference.com/reference/array_distinct/ for performance considerations.
                columns.add(functions.array_distinct(arrayColumn).alias(aggregation.newColumnName));
            }
        });
        return columns;
    }
}
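
For reference, below is a minimal sketch of how this class might be exercised programmatically, mirroring the CLI options --group-by id --aggregate names=first,last. The caller class, the SparkSession variable, and the people.csv path are illustrative assumptions rather than part of Flux; the caller sits in the same package because AggregationParams is package-private.

package com.marklogic.flux.impl.importdata;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Hypothetical caller used only for illustration.
class AggregationParamsUsageSketch {

    Dataset<Row> groupPeople(SparkSession session) {
        // Equivalent of: --group-by id --aggregate names=first,last
        AggregationParams params = new AggregationParams();
        params.setGroupBy("id");
        params.addAggregationExpression("names", "first", "last");

        // "people.csv" is an assumed input with id, first, and last columns.
        Dataset<Row> rows = session.read().option("header", "true").csv("people.csv");

        // Each distinct "id" produces one row whose "names" column is an array
        // of {first, last} structs built via collect_list + array_distinct.
        return params.applyGroupBy(rows);
    }
}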



