All downloads are FREE. Search and download functionality uses the official Maven repository.

com.querydsl.sql.AbstractSQLQuery Maven / Gradle / Ivy

/*
 * Copyright 2015, The Querydsl Team (http://www.querydsl.com/team)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.querydsl.sql;

import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

import javax.annotation.Nullable;
import javax.inject.Provider;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import com.google.common.collect.ImmutableList;
import com.mysema.commons.lang.CloseableIterator;
import com.querydsl.core.*;
import com.querydsl.core.support.QueryMixin;
import com.querydsl.core.types.*;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.SimpleExpression;
import com.querydsl.core.types.dsl.Wildcard;
import com.querydsl.core.util.ResultSetAdapter;

/**
 * {@code AbstractSQLQuery} is the base type for SQL query implementations
 *
 * @param  result type
 * @param  concrete subtype
 *
 * @author tiwe
 */
public abstract class AbstractSQLQuery> extends ProjectableSQLQuery {

    protected static final String PARENT_CONTEXT = AbstractSQLQuery.class.getName() + "#PARENT_CONTEXT";

    private static final Logger logger = LoggerFactory.getLogger(AbstractSQLQuery.class);

    private static final QueryFlag rowCountFlag = new QueryFlag(QueryFlag.Position.AFTER_PROJECTION, ", count(*) over() ");

    @Nullable
    private Provider connProvider;

    @Nullable
    private Connection conn;

    protected SQLListeners listeners;

    protected boolean useLiterals;

    private boolean getLastCell;

    private Object lastCell;

    private SQLListenerContext parentContext;

    private StatementOptions statementOptions = StatementOptions.DEFAULT;

    public AbstractSQLQuery(@Nullable Connection conn, Configuration configuration) {
        this(conn, configuration, new DefaultQueryMetadata());
    }

    public AbstractSQLQuery(@Nullable Connection conn, Configuration configuration, QueryMetadata metadata) {
        super(new QueryMixin(metadata, false), configuration);
        this.conn = conn;
        this.listeners = new SQLListeners(configuration.getListeners());
        this.useLiterals = configuration.getUseLiterals();
    }

    public AbstractSQLQuery(Provider connProvider, Configuration configuration) {
        this(connProvider, configuration, new DefaultQueryMetadata());
    }

    public AbstractSQLQuery(Provider connProvider, Configuration configuration, QueryMetadata metadata) {
        super(new QueryMixin(metadata, false), configuration);
        this.connProvider = connProvider;
        this.listeners = new SQLListeners(configuration.getListeners());
        this.useLiterals = configuration.getUseLiterals();
    }

    /**
     * Create an alias for the expression
     *
     * @param alias alias
     * @return this as alias
     */
    public SimpleExpression as(String alias) {
        return Expressions.as(this, alias);
    }

    /**
     * Create an alias for the expression
     *
     * @param alias alias
     * @return this as alias
     */
    @SuppressWarnings("unchecked")
    public SimpleExpression as(Path alias) {
        return Expressions.as(this, (Path) alias);
    }

    /**
     * Add a listener
     *
     * @param listener listener to add
     */
    public void addListener(SQLListener listener) {
        listeners.add(listener);
    }

    @Override
    public long fetchCount() {
        try {
            return unsafeCount();
        } catch (SQLException e) {
            String error = "Caught " + e.getClass().getName();
            logger.error(error, e);
            throw configuration.translate(e);
        }
    }

    /**
     * If you use forUpdate() with a backend that uses page or row locks, rows examined by the
     * query are write-locked until the end of the current transaction.
     *
     * Not supported for SQLite and CUBRID
     *
     * @return the current object
     */
    public Q forUpdate() {
        QueryFlag forUpdateFlag = configuration.getTemplates().getForUpdateFlag();
        return addFlag(forUpdateFlag);
    }

    /**
     * FOR SHARE causes the rows retrieved by the SELECT statement to be locked as though for update.
     *
     * Supported by MySQL, PostgreSQL, SQLServer.
     *
     * @return the current object
     *
     * @throws QueryException
     *          if the FOR SHARE is not supported.
     */
    public Q forShare() {
        return forShare(false);
    }

    /**
     * FOR SHARE causes the rows retrieved by the SELECT statement to be locked as though for update.
     *
     * Supported by MySQL, PostgreSQL, SQLServer.
     *
     * @param fallbackToForUpdate
     *          if the FOR SHARE is not supported and this parameter is true, the
     *          {@link #forUpdate()} functionality will be used.
     *
     * @return the current object
     *
     * @throws QueryException
     *          if the FOR SHARE is not supported and fallbackToForUpdate is set to
     *          false.
     */
    public Q forShare(boolean fallbackToForUpdate) {
        SQLTemplates sqlTemplates = configuration.getTemplates();

        if (sqlTemplates.isForShareSupported()) {
            QueryFlag forShareFlag = sqlTemplates.getForShareFlag();
            return addFlag(forShareFlag);
        }

        if (fallbackToForUpdate) {
            return forUpdate();
        }

        throw new QueryException("Using forShare() is not supported");
    }

    @Override
    protected SQLSerializer createSerializer() {
        SQLSerializer serializer = new SQLSerializer(configuration);
        serializer.setUseLiterals(useLiterals);
        return serializer;
    }

    @Nullable
    private  U get(ResultSet rs, Expression expr, int i, Class type) throws SQLException {
        return configuration.get(rs, expr instanceof Path ? (Path) expr : null, i, type);
    }

    private void set(PreparedStatement stmt, Path path, int i, Object value) throws SQLException {
        configuration.set(stmt, path, i, value);
    }

    /**
     * Called to create and start a new SQL Listener context
     *
     * @param connection the database connection
     * @param metadata   the meta data for that context
     * @return the newly started context
     */
    protected SQLListenerContextImpl startContext(Connection connection, QueryMetadata metadata) {
        SQLListenerContextImpl context = new SQLListenerContextImpl(metadata, connection);
        if (parentContext != null) {
            context.setData(PARENT_CONTEXT, parentContext);
        }
        listeners.start(context);
        return context;
    }

    /**
     * Called to make the call back to listeners when an exception happens
     *
     * @param context the current context in play
     * @param e       the exception
     */
    protected void onException(SQLListenerContextImpl context, Exception e) {
        context.setException(e);
        listeners.exception(context);
    }

    /**
     * Called to end a SQL listener context
     *
     * @param context the listener context to end
     */
    protected void endContext(SQLListenerContext context) {
        listeners.end(context);
    }

    /**
     * Get the results as a JDBC ResultSet
     *
     * @param exprs the expression arguments to retrieve
     * @return results as ResultSet
     * @deprecated Use @{code select(..)} to define the projection and {@code getResults()} to obtain
     *             the result set
     */
    @Deprecated
    public ResultSet getResults(Expression... exprs) {
        if (exprs.length > 0) {
            queryMixin.setProjection(exprs);
        }
        return getResults();
    }

    /**
     * Get the results as a JDBC ResultSet
     *
     * @return results as ResultSet
     */
    public ResultSet getResults() {
        final SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata());
        String queryString = null;
        List constants = ImmutableList.of();

        try {
            listeners.preRender(context);
            SQLSerializer serializer = serialize(false);
            queryString = serializer.toString();
            logQuery(queryString, serializer.getConstants());
            context.addSQL(queryString);
            listeners.rendered(context);

            listeners.notifyQuery(queryMixin.getMetadata());

            constants = serializer.getConstants();

            listeners.prePrepare(context);
            final PreparedStatement stmt = getPreparedStatement(queryString);
            setParameters(stmt, constants, serializer.getConstantPaths(), getMetadata().getParams());
            context.addPreparedStatement(stmt);
            listeners.prepared(context);

            listeners.preExecute(context);
            final ResultSet rs = stmt.executeQuery();
            listeners.executed(context);

            return new ResultSetAdapter(rs) {
                @Override
                public void close() throws SQLException {
                    try {
                        super.close();
                    } finally {
                        stmt.close();
                        reset();
                        endContext(context);
                    }
                }
            };
        } catch (SQLException e) {
            onException(context, e);
            reset();
            endContext(context);
            throw configuration.translate(queryString, constants, e);
        }
    }

    private PreparedStatement getPreparedStatement(String queryString) throws SQLException {
        PreparedStatement statement = connection().prepareStatement(queryString);
        if (statementOptions.getFetchSize() != null) {
            statement.setFetchSize(statementOptions.getFetchSize());
        }
        if (statementOptions.getMaxFieldSize() != null) {
            statement.setMaxFieldSize(statementOptions.getMaxFieldSize());
        }
        if (statementOptions.getQueryTimeout() != null) {
            statement.setQueryTimeout(statementOptions.getQueryTimeout());
        }
        if (statementOptions.getMaxRows() != null) {
            statement.setMaxRows(statementOptions.getMaxRows());
        }
        return statement;
    }

    protected Configuration getConfiguration() {
        return configuration;
    }

    @SuppressWarnings("unchecked")
    @Override
    public CloseableIterator iterate() {
        Expression expr = (Expression) queryMixin.getMetadata().getProjection();
        return iterateSingle(queryMixin.getMetadata(), expr);
    }

    @SuppressWarnings("unchecked")
    private CloseableIterator iterateSingle(QueryMetadata metadata, @Nullable final Expression expr) {
        SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata());
        String queryString = null;
        List constants = ImmutableList.of();

        try {
            listeners.preRender(context);
            SQLSerializer serializer = serialize(false);
            queryString = serializer.toString();
            logQuery(queryString, serializer.getConstants());
            context.addSQL(queryString);
            listeners.rendered(context);


            listeners.notifyQuery(queryMixin.getMetadata());
            constants = serializer.getConstants();

            listeners.prePrepare(context);
            final PreparedStatement stmt = getPreparedStatement(queryString);
            setParameters(stmt, constants, serializer.getConstantPaths(), metadata.getParams());
            context.addPreparedStatement(stmt);
            listeners.prepared(context);

            listeners.preExecute(context);
            final ResultSet rs = stmt.executeQuery();
            listeners.executed(context);

            if (expr == null) {
                return new SQLResultIterator(configuration, stmt, rs, listeners, context) {
                    @Override
                    public T produceNext(ResultSet rs) throws Exception {
                        return (T) rs.getObject(1);
                    }
                };
            } else if (expr instanceof FactoryExpression) {
                return new SQLResultIterator(configuration, stmt, rs, listeners, context) {
                    @Override
                    public T produceNext(ResultSet rs) throws Exception {
                        return newInstance((FactoryExpression) expr, rs, 0);
                    }
                };
            } else if (expr.equals(Wildcard.all)) {
                return new SQLResultIterator(configuration, stmt, rs, listeners, context) {
                    @Override
                    public T produceNext(ResultSet rs) throws Exception {
                        Object[] rv = new Object[rs.getMetaData().getColumnCount()];
                        for (int i = 0; i < rv.length; i++) {
                            rv[i] = rs.getObject(i + 1);
                        }
                        return (T) rv;
                    }
                };
            } else {
                return new SQLResultIterator(configuration, stmt, rs, listeners, context) {
                    @Override
                    public T produceNext(ResultSet rs) throws Exception {
                        return get(rs, expr, 1, expr.getType());
                    }
                };
            }

        } catch (SQLException e) {
            onException(context, e);
            endContext(context);
            throw configuration.translate(queryString, constants, e);
        } catch (RuntimeException e) {
            logger.error("Caught " + e.getClass().getName() + " for " + queryString);
            throw e;
        } finally {
            reset();
        }
    }

    @SuppressWarnings("unchecked")
    @Override
    public List fetch() {
        Expression expr = (Expression) queryMixin.getMetadata().getProjection();
        SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata());
        String queryString = null;
        List constants = ImmutableList.of();

        try {
            listeners.preRender(context);
            SQLSerializer serializer = serialize(false);
            queryString = serializer.toString();
            logQuery(queryString, serializer.getConstants());
            context.addSQL(queryString);
            listeners.rendered(context);

            listeners.notifyQuery(queryMixin.getMetadata());
            constants = serializer.getConstants();

            listeners.prePrepare(context);
            final PreparedStatement stmt = getPreparedStatement(queryString);
            try {
                setParameters(stmt, constants, serializer.getConstantPaths(), queryMixin.getMetadata().getParams());
                context.addPreparedStatement(stmt);
                listeners.prepared(context);

                listeners.preExecute(context);
                final ResultSet rs = stmt.executeQuery();
                listeners.executed(context);
                try {
                    lastCell = null;
                    final List rv = new ArrayList();
                    if (expr instanceof FactoryExpression) {
                        FactoryExpression fe = (FactoryExpression) expr;
                        while (rs.next()) {
                            if (getLastCell) {
                                lastCell = rs.getObject(fe.getArgs().size() + 1);
                                getLastCell = false;
                            }
                            rv.add(newInstance(fe, rs, 0));
                        }
                    } else if (expr.equals(Wildcard.all)) {
                        while (rs.next()) {
                            Object[] row = new Object[rs.getMetaData().getColumnCount()];
                            if (getLastCell) {
                                lastCell = rs.getObject(row.length);
                                getLastCell = false;
                            }
                            for (int i = 0; i < row.length; i++) {
                                row[i] = rs.getObject(i + 1);
                            }
                            rv.add((T) row);
                        }
                    } else {
                        while (rs.next()) {
                            if (getLastCell) {
                                lastCell = rs.getObject(2);
                                getLastCell = false;
                            }
                            rv.add(get(rs, expr, 1, expr.getType()));
                        }
                    }
                    return rv;
                } catch (IllegalAccessException e) {
                    onException(context, e);
                    throw new QueryException(e);
                } catch (InvocationTargetException e) {
                    onException(context,e);
                    throw new QueryException(e);
                } catch (InstantiationException e) {
                    onException(context,e);
                    throw new QueryException(e);
                } catch (SQLException e) {
                    onException(context,e);
                    throw configuration.translate(queryString, constants, e);
                } finally {
                    rs.close();
                }
            } finally {
                stmt.close();
            }
        } catch (SQLException e) {
            onException(context, e);
            throw configuration.translate(queryString, constants, e);
        } finally {
            endContext(context);
            reset();
        }
    }

    @SuppressWarnings("unchecked")
    @Override
    public QueryResults fetchResults() {
        parentContext = startContext(connection(), queryMixin.getMetadata());
        Expression expr = (Expression) queryMixin.getMetadata().getProjection();
        QueryModifiers originalModifiers = queryMixin.getMetadata().getModifiers();
        try {
            if (configuration.getTemplates().isCountViaAnalytics()
                && queryMixin.getMetadata().getGroupBy().isEmpty()) {
                List results;
                try {
                    queryMixin.addFlag(rowCountFlag);
                    getLastCell = true;
                    results = fetch();
                } finally {
                    queryMixin.removeFlag(rowCountFlag);
                }
                long total;
                if (!results.isEmpty()) {
                    if (lastCell instanceof Number) {
                        total = ((Number) lastCell).longValue();
                    } else {
                        throw new IllegalStateException("Unsupported lastCell instance " + lastCell);
                    }
                } else {
                    total = fetchCount();
                }
                return new QueryResults(results, originalModifiers, total);

            } else {
                queryMixin.setProjection(expr);
                long total = fetchCount();
                if (total > 0) {
                    return new QueryResults(fetch(), originalModifiers, total);
                } else {
                    return QueryResults.emptyResults();
                }
            }

        } finally {
            endContext(parentContext);
            reset();
            getLastCell = false;
            parentContext = null;
        }
    }

    private  RT newInstance(FactoryExpression c, ResultSet rs, int offset)
        throws InstantiationException, IllegalAccessException, InvocationTargetException, SQLException {
        Object[] args = new Object[c.getArgs().size()];
        for (int i = 0; i < args.length; i++) {
            args[i] = get(rs, c.getArgs().get(i), offset + i + 1, c.getArgs().get(i).getType());
        }
        return c.newInstance(args);
    }

    private void reset() {
        cleanupMDC();
    }

    protected void setParameters(PreparedStatement stmt, List objects, List> constantPaths,
            Map, ?> params) {
        if (objects.size() != constantPaths.size()) {
            throw new IllegalArgumentException("Expected " + objects.size() +
                    " paths, but got " + constantPaths.size());
        }
        for (int i = 0; i < objects.size(); i++) {
            Object o = objects.get(i);
            try {
                if (o instanceof ParamExpression) {
                    if (!params.containsKey(o)) {
                        throw new ParamNotSetException((ParamExpression) o);
                    }
                    o = params.get(o);
                }
                set(stmt, constantPaths.get(i), i + 1, o);
            } catch (SQLException e) {
                throw configuration.translate(e);
            }
        }
    }

    private long unsafeCount() throws SQLException {
        SQLListenerContextImpl context = startContext(connection(), getMetadata());
        String queryString = null;
        List constants = ImmutableList.of();
        PreparedStatement stmt = null;
        ResultSet rs = null;

        try {
            listeners.preRender(context);
            SQLSerializer serializer = serialize(true);
            queryString = serializer.toString();
            logQuery(queryString, serializer.getConstants());
            context.addSQL(queryString);
            listeners.rendered(context);

            constants = serializer.getConstants();
            listeners.prePrepare(context);

            stmt = getPreparedStatement(queryString);
            setParameters(stmt, constants, serializer.getConstantPaths(), getMetadata().getParams());

            context.addPreparedStatement(stmt);
            listeners.prepared(context);

            listeners.preExecute(context);
            rs = stmt.executeQuery();
            boolean hasResult = rs.next();
            listeners.executed(context);

            if (hasResult) {
                return rs.getLong(1);
            } else {
                return 0;
            }
        } catch (SQLException e) {
            onException(context, e);
            throw configuration.translate(queryString, constants, e);
        } finally {
            try {
                if (rs != null) {
                    rs.close();
                }
            } finally {
                if (stmt != null) {
                    stmt.close();
                }
            }
            endContext(context);
            cleanupMDC();
        }
    }

    protected void logQuery(String queryString, Collection parameters) {
        if (logger.isDebugEnabled()) {
            String normalizedQuery = queryString.replace('\n', ' ');
            MDC.put(MDC_QUERY, normalizedQuery);
            MDC.put(MDC_PARAMETERS, String.valueOf(parameters));
            logger.debug(normalizedQuery);
        }
    }

    protected void cleanupMDC() {
        MDC.remove(MDC_QUERY);
        MDC.remove(MDC_PARAMETERS);
    }

    private Connection connection() {
        if (conn == null) {
            if (connProvider != null) {
                conn = connProvider.get();
            } else {
                throw new IllegalStateException("No connection provided");
            }
        }
        return conn;
    }

    /**
     * Set whether literals are used in SQL strings instead of parameter bindings (default: false)
     *
     * 

Warning: When literals are used, prepared statement won't have any parameter bindings * and also batch statements will only be simulated, but not executed as actual batch statements.

* * @param useLiterals true for literals and false for bindings */ public void setUseLiterals(boolean useLiterals) { this.useLiterals = useLiterals; } @Override protected void clone(Q query) { super.clone(query); this.useLiterals = query.useLiterals; this.listeners = new SQLListeners(query.listeners); } @Override public Q clone() { return this.clone(this.conn); } public abstract Q clone(Connection connection); /** * Set the options to be applied to the JDBC statements of this query * * @param statementOptions options to be applied to statements */ public void setStatementOptions(StatementOptions statementOptions) { this.statementOptions = statementOptions; } }