// tech.tablesaw.aggregate.PivotTable Maven / Gradle / Ivy
package tech.tablesaw.aggregate;
import tech.tablesaw.api.CategoricalColumn;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.NumberColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.table.TableSlice;
import tech.tablesaw.table.TableSliceGroup;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * PivotTable is used to 'rotate' a source table so that it is summarized on the values of some
 * column. As implemented here, you supply:
 *
 * <ul>
 *   <li>a "key" categorical column from which the primary grouping is created; each row of the
 *       result corresponds to one group of this column
 *   <li>a second categorical column for which a subtotal is created; this produces n columns on
 *       each row of the result, one column for each unique value
 *   <li>a numeric column that provides the values to be summarized
 *   <li>an aggregation function that defines what operation is performed on the values in the
 *       subgroups
 * </ul>
 */
public class PivotTable {
public static Table pivot(Table table, CategoricalColumn> column1, CategoricalColumn> column2, NumberColumn> values, AggregateFunction,?> aggregateFunction) {
TableSliceGroup tsg = table.splitOn(column1);
Table pivotTable = Table.create("Pivot: " + column1.name() + " x " + column2.name());
pivotTable.addColumns(column1.type().create(column1.name()));
List valueColumnNames = getValueColumnNames(table, column2);
for (String colName : valueColumnNames) {
pivotTable.addColumns(DoubleColumn.create(colName));
}
int valueIndex = table.columnIndex(column2);
int keyIndex = table.columnIndex(column1);
String key;
for (TableSlice slice : tsg.getSlices()) {
key = String.valueOf(slice.get(0, keyIndex));
pivotTable.column(0).appendCell(key);
Map valueMap = getValueMap(column1, column2, values, valueIndex, slice, aggregateFunction);
for (String columnName : valueColumnNames) {
Double aDouble = valueMap.get(columnName);
NumberColumn> pivotValueColumn = pivotTable.numberColumn(columnName);
if (aDouble == null) {
pivotValueColumn.appendMissing();
} else {
pivotValueColumn.appendObj(aDouble);
}
}
}
return pivotTable;
}
private static Map getValueMap(
CategoricalColumn> column1,
CategoricalColumn> column2,
NumberColumn> values,
int valueIndex,
TableSlice slice,
AggregateFunction,?> function) {
Table temp = slice.asTable();
Table summary = temp.summarize(values.name(), function).by(column1.name(), column2.name());
Map valueMap = new HashMap<>();
NumberColumn> nc = summary.numberColumn(summary.columnCount() - 1);
for (int i = 0; i < summary.rowCount(); i++) {
valueMap.put(String.valueOf(summary.get(i, 1)), nc.getDouble(i));
}
return valueMap;
}
private static List getValueColumnNames(Table table, CategoricalColumn> column2) {
List valueColumnNames = new ArrayList<>();
for (Object colName : table.column(column2.name()).unique()) {
valueColumnNames.add(String.valueOf(colName));
}
valueColumnNames.sort(String::compareTo);
return valueColumnNames;
}
}