io.trino.benchto.driver.macro.query.QueryMacroExecutionDriver Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.benchto.driver.macro.query;
import io.trino.benchto.driver.Benchmark;
import io.trino.benchto.driver.BenchmarkExecutionException;
import io.trino.benchto.driver.Query;
import io.trino.benchto.driver.loader.BenchmarkDescriptor;
import io.trino.benchto.driver.loader.QueryLoader;
import io.trino.benchto.driver.loader.SqlStatementGenerator;
import io.trino.benchto.driver.macro.MacroExecutionDriver;
import io.trino.benchto.driver.utils.QueryUtils;
import io.trino.jdbc.TrinoConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;

import javax.sql.DataSource;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
/**
 * {@link MacroExecutionDriver} for benchmark macros stored as {@code .sql} files.
 * <p>
 * The macro file is loaded through {@link QueryLoader} and turned into a list of SQL statements by
 * {@link SqlStatementGenerator}, using the benchmark's non-reserved keyword variables. The statements are then
 * executed in order. On a connection that wraps a {@link TrinoConnection}, {@code SET SESSION} statements are not
 * sent as SQL. They are applied through {@link TrinoConnection#setSessionProperty(String, String)} instead.
 */
@Component
public class QueryMacroExecutionDriver
        implements MacroExecutionDriver
{
    private static final Logger LOGGER = LoggerFactory.getLogger(QueryMacroExecutionDriver.class);

    /** Lower-case prefix identifying session-property statements. */
    private static final String SET_SESSION = "set session";

    /**
     * Matches {@code key=value} in these forms: {@code key=value}, {@code key='value'} and
     * {@code key = 'value'}.
     * Group 1 is the key, i.e. everything before the first '='. Group 2 is the value without surrounding single
     * quotes. The value itself must not contain a single quote.
     */
    private static final Pattern KEY_VALUE_PATTERN = Pattern.compile("([^=]+)=\\s*'?([^']+)'?\\s*");

    @Autowired
    private ApplicationContext applicationContext;

    @Autowired
    private QueryLoader queryLoader;

    @Autowired
    private SqlStatementGenerator sqlStatementGenerator;

    /**
     * Reports whether this driver handles the given macro.
     *
     * @param macroName macro identifier from the benchmark descriptor
     * @return {@code true} iff the name ends with {@code .sql}
     */
    public boolean canExecuteBenchmarkMacro(String macroName)
    {
        return macroName.endsWith(".sql");
    }

    /**
     * Loads the SQL macro file and executes all of its statements.
     * <p>
     * Statements run on the caller-supplied connection unless that connection is absent or the macro file declares
     * its own data source ({@link BenchmarkDescriptor#DATA_SOURCE_KEY}). In that case a connection is opened from
     * the macro's data source, falling back to the benchmark's data source, and closed afterwards.
     *
     * @param macroName path of the {@code .sql} macro file, as understood by {@link QueryLoader#loadFromFile}
     * @param benchmarkOptional the benchmark the macro belongs to; must be present
     * @param connectionOptional optional connection to reuse; it is never closed by this method
     * @throws IllegalArgumentException if {@code benchmarkOptional} is empty
     * @throws BenchmarkExecutionException if any statement fails with an {@link SQLException}
     */
    @Override
    public void runBenchmarkMacro(String macroName, Optional<Benchmark> benchmarkOptional, Optional<Connection> connectionOptional)
    {
        checkArgument(benchmarkOptional.isPresent(), "Benchmark is required to run query based macro");
        Benchmark benchmark = benchmarkOptional.get();
        Query macroQuery = queryLoader.loadFromFile(macroName);
        List<String> sqlStatements = sqlStatementGenerator.generateQuerySqlStatement(macroQuery, benchmark.getNonReservedKeywordVariables());

        try {
            // A data source declared in the macro file overrides the connection handed in by the caller.
            if (connectionOptional.isPresent() && !macroQuery.getProperty(BenchmarkDescriptor.DATA_SOURCE_KEY).isPresent()) {
                runSqlStatements(connectionOptional.get(), sqlStatements);
            }
            else {
                String dataSourceName = macroQuery.getProperty(BenchmarkDescriptor.DATA_SOURCE_KEY, benchmark.getDataSource());
                // This connection is opened here, so it is closed here as well.
                try (Connection connection = getConnectionFor(dataSourceName)) {
                    runSqlStatements(connection, sqlStatements);
                }
            }
        }
        catch (SQLException e) {
            throw new BenchmarkExecutionException(
                    "Could not execute macro SQL queries for benchmark: " + benchmark, e);
        }
    }

    /**
     * Executes the statements one by one on {@code connection}.
     * <p>
     * On a connection that wraps a {@link TrinoConnection}, {@code SET SESSION} statements are handed to
     * {@link #setSessionForTrino}. Every other statement is executed through JDBC. If a statement produces a
     * result set, the rows are fetched via {@link QueryUtils#fetchRows}.
     */
    private void runSqlStatements(Connection connection, List<String> sqlStatements)
            throws SQLException
    {
        for (String rawStatement : sqlStatements) {
            String sqlStatement = rawStatement.trim();
            LOGGER.info("Executing macro query: {}", sqlStatement);
            // Locale.ROOT keeps the prefix check locale-independent.
            // Under a Turkish default locale, "SET SESSION".toLowerCase() contains a dotless 'ı' and would not match.
            if (sqlStatement.toLowerCase(Locale.ROOT).startsWith(SET_SESSION) && connection.isWrapperFor(TrinoConnection.class)) {
                setSessionForTrino(connection, sqlStatement);
            }
            else {
                try (Statement statement = connection.createStatement()) {
                    if (statement.execute(sqlStatement)) {
                        try (ResultSet resultSet = statement.getResultSet()) {
                            QueryUtils.fetchRows(sqlStatement, resultSet);
                        }
                    }
                }
            }
        }
    }

    /**
     * Applies a {@code SET SESSION key = value} statement as a session property on the underlying
     * {@link TrinoConnection}.
     *
     * @throws UnsupportedOperationException if the connection cannot be unwrapped to {@link TrinoConnection};
     * the original {@link SQLException} is attached as the cause
     * @throws IllegalStateException if the statement does not match the expected key/value format
     */
    private void setSessionForTrino(Connection connection, String sqlStatement)
    {
        TrinoConnection trinoConnection;
        try {
            trinoConnection = connection.unwrap(TrinoConnection.class);
        }
        catch (SQLException e) {
            throw new UnsupportedOperationException(
                    format("SET SESSION for non TrinoConnection [%s] is not supported", connection.getClass()), e);
        }
        String[] keyValue = extractKeyValue(sqlStatement);
        // extractKeyValue already trims both parts.
        trinoConnection.setSessionProperty(keyValue[0], keyValue[1]);
    }

    /**
     * Splits a {@code SET SESSION} statement into its property name and value.
     * <p>
     * The first {@code "set session".length()} characters are dropped without being checked. Callers must already
     * have verified that the statement starts with that prefix, in any letter case.
     *
     * @param sqlStatement statement such as {@code SET SESSION query_max_memory = '1GB'}
     * @return a two-element array {@code {key, value}}; both are trimmed, and the value has surrounding single
     * quotes removed
     * @throws IllegalStateException if the remainder is not of the form {@code key=value}
     */
    public static String[] extractKeyValue(String sqlStatement)
    {
        String keyValueSql = sqlStatement.substring(SET_SESSION.length()).trim();
        Matcher matcher = KEY_VALUE_PATTERN.matcher(keyValueSql);
        checkState(matcher.matches(), "Unexpected SET SESSION format [%s]", sqlStatement);
        return new String[] {matcher.group(1).trim(), matcher.group(2).trim()};
    }

    /**
     * Opens a new JDBC connection from the {@link DataSource} bean registered under {@code dataSourceName} in the
     * Spring application context. The caller is responsible for closing it.
     */
    private Connection getConnectionFor(String dataSourceName)
            throws SQLException
    {
        return applicationContext.getBean(dataSourceName, DataSource.class).getConnection();
    }
}