All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.trino.gateway.ha.handler.ProxyUtils Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.gateway.ha.handler;

import com.google.common.base.Splitter;
import com.google.common.io.CharStreams;
import io.airlift.log.Logger;
import jakarta.servlet.http.HttpServletRequest;

import java.io.InputStreamReader;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Strings.isNullOrEmpty;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.TRINO_UI_PATH;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.V1_QUERY_PATH;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.ENGLISH;

public final class ProxyUtils
{
    public static final String SOURCE_HEADER = "X-Trino-Source";
    public static final String AUTHORIZATION = "Authorization";

    private static final Logger log = Logger.get(ProxyUtils.class);
    public static final int QUERY_TEXT_LENGTH_FOR_HISTORY = 200;
    /**
     * This regular expression matches query ids as they appear in the path of a URL. The query id must be preceded
     * by a "/". A query id is defined as three groups of digits separated by underscores, with a final group
     * consisting of word characters (letters, digits or underscores).
     */
    private static final Pattern QUERY_ID_PATH_PATTERN = Pattern.compile(".*/(\\d+_\\d+_\\d+_\\w+).*");
    /**
     * This regular expression matches query ids as they appear in the query parameters of a URL. The query id is
     * defined as in QUERY_ID_PATH_PATTERN. The query id must either be at the beginning of the query parameter
     * string, or be preceded by %2F (a URL-encoded "/") or by "query_id=", with or without the underscore and in
     * any capitalization.
     */
    private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*");
    /**
     * Captures the first single-quoted token that contains neither whitespace nor a quote, e.g. the id in
     * {@code query_id => '20240101_000000_00000_abcde'}.
     */
    private static final Pattern EXTRACT_BETWEEN_SINGLE_QUOTES = Pattern.compile("'([^\\s']+)'");

    private ProxyUtils() {}

    /**
     * Determines the user a request is issued on behalf of.
     *
     * <p>A non-empty user header wins. Otherwise the user name is taken from an HTTP Basic
     * {@code Authorization} header of the form {@code Basic base64(user:password)}.
     *
     * @param userHeader value of the user header; may be {@code null} or empty
     * @param authorization value of the {@code Authorization} header; may be {@code null}
     * @return the user name, or an empty string when no user can be determined (missing, non-Basic,
     *         empty or malformed credentials)
     */
    public static String getQueryUser(String userHeader, String authorization)
    {
        if (!isNullOrEmpty(userHeader)) {
            log.debug("User from header %s", USER_HEADER);
            return userHeader;
        }

        log.debug("User from basic authentication");
        String user = "";
        if (authorization == null) {
            log.debug("No basic auth header found.");
            return user;
        }

        int space = authorization.indexOf(' ');
        if ((space < 0) || !authorization.substring(0, space).equalsIgnoreCase("basic")) {
            log.error("Basic auth format is invalid");
            return user;
        }

        String headerInfo = authorization.substring(space + 1).trim();
        if (isNullOrEmpty(headerInfo)) {
            log.error("Encoded value of basic auth doesn't exist");
            return user;
        }

        String info;
        try {
            info = new String(Base64.getDecoder().decode(headerInfo), UTF_8);
        }
        catch (IllegalArgumentException e) {
            // Malformed Base64 is just another invalid credential format: report it instead of
            // letting the exception abort request handling. The value itself is not logged.
            log.error("Encoded value of basic auth is not valid Base64");
            return user;
        }

        // limit(2): everything after the first ':' is the password, which may itself contain ':'
        List<String> parts = Splitter.on(':').limit(2).splitToList(info);
        // Splitter always yields at least one element, so only an empty user name needs reporting
        if (parts.get(0).isEmpty()) {
            log.error("No user inside the basic auth text");
            return user;
        }
        return parts.get(0);
    }

    /**
     * Extracts the query id a request refers to.
     *
     * <p>If the request body is a {@code system.runtime.kill_query} call, the quoted {@code query_id} argument is
     * returned. Otherwise the id is derived from the request path and query string via
     * {@link #extractQueryIdIfPresent(String, String, List)}.
     *
     * <p>This reads the request's input stream to the end; callers that still need the body must pass a
     * request whose stream can be read again.
     *
     * @param request the incoming request
     * @param statementPaths statement path prefixes (e.g. {@code /v1/statement}) recognized in addition to
     *        {@code V1_QUERY_PATH}
     * @return the query id, or {@code null} if none was found
     */
    public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
    {
        String path = request.getRequestURI();
        String queryParams = request.getQueryString();
        try {
            String queryText = CharStreams.toString(new InputStreamReader(request.getInputStream(), UTF_8));
            String killQueryId = extractKillQueryId(queryText);
            if (killQueryId != null) {
                return killQueryId;
            }
        }
        catch (Exception e) {
            log.error(e, "Error extracting query payload from request");
        }

        return extractQueryIdIfPresent(path, queryParams, statementPaths);
    }

    /**
     * Returns the quoted {@code query_id} argument of a {@code system.runtime.kill_query} call contained in
     * {@code queryText}, or {@code null} if the text is not such a call or no quoted id is found.
     */
    private static String extractKillQueryId(String queryText)
    {
        if (isNullOrEmpty(queryText)
                || !queryText.toLowerCase(ENGLISH).contains("system.runtime.kill_query")) {
            return null;
        }
        for (String part : queryText.split(",")) {
            // SQL identifiers are case-insensitive, so QUERY_ID => '...' must be accepted too
            if (part.toLowerCase(ENGLISH).contains("query_id")) {
                Matcher matcher = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
                if (matcher.find()) {
                    // group(1) is non-empty: the pattern requires at least one character between the quotes
                    return matcher.group(1);
                }
            }
        }
        return null;
    }

    /**
     * Extracts a query id from a request path and/or query string.
     *
     * <p>Statement paths (a prefix from {@code statementPaths} or {@code V1_QUERY_PATH}) take the id from the
     * path segment after the prefix, skipping a {@code queued}/{@code scheduled}/{@code executing}/
     * {@code partialCancel} state segment. UI paths ({@code TRINO_UI_PATH}) are matched against
     * {@code QUERY_ID_PATH_PATTERN}. A query id found in {@code queryParams} overrides one found in the path.
     *
     * @param path request URI path; {@code null} yields {@code null}
     * @param queryParams raw query string; may be {@code null} or empty
     * @param statementPaths statement path prefixes to recognize
     * @return the query id, or {@code null} if none was found
     */
    public static String extractQueryIdIfPresent(String path, String queryParams, List<String> statementPaths)
    {
        if (path == null) {
            return null;
        }
        String queryId = null;
        log.debug("Trying to extract query id from path [%s] or queryString [%s]", path, queryParams);
        // matchingStatementPath should match paths such as /v1/statement/executing/query_id/nonce/sequence_number,
        // and if custom paths are supplied using the statementPaths configuration, paths such as
        // /custom/statement/path/executing/query_id/nonce/sequence_number
        Optional<String> matchingStatementPath = statementPaths.stream().filter(path::startsWith).findAny();
        if (matchingStatementPath.isPresent() || path.startsWith(V1_QUERY_PATH)) {
            // The path is known to start with this prefix; strip only that leading occurrence
            // (String.replace would also remove any later occurrence inside the path).
            String prefix = matchingStatementPath.orElse(V1_QUERY_PATH);
            queryId = extractQueryIdFromStatementPath(path.substring(prefix.length()));
        }
        else if (path.startsWith(TRINO_UI_PATH)) {
            Matcher matcher = QUERY_ID_PATH_PATTERN.matcher(path);
            if (matcher.matches()) {
                queryId = matcher.group(1);
            }
        }
        if (!isNullOrEmpty(queryParams)) {
            Matcher matcher = QUERY_ID_PARAM_PATTERN.matcher(queryParams);
            if (matcher.matches()) {
                queryId = matcher.group(1);
            }
        }
        log.debug("Query id in URL [%s]", queryId);
        return queryId;
    }

    /**
     * Picks the query id out of the part of a statement path that follows its prefix, e.g.
     * {@code /executing/{queryId}/{nonce}/{sequence}} or {@code /{queryId}}. Returns {@code null} when the
     * path has no segment that could hold an id (including a bare state segment such as {@code /executing}).
     */
    private static String extractQueryIdFromStatementPath(String remainder)
    {
        // remainder starts with '/', so tokens[0] is the empty string before it
        String[] tokens = remainder.split("/");
        if (tokens.length < 2) {
            return null;
        }
        if (isStatementStateSegment(tokens[1])) {
            // Guard against paths that end right after the state segment
            return tokens.length > 2 ? tokens[2] : null;
        }
        return tokens[1];
    }

    /** Returns whether {@code segment} is a statement state segment that precedes the query id in the path. */
    private static boolean isStatementStateSegment(String segment)
    {
        return segment.equals("queued")
                || segment.equals("scheduled")
                || segment.equals("executing")
                || segment.equals("partialCancel");
    }

    /**
     * Builds the URI to forward {@code request} to on another backend: the backend host followed by the
     * original request URI and, if present, the original query string.
     *
     * @param backendHost scheme and authority of the target backend (e.g. {@code http://host:8080})
     * @param request the incoming request
     * @return the rewritten URI as a string
     */
    public static String buildUriWithNewBackend(String backendHost, HttpServletRequest request)
    {
        return backendHost + request.getRequestURI() + (request.getQueryString() != null ? "?" + request.getQueryString() : "");
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy