com.github.chengyuxing.sql.support.JdbcSupport Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of rabbit-sql Show documentation
Show all versions of rabbit-sql Show documentation
Lightweight JDBC wrapper supporting DDL, DML, queries, PL/SQL, stored procedures/functions, transactions, and SQL file
management.
package com.github.chengyuxing.sql.support;
import com.github.chengyuxing.common.DataRow;
import com.github.chengyuxing.common.UncheckedCloseable;
import com.github.chengyuxing.common.utils.ObjectUtil;
import com.github.chengyuxing.common.utils.StringUtil;
import com.github.chengyuxing.sql.exceptions.UncheckedSqlException;
import com.github.chengyuxing.sql.types.Param;
import com.github.chengyuxing.sql.types.ParamMode;
import com.github.chengyuxing.sql.utils.JdbcUtil;
import com.github.chengyuxing.sql.utils.SqlGenerator;
import com.github.chengyuxing.sql.utils.SqlHighlighter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.sql.*;
import java.time.LocalDateTime;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
/**
* JDBC support
* Provide basic support: {@link Stream stream query}, {@code ddl}, {@code dml}, {@code store procedure/function}, {@code plsql}.
* e.g. sql statement:
*
* select * from ...
* where name = :name
* or id in (${!idList}) ${cnd};
*
*
* @see com.github.chengyuxing.sql.types.Variable
*/
public abstract class JdbcSupport extends SqlParser {
private final static Logger log = LoggerFactory.getLogger(JdbcSupport.class);
/**
* Get datasource.
*
* @return datasource
*/
protected abstract DataSource getDataSource();
/**
* Get connection.
*
* @return connection
*/
protected abstract Connection getConnection();
/**
* Release connection when execute finished.
*
* @param connection connection
* @param dataSource datasource
*/
protected abstract void releaseConnection(Connection connection, DataSource dataSource);
/**
* Handle prepared statement value.
*
* @param ps PreparedStatement
* @param index parameter index
* @param value parameter value
* @throws SQLException ex
*/
protected abstract void doHandleStatementValue(PreparedStatement ps, int index, Object value) throws SQLException;
/**
* Sql watcher
*
* @param sql sql
* @param args args
* @param startTime connection request start time
* @param endTime finish execute time
* @param throwable throwable
*/
protected void watchSql(String sql, Object args, long startTime, long endTime, Throwable throwable) {
}
/**
* jdbc execute sql timeout.
*
* @return time out (seconds)
* @see Statement#setQueryTimeout(int)
*/
protected int queryTimeout() {
return 0;
}
/**
* Execute any prepared sql.
*
* @param sql sql
* @param callback statement callback
* @param result type
* @return any result
* @throws UncheckedSqlException if connection states error
*/
protected T execute(final String sql, StatementCallback callback) {
PreparedStatement statement = null;
Connection connection = null;
try {
connection = getConnection();
//noinspection SqlSourceToSinkFlow
statement = connection.prepareStatement(sql);
return callback.doInStatement(statement);
} catch (SQLException e) {
try {
JdbcUtil.closeStatement(statement);
} catch (SQLException ex) {
e.addSuppressed(ex);
}
statement = null;
releaseConnection(connection, getDataSource());
throw new UncheckedSqlException("execute sql:\n" + sql, e);
} finally {
try {
JdbcUtil.closeStatement(statement);
} catch (SQLException e) {
log.error("close statement error.", e);
}
releaseConnection(connection, getDataSource());
}
}
/**
* Set prepared sql statement args.
*
* @param ps sql statement object
* @param args args
* @param names ordered arg names
* @throws SQLException if connection states error
*/
protected void setPreparedSqlArgs(PreparedStatement ps, Map args, Map> names) throws SQLException {
for (Map.Entry> e : names.entrySet()) {
String name = e.getKey();
Object value = name.contains(".") ? ObjectUtil.getDeepValue(args, name) : args.get(name);
for (Integer i : e.getValue()) {
doHandleStatementValue(ps, i, value);
}
}
}
/**
* Set callable statement args.
*
* @param cs store procedure/function statement object
* @param args args
* @param names ordered arg names
* @throws SQLException if connection states error
*/
protected void setPreparedStoreArgs(CallableStatement cs, Map args, Map> names) throws SQLException {
if (args != null && !args.isEmpty()) {
// adapt postgresql
// out and inout param first
for (Map.Entry> e : names.entrySet()) {
Param param = args.get(e.getKey());
if (param.getParamMode() == ParamMode.OUT || param.getParamMode() == ParamMode.IN_OUT) {
for (Integer i : e.getValue()) {
cs.registerOutParameter(i, param.getType().typeNumber());
}
}
}
// in param next
for (Map.Entry> e : names.entrySet()) {
Param param = args.get(e.getKey());
for (Integer i : e.getValue()) {
if (param.getParamMode() == ParamMode.IN || param.getParamMode() == ParamMode.IN_OUT) {
doHandleStatementValue(cs, i, param.getValue());
}
}
}
}
}
/**
* Execute query, ddl, dml or plsql statement.
* Execute result:
*
* - result: {@link DataRow#getFirst(Object...) getFirst()} or {@link DataRow#get(Object) get("result")}
* - type: {@link DataRow#getString(int, String...) getString(1)} 或 {@link DataRow#get(Object) get("type")}
*
*
* @param sql named parameter sql
* @param args args
* @return Query: List{@code }, DML: affected row count, DDL: 0
* @throws UncheckedSqlException sql execute error
*/
public DataRow execute(final String sql, Map args) {
long startTime = System.currentTimeMillis();
SqlGenerator.GeneratedSqlMetaData sqlMetaData = prepare(sql, args);
final Map> argNames = sqlMetaData.getArgNameIndexMapping();
final String preparedSql = sqlMetaData.getResultSql();
final Map myArgs = sqlMetaData.getArgs();
Throwable reason = null;
try {
debugSql(sqlMetaData.getNamedParamSql(), Collections.singletonList(myArgs));
return execute(preparedSql, ps -> {
ps.setQueryTimeout(queryTimeout());
setPreparedSqlArgs(ps, myArgs, argNames);
boolean isQuery = ps.execute();
printSqlConsole(ps);
if (isQuery) {
ResultSet resultSet = ps.getResultSet();
List rows = JdbcUtil.createDataRows(resultSet, preparedSql, -1);
JdbcUtil.closeResultSet(resultSet);
return DataRow.of("result", rows, "type", "QUERY");
}
int count = ps.getUpdateCount();
return DataRow.of("result", count, "type", "DD(M)L");
});
} catch (Exception e) {
reason = e;
throw new RuntimeException("prepare sql error:\n" + sql + "\n" + myArgs, e);
} finally {
watchSql(sql, myArgs, startTime, System.currentTimeMillis(), reason);
}
}
/**
* Lazy execute query based on {@link Stream} support, real execute query when terminal
* operation called, every Stream query hold a connection, in case connection pool dead
* do must have to close this stream finally, e.g.
* Auto close by {@code try-with-resource}:
*
*
* try ({@link Stream}<{@link DataRow}> stream = executeQueryStream(...)) {
* stream.limit(10).forEach(System.out::println);
* }
*
* Manual close by call {@link Stream#close()}:
*
*
* {@link Stream}<{@link DataRow}> stream = executeQueryStream(...);
* ...
* stream.close();
*
*
* @param sql named parameter sql, e.g. select * from test.user where id = :id
* @param args args
* @return Stream query result
* @throws UncheckedSqlException sql execute error
*/
public Stream executeQueryStream(final String sql, Map args) {
long startTime = System.currentTimeMillis();
SqlGenerator.GeneratedSqlMetaData sqlMetaData = prepare(sql, args);
final Map> argNames = sqlMetaData.getArgNameIndexMapping();
final String preparedSql = sqlMetaData.getResultSql();
final Map myArgs = sqlMetaData.getArgs();
UncheckedCloseable close = null;
Throwable reason = null;
try {
Connection connection = getConnection();
// if this query is not in transaction, it's connection managed by Stream
// if transaction is active connection will not be close when read stream to the end in 'try-with-resource' block
close = UncheckedCloseable.wrap(() -> releaseConnection(connection, getDataSource()));
debugSql(sqlMetaData.getNamedParamSql(), Collections.singletonList(myArgs));
//noinspection SqlSourceToSinkFlow
PreparedStatement ps = connection.prepareStatement(preparedSql);
ps.setQueryTimeout(queryTimeout());
close = close.nest(ps);
setPreparedSqlArgs(ps, myArgs, argNames);
ResultSet resultSet = ps.executeQuery();
close = close.nest(resultSet);
return StreamSupport.stream(new Spliterators.AbstractSpliterator(Long.MAX_VALUE, Spliterator.ORDERED) {
final String[] names = JdbcUtil.createNames(resultSet, preparedSql);
@Override
public boolean tryAdvance(Consumer super DataRow> action) {
try {
if (!resultSet.next()) {
return false;
}
action.accept(JdbcUtil.createDataRow(names, resultSet));
return true;
} catch (SQLException ex) {
throw new UncheckedSqlException("reading result set of query:\n" + preparedSql + "\nerror.", ex);
}
}
}, false).onClose(close);
} catch (Exception ex) {
if (close != null) {
try {
close.close();
} catch (Exception e) {
ex.addSuppressed(e);
}
}
reason = ex;
throw new RuntimeException("streaming query error:\n" + preparedSql + "\n" + myArgs, ex);
} finally {
watchSql(sql, myArgs, startTime, System.currentTimeMillis(), reason);
}
}
/**
* Batch execute not prepared sql ({@code ddl} or {@code dml}).
*
* @param sqls more than 1 sql
* @param batchSize batch size
* @return affected row count
* @throws UncheckedSqlException execute sql error
* @throws IllegalArgumentException if sqls count less than 1
*/
public int executeBatch(final List sqls, int batchSize) {
if (batchSize < 1) {
throw new IllegalArgumentException("batchSize must greater than 0.");
}
if (sqls.isEmpty()) {
return 0;
}
long startTime = System.currentTimeMillis();
Statement s = null;
Throwable reason = null;
Connection connection = getConnection();
try {
s = connection.createStatement();
final Stream.Builder result = Stream.builder();
int i = 1;
for (String sql : sqls) {
if (StringUtil.isEmpty(sql)) {
continue;
}
String parsedSql = parseSql(sql, Collections.emptyMap()).getItem1();
debugSql(parsedSql, Collections.emptyList());
//noinspection SqlSourceToSinkFlow
s.addBatch(parsedSql);
if (i % batchSize == 0) {
result.add(s.executeBatch());
s.clearBatch();
}
i++;
}
result.add(s.executeBatch());
s.clearBatch();
return result.build().flatMapToInt(IntStream::of).sum();
} catch (SQLException e) {
try {
JdbcUtil.closeStatement(s);
} catch (SQLException ex) {
e.addSuppressed(ex);
}
s = null;
releaseConnection(connection, getDataSource());
reason = e;
throw new UncheckedSqlException("execute batch error.", e);
} finally {
try {
JdbcUtil.closeStatement(s);
} catch (SQLException e) {
log.error("close statement error.", e);
}
releaseConnection(connection, getDataSource());
watchSql(String.join("###", sqls), null, startTime, System.currentTimeMillis(), reason);
}
}
/**
* Batch execute prepared non-query sql ({@code insert}, {@code update}, {@code delete}).
*
* @param sql named parameter sql
* @param args args collection
* @param batchSize batch size
* @return affected row count
*/
public int executeBatchUpdate(final String sql, Collection extends Map> args, int batchSize) {
if (batchSize < 1) {
throw new IllegalArgumentException("batchSize must greater than 0.");
}
long startTime = System.currentTimeMillis();
Map first = args.iterator().next();
SqlGenerator.GeneratedSqlMetaData sqlMetaData = prepare(sql, first);
final Map> argNames = sqlMetaData.getArgNameIndexMapping();
final String preparedSql = sqlMetaData.getResultSql();
Throwable reason = null;
try {
debugSql(sqlMetaData.getNamedParamSql(), args);
return execute(preparedSql, ps -> {
final Stream.Builder result = Stream.builder();
int i = 1;
for (Map item : args) {
setPreparedSqlArgs(ps, item, argNames);
ps.addBatch();
if (i % batchSize == 0) {
result.add(ps.executeBatch());
ps.clearBatch();
}
i++;
}
result.add(ps.executeBatch());
ps.clearBatch();
return result.build().flatMapToInt(IntStream::of).sum();
});
} catch (Exception e) {
reason = e;
throw new RuntimeException("prepare sql error:\n" + sql, e);
} finally {
watchSql(sql, args, startTime, System.currentTimeMillis(), reason);
}
}
/**
* Execute prepared non-query sql ({@code insert}, {@code update}, {@code delete})
* e.g. insert statement:
*
* insert into table (a,b,c) values (:v1,:v2,:v3)
*
* args:
*
* {v1:'a',v2:'b',v3:'c'}
*
*
* @param sql named parameter sql
* @param args args
* @return affect row count
*/
public int executeUpdate(final String sql, Map args) {
long startTime = System.currentTimeMillis();
SqlGenerator.GeneratedSqlMetaData sqlMetaData = prepare(sql, args);
final Map> argNames = sqlMetaData.getArgNameIndexMapping();
final String preparedSql = sqlMetaData.getResultSql();
final Map myArgs = sqlMetaData.getArgs();
Throwable reason = null;
try {
debugSql(sqlMetaData.getNamedParamSql(), Collections.singletonList(myArgs));
return execute(preparedSql, sc -> {
sc.setQueryTimeout(queryTimeout());
if (myArgs.isEmpty()) {
return sc.executeUpdate();
}
setPreparedSqlArgs(sc, myArgs, argNames);
return sc.executeUpdate();
});
} catch (Exception e) {
reason = e;
throw new RuntimeException("prepare sql error:\n" + sql + "\n" + myArgs, e);
} finally {
watchSql(sql, myArgs, startTime, System.currentTimeMillis(), reason);
}
}
/**
* Execute store {@code procedure} or {@code function}.
*
*
* { call func1(:in1, :in2, :out1, :out2) }
* { call func2(:out::refcursor) } //postgresql
* { :out = call func3() }
* { call func_returns_table() } //postgresql
* call procedure() //postgresql v13+
*
*
* 2 ways to get result:
*
* - zero OUT parameters: {@link DataRow#getFirst(Object...) getFirst()} or {@link DataRow#getFirstAs(Object...) getFirstAs()}
* - by OUT parameter name: {@link DataRow#getAs(String, Object...) getAs(String)} or {@link DataRow#get(Object) get(String)}
*
*
* @param procedure procedure
* @param args args
* @return DataRow
* @throws UncheckedSqlException execute procedure error
*/
public DataRow executeCallStatement(final String procedure, Map args) {
long startTime = System.currentTimeMillis();
SqlGenerator.GeneratedSqlMetaData sqlMetaData = prepare(procedure, args);
final String executeSql = sqlMetaData.getResultSql();
final Map> argNames = sqlMetaData.getArgNameIndexMapping();
Connection connection = getConnection();
CallableStatement statement = null;
Throwable reason = null;
try {
debugSql(sqlMetaData.getNamedParamSql(), Collections.singletonList(args));
//noinspection SqlSourceToSinkFlow
statement = connection.prepareCall(executeSql);
statement.setQueryTimeout(queryTimeout());
List outNames = new ArrayList<>();
if (!args.isEmpty()) {
setPreparedStoreArgs(statement, args, argNames);
for (String name : argNames.keySet()) {
if (args.containsKey(name)) {
ParamMode mode = args.get(name).getParamMode();
if (mode == ParamMode.OUT || mode == ParamMode.IN_OUT) {
outNames.add(name);
}
}
}
}
statement.execute();
printSqlConsole(statement);
if (outNames.isEmpty()) {
ResultSet resultSet = statement.getResultSet();
List dataRows = JdbcUtil.createDataRows(resultSet, "", -1);
JdbcUtil.closeResultSet(resultSet);
if (dataRows.isEmpty()) {
return DataRow.of();
}
return DataRow.of("result", dataRows);
}
Object[] values = new Object[outNames.size()];
int resultIndex = 0;
for (Map.Entry> e : argNames.entrySet()) {
if (outNames.contains(e.getKey())) {
for (Integer i : e.getValue()) {
Object result = statement.getObject(i);
if (Objects.isNull(result)) {
values[resultIndex] = null;
} else if (result instanceof ResultSet) {
List rows = JdbcUtil.createDataRows((ResultSet) result, "", -1);
JdbcUtil.closeResultSet((ResultSet) result);
values[resultIndex] = rows;
} else {
values[resultIndex] = result;
}
resultIndex++;
}
}
}
return DataRow.of(outNames.toArray(new String[0]), values);
} catch (SQLException e) {
try {
JdbcUtil.closeStatement(statement);
} catch (SQLException ex) {
e.addSuppressed(ex);
}
statement = null;
reason = e;
releaseConnection(connection, getDataSource());
throw new UncheckedSqlException("execute procedure error:\n" + procedure + "\n" + args, e);
} finally {
try {
JdbcUtil.closeStatement(statement);
} catch (SQLException e) {
log.error("close statement error.", e);
}
releaseConnection(connection, getDataSource());
watchSql(procedure, args, startTime, System.currentTimeMillis(), reason);
}
}
/**
* Debug executed sql and args.
*
* @param sql sql
* @param args args
*/
protected void debugSql(String sql, Collection extends Map> args) {
if (log.isDebugEnabled()) {
log.debug("SQL: {}", SqlHighlighter.highlightIfAnsiCapable(sql));
for (Map arg : args) {
StringJoiner sb = new StringJoiner(", ", "{", "}");
arg.forEach((k, v) -> {
if (v == null) {
sb.add(k + " -> null");
} else {
sb.add(k + " -> " + v + "(" + v.getClass().getSimpleName() + ")");
}
});
log.debug("Args: {}", sb);
}
}
}
/**
* Print sql log, e.g postgresql:
*
* raise notice 'my console.';
*
*
* @param sc sql statement object
*/
private void printSqlConsole(Statement sc) {
if (log.isWarnEnabled()) {
try {
SQLWarning warning = sc.getWarnings();
if (warning != null) {
String state = warning.getSQLState();
warning.forEach(r -> log.warn("[{}] [{}] {}", LocalDateTime.now(), state, r.getMessage()));
}
} catch (SQLException e) {
log.error("get sql warning error.", e);
}
}
}
}