All downloads are free. Search and download functionality uses the official Maven repository.

com.blazebit.persistence.impl.query.CustomQuerySpecification Maven / Gradle / Ivy

The newest version!
package com.blazebit.persistence.impl.query;

import com.blazebit.persistence.impl.AbstractCommonQueryBuilder;
import com.blazebit.persistence.impl.CustomSQLQuery;
import com.blazebit.persistence.impl.CustomSQLTypedQuery;
import com.blazebit.persistence.impl.plan.CustomSelectQueryPlan;
import com.blazebit.persistence.impl.plan.ModificationQueryPlan;
import com.blazebit.persistence.impl.plan.SelectQueryPlan;
import com.blazebit.persistence.spi.*;

import javax.persistence.EntityManager;
import javax.persistence.Query;
import java.util.*;

/**
 * Query specification that takes the SQL rendered for a JPA base query and post-processes it on the SQL level:
 * CTEs are rendered into a {@code WITH} clause, cascading delete statements are rendered as additional CTEs,
 * entity function nodes are replaced by their {@code VALUES} clause and left joins identified by SQL alias are
 * rewritten into subqueries (key-restricted left joins).
 *
 * <p>The SQL, the participating queries and the added CTEs are computed lazily by {@link #initialize()} and cached
 * until {@link #onParameterChange(String)} reports a change of one of the list-valued parameters.
 *
 * @param <T> the result type of the select plans created by this specification
 * @author Christian Beikov
 * @since 1.2.0
 */
public class CustomQuerySpecification<T> implements QuerySpecification<T> {

    protected final EntityManager em;
    protected final DbmsDialect dbmsDialect;
    protected final ServiceProvider serviceProvider;
    protected final ExtendedQuerySupport extendedQuerySupport;

    protected final DbmsStatementType statementType;
    protected final Query baseQuery;
    /** Names of list-valued parameters; a change of one of these invalidates the cached SQL. */
    protected final Set<String> parameterListNames;
    /** Not read by this class ({@link #applyExtendedSql} passes {@code null}); presumably used by subclasses - verify. */
    protected final String limit;
    /** Not read by this class ({@link #applyExtendedSql} passes {@code null}); presumably used by subclasses - verify. */
    protected final String offset;

    /** SQL aliases of left joined tables that are rewritten to subqueries by {@link #applyLeftJoinSubqueryRewrite}. */
    protected final List<String> keyRestrictedLeftJoinAliases;
    /** Entity function nodes whose table references are replaced by their values clause. */
    protected final List<EntityFunctionNode> entityFunctionNodes;
    /** Passed to {@link DbmsDialect#getWithClause(boolean)} when rendering the WITH clause. */
    protected final boolean recursive;
    protected final List<CTENode> ctes;
    /** When {@code false}, {@link #applyCtes} renders nothing. */
    protected final boolean shouldRenderCtes;

    /** Whether the cached {@link #sql}, {@link #participatingQueries} and {@link #addedCtes} must be recomputed. */
    protected boolean dirty;
    protected String sql;
    protected List<Query> participatingQueries;
    protected Map<String, String> addedCtes;

    public CustomQuerySpecification(AbstractCommonQueryBuilder<?, ?, ?, ?, ?> commonQueryBuilder, Query baseQuery, Set<String> parameterListNames, String limit, String offset,
                                    List<String> keyRestrictedLeftJoinAliases, List<EntityFunctionNode> entityFunctionNodes, boolean recursive, List<CTENode> ctes, boolean shouldRenderCtes) {
        this.em = commonQueryBuilder.getEntityManager();
        this.dbmsDialect = commonQueryBuilder.getService(DbmsDialect.class);
        this.serviceProvider = commonQueryBuilder;
        this.extendedQuerySupport = commonQueryBuilder.getService(ExtendedQuerySupport.class);
        this.statementType = commonQueryBuilder.getStatementType();
        this.baseQuery = baseQuery;
        this.parameterListNames = parameterListNames;
        this.limit = limit;
        this.offset = offset;

        this.keyRestrictedLeftJoinAliases = keyRestrictedLeftJoinAliases;
        this.entityFunctionNodes = entityFunctionNodes;
        this.recursive = recursive;
        this.ctes = ctes;
        this.shouldRenderCtes = shouldRenderCtes;
        this.dirty = true;
    }

    /**
     * Not supported by this specification.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ModificationQueryPlan createModificationPlan(int firstResult, int maxResults) {
        throw new UnsupportedOperationException();
    }

    /**
     * Creates a select plan for the (lazily initialized) custom SQL of this specification.
     */
    @Override
    public SelectQueryPlan<T> createSelectPlan(int firstResult, int maxResults) {
        // getSql() initializes participatingQueries as a side effect, so it must run first
        final String renderedSql = getSql();
        return new CustomSelectQueryPlan<T>(extendedQuerySupport, serviceProvider, baseQuery, participatingQueries, renderedSql, firstResult, maxResults);
    }

    @Override
    public String getSql() {
        if (dirty) {
            initialize();
        }
        return sql;
    }

    @Override
    public List<Query> getParticipatingQueries() {
        if (dirty) {
            initialize();
        }
        return participatingQueries;
    }

    @Override
    public Map<String, String> getAddedCtes() {
        if (dirty) {
            initialize();
        }
        return addedCtes;
    }

    @Override
    public Query getBaseQuery() {
        return baseQuery;
    }

    /**
     * Marks the cached state as stale when a list-valued parameter changed, since the rendered SQL depends on it.
     */
    @Override
    public void onParameterChange(String parameterName) {
        if (parameterListNames.contains(parameterName)) {
            dirty = true;
        }
    }

    /**
     * Computes the SQL, the participating queries and the added CTEs and clears the dirty flag.
     * The base query is always the last participating query.
     */
    protected void initialize() {
        List<Query> participatingQueries = new ArrayList<Query>();

        String sqlQuery = extendedQuerySupport.getSql(em, baseQuery);
        StringBuilder sqlSb = applySqlTransformations(baseQuery, sqlQuery, participatingQueries);
        StringBuilder withClause = applyCtes(sqlSb, baseQuery, participatingQueries);
        Map<String, String> addedCtes = applyExtendedSql(sqlSb, false, false, withClause, null, null);
        participatingQueries.add(baseQuery);

        this.sql = sqlSb.toString();
        this.participatingQueries = participatingQueries;
        this.addedCtes = addedCtes;
        this.dirty = false;
    }

    /**
     * Lets the dialect append its extended SQL (e.g. the WITH clause) to {@code sqlSb}.
     * Limit and offset are deliberately passed as {@code null} here.
     *
     * @return the CTEs added by the dialect, as returned by {@link DbmsDialect}
     */
    protected Map<String, String> applyExtendedSql(StringBuilder sqlSb, boolean isSubquery, boolean isEmbedded, StringBuilder withClause, String[] returningColumns, Map<DbmsModificationState, String> includedModificationStates) {
        return dbmsDialect.appendExtendedSql(sqlSb, statementType, isSubquery, isEmbedded, withClause, null, null, returningColumns, includedModificationStates);
    }

    /**
     * Renders the WITH clause for all CTEs, added CTEs and cascading deletes. Also rewrites {@code sqlSb} so that
     * it references the CTEs directly.
     *
     * @param sqlSb the main query SQL, modified in place
     * @param baseQuery the main query
     * @param participatingQueries receives every query whose parameters are needed by the rendered SQL
     * @return the WITH clause, or {@code null} if nothing needs to be rendered
     */
    protected StringBuilder applyCtes(StringBuilder sqlSb, Query baseQuery, List<Query> participatingQueries) {
        if (!shouldRenderCtes || (ctes.isEmpty() && statementType != DbmsStatementType.DELETE)) {
            return null;
        }
        // EntityAlias -> CteName
        Map<String, String> tableNameRemapping = new LinkedHashMap<String, String>(0);

        StringBuilder sb = new StringBuilder(ctes.size() * 100);
        sb.append(dbmsDialect.getWithClause(recursive));
        sb.append(" ");

        boolean firstCte = true;
        for (CTENode cteInfo : ctes) {
            // Build queries and add as participating queries
            QuerySpecification<?> nonRecursiveQuerySpecification = cteInfo.getNonRecursiveQuerySpecification();
            Query nonRecursiveQuery = nonRecursiveQuerySpecification.getBaseQuery();
            participatingQueries.addAll(nonRecursiveQuerySpecification.getParticipatingQueries());

            QuerySpecification<?> recursiveQuerySpecification = null;
            if (cteInfo.isRecursive()) {
                recursiveQuerySpecification = cteInfo.getRecursiveQuerySpecification();
                participatingQueries.addAll(recursiveQuerySpecification.getParticipatingQueries());
            }

            // add cascading delete statements as CTEs
            firstCte = applyCascadingDelete(nonRecursiveQuery, participatingQueries, sb, cteInfo.getName(), firstCte);

            firstCte = applyAddedCtes(nonRecursiveQuerySpecification, cteInfo.getNonRecursiveTableNameRemappings(), sb, tableNameRemapping, firstCte);
            firstCte = applyAddedCtes(recursiveQuerySpecification, cteInfo.getRecursiveTableNameRemappings(), sb, tableNameRemapping, firstCte);

            if (firstCte) {
                firstCte = false;
            } else {
                sb.append(",\n");
            }

            sb.append(cteInfo.getHead());
            sb.append(" AS(\n");

            sb.append(nonRecursiveQuerySpecification.getSql());

            if (cteInfo.isRecursive()) {
                if (cteInfo.isUnionAll()) {
                    sb.append("\nUNION ALL\n");
                } else {
                    sb.append("\nUNION\n");
                }
                sb.append(recursiveQuerySpecification.getSql());
            } else if (!dbmsDialect.supportsNonRecursiveWithClause()) {
                sb.append(cteInfo.getNonRecursiveWithClauseSuffix());
            }

            sb.append("\n)");
        }

        // Add cascading delete statements from base query as CTEs
        firstCte = applyCascadingDelete(baseQuery, participatingQueries, sb, "main_query", firstCte);

        // If no CTE has been added, we can just return
        if (firstCte) {
            return null;
        }

        for (CTENode cteInfo : ctes) {
            String cteName = cteInfo.getEntityName();
            // TODO: this is a hibernate specific integration detail
            // Replace the subview subselect that is generated for this cte, in the WITH clause and in the main query
            replaceSubviewSubselect(sb, cteName);
            replaceSubviewSubselect(sqlSb, cteName);
        }

        sb.append("\n");

        for (Map.Entry<String, String> tableNameRemappingEntry : tableNameRemapping.entrySet()) {
            String sqlAlias = extendedQuerySupport.getSqlAlias(em, baseQuery, tableNameRemappingEntry.getKey());
            String newCteName = tableNameRemappingEntry.getValue();

            applyTableNameRemapping(sqlSb, sqlAlias, newCteName, null);
        }

        return sb;
    }

    /**
     * Replaces every {@code ( select * from NAME )} in {@code sb} with {@code NAME}.
     * TODO: this is a hibernate specific integration detail (the subselect generated for subview entities)
     */
    private static void replaceSubviewSubselect(StringBuilder sb, String name) {
        final String subselect = "( select * from " + name + " )";
        int subselectIndex = 0;
        while ((subselectIndex = sb.indexOf(subselect, subselectIndex)) > -1) {
            sb.replace(subselectIndex, subselectIndex + subselect.length(), name);
            // Continue after the inserted name
            subselectIndex += name.length();
        }
    }

    /**
     * Appends the CTEs added by {@code querySpecification} to the WITH clause {@code sb}. Records in
     * {@code tableNameRemapping} every entity alias from {@code cteTableNameRemappings} that maps to one of them.
     *
     * @return the updated "no CTE rendered yet" flag
     */
    private boolean applyAddedCtes(QuerySpecification<?> querySpecification, Map<String, String> cteTableNameRemappings, StringBuilder sb, Map<String, String> tableNameRemapping, boolean firstCte) {
        if (querySpecification == null) {
            return firstCte;
        }
        // CteName -> CteQueryString
        Map<String, String> addedCtes = querySpecification.getAddedCtes();
        if (addedCtes == null || addedCtes.isEmpty()) {
            return firstCte;
        }

        for (Map.Entry<String, String> simpleCteEntry : addedCtes.entrySet()) {
            for (Map.Entry<String, String> cteTableNameRemapping : cteTableNameRemappings.entrySet()) {
                if (cteTableNameRemapping.getValue().equals(simpleCteEntry.getKey())) {
                    tableNameRemapping.put(cteTableNameRemapping.getKey(), cteTableNameRemapping.getValue());
                }
            }

            if (firstCte) {
                firstCte = false;
            } else {
                sb.append(",\n");
            }

            sb.append(simpleCteEntry.getKey());
            sb.append(" AS (\n");
            sb.append(simpleCteEntry.getValue());
            sb.append("\n)");
        }

        return firstCte;
    }

    /**
     * Appends the cascading delete SQL statements of {@code baseQuery} as CTEs named {@code cteBaseName_N}.
     * {@code baseQuery} is added to the participating queries once per cascade.
     *
     * @return the updated "no CTE rendered yet" flag
     */
    private boolean applyCascadingDelete(Query baseQuery, List<Query> participatingQueries, StringBuilder sb, String cteBaseName, boolean firstCte) {
        List<String> cascadingDeleteSqls = extendedQuerySupport.getCascadingDeleteSql(em, baseQuery);
        StringBuilder cascadingDeleteSqlSb = new StringBuilder();
        int cteBaseNameCount = 0;
        for (String cascadingDeleteSql : cascadingDeleteSqls) {
            if (firstCte) {
                firstCte = false;
            } else {
                sb.append(",\n");
            }

            // Since we kind of need the parameters from the base query, it will participate for each cascade
            participatingQueries.add(baseQuery);

            sb.append(cteBaseName);
            sb.append('_').append(cteBaseNameCount++);
            sb.append(" AS (\n");

            cascadingDeleteSqlSb.setLength(0);
            cascadingDeleteSqlSb.append(cascadingDeleteSql);
            dbmsDialect.appendExtendedSql(cascadingDeleteSqlSb, DbmsStatementType.DELETE, false, true, null, null, null, null, null);
            sb.append(cascadingDeleteSqlSb);

            sb.append("\n)");
        }

        return firstCte;
    }

    /**
     * Applies the left join subquery rewrites and the entity function replacements to {@code sqlQuery}.
     * Each entity function node's value query is added to {@code participatingQueries}.
     *
     * @return a builder holding the (possibly) transformed SQL
     */
    protected StringBuilder applySqlTransformations(Query baseQuery, String sqlQuery, List<Query> participatingQueries) {
        if (entityFunctionNodes.isEmpty() && keyRestrictedLeftJoinAliases.isEmpty()) {
            return new StringBuilder(sqlQuery);
        }

        // TODO: find a better size estimate
        StringBuilder sb = new StringBuilder(sqlQuery.length() +
                // Just a stupid estimate
                entityFunctionNodes.size() * 100 +
                // we put "(select * from )" around
                keyRestrictedLeftJoinAliases.size() * 20);
        sb.append(sqlQuery);

        for (String sqlAlias : keyRestrictedLeftJoinAliases) {
            applyLeftJoinSubqueryRewrite(sb, sqlAlias);
        }

        for (EntityFunctionNode node : entityFunctionNodes) {
            String valuesTableSqlAlias = node.getTableAlias();
            String valuesClause = node.getValuesClause();
            String valuesAliases = node.getValuesAliases();

            // TODO: this is a hibernate specific integration detail
            // Replace the subview subselect that is generated for this subselect
            replaceSubviewSubselect(sb, node.getEntityClass().getSimpleName());

            applyTableNameRemapping(sb, valuesTableSqlAlias, valuesClause, valuesAliases);
            participatingQueries.add(node.getValueQuery());
        }

        return sb;
    }

    /**
     * Returns whether the match of {@code " " + alias} at {@code searchIndex} is a standalone alias token.
     * It is not standalone when it is dereferenced ({@code alias.col}) or is only the prefix of a longer identifier
     * (e.g. {@code entity0_} inside {@code entity0_1_}).
     */
    private boolean isStandaloneAlias(StringBuilder sb, int searchIndex, int searchAliasLength) {
        int afterIndex = searchIndex + searchAliasLength;
        if (afterIndex >= sb.length()) {
            // The alias ends the SQL
            return true;
        }
        char c = sb.charAt(afterIndex);
        return c != '.' && !isIdentifier(c);
    }

    /**
     * Returns the {@code [start, end)} range of the table name that precedes the alias at {@code aliasIndex}.
     * The alias may be introduced with or without the {@code AS} keyword.
     */
    private int[] getTableNameRange(StringBuilder sb, int aliasIndex) {
        final String searchAs = " as";
        if (aliasIndex >= searchAs.length() && searchAs.equalsIgnoreCase(sb.substring(aliasIndex - searchAs.length(), aliasIndex))) {
            // Uses aliasing with the AS keyword
            return rtrimBackwardsToFirstWhitespace(sb, aliasIndex - searchAs.length());
        }
        // Uses aliasing without the AS keyword
        return rtrimBackwardsToFirstWhitespace(sb, aliasIndex);
    }

    /**
     * Rewrites the join table join that precedes the left join of {@code sqlAlias} into a subquery.
     * The subquery selects the key column, all target table columns and the join table's parent columns.
     * The original on condition is moved behind the subquery.
     *
     * @throws IllegalStateException if the expected join structure cannot be found in the SQL
     */
    private void applyLeftJoinSubqueryRewrite(StringBuilder sb, String sqlAlias) {
        final String searchAlias = " " + sqlAlias;
        int searchIndex = 0;
        while ((searchIndex = sb.indexOf(searchAlias, searchIndex)) > -1) {
            if (isStandaloneAlias(sb, searchIndex, searchAlias.length())) {
                int[] indexRange = getTableNameRange(sb, searchIndex);

                // Jump back two left joins to further inspect the join table
                String leftJoinString = "left outer join ";
                int joinTableJoinIndex = -1;
                int targetTableJoinIndex = -1;
                int currentIndex = -1;
                while ((currentIndex = sb.indexOf(leftJoinString, currentIndex + 1)) < indexRange[0] && currentIndex > 0) {
                    joinTableJoinIndex = targetTableJoinIndex;
                    targetTableJoinIndex = currentIndex;
                }

                if (joinTableJoinIndex < 1) {
                    throw new IllegalStateException("The left join for subquery rewriting could not be found!");
                }

                int joinTableIndex = joinTableJoinIndex + leftJoinString.length();

                // Extract the on condition so we can move it
                String onString = " on ";
                int onIndex = sb.indexOf(onString, joinTableIndex);

                if (onIndex < 0 || onIndex > targetTableJoinIndex) {
                    throw new IllegalStateException("The left join for subquery rewriting could not be found!");
                }
                StringBuilder onCondition = new StringBuilder(sb.substring(onIndex, targetTableJoinIndex));

                // Extract the join table alias since we need to replace it
                int aliasIndex = sb.indexOf(" ", joinTableIndex) + 1;
                String joinTableAlias = sb.substring(aliasIndex, onIndex);

                int realOnConditionStartIndex = indexRange[1];
                // Find the index at which the actual key restriction begins
                String realOnConditionStart = " and (";
                int realOnConditionIndex = sb.indexOf(realOnConditionStart, realOnConditionStartIndex);
                if (realOnConditionIndex < 0) {
                    throw new IllegalStateException("The key restriction of the left join for subquery rewriting could not be found!");
                }

                // We need to find the column name of the key
                List<String> joinTableParentExpressions = getColumnExpressions(sb, joinTableAlias, onIndex, targetTableJoinIndex);
                List<String> joinTableKeyExpressions = getColumnExpressions(sb, joinTableAlias, realOnConditionIndex, sb.length());

                if (joinTableKeyExpressions.size() != 1) {
                    throw new IllegalStateException("Expected exactly one key expression but found: " + joinTableKeyExpressions.size());
                }

                String joinTableKeyExpression = joinTableKeyExpressions.get(0);

                // Construct the subquery part that will replace the join table join part
                String joinTableKeyAlias = "join_table_key";
                String joinTableParentAliasPrefix = "join_table_parent_";
                StringBuilder subqueryPrefixSb = new StringBuilder();
                subqueryPrefixSb.append("(select ");
                subqueryPrefixSb.append(joinTableKeyExpression);
                subqueryPrefixSb.append(" as ");
                subqueryPrefixSb.append(joinTableKeyAlias);
                subqueryPrefixSb.append(", ");
                subqueryPrefixSb.append(sqlAlias);
                subqueryPrefixSb.append(".*");

                for (int i = 0; i < joinTableParentExpressions.size(); i++) {
                    subqueryPrefixSb.append(", ");
                    subqueryPrefixSb.append(joinTableParentExpressions.get(i));
                    subqueryPrefixSb.append(" as ");
                    subqueryPrefixSb.append(joinTableParentAliasPrefix);
                    subqueryPrefixSb.append(i);

                    String newParentExpression = sqlAlias + "." + joinTableParentAliasPrefix + i;
                    int lengthDifference = newParentExpression.length() - joinTableParentExpressions.get(i).length();
                    replaceExpressionUntil(0, onCondition.length(), lengthDifference, onCondition, joinTableParentExpressions.get(i), newParentExpression);
                }

                subqueryPrefixSb.append(" from ");

                // Replace the join table join with a subquery part
                String subqueryPrefix = subqueryPrefixSb.toString();
                String subqueryInsert = subqueryPrefix + sb.substring(joinTableIndex, onIndex);
                sb.replace(joinTableIndex, targetTableJoinIndex - 1, subqueryInsert);

                // Adapt index since we replaced stuff before
                realOnConditionStartIndex += (subqueryInsert.length() - (targetTableJoinIndex - joinTableIndex));
                realOnConditionIndex += (subqueryInsert.length() - (targetTableJoinIndex - joinTableIndex - 1));

                // Insert the target table alias for the subquery and the on condition for joining with the parent
                String subqueryEnd = ") " + sqlAlias + onCondition;
                sb.insert(realOnConditionIndex, subqueryEnd);
                realOnConditionStartIndex += subqueryEnd.length();

                // Replace the join table key expression with the target table key expression until reaching joinTableIndex and then again at realOnConditionStartIndex
                String targetTableKeyExpression = sqlAlias + "." + joinTableKeyAlias;
                int lengthDifference = targetTableKeyExpression.length() - joinTableKeyExpression.length();
                // Replace the join table alias with the target table alias until reaching joinTableIndex
                int diff = replaceExpressionUntil(-1, joinTableIndex, lengthDifference, sb, joinTableKeyExpression, targetTableKeyExpression);
                // and then again from realOnConditionStartIndex until the end
                replaceExpressionUntil(realOnConditionStartIndex + diff, sb.length(), lengthDifference, sb, joinTableKeyExpression, targetTableKeyExpression);

                break;
            }

            searchIndex = searchIndex + 1;
        }
    }

    /**
     * Collects the column expressions {@code tableAlias.column} that start within {@code [startIndex, endIndex)}.
     *
     * @throws IllegalStateException if the range is non-empty but contains no such expression
     */
    private List<String> getColumnExpressions(StringBuilder sb, String tableAlias, int startIndex, int endIndex) {
        String columnExpressionStart = tableAlias + ".";
        List<String> columnExpressions = new ArrayList<String>();
        while (startIndex < endIndex) {
            int expressionIndex = sb.indexOf(columnExpressionStart, startIndex);

            // indexOf is not bounded by endIndex, so a match past it counts as "not found"
            if (expressionIndex < 0 || expressionIndex >= endIndex) {
                if (columnExpressions.isEmpty()) {
                    throw new IllegalStateException("The join table column expression needed for subquery rewriting could not be found!");
                }
                break;
            }

            StringBuilder columnExpressionSb = new StringBuilder(80);
            columnExpressionSb.append(columnExpressionStart);
            expressionIndex += columnExpressionStart.length();
            while (expressionIndex < sb.length() && isIdentifier(sb.charAt(expressionIndex))) {
                columnExpressionSb.append(sb.charAt(expressionIndex));
                expressionIndex++;
            }
            columnExpressions.add(columnExpressionSb.toString());
            startIndex = expressionIndex;
        }

        return columnExpressions;
    }

    /**
     * Replaces standalone occurrences of {@code oldExpression} with {@code newExpression}. The search starts after
     * {@code searchIndex}, and only matches starting before {@code endIndex} (shifted by earlier replacements) are replaced.
     *
     * @return the total length change applied to {@code sb}
     */
    private int replaceExpressionUntil(int searchIndex, int endIndex, int lengthDifference, StringBuilder sb, String oldExpression, String newExpression) {
        int diff = 0;
        while ((searchIndex = sb.indexOf(oldExpression, searchIndex + 1)) > -1 && searchIndex < endIndex) {
            int afterIndex = searchIndex + oldExpression.length();
            // Skip matches that are only part of a longer identifier
            boolean partOfIdentifier = (searchIndex > 0 && isIdentifierStart(sb.charAt(searchIndex - 1)))
                    || (afterIndex < sb.length() && isIdentifier(sb.charAt(afterIndex)));
            if (partOfIdentifier) {
                continue;
            }
            sb.replace(searchIndex, afterIndex, newExpression);
            searchIndex += lengthDifference;
            endIndex += lengthDifference;
            diff += lengthDifference;
        }
        return diff;
    }

    private boolean isIdentifierStart(char c) {
        return Character.isLetter(c) || c == '_';
    }

    private boolean isIdentifier(char c) {
        return Character.isLetterOrDigit(c) || c == '_';
    }

    /**
     * Replaces the table name in front of every standalone occurrence of {@code sqlAlias} with {@code newCteName}.
     * If {@code aliasExtension} is non-null, it is inserted directly after the alias.
     */
    private void applyTableNameRemapping(StringBuilder sb, String sqlAlias, String newCteName, String aliasExtension) {
        final String searchAlias = " " + sqlAlias;
        int searchIndex = 0;
        while ((searchIndex = sb.indexOf(searchAlias, searchIndex)) > -1) {
            if (isStandaloneAlias(sb, searchIndex, searchAlias.length())) {
                int[] indexRange = getTableNameRange(sb, searchIndex);

                int oldLength = indexRange[1] - indexRange[0];
                // Replace table name with cte name
                sb.replace(indexRange[0], indexRange[1], newCteName);

                if (aliasExtension != null) {
                    sb.insert(searchIndex + searchAlias.length() + (newCteName.length() - oldLength), aliasExtension);
                    searchIndex += aliasExtension.length();
                }

                // Adjust index after replacing
                searchIndex += newCteName.length() - oldLength;
            }

            searchIndex = searchIndex + 1;
        }
    }

    /**
     * Starting at {@code startIndex}, skips trailing whitespace backwards and then returns the {@code [start, end)}
     * range of the preceding token. The token is delimited by whitespace or a comma.
     */
    private int[] rtrimBackwardsToFirstWhitespace(StringBuilder sb, int startIndex) {
        int tableNameStartIndex;
        int tableNameEndIndex = startIndex;
        boolean text = false;
        for (tableNameStartIndex = tableNameEndIndex; tableNameStartIndex >= 0; tableNameStartIndex--) {
            if (text) {
                final char c = sb.charAt(tableNameStartIndex);
                if (Character.isWhitespace(c) || c == ',') {
                    tableNameStartIndex++;
                    break;
                }
            } else {
                if (Character.isWhitespace(sb.charAt(tableNameStartIndex))) {
                    tableNameEndIndex--;
                } else {
                    text = true;
                    tableNameEndIndex++;
                }
            }
        }

        // Without a delimiter the loop ends at -1, i.e. the token starts at the beginning of the buffer
        return new int[]{ Math.max(tableNameStartIndex, 0), tableNameEndIndex };
    }

    /**
     * Returns the SQL of {@code query}, preferring the already rendered SQL of custom SQL queries.
     * NOTE(review): not referenced within this class - possibly used by subclasses or dead code; verify.
     */
    private String getSql(Query query) {
        if (query instanceof CustomSQLQuery) {
            return ((CustomSQLQuery) query).getSql();
        } else if (query instanceof CustomSQLTypedQuery) {
            return ((CustomSQLTypedQuery<?>) query).getSql();
        }
        return extendedQuerySupport.getSql(em, query);
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy