// All downloads are free. Search and download functionality uses the official Maven repository.

// com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor (Maven / Gradle / Ivy)

/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.extension.plugins.inner;

import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.*;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
import lombok.*;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;
import java.util.stream.Collectors;

/**
 * @author hubin
 * @since 3.4.0
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings({"rawtypes"})
public class TenantLineInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

    private TenantLineHandler tenantLineHandler;

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        mpBs.sql(parserSingle(mpBs.sql(), null));
    }

    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpSh.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
            if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
                return;
            }
            PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
            mpBs.sql(parserMulti(mpBs.sql(), null));
        }
    }

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        processSelectBody(select.getSelectBody());
        List withItemsList = select.getWithItemsList();
        if (!CollectionUtils.isEmpty(withItemsList)) {
            withItemsList.forEach(this::processSelectBody);
        }
    }

    protected void processSelectBody(SelectBody selectBody) {
        if (selectBody == null) {
            return;
        }
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            processSelectBody(withItem.getSubSelect().getSelectBody());
        } else {
            SetOperationList operationList = (SetOperationList) selectBody;
            List selectBodyList = operationList.getSelects();
            if (CollectionUtils.isNotEmpty(selectBodyList)) {
                selectBodyList.forEach(this::processSelectBody);
            }
        }
    }

    @Override
    protected void processInsert(Insert insert, int index, String sql, Object obj) {
        if (tenantLineHandler.ignoreTable(insert.getTable().getName())) {
            // 过滤退出执行
            return;
        }
        List columns = insert.getColumns();
        if (CollectionUtils.isEmpty(columns)) {
            // 针对不给列名的insert 不处理
            return;
        }
        String tenantIdColumn = tenantLineHandler.getTenantIdColumn();
        if (tenantLineHandler.ignoreInsert(columns, tenantIdColumn)) {
            // 针对已给出租户列的insert 不处理
            return;
        }
        columns.add(new Column(tenantIdColumn));

        // fixed gitee pulls/141 duplicate update
        List duplicateUpdateColumns = insert.getDuplicateUpdateExpressionList();
        if (CollectionUtils.isNotEmpty(duplicateUpdateColumns)) {
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(new StringValue(tenantIdColumn));
            equalsTo.setRightExpression(tenantLineHandler.getTenantId());
            duplicateUpdateColumns.add(equalsTo);
        }

        Select select = insert.getSelect();
        if (select != null) {
            this.processInsertSelect(select.getSelectBody());
        } else if (insert.getItemsList() != null) {
            // fixed github pull/295
            ItemsList itemsList = insert.getItemsList();
            if (itemsList instanceof MultiExpressionList) {
                ((MultiExpressionList) itemsList).getExpressionLists().forEach(el -> el.getExpressions().add(tenantLineHandler.getTenantId()));
            } else {
                ((ExpressionList) itemsList).getExpressions().add(tenantLineHandler.getTenantId());
            }
        } else {
            throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
        }
    }

    /**
     * update 语句处理
     */
    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        final Table table = update.getTable();
        if (tenantLineHandler.ignoreTable(table.getName())) {
            // 过滤退出执行
            return;
        }
        update.setWhere(this.andExpression(table, update.getWhere()));
    }

    /**
     * delete 语句处理
     */
    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        if (tenantLineHandler.ignoreTable(delete.getTable().getName())) {
            // 过滤退出执行
            return;
        }
        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
    }

    /**
     * delete update 语句 where 处理
     */
    protected BinaryExpression andExpression(Table table, Expression where) {
        //获得where条件表达式
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(this.getAliasColumn(table));
        equalsTo.setRightExpression(tenantLineHandler.getTenantId());
        if (null != where) {
            if (where instanceof OrExpression) {
                return new AndExpression(equalsTo, new Parenthesis(where));
            } else {
                return new AndExpression(equalsTo, where);
            }
        }
        return equalsTo;
    }


    /**
     * 处理 insert into select
     * 

* 进入这里表示需要 insert 的表启用了多租户,则 select 的表都启动了 * * @param selectBody SelectBody */ protected void processInsertSelect(SelectBody selectBody) { PlainSelect plainSelect = (PlainSelect) selectBody; FromItem fromItem = plainSelect.getFromItem(); if (fromItem instanceof Table) { // fixed gitee pulls/141 duplicate update processPlainSelect(plainSelect); appendSelectItem(plainSelect.getSelectItems()); } else if (fromItem instanceof SubSelect) { SubSelect subSelect = (SubSelect) fromItem; appendSelectItem(plainSelect.getSelectItems()); processInsertSelect(subSelect.getSelectBody()); } } /** * 追加 SelectItem * * @param selectItems SelectItem */ protected void appendSelectItem(List selectItems) { if (CollectionUtils.isEmpty(selectItems)) { return; } if (selectItems.size() == 1) { SelectItem item = selectItems.get(0); if (item instanceof AllColumns || item instanceof AllTableColumns) { return; } } selectItems.add(new SelectExpressionItem(new Column(tenantLineHandler.getTenantIdColumn()))); } /** * 处理 PlainSelect */ protected void processPlainSelect(PlainSelect plainSelect) { //#3087 github List selectItems = plainSelect.getSelectItems(); if (CollectionUtils.isNotEmpty(selectItems)) { selectItems.forEach(this::processSelectItem); } // 处理 where 中的子查询 Expression where = plainSelect.getWhere(); processWhereSubSelect(where); // 处理 fromItem FromItem fromItem = plainSelect.getFromItem(); List

list = processFromItem(fromItem); List
mainTables = new ArrayList<>(list); // 处理 join List joins = plainSelect.getJoins(); if (CollectionUtils.isNotEmpty(joins)) { mainTables = processJoins(mainTables, joins); } // 当有 mainTable 时,进行 where 条件追加 if (CollectionUtils.isNotEmpty(mainTables)) { plainSelect.setWhere(builderExpression(where, mainTables)); } } private List
processFromItem(FromItem fromItem) { // 处理括号括起来的表达式 while (fromItem instanceof ParenthesisFromItem) { fromItem = ((ParenthesisFromItem) fromItem).getFromItem(); } List
mainTables = new ArrayList<>(); // 无 join 时的处理逻辑 if (fromItem instanceof Table) { Table fromTable = (Table) fromItem; mainTables.add(fromTable); } else if (fromItem instanceof SubJoin) { // SubJoin 类型则还需要添加上 where 条件 List
tables = processSubJoin((SubJoin) fromItem); mainTables.addAll(tables); } else { // 处理下 fromItem processOtherFromItem(fromItem); } return mainTables; } /** * 处理where条件内的子查询 *

* 支持如下: * 1. in * 2. = * 3. > * 4. < * 5. >= * 6. <= * 7. <> * 8. EXISTS * 9. NOT EXISTS *

* 前提条件: * 1. 子查询必须放在小括号中 * 2. 子查询一般放在比较操作符的右边 * * @param where where 条件 */ protected void processWhereSubSelect(Expression where) { if (where == null) { return; } if (where instanceof FromItem) { processOtherFromItem((FromItem) where); return; } if (where.toString().indexOf("SELECT") > 0) { // 有子查询 if (where instanceof BinaryExpression) { // 比较符号 , and , or , 等等 BinaryExpression expression = (BinaryExpression) where; processWhereSubSelect(expression.getLeftExpression()); processWhereSubSelect(expression.getRightExpression()); } else if (where instanceof InExpression) { // in InExpression expression = (InExpression) where; Expression inExpression = expression.getRightExpression(); if (inExpression instanceof SubSelect) { processSelectBody(((SubSelect) inExpression).getSelectBody()); } } else if (where instanceof ExistsExpression) { // exists ExistsExpression expression = (ExistsExpression) where; processWhereSubSelect(expression.getRightExpression()); } else if (where instanceof NotExpression) { // not exists NotExpression expression = (NotExpression) where; processWhereSubSelect(expression.getExpression()); } else if (where instanceof Parenthesis) { Parenthesis expression = (Parenthesis) where; processWhereSubSelect(expression.getExpression()); } } } protected void processSelectItem(SelectItem selectItem) { if (selectItem instanceof SelectExpressionItem) { SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; if (selectExpressionItem.getExpression() instanceof SubSelect) { processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody()); } else if (selectExpressionItem.getExpression() instanceof Function) { processFunction((Function) selectExpressionItem.getExpression()); } } } /** * 处理函数 *

支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)

*

fixed gitee pulls/141

* * @param function */ protected void processFunction(Function function) { ExpressionList parameters = function.getParameters(); if (parameters != null) { parameters.getExpressions().forEach(expression -> { if (expression instanceof SubSelect) { processSelectBody(((SubSelect) expression).getSelectBody()); } else if (expression instanceof Function) { processFunction((Function) expression); } }); } } /** * 处理子查询等 */ protected void processOtherFromItem(FromItem fromItem) { // 去除括号 while (fromItem instanceof ParenthesisFromItem) { fromItem = ((ParenthesisFromItem) fromItem).getFromItem(); } if (fromItem instanceof SubSelect) { SubSelect subSelect = (SubSelect) fromItem; if (subSelect.getSelectBody() != null) { processSelectBody(subSelect.getSelectBody()); } } else if (fromItem instanceof ValuesList) { logger.debug("Perform a subQuery, if you do not give us feedback"); } else if (fromItem instanceof LateralSubSelect) { LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem; if (lateralSubSelect.getSubSelect() != null) { SubSelect subSelect = lateralSubSelect.getSubSelect(); if (subSelect.getSelectBody() != null) { processSelectBody(subSelect.getSelectBody()); } } } } /** * 处理 sub join * * @param subJoin subJoin * @return Table subJoin 中的主表 */ private List
processSubJoin(SubJoin subJoin) { List
mainTables = new ArrayList<>(); if (subJoin.getJoinList() != null) { List
list = processFromItem(subJoin.getLeft()); mainTables.addAll(list); mainTables = processJoins(mainTables, subJoin.getJoinList()); } return mainTables; } /** * 处理 joins * * @param mainTables 可以为 null * @param joins join 集合 * @return List
右连接查询的 Table 列表 */ private List
processJoins(List
mainTables, List joins) { // join 表达式中最终的主表 Table mainTable = null; // 当前 join 的左表 Table leftTable = null; if (mainTables == null) { mainTables = new ArrayList<>(); } else if (mainTables.size() == 1) { mainTable = mainTables.get(0); leftTable = mainTable; } //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名 Deque> onTableDeque = new LinkedList<>(); for (Join join : joins) { // 处理 on 表达式 FromItem joinItem = join.getRightItem(); // 获取当前 join 的表,subJoint 可以看作是一张表 List
joinTables = null; if (joinItem instanceof Table) { joinTables = new ArrayList<>(); joinTables.add((Table) joinItem); } else if (joinItem instanceof SubJoin) { joinTables = processSubJoin((SubJoin) joinItem); } if (joinTables != null) { // 如果是隐式内连接 if (join.isSimple()) { mainTables.addAll(joinTables); continue; } // 当前表是否忽略 Table joinTable = joinTables.get(0); List
onTables = null; // 如果不要忽略,且是右连接,则记录下当前表 if (join.isRight()) { mainTable = joinTable; if (leftTable != null) { onTables = Collections.singletonList(leftTable); } } else if (join.isLeft()) { onTables = Collections.singletonList(joinTable); } else if (join.isInner()) { if (mainTable == null) { onTables = Collections.singletonList(joinTable); } else { onTables = Arrays.asList(mainTable, joinTable); } mainTable = null; } mainTables = new ArrayList<>(); if (mainTable != null) { mainTables.add(mainTable); } // 获取 join 尾缀的 on 表达式列表 Collection originOnExpressions = join.getOnExpressions(); // 正常 join on 表达式只有一个,立刻处理 if (originOnExpressions.size() == 1 && onTables != null) { List onExpressions = new LinkedList<>(); onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables)); join.setOnExpressions(onExpressions); leftTable = joinTable; continue; } // 表名压栈,忽略的表压入 null,以便后续不处理 onTableDeque.push(onTables); // 尾缀多个 on 表达式的时候统一处理 if (originOnExpressions.size() > 1) { Collection onExpressions = new LinkedList<>(); for (Expression originOnExpression : originOnExpressions) { List
currentTableList = onTableDeque.poll(); if (CollectionUtils.isEmpty(currentTableList)) { onExpressions.add(originOnExpression); } else { onExpressions.add(builderExpression(originOnExpression, currentTableList)); } } join.setOnExpressions(onExpressions); } leftTable = joinTable; } else { processOtherFromItem(joinItem); leftTable = null; } } return mainTables; } /** * 处理条件 */ protected Expression builderExpression(Expression currentExpression, List
tables) { // 没有表需要处理直接返回 if (CollectionUtils.isEmpty(tables)) { return currentExpression; } // 构造每张表的条件 List
tempTables = tables.stream() .filter(x -> !tenantLineHandler.ignoreTable(x.getName())) .collect(Collectors.toList()); // 没有表需要处理直接返回 if (CollectionUtils.isEmpty(tempTables)) { return currentExpression; } Expression tenantId = tenantLineHandler.getTenantId(); List equalsTos = tempTables.stream() .map(item -> new EqualsTo(getAliasColumn(item), tenantId)) .collect(Collectors.toList()); // 注入的表达式 Expression injectExpression = equalsTos.get(0); // 如果有多表,则用 and 连接 if (equalsTos.size() > 1) { for (int i = 1; i < equalsTos.size(); i++) { injectExpression = new AndExpression(injectExpression, equalsTos.get(i)); } } if (currentExpression == null) { return injectExpression; } if (currentExpression instanceof OrExpression) { return new AndExpression(new Parenthesis(currentExpression), injectExpression); } else { return new AndExpression(currentExpression, injectExpression); } } /** * 租户字段别名设置 *

tenantId 或 tableAlias.tenantId

* * @param table 表对象 * @return 字段 */ protected Column getAliasColumn(Table table) { StringBuilder column = new StringBuilder(); // 为了兼容隐式内连接,没有别名时条件就需要加上表名 if (table.getAlias() != null) { column.append(table.getAlias().getName()); } else { column.append(table.getName()); } column.append(StringPool.DOT).append(tenantLineHandler.getTenantIdColumn()); return new Column(column.toString()); } @Override public void setProperties(Properties properties) { PropertyMapper.newInstance(properties).whenNotBlank("tenantLineHandler", ClassUtils::newInstance, this::setTenantLineHandler); } }