org.h2.expression.JavaAggregate Maven / Gradle / Ivy
/*
* Copyright 2004-2011 H2 Group. Multiple-Licensed under the H2 License,
* Version 1.0, and under the Eclipse Public License, Version 1.0
* (http://h2database.com/html/license.html).
* Initial Developer: H2 Group
*/
package org.h2.expression;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import org.h2.api.AggregateFunction;
import org.h2.command.Parser;
import org.h2.command.dml.Select;
import org.h2.constant.ErrorCode;
import org.h2.engine.Session;
import org.h2.engine.UserAggregate;
import org.h2.message.DbException;
import org.h2.table.ColumnResolver;
import org.h2.table.TableFilter;
import org.h2.util.StatementBuilder;
import org.h2.value.DataType;
import org.h2.value.Value;
import org.h2.value.ValueNull;
/**
* This class wraps a user-defined aggregate.
*/
public class JavaAggregate extends Expression {

    /** The user-defined aggregate definition this expression invokes. */
    private final UserAggregate userAggregate;

    /** The query whose current group map holds the per-group aggregate state. */
    private final Select select;

    /** Instance created in optimize() to ask the aggregate for its result type. */
    private AggregateFunction aggregate;

    /** The argument expressions; replaced by their optimized forms in optimize(). */
    private Expression[] args;

    /** The value types of the (optimized) arguments, computed in optimize(). */
    private int[] argTypes;

    /** The value type of the aggregate result, computed in optimize(). */
    private int dataType;

    /** The connection passed to AggregateFunction.init for every instance. */
    private Connection userConnection;

    /**
     * The group row id that was last added, so that the same row is not
     * added twice to the aggregate.
     */
    private int lastGroupRowId;

    /**
     * Create a new expression that calls a user-defined aggregate.
     *
     * @param userAggregate the aggregate definition
     * @param args the argument expressions
     * @param select the select statement that provides the current group
     */
    public JavaAggregate(UserAggregate userAggregate, Expression[] args, Select select) {
        this.userAggregate = userAggregate;
        this.args = args;
        this.select = select;
    }

    /**
     * Estimate the evaluation cost: a fixed cost of 5 plus the cost of
     * every argument.
     *
     * @return the estimated cost
     */
    public int getCost() {
        int cost = 5;
        for (Expression e : args) {
            cost += e.getCost();
        }
        return cost;
    }

    /**
     * The precision of the result is not known, so the maximum is reported.
     *
     * @return Integer.MAX_VALUE
     */
    public long getPrecision() {
        return Integer.MAX_VALUE;
    }

    /**
     * The display size of the result is not known, so the maximum is
     * reported.
     *
     * @return Integer.MAX_VALUE
     */
    public int getDisplaySize() {
        return Integer.MAX_VALUE;
    }

    /**
     * Get the default scale of the result data type.
     *
     * @return the default scale of the type determined in optimize()
     */
    public int getScale() {
        return DataType.getDataType(dataType).defaultScale;
    }

    /**
     * Build the SQL text: the quoted aggregate name followed by the
     * comma-separated argument list in parentheses.
     *
     * @return the SQL representation of this expression
     */
    public String getSQL() {
        StatementBuilder buff = new StatementBuilder();
        buff.append(Parser.quoteIdentifier(userAggregate.getName())).append('(');
        for (Expression e : args) {
            buff.appendExceptFirst(", ");
            buff.append(e.getSQL());
        }
        return buff.append(')').toString();
    }

    /**
     * Get the value type of the result, as determined in optimize().
     *
     * @return the value type
     */
    public int getType() {
        return dataType;
    }

    /**
     * Check whether this expression and all its arguments satisfy the
     * visitor. A user-defined aggregate is never treated as deterministic
     * and can never use the min/max/count optimization. When collecting
     * dependencies, the aggregate definition itself is registered.
     *
     * @param visitor the visitor
     * @return true if the condition holds for this expression and all arguments
     */
    public boolean isEverything(ExpressionVisitor visitor) {
        switch(visitor.getType()) {
        case ExpressionVisitor.DETERMINISTIC:
            // TODO optimization: some functions are deterministic, but we don't
            // know (no setting for that)
            // fall through: treated like the non-optimizable case below
        case ExpressionVisitor.OPTIMIZABLE_MIN_MAX_COUNT_ALL:
            // user defined aggregate functions can not be optimized
            return false;
        case ExpressionVisitor.GET_DEPENDENCIES:
            visitor.addDependency(userAggregate);
            break;
        default:
        }
        for (Expression e : args) {
            if (e != null && !e.isEverything(visitor)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Map the columns of every argument expression.
     *
     * @param resolver the column resolver
     * @param level the subquery nesting level
     */
    public void mapColumns(ColumnResolver resolver, int level) {
        for (Expression arg : args) {
            arg.mapColumns(resolver, level);
        }
    }

    /**
     * Optimize the arguments and determine the argument and result types.
     * This also creates the connection handed to AggregateFunction.init.
     *
     * @param session the session
     * @return this expression
     * @throws DbException if the aggregate cannot be instantiated or
     *         initialized, or if it fails to report its result type
     */
    public Expression optimize(Session session) {
        userConnection = session.createConnection(false);
        int len = args.length;
        argTypes = new int[len];
        int[] argSqlTypes = new int[len];
        for (int i = 0; i < len; i++) {
            // Take the type from the optimized expression. optimize() may
            // return a different object, and the un-optimized one may not
            // have a resolved type.
            Expression expr = args[i].optimize(session);
            args[i] = expr;
            int type = expr.getType();
            argTypes[i] = type;
            argSqlTypes[i] = DataType.convertTypeToSQLType(type);
        }
        try {
            aggregate = getInstance();
            dataType = DataType.convertSQLTypeToValueType(aggregate.getType(argSqlTypes));
        } catch (SQLException e) {
            throw DbException.convert(e);
        }
        return this;
    }

    /**
     * Propagate the evaluatable flag to every argument.
     *
     * @param tableFilter the table filter
     * @param b the new value
     */
    public void setEvaluatable(TableFilter tableFilter, boolean b) {
        for (Expression e : args) {
            e.setEvaluatable(tableFilter, b);
        }
    }

    /**
     * Create a new aggregate instance and initialize it with the user
     * connection.
     *
     * @return the initialized instance
     * @throws SQLException if init fails
     */
    private AggregateFunction getInstance() throws SQLException {
        AggregateFunction agg = userAggregate.getInstance();
        agg.init(userConnection);
        return agg;
    }

    /**
     * Get the aggregate result for the current group. If no row was added
     * to the group, the result of a freshly initialized instance is returned.
     * A null result is mapped to ValueNull.
     *
     * @param session the session
     * @return the result value
     * @throws DbException with INVALID_USE_OF_AGGREGATE_FUNCTION_1 if there
     *         is no current group
     */
    public Value getValue(Session session) {
        HashMap group = select.getCurrentGroup();
        if (group == null) {
            throw DbException.get(ErrorCode.INVALID_USE_OF_AGGREGATE_FUNCTION_1, getSQL());
        }
        try {
            AggregateFunction agg = (AggregateFunction) group.get(this);
            if (agg == null) {
                agg = getInstance();
            }
            Object obj = agg.getResult();
            if (obj == null) {
                return ValueNull.INSTANCE;
            }
            return DataType.convertToValue(session, obj, dataType);
        } catch (SQLException e) {
            throw DbException.convert(e);
        }
    }

    /**
     * Add the current row to the aggregate instance of the current group.
     * The instance is created and stored in the group map on first use.
     * With exactly one argument, that argument's value is passed to add().
     * Otherwise an array holding all argument values is passed.
     *
     * @param session the session
     */
    public void updateAggregate(Session session) {
        HashMap group = select.getCurrentGroup();
        if (group == null) {
            // this is a different level (the enclosing query)
            return;
        }
        int groupRowId = select.getCurrentGroupRowId();
        if (lastGroupRowId == groupRowId) {
            // already visited
            return;
        }
        lastGroupRowId = groupRowId;
        AggregateFunction agg = (AggregateFunction) group.get(this);
        try {
            if (agg == null) {
                agg = getInstance();
                group.put(this, agg);
            }
            int len = args.length;
            Object[] argValues = new Object[len];
            for (int i = 0; i < len; i++) {
                Value v = args[i].getValue(session).convertTo(argTypes[i]);
                argValues[i] = v.getObject();
            }
            // a single argument is passed unwrapped, several as an array
            agg.add(len == 1 ? argValues[0] : argValues);
        } catch (SQLException e) {
            throw DbException.convert(e);
        }
    }
}