All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.benchto.driver.macro.query.QueryMacroExecutionDriver Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.benchto.driver.macro.query;

import io.trino.benchto.driver.Benchmark;
import io.trino.benchto.driver.BenchmarkExecutionException;
import io.trino.benchto.driver.Query;
import io.trino.benchto.driver.loader.BenchmarkDescriptor;
import io.trino.benchto.driver.loader.QueryLoader;
import io.trino.benchto.driver.loader.SqlStatementGenerator;
import io.trino.benchto.driver.macro.MacroExecutionDriver;
import io.trino.benchto.driver.utils.QueryUtils;
import io.trino.jdbc.TrinoConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;

import javax.sql.DataSource;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;

/**
 * Runs benchmark macros whose name ends with {@code .sql}.
 * <p>
 * The macro file is loaded with {@link QueryLoader}, expanded into individual SQL statements by
 * {@link SqlStatementGenerator} using the benchmark's non-reserved keyword variables, and the
 * statements are executed one after another. A {@code SET SESSION} statement issued against a
 * connection that wraps a {@link TrinoConnection} is applied via
 * {@link TrinoConnection#setSessionProperty(String, String)} rather than executed as a JDBC statement.
 */
@Component
public class QueryMacroExecutionDriver
        implements MacroExecutionDriver
{
    private static final Logger LOGGER = LoggerFactory.getLogger(QueryMacroExecutionDriver.class);
    private static final String SET_SESSION = "set session";
    // Accepts "key=value" and "key='value'". Whitespace after '=' is allowed so that the common
    // "key = 'value'" form parses too; the key may keep trailing spaces and is trimmed by extractKeyValue.
    private static final Pattern KEY_VALUE_PATTERN = Pattern.compile("([^=]+)=\\s*'??([^']+)'??");

    @Autowired
    private ApplicationContext applicationContext;

    @Autowired
    private QueryLoader queryLoader;

    @Autowired
    private SqlStatementGenerator sqlStatementGenerator;

    /**
     * Returns whether this driver handles the given macro, i.e. whether its name ends with {@code .sql}.
     */
    public boolean canExecuteBenchmarkMacro(String macroName)
    {
        return macroName.endsWith(".sql");
    }

    /**
     * Loads the SQL macro file {@code macroName} and executes its statements.
     * <p>
     * When {@code connectionOptional} is present and the macro does not set the
     * {@link BenchmarkDescriptor#DATA_SOURCE_KEY} property, the statements run on that connection,
     * which is not closed here. Otherwise a connection is obtained from the {@link DataSource} bean
     * named by the macro's data source property (with {@code benchmark.getDataSource()} passed as the
     * default) and is closed once the statements have run.
     *
     * @param macroName name of the {@code .sql} macro file to load
     * @param benchmarkOptional benchmark supplying template variables and the default data source; required
     * @param connectionOptional optional caller-owned connection to run the statements on
     * @throws IllegalArgumentException if {@code benchmarkOptional} is empty
     * @throws BenchmarkExecutionException if executing the statements fails with an {@link SQLException}
     */
    @Override
    public void runBenchmarkMacro(String macroName, Optional<Benchmark> benchmarkOptional, Optional<Connection> connectionOptional)
    {
        checkArgument(benchmarkOptional.isPresent(), "Benchmark is required to run query based macro");
        Benchmark benchmark = benchmarkOptional.get();
        Query macroQuery = queryLoader.loadFromFile(macroName);

        List<String> sqlStatements = sqlStatementGenerator.generateQuerySqlStatement(macroQuery, benchmark.getNonReservedKeywordVariables());

        try {
            if (connectionOptional.isPresent() && !macroQuery.getProperty(BenchmarkDescriptor.DATA_SOURCE_KEY).isPresent()) {
                // Caller owns this connection, so it is intentionally left open.
                runSqlStatements(connectionOptional.get(), sqlStatements);
            }
            else {
                String dataSourceName = macroQuery.getProperty(BenchmarkDescriptor.DATA_SOURCE_KEY, benchmark.getDataSource());
                try (Connection connection = getConnectionFor(dataSourceName)) {
                    runSqlStatements(connection, sqlStatements);
                }
            }
        }
        catch (SQLException e) {
            throw new BenchmarkExecutionException(
                    "Could not execute macro SQL queries for benchmark: " + benchmark, e);
        }
    }

    /**
     * Executes {@code sqlStatements} in order on {@code connection}.
     * <p>
     * Each statement is trimmed first. {@code SET SESSION} statements (matched case-insensitively) on a
     * connection wrapping a {@link TrinoConnection} are applied through
     * {@link TrinoConnection#setSessionProperty(String, String)}; every other statement is executed
     * through a JDBC {@link Statement}, and any result set it produces is passed to
     * {@link QueryUtils#fetchRows}.
     */
    private void runSqlStatements(Connection connection, List<String> sqlStatements)
            throws SQLException
    {
        for (String rawStatement : sqlStatements) {
            String sqlStatement = rawStatement.trim();
            LOGGER.info("Executing macro query: {}", sqlStatement);
            if (isSetSession(sqlStatement) && connection.isWrapperFor(TrinoConnection.class)) {
                setSessionForTrino(connection, sqlStatement);
            }
            else {
                try (Statement statement = connection.createStatement()) {
                    // execute() returns true only when the statement produced a ResultSet.
                    if (statement.execute(sqlStatement)) {
                        try (ResultSet resultSet = statement.getResultSet()) {
                            QueryUtils.fetchRows(sqlStatement, resultSet);
                        }
                    }
                }
            }
        }
    }

    /**
     * Returns whether {@code sqlStatement} starts with {@code SET SESSION}, ignoring case.
     * <p>
     * Uses {@link String#regionMatches(boolean, int, String, int, int)} instead of
     * {@code toLowerCase()} so the result does not depend on the JVM default locale
     * (in a Turkish locale {@code "SESSION".toLowerCase()} yields a dotless {@code ı}).
     */
    private static boolean isSetSession(String sqlStatement)
    {
        return sqlStatement.regionMatches(true, 0, SET_SESSION, 0, SET_SESSION.length());
    }

    /**
     * Applies a {@code SET SESSION key=value} statement as a session property of the wrapped
     * {@link TrinoConnection}.
     *
     * @throws UnsupportedOperationException if {@code connection} cannot be unwrapped to a
     *         {@link TrinoConnection}; the underlying {@link SQLException} is kept as the cause
     * @throws IllegalStateException if the statement is not in {@code key=value} form
     */
    private void setSessionForTrino(Connection connection, String sqlStatement)
    {
        TrinoConnection trinoConnection;
        try {
            trinoConnection = connection.unwrap(TrinoConnection.class);
        }
        catch (SQLException e) {
            throw new UnsupportedOperationException(format("SET SESSION for non TrinoConnection [%s] is not supported", connection.getClass()), e);
        }
        // extractKeyValue already returns trimmed key and value.
        String[] keyValue = extractKeyValue(sqlStatement);
        trinoConnection.setSessionProperty(keyValue[0], keyValue[1]);
    }

    /**
     * Splits a {@code SET SESSION key=value} statement into its key and value.
     * <p>
     * The first {@code "set session".length()} characters are assumed to be the {@code SET SESSION}
     * prefix and are skipped. The value may optionally be wrapped in single quotes, which are removed.
     * Both parts are trimmed.
     *
     * @return a two-element array: {@code [key, value]}
     * @throws IllegalStateException if the remainder does not match {@code key=value} / {@code key='value'}
     */
    public static String[] extractKeyValue(String sqlStatement)
    {
        String keyValueSql = sqlStatement.substring(SET_SESSION.length()).trim();
        Matcher matcher = KEY_VALUE_PATTERN.matcher(keyValueSql);
        checkState(matcher.matches(), "Unexpected SET SESSION format [%s]", sqlStatement);
        String[] keyValue = new String[2];
        keyValue[0] = matcher.group(1).trim();
        keyValue[1] = matcher.group(2).trim();
        return keyValue;
    }

    /**
     * Opens a new connection from the {@link DataSource} bean named {@code dataSourceName}.
     * The caller is responsible for closing it.
     */
    private Connection getConnectionFor(String dataSourceName)
            throws SQLException
    {
        return applicationContext.getBean(dataSourceName, DataSource.class).getConnection();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy