com.marklogic.flux.impl.importdata.AggregationParams Maven / Gradle / Ivy
Flux API for data movement with MarkLogic
/*
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
*/
package com.marklogic.flux.impl.importdata;
import com.marklogic.flux.api.FluxException;
import org.apache.spark.sql.*;
import picocli.CommandLine;
import java.util.*;
import java.util.stream.Collectors;
class AggregationParams implements CommandLine.ITypeConverter<AggregationParams.Aggregation> {
private static final String AGGREGATE_DELIMITER = ",";
@CommandLine.Option(
names = "--group-by",
description = "Name of a column to group the rows by before constructing documents. Typically used with at " +
"least one instance of the --aggregate option.")
private String groupBy;
@CommandLine.Option(
names = "--aggregate",
description = "Define an aggregation of multiple columns into a new column. Each aggregation must be of the " +
"form newColumnName=column1,column2,etc. Requires the use of --group-by.",
converter = AggregationParams.class
)
private List<Aggregation> aggregations = new ArrayList<>();
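// Illustrative example (not from the source): a Flux import command could pass
// --group-by customerId --aggregate "orders=orderId,orderDate"
// to collapse several order rows per customer into a single document containing an "orders" array.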
public static class Aggregation {
private String newColumnName;
private List<String> columnNamesToGroup;
public Aggregation(String newColumnName, List<String> columnNamesToGroup) {
this.newColumnName = newColumnName;
this.columnNamesToGroup = columnNamesToGroup;
}
}
@Override
public Aggregation convert(String value) {
String[] parts = value.split("=");
if (parts.length != 2) {
throw new FluxException(String.format("Invalid aggregation: %s; must be of " +
"the form newColumnName=columnToGroup1,columnToGroup2,etc.", value));
}
final String newColumnName = parts[0];
String[] columnNamesToAggregate = parts[1].split(AGGREGATE_DELIMITER);
return new Aggregation(newColumnName, Arrays.asList(columnNamesToAggregate));
}
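// For example, convert("items=itemName,itemPrice") (illustrative values) yields an Aggregation whose
// newColumnName is "items" and whose columnNamesToGroup is ["itemName", "itemPrice"].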
public void setGroupBy(String groupBy) {
this.groupBy = groupBy;
}
public void addAggregationExpression(String newColumnName, String... columns) {
if (this.aggregations == null) {
this.aggregations = new ArrayList<>();
}
this.aggregations.add(new Aggregation(newColumnName, Arrays.asList(columns)));
}
public Dataset<Row> applyGroupBy(Dataset<Row> dataset) {
if (groupBy == null || groupBy.trim().isEmpty()) {
return dataset;
}
final RelationalGroupedDataset groupedDataset = dataset.groupBy(this.groupBy);
List<Column> columns = getColumnsNotInAggregation(dataset);
List<Column> aggregationColumns = makeAggregationColumns();
columns.addAll(aggregationColumns);
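// RelationalGroupedDataset.agg expects one Column plus a varargs array of further Columns,
// so the combined list is split into a head column and the remaining columns.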
final Column aliasColumn = columns.get(0);
final Column[] columnsToGroup = columns.subList(1, columns.size()).toArray(new Column[]{});
try {
return groupedDataset.agg(aliasColumn, columnsToGroup);
} catch (Exception e) {
String columnNames = aggregations.stream().map(agg -> agg.columnNamesToGroup.toString()).collect(Collectors.joining(", "));
throw new FluxException(String.format("Unable to aggregate columns: %s; please ensure that each column " +
"name will be present in the data read from the data source.", columnNames), e);
}
}
/**
* @param dataset the dataset read from the data source, before any grouping is applied
* @return a list of columns reflecting each column that is not referenced in an aggregation and is also not the
* "groupBy" column. These columns are assumed to have the same value in every row, and thus only the first value
* is needed for each column.
*/
private List<Column> getColumnsNotInAggregation(Dataset<Row> dataset) {
Set<String> aggregatedColumnNames = new HashSet<>();
aggregations.forEach(agg -> aggregatedColumnNames.addAll(agg.columnNamesToGroup));
List<Column> columns = new ArrayList<>();
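// For example (illustrative names): with groupBy "customerId" and an aggregation over "itemName",
// a non-aggregated column such as "customerName" becomes first("customerName") aliased to its
// original name, since its value is assumed to be identical on every row within a group.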
for (String name : dataset.schema().names()) {
if (!aggregatedColumnNames.contains(name) && !groupBy.equals(name)) {
columns.add(functions.first(name).alias(name));
}
}
return columns;
}
/**
* @return a list of columns, one per aggregation.
*/
private List<Column> makeAggregationColumns() {
List<Column> columns = new ArrayList<>();
aggregations.forEach(aggregation -> {
final List<String> columnNames = aggregation.columnNamesToGroup;
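// A single aggregated column produces an array of scalar values; multiple columns produce an
// array of structs (one object per row in the group).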
if (columnNames.size() == 1) {
Column column = new Column(columnNames.get(0));
Column listOfValuesColumn = functions.collect_list(functions.concat(column));
columns.add(listOfValuesColumn.alias(aggregation.newColumnName));
} else {
Column[] structColumns = columnNames.stream().map(functions::col).toArray(Column[]::new);
Column arrayColumn = functions.collect_list(functions.struct(structColumns));
// array_distinct removes duplicate structs that can appear when the query contains two or more joins.
// See https://www.sparkreference.com/reference/array_distinct/ for performance considerations.
columns.add(functions.array_distinct(arrayColumn).alias(aggregation.newColumnName));
}
});
return columns;
}
}
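Below is a minimal usage sketch, not part of the Flux source. It assumes a local SparkSession and a hypothetical OrderLine bean with columns customerId, customerName, itemName, and itemPrice; because AggregationParams is package-private, such a harness would have to live in the com.marklogic.flux.impl.importdata package.

package com.marklogic.flux.impl.importdata;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.util.Arrays;

public class AggregationParamsSketch {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .appName("aggregation-params-sketch")
            .getOrCreate();

        // Illustrative data: two order lines for customer 1, one for customer 2.
        Dataset<Row> rows = spark.createDataFrame(Arrays.asList(
            new OrderLine(1, "Jane", "widget", 10.0),
            new OrderLine(1, "Jane", "gadget", 25.0),
            new OrderLine(2, "Raj", "widget", 10.0)
        ), OrderLine.class);

        AggregationParams params = new AggregationParams();
        params.setGroupBy("customerId");
        // Programmatic equivalent of --aggregate "items=itemName,itemPrice".
        params.addAggregationExpression("items", "itemName", "itemPrice");

        // One row per customerId: customerName keeps its first value, and "items"
        // holds an array of {itemName, itemPrice} structs.
        params.applyGroupBy(rows).show(false);

        spark.stop();
    }

    // Simple bean used only to build the example dataset.
    public static class OrderLine {
        private final int customerId;
        private final String customerName;
        private final String itemName;
        private final double itemPrice;

        public OrderLine(int customerId, String customerName, String itemName, double itemPrice) {
            this.customerId = customerId;
            this.customerName = customerName;
            this.itemName = itemName;
            this.itemPrice = itemPrice;
        }

        public int getCustomerId() { return customerId; }
        public String getCustomerName() { return customerName; }
        public String getItemName() { return itemName; }
        public double getItemPrice() { return itemPrice; }
    }
}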