All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.cli.lexer.StatementSplitter Maven / Gradle / Ivy

There is a newer version: 464
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.cli.lexer;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.grammar.sql.SqlBaseBaseVisitor;
import io.trino.grammar.sql.SqlBaseLexer;
import io.trino.grammar.sql.SqlBaseParser;
import io.trino.grammar.sql.SqlBaseParser.FunctionSpecificationContext;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenSource;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.RuleNode;

import java.util.List;
import java.util.Objects;
import java.util.Set;

import static java.util.Objects.requireNonNull;

public class StatementSplitter
{
    private final List completeStatements;
    private final String partialStatement;

    public StatementSplitter(String sql)
    {
        this(sql, ImmutableSet.of(";"));
    }

    public StatementSplitter(String sql, Set delimiters)
    {
        DelimiterLexer lexer = getLexer(sql, delimiters);
        CommonTokenStream tokenStream = new CommonTokenStream(lexer);
        tokenStream.fill();

        SqlBaseParser parser = new SqlBaseParser(tokenStream);
        parser.removeErrorListeners();

        ImmutableList.Builder list = ImmutableList.builder();
        StringBuilder sb = new StringBuilder();
        int index = 0;

        while (index < tokenStream.size()) {
            ParserRuleContext context = parser.statement();

            if (containsFunction(context)) {
                Token stop = context.getStop();
                if ((stop != null) && (stop.getTokenIndex() >= index)) {
                    int endIndex = stop.getTokenIndex();
                    while (index <= endIndex) {
                        Token token = tokenStream.get(index);
                        index++;
                        sb.append(token.getText());
                    }
                }
            }

            while (index < tokenStream.size()) {
                Token token = tokenStream.get(index);
                index++;
                if (token.getType() == Token.EOF) {
                    break;
                }
                if (lexer.isDelimiter(token)) {
                    String statement = sb.toString().trim();
                    if (!statement.isEmpty()) {
                        list.add(new Statement(statement, token.getText()));
                    }
                    sb = new StringBuilder();
                    break;
                }
                sb.append(token.getText());
            }
        }

        this.completeStatements = list.build();
        this.partialStatement = sb.toString().trim();
    }

    public List getCompleteStatements()
    {
        return completeStatements;
    }

    public String getPartialStatement()
    {
        return partialStatement;
    }

    public static String squeezeStatement(String sql)
    {
        TokenSource tokens = getLexer(sql, ImmutableSet.of());
        StringBuilder sb = new StringBuilder();
        while (true) {
            Token token = tokens.nextToken();
            if (token.getType() == Token.EOF) {
                break;
            }
            if (token.getType() == SqlBaseLexer.WS) {
                sb.append(' ');
            }
            else {
                sb.append(token.getText());
            }
        }
        return sb.toString().trim();
    }

    public static boolean isEmptyStatement(String sql)
    {
        TokenSource tokens = getLexer(sql, ImmutableSet.of());
        while (true) {
            Token token = tokens.nextToken();
            if (token.getType() == Token.EOF) {
                return true;
            }
            if (token.getChannel() != Token.HIDDEN_CHANNEL) {
                return false;
            }
        }
    }

    public static DelimiterLexer getLexer(String sql, Set terminators)
    {
        requireNonNull(sql, "sql is null");
        return new DelimiterLexer(CharStreams.fromString(sql), terminators);
    }

    private static boolean containsFunction(ParseTree tree)
    {
        return new SqlBaseBaseVisitor()
        {
            @Override
            protected Boolean defaultResult()
            {
                return false;
            }

            @Override
            protected Boolean aggregateResult(Boolean aggregate, Boolean nextResult)
            {
                return aggregate || nextResult;
            }

            @Override
            protected boolean shouldVisitNextChild(RuleNode node, Boolean currentResult)
            {
                return !currentResult;
            }

            @Override
            public Boolean visitFunctionSpecification(FunctionSpecificationContext context)
            {
                return true;
            }
        }.visit(tree);
    }

    public static class Statement
    {
        private final String statement;
        private final String terminator;

        public Statement(String statement, String terminator)
        {
            this.statement = requireNonNull(statement, "statement is null");
            this.terminator = requireNonNull(terminator, "terminator is null");
        }

        public String statement()
        {
            return statement;
        }

        public String terminator()
        {
            return terminator;
        }

        @Override
        public boolean equals(Object obj)
        {
            if (this == obj) {
                return true;
            }
            if ((obj == null) || (getClass() != obj.getClass())) {
                return false;
            }
            Statement o = (Statement) obj;
            return Objects.equals(statement, o.statement) &&
                    Objects.equals(terminator, o.terminator);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(statement, terminator);
        }

        @Override
        public String toString()
        {
            return statement + terminator;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy