All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.sql.parser.ErrorHandler Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.parser;

import com.google.common.collect.ImmutableSet;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.NoViableAltException;
import org.antlr.v4.runtime.Parser;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.RuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.Vocabulary;
import org.antlr.v4.runtime.atn.ATN;
import org.antlr.v4.runtime.atn.ATNState;
import org.antlr.v4.runtime.atn.NotSetTransition;
import org.antlr.v4.runtime.atn.PrecedencePredicateTransition;
import org.antlr.v4.runtime.atn.RuleStartState;
import org.antlr.v4.runtime.atn.RuleStopState;
import org.antlr.v4.runtime.atn.RuleTransition;
import org.antlr.v4.runtime.atn.Transition;
import org.antlr.v4.runtime.atn.WildcardTransition;
import org.antlr.v4.runtime.misc.IntervalSet;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;

import static com.google.common.base.MoreObjects.firstNonNull;
import static java.util.logging.Level.SEVERE;
import static org.antlr.v4.runtime.atn.ATNState.RULE_START;

class ErrorHandler
        extends BaseErrorListener
{
    private static final Logger LOG = Logger.getLogger(ErrorHandler.class.getName());

    private final Map specialRules;
    private final Map specialTokens;
    private final Set ignoredRules;

    private ErrorHandler(Map specialRules, Map specialTokens, Set ignoredRules)
    {
        this.specialRules = new HashMap<>(specialRules);
        this.specialTokens = specialTokens;
        this.ignoredRules = new HashSet<>(ignoredRules);
    }

    @Override
    public void syntaxError(Recognizer recognizer, Object offendingSymbol, int line, int charPositionInLine, String message, RecognitionException e)
    {
        try {
            Parser parser = (Parser) recognizer;

            ATN atn = parser.getATN();

            ATNState currentState;
            Token currentToken;
            RuleContext context;

            if (e != null) {
                currentState = atn.states.get(e.getOffendingState());
                currentToken = e.getOffendingToken();
                context = e.getCtx();

                if (e instanceof NoViableAltException) {
                    currentToken = ((NoViableAltException) e).getStartToken();
                }
            }
            else {
                currentState = atn.states.get(parser.getState());
                currentToken = parser.getCurrentToken();
                context = parser.getContext();
            }

            Analyzer analyzer = new Analyzer(parser, specialRules, specialTokens, ignoredRules);
            Result result = analyzer.process(currentState, currentToken.getTokenIndex(), context);

            // pick the candidate tokens associated largest token index processed (i.e., the path that consumed the most input)
            String expected = result.getExpected().stream()
                    .sorted()
                    .collect(Collectors.joining(", "));

            message = "mismatched input '%s'. Expecting: %s".formatted(parser.getTokenStream().get(result.getErrorTokenIndex()).getText(), expected);
        }
        catch (Exception exception) {
            LOG.log(SEVERE, "Unexpected failure when handling parsing error. This is likely a bug in the implementation", exception);
        }

        throw new ParsingException(message, e, line, charPositionInLine + 1);
    }

    private static class ParsingState
    {
        public final ATNState state;
        public final int tokenIndex;
        public final boolean suppressed;
        public final Parser parser;

        public ParsingState(ATNState state, int tokenIndex, boolean suppressed, Parser parser)
        {
            this.state = state;
            this.tokenIndex = tokenIndex;
            this.suppressed = suppressed;
            this.parser = parser;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            ParsingState that = (ParsingState) o;
            return tokenIndex == that.tokenIndex &&
                    state.equals(that.state);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(state, tokenIndex);
        }

        @Override
        public String toString()
        {
            Token token = parser.getTokenStream().get(tokenIndex);

            String text = firstNonNull(token.getText(), "?");
            if (text != null) {
                text = text.replace("\\", "\\\\");
                text = text.replace("\n", "\\n");
                text = text.replace("\r", "\\r");
                text = text.replace("\t", "\\t");
            }

            return "%s%s:%s @ %s:<%s>:%s".formatted(
                    suppressed ? "-" : "+",
                    parser.getRuleNames()[state.ruleIndex],
                    state.stateNumber,
                    tokenIndex,
                    parser.getVocabulary().getSymbolicName(token.getType()),
                    text);
        }
    }

    private static class Analyzer
    {
        private final Parser parser;
        private final ATN atn;
        private final Vocabulary vocabulary;
        private final Map specialRules;
        private final Map specialTokens;
        private final Set ignoredRules;
        private final TokenStream stream;

        private int furthestTokenIndex = -1;
        private final Set candidates = new HashSet<>();

        private final Map> memo = new HashMap<>();

        public Analyzer(
                Parser parser,
                Map specialRules,
                Map specialTokens,
                Set ignoredRules)
        {
            this.parser = parser;
            this.stream = parser.getTokenStream();
            this.atn = parser.getATN();
            this.vocabulary = parser.getVocabulary();
            this.specialRules = specialRules;
            this.specialTokens = specialTokens;
            this.ignoredRules = ignoredRules;
        }

        public Result process(ATNState currentState, int tokenIndex, RuleContext context)
        {
            RuleStartState startState = atn.ruleToStartState[currentState.ruleIndex];

            if (isReachable(currentState, startState)) {
                // We've been dropped inside a rule in a state that's reachable via epsilon transitions. This is,
                // effectively, equivalent to starting at the beginning (or immediately outside) the rule.
                // In that case, backtrack to the beginning to be able to take advantage of logic that remaps
                // some rules to well-known names for reporting purposes
                currentState = startState;
            }

            Set endTokens = process(new ParsingState(currentState, tokenIndex, false, parser), 0);
            while (!endTokens.isEmpty() && context.invokingState != -1) {
                ATNState nextState = ((RuleTransition) atn.states.get(context.invokingState).transition(0)).followState;
                endTokens = endTokens.stream()
                    .flatMap(endToken -> process(new ParsingState(nextState, endToken, false, parser), 0).stream())
                    .collect(Collectors.toSet());
                context = context.parent;
            }

            return new Result(furthestTokenIndex, candidates);
        }

        private boolean isReachable(ATNState target, RuleStartState from)
        {
            Deque activeStates = new ArrayDeque<>();
            activeStates.add(from);

            while (!activeStates.isEmpty()) {
                ATNState current = activeStates.pop();

                if (current.stateNumber == target.stateNumber) {
                    return true;
                }

                for (int i = 0; i < current.getNumberOfTransitions(); i++) {
                    Transition transition = current.transition(i);

                    if (transition.isEpsilon()) {
                        activeStates.push(transition.target);
                    }
                }
            }

            return false;
        }

        private Set process(ParsingState start, int precedence)
        {
            Set result = memo.get(start);
            if (result != null) {
                return result;
            }

            ImmutableSet.Builder endTokens = ImmutableSet.builder();

            // Simulates the ATN by consuming input tokens and walking transitions.
            // The ATN can be in multiple states (similar to an NFA)
            Deque activeStates = new ArrayDeque<>();
            activeStates.add(start);

            while (!activeStates.isEmpty()) {
                ParsingState current = activeStates.pop();

                ATNState state = current.state;
                int tokenIndex = current.tokenIndex;
                boolean suppressed = current.suppressed;

                while (stream.get(tokenIndex).getChannel() == Token.HIDDEN_CHANNEL) {
                    // Ignore whitespace
                    tokenIndex++;
                }
                int currentToken = stream.get(tokenIndex).getType();

                if (state.getStateType() == RULE_START) {
                    int rule = state.ruleIndex;

                    if (specialRules.containsKey(rule)) {
                        if (!suppressed) {
                            record(tokenIndex, specialRules.get(rule));
                        }
                        suppressed = true;
                    }
                    else if (ignoredRules.contains(rule)) {
                        // TODO expand ignored rules like we expand special rules
                        continue;
                    }
                }

                if (state instanceof RuleStopState) {
                    endTokens.add(tokenIndex);
                    continue;
                }

                for (int i = 0; i < state.getNumberOfTransitions(); i++) {
                    Transition transition = state.transition(i);

                    if (transition instanceof RuleTransition ruleTransition) {
                        for (int endToken : process(new ParsingState(ruleTransition.target, tokenIndex, suppressed, parser), ruleTransition.precedence)) {
                            activeStates.push(new ParsingState(ruleTransition.followState, endToken, suppressed && endToken == currentToken, parser));
                        }
                    }
                    else if (transition instanceof PrecedencePredicateTransition) {
                        if (precedence < ((PrecedencePredicateTransition) transition).precedence) {
                            activeStates.push(new ParsingState(transition.target, tokenIndex, suppressed, parser));
                        }
                    }
                    else if (transition.isEpsilon()) {
                        activeStates.push(new ParsingState(transition.target, tokenIndex, suppressed, parser));
                    }
                    else if (transition instanceof WildcardTransition) {
                        throw new UnsupportedOperationException("not yet implemented: wildcard transition");
                    }
                    else {
                        IntervalSet labels = transition.label();

                        if (transition instanceof NotSetTransition) {
                            labels = labels.complement(IntervalSet.of(Token.MIN_USER_TOKEN_TYPE, atn.maxTokenType));
                        }

                        // Surprisingly, TokenStream (i.e. BufferedTokenStream) may not have loaded all the tokens from the
                        // underlying stream. TokenStream.get() does not force tokens to be buffered -- it just returns what's
                        // in the current buffer, or fail with an IndexOutOfBoundsError. Since Antlr decided the error occurred
                        // within the current set of buffered tokens, stop when we reach the end of the buffer.
                        if (labels.contains(currentToken) && tokenIndex < stream.size() - 1) {
                            activeStates.push(new ParsingState(transition.target, tokenIndex + 1, false, parser));
                        }
                        else {
                            if (!suppressed) {
                                record(tokenIndex, getTokenNames(labels));
                            }
                        }
                    }
                }
            }

            result = endTokens.build();
            memo.put(start, result);
            return result;
        }

        private void record(int tokenIndex, String label)
        {
            record(tokenIndex, ImmutableSet.of(label));
        }

        private void record(int tokenIndex, Set labels)
        {
            if (tokenIndex >= furthestTokenIndex) {
                if (tokenIndex > furthestTokenIndex) {
                    candidates.clear();
                    furthestTokenIndex = tokenIndex;
                }

                candidates.addAll(labels);
            }
        }

        private Set getTokenNames(IntervalSet tokens)
        {
            Set names = new HashSet<>();
            for (int i = 0; i < tokens.size(); i++) {
                int token = tokens.get(i);
                if (token == Token.EOF) {
                    names.add("");
                }
                else {
                    names.add(specialTokens.getOrDefault(token, vocabulary.getDisplayName(token)));
                }
            }

            return names;
        }
    }

    public static Builder builder()
    {
        return new Builder();
    }

    public static class Builder
    {
        private final Map specialRules = new HashMap<>();
        private final Map specialTokens = new HashMap<>();
        private final Set ignoredRules = new HashSet<>();

        public Builder specialRule(int ruleId, String name)
        {
            specialRules.put(ruleId, name);
            return this;
        }

        public Builder specialToken(int tokenId, String name)
        {
            specialTokens.put(tokenId, name);
            return this;
        }

        public Builder ignoredRule(int ruleId)
        {
            ignoredRules.add(ruleId);
            return this;
        }

        public ErrorHandler build()
        {
            return new ErrorHandler(specialRules, specialTokens, ignoredRules);
        }
    }

    private static class Result
    {
        private final int errorTokenIndex;
        private final Set expected;

        public Result(int errorTokenIndex, Set expected)
        {
            this.errorTokenIndex = errorTokenIndex;
            this.expected = expected;
        }

        public int getErrorTokenIndex()
        {
            return errorTokenIndex;
        }

        public Set getExpected()
        {
            return expected;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy