All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.prestosql.server.HttpRequestSessionContext Maven / Gradle / Ivy

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.server;

import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.prestosql.Session.ResourceEstimateBuilder;
import io.prestosql.security.AccessControl;
import io.prestosql.spi.security.AccessDeniedException;
import io.prestosql.spi.security.GroupProvider;
import io.prestosql.spi.security.Identity;
import io.prestosql.spi.security.SelectedRole;
import io.prestosql.spi.session.ResourceEstimates;
import io.prestosql.sql.parser.ParsingException;
import io.prestosql.sql.parser.ParsingOptions;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.transaction.TransactionId;

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;

import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static io.prestosql.client.PrestoHeaders.PRESTO_CATALOG;
import static io.prestosql.client.PrestoHeaders.PRESTO_CLIENT_CAPABILITIES;
import static io.prestosql.client.PrestoHeaders.PRESTO_CLIENT_INFO;
import static io.prestosql.client.PrestoHeaders.PRESTO_CLIENT_TAGS;
import static io.prestosql.client.PrestoHeaders.PRESTO_EXTRA_CREDENTIAL;
import static io.prestosql.client.PrestoHeaders.PRESTO_LANGUAGE;
import static io.prestosql.client.PrestoHeaders.PRESTO_PATH;
import static io.prestosql.client.PrestoHeaders.PRESTO_PREPARED_STATEMENT;
import static io.prestosql.client.PrestoHeaders.PRESTO_RESOURCE_ESTIMATE;
import static io.prestosql.client.PrestoHeaders.PRESTO_ROLE;
import static io.prestosql.client.PrestoHeaders.PRESTO_SCHEMA;
import static io.prestosql.client.PrestoHeaders.PRESTO_SESSION;
import static io.prestosql.client.PrestoHeaders.PRESTO_SOURCE;
import static io.prestosql.client.PrestoHeaders.PRESTO_TIME_ZONE;
import static io.prestosql.client.PrestoHeaders.PRESTO_TRACE_TOKEN;
import static io.prestosql.client.PrestoHeaders.PRESTO_TRANSACTION_ID;
import static io.prestosql.client.PrestoHeaders.PRESTO_USER;
import static io.prestosql.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;

public final class HttpRequestSessionContext
        implements SessionContext
{
    private static final Splitter DOT_SPLITTER = Splitter.on('.');
    public static final String AUTHENTICATED_IDENTITY = "presto.authenticated-identity";

    private final String catalog;
    private final String schema;
    private final String path;

    private final Optional authenticatedIdentity;
    private final Identity identity;

    private final String source;
    private final Optional traceToken;
    private final String userAgent;
    private final String remoteUserAddress;
    private final String timeZoneId;
    private final String language;
    private final Set clientTags;
    private final Set clientCapabilities;
    private final ResourceEstimates resourceEstimates;

    private final Map systemProperties;
    private final Map> catalogSessionProperties;

    private final Map preparedStatements;

    private final Optional transactionId;
    private final boolean clientTransactionSupport;
    private final String clientInfo;

    public HttpRequestSessionContext(MultivaluedMap headers, String remoteAddress, Optional authenticatedIdentity, GroupProvider groupProvider)
            throws WebApplicationException
    {
        catalog = trimEmptyToNull(headers.getFirst(PRESTO_CATALOG));
        schema = trimEmptyToNull(headers.getFirst(PRESTO_SCHEMA));
        path = trimEmptyToNull(headers.getFirst(PRESTO_PATH));
        assertRequest((catalog != null) || (schema == null), "Schema is set but catalog is not");

        this.authenticatedIdentity = requireNonNull(authenticatedIdentity, "authenticatedIdentity is null");
        identity = buildSessionIdentity(authenticatedIdentity, headers, groupProvider);

        source = headers.getFirst(PRESTO_SOURCE);
        traceToken = Optional.ofNullable(trimEmptyToNull(headers.getFirst(PRESTO_TRACE_TOKEN)));
        userAgent = headers.getFirst(USER_AGENT);
        remoteUserAddress = remoteAddress;
        timeZoneId = headers.getFirst(PRESTO_TIME_ZONE);
        language = headers.getFirst(PRESTO_LANGUAGE);
        clientInfo = headers.getFirst(PRESTO_CLIENT_INFO);
        clientTags = parseClientTags(headers);
        clientCapabilities = parseClientCapabilities(headers);
        resourceEstimates = parseResourceEstimate(headers);

        // parse session properties
        ImmutableMap.Builder systemProperties = ImmutableMap.builder();
        Map> catalogSessionProperties = new HashMap<>();
        for (Entry entry : parseSessionHeaders(headers).entrySet()) {
            String fullPropertyName = entry.getKey();
            String propertyValue = entry.getValue();
            List nameParts = DOT_SPLITTER.splitToList(fullPropertyName);
            if (nameParts.size() == 1) {
                String propertyName = nameParts.get(0);

                assertRequest(!propertyName.isEmpty(), "Invalid %s header", PRESTO_SESSION);

                // catalog session properties cannot be validated until the transaction has stated, so we delay system property validation also
                systemProperties.put(propertyName, propertyValue);
            }
            else if (nameParts.size() == 2) {
                String catalogName = nameParts.get(0);
                String propertyName = nameParts.get(1);

                assertRequest(!catalogName.isEmpty(), "Invalid %s header", PRESTO_SESSION);
                assertRequest(!propertyName.isEmpty(), "Invalid %s header", PRESTO_SESSION);

                // catalog session properties cannot be validated until the transaction has stated
                catalogSessionProperties.computeIfAbsent(catalogName, id -> new HashMap<>()).put(propertyName, propertyValue);
            }
            else {
                throw badRequest(format("Invalid %s header", PRESTO_SESSION));
            }
        }
        this.systemProperties = systemProperties.build();
        this.catalogSessionProperties = catalogSessionProperties.entrySet().stream()
                .collect(toImmutableMap(Entry::getKey, entry -> ImmutableMap.copyOf(entry.getValue())));

        preparedStatements = parsePreparedStatementsHeaders(headers);

        String transactionIdHeader = headers.getFirst(PRESTO_TRANSACTION_ID);
        clientTransactionSupport = transactionIdHeader != null;
        transactionId = parseTransactionId(transactionIdHeader);
    }

    public static Identity extractAuthorizedIdentity(HttpServletRequest servletRequest, HttpHeaders httpHeaders, AccessControl accessControl, GroupProvider groupProvider)
    {
        return extractAuthorizedIdentity(
                Optional.ofNullable((Identity) servletRequest.getAttribute(AUTHENTICATED_IDENTITY)),
                httpHeaders.getRequestHeaders(),
                accessControl,
                groupProvider);
    }

    public static Identity extractAuthorizedIdentity(Optional optionalAuthenticatedIdentity, MultivaluedMap headers, AccessControl accessControl, GroupProvider groupProvider)
            throws AccessDeniedException
    {
        Identity identity = buildSessionIdentity(optionalAuthenticatedIdentity, headers, groupProvider);

        accessControl.checkCanSetUser(identity.getPrincipal(), identity.getUser());

        // authenticated may not present for HTTP or if authentication is not setup
        optionalAuthenticatedIdentity.ifPresent(authenticatedIdentity -> {
            // only check impersonation if authenticated user is not the same as the explicitly set user
            if (!authenticatedIdentity.getUser().equals(identity.getUser())) {
                accessControl.checkCanImpersonateUser(authenticatedIdentity, identity.getUser());
            }
        });

        return identity;
    }

    private static Identity buildSessionIdentity(Optional authenticatedIdentity, MultivaluedMap headers, GroupProvider groupProvider)
    {
        String prestoUser = trimEmptyToNull(headers.getFirst(PRESTO_USER));
        String user = prestoUser != null ? prestoUser : authenticatedIdentity.map(Identity::getUser).orElse(null);
        assertRequest(user != null, "User must be set");
        return authenticatedIdentity
                .map(identity -> Identity.from(identity).withUser(user))
                .orElseGet(() -> Identity.forUser(user))
                .withAdditionalRoles(parseRoleHeaders(headers))
                .withAdditionalExtraCredentials(parseExtraCredentials(headers))
                .withAdditionalGroups(groupProvider.getGroups(user))
                .build();
    }

    @Override
    public Optional getAuthenticatedIdentity()
    {
        return authenticatedIdentity;
    }

    @Override
    public Identity getIdentity()
    {
        return identity;
    }

    @Override
    public String getCatalog()
    {
        return catalog;
    }

    @Override
    public String getSchema()
    {
        return schema;
    }

    @Override
    public String getPath()
    {
        return path;
    }

    @Override
    public String getSource()
    {
        return source;
    }

    @Override
    public String getRemoteUserAddress()
    {
        return remoteUserAddress;
    }

    @Override
    public String getUserAgent()
    {
        return userAgent;
    }

    @Override
    public String getClientInfo()
    {
        return clientInfo;
    }

    @Override
    public Set getClientTags()
    {
        return clientTags;
    }

    @Override
    public Set getClientCapabilities()
    {
        return clientCapabilities;
    }

    @Override
    public ResourceEstimates getResourceEstimates()
    {
        return resourceEstimates;
    }

    @Override
    public String getTimeZoneId()
    {
        return timeZoneId;
    }

    @Override
    public String getLanguage()
    {
        return language;
    }

    @Override
    public Map getSystemProperties()
    {
        return systemProperties;
    }

    @Override
    public Map> getCatalogSessionProperties()
    {
        return catalogSessionProperties;
    }

    @Override
    public Map getPreparedStatements()
    {
        return preparedStatements;
    }

    @Override
    public Optional getTransactionId()
    {
        return transactionId;
    }

    @Override
    public boolean supportClientTransaction()
    {
        return clientTransactionSupport;
    }

    @Override
    public Optional getTraceToken()
    {
        return traceToken;
    }

    private static List splitHttpHeader(MultivaluedMap headers, String name)
    {
        List values = firstNonNull(headers.get(name), ImmutableList.of());
        Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
        return values.stream()
                .map(splitter::splitToList)
                .flatMap(Collection::stream)
                .collect(toImmutableList());
    }

    private static Map parseSessionHeaders(MultivaluedMap headers)
    {
        return parseProperty(headers, PRESTO_SESSION);
    }

    private static Map parseRoleHeaders(MultivaluedMap headers)
    {
        ImmutableMap.Builder roles = ImmutableMap.builder();
        parseProperty(headers, PRESTO_ROLE).forEach((key, value) -> {
            SelectedRole role;
            try {
                role = SelectedRole.valueOf(value);
            }
            catch (IllegalArgumentException e) {
                throw badRequest(format("Invalid %s header", PRESTO_ROLE));
            }
            roles.put(key, role);
        });
        return roles.build();
    }

    private static Map parseExtraCredentials(MultivaluedMap headers)
    {
        return parseProperty(headers, PRESTO_EXTRA_CREDENTIAL);
    }

    private static Map parseProperty(MultivaluedMap headers, String headerName)
    {
        Map properties = new HashMap<>();
        for (String header : splitHttpHeader(headers, headerName)) {
            List nameValue = Splitter.on('=').trimResults().splitToList(header);
            assertRequest(nameValue.size() == 2, "Invalid %s header", headerName);
            try {
                properties.put(nameValue.get(0), urlDecode(nameValue.get(1)));
            }
            catch (IllegalArgumentException e) {
                throw badRequest(format("Invalid %s header: %s", headerName, e));
            }
        }
        return properties;
    }

    private static Set parseClientTags(MultivaluedMap headers)
    {
        Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
        return ImmutableSet.copyOf(splitter.split(nullToEmpty(headers.getFirst(PRESTO_CLIENT_TAGS))));
    }

    private static Set parseClientCapabilities(MultivaluedMap headers)
    {
        Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
        return ImmutableSet.copyOf(splitter.split(nullToEmpty(headers.getFirst(PRESTO_CLIENT_CAPABILITIES))));
    }

    private static ResourceEstimates parseResourceEstimate(MultivaluedMap headers)
    {
        ResourceEstimateBuilder builder = new ResourceEstimateBuilder();
        parseProperty(headers, PRESTO_RESOURCE_ESTIMATE).forEach((name, value) -> {
            try {
                switch (name.toUpperCase(ENGLISH)) {
                    case ResourceEstimates.EXECUTION_TIME:
                        builder.setExecutionTime(Duration.valueOf(value));
                        break;
                    case ResourceEstimates.CPU_TIME:
                        builder.setCpuTime(Duration.valueOf(value));
                        break;
                    case ResourceEstimates.PEAK_MEMORY:
                        builder.setPeakMemory(DataSize.valueOf(value));
                        break;
                    default:
                        throw badRequest(format("Unsupported resource name %s", name));
                }
            }
            catch (IllegalArgumentException e) {
                throw badRequest(format("Unsupported format for resource estimate '%s': %s", value, e));
            }
        });

        return builder.build();
    }

    private static void assertRequest(boolean expression, String format, Object... args)
    {
        if (!expression) {
            throw badRequest(format(format, args));
        }
    }

    private static Map parsePreparedStatementsHeaders(MultivaluedMap headers)
    {
        ImmutableMap.Builder preparedStatements = ImmutableMap.builder();
        parseProperty(headers, PRESTO_PREPARED_STATEMENT).forEach((key, sqlString) -> {
            String statementName;
            try {
                statementName = urlDecode(key);
            }
            catch (IllegalArgumentException e) {
                throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage()));
            }

            // Validate statement
            SqlParser sqlParser = new SqlParser();
            try {
                sqlParser.createStatement(sqlString, new ParsingOptions(AS_DOUBLE /* anything */));
            }
            catch (ParsingException e) {
                throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage()));
            }

            preparedStatements.put(statementName, sqlString);
        });

        return preparedStatements.build();
    }

    private static Optional parseTransactionId(String transactionId)
    {
        transactionId = trimEmptyToNull(transactionId);
        if (transactionId == null || transactionId.equalsIgnoreCase("none")) {
            return Optional.empty();
        }
        try {
            return Optional.of(TransactionId.valueOf(transactionId));
        }
        catch (Exception e) {
            throw badRequest(e.getMessage());
        }
    }

    private static WebApplicationException badRequest(String message)
    {
        throw new WebApplicationException(message, Response
                .status(Status.BAD_REQUEST)
                .type(MediaType.TEXT_PLAIN)
                .entity(message)
                .build());
    }

    private static String trimEmptyToNull(String value)
    {
        return emptyToNull(nullToEmpty(value).trim());
    }

    private static String urlDecode(String value)
    {
        try {
            return URLDecoder.decode(value, "UTF-8");
        }
        catch (UnsupportedEncodingException e) {
            throw new AssertionError(e);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy