![JAR search and dependency download from the Maven repository](/logo.png)
io.trino.server.HttpRequestSessionContextFactory Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.server;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session.ResourceEstimateBuilder;
import io.trino.client.ProtocolDetectionException;
import io.trino.client.ProtocolHeaders;
import io.trino.metadata.Metadata;
import io.trino.security.AccessControl;
import io.trino.server.protocol.PreparedStatementEncoder;
import io.trino.spi.security.AccessDeniedException;
import io.trino.spi.security.GroupProvider;
import io.trino.spi.security.Identity;
import io.trino.spi.security.SelectedRole;
import io.trino.spi.security.SelectedRole.Type;
import io.trino.spi.session.ResourceEstimates;
import io.trino.sql.parser.ParsingException;
import io.trino.sql.parser.SqlParser;
import io.trino.transaction.TransactionId;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
import java.net.URLDecoder;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static io.trino.client.ProtocolHeaders.detectProtocol;
import static io.trino.spi.security.AccessDeniedException.denySetRole;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
/**
 * Builds {@link SessionContext} instances and authorized {@link Identity} objects from the
 * headers of an HTTP request.
 * <p>
 * Header names are resolved through {@link ProtocolHeaders}, which is detected from the request's
 * header names (and an optional alternate header name). Malformed headers are rejected with a
 * {@link WebApplicationException} carrying a {@code 400 Bad Request} plain-text response.
 */
public class HttpRequestSessionContextFactory
{
    private static final Splitter DOT_SPLITTER = Splitter.on('.');
    // Guava splitters are immutable, so they are shared instead of being rebuilt on every call
    private static final Splitter COMMA_SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings();
    private static final Splitter EQUALS_SPLITTER = Splitter.on('=').trimResults();

    /**
     * Servlet request attribute from which the authenticated {@link Identity} is read.
     */
    public static final String AUTHENTICATED_IDENTITY = "trino.authenticated-identity";

    private final PreparedStatementEncoder preparedStatementEncoder;
    private final Metadata metadata;
    private final GroupProvider groupProvider;
    private final AccessControl accessControl;

    @Inject
    public HttpRequestSessionContextFactory(PreparedStatementEncoder preparedStatementEncoder, Metadata metadata, GroupProvider groupProvider, AccessControl accessControl)
    {
        this.preparedStatementEncoder = requireNonNull(preparedStatementEncoder, "preparedStatementEncoder is null");
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.groupProvider = requireNonNull(groupProvider, "groupProvider is null");
        this.accessControl = requireNonNull(accessControl, "accessControl is null");
    }

    /**
     * Creates the session context for a request from its HTTP headers.
     *
     * @param headers the request headers
     * @param alternateHeaderName optional alternate header name, passed to protocol detection
     * @param remoteAddress the client's remote address, if known
     * @param authenticatedIdentity the identity established by authentication, if any
     * @return the parsed session context
     * @throws WebApplicationException with status 400 if the protocol cannot be detected or any
     *         session header is malformed
     */
    public SessionContext createSessionContext(
            MultivaluedMap<String, String> headers,
            Optional<String> alternateHeaderName,
            Optional<String> remoteAddress,
            Optional<Identity> authenticatedIdentity)
            throws WebApplicationException
    {
        ProtocolHeaders protocolHeaders = detectProtocolHeaders(alternateHeaderName, headers);

        Optional<String> catalog = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestCatalog())));
        Optional<String> schema = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestSchema())));
        Optional<String> path = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestPath())));
        assertRequest(catalog.isPresent() || schema.isEmpty(), "Schema is set but catalog is not");

        requireNonNull(authenticatedIdentity, "authenticatedIdentity is null");
        Identity identity = buildSessionIdentity(authenticatedIdentity, protocolHeaders, headers);
        Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers);
        SelectedRole selectedRole = parseSystemRoleHeaders(protocolHeaders, headers);

        Optional<String> source = Optional.ofNullable(headers.getFirst(protocolHeaders.requestSource()));
        Optional<String> traceToken = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestTraceToken())));
        Optional<String> userAgent = Optional.ofNullable(headers.getFirst(USER_AGENT));
        Optional<String> remoteUserAddress = requireNonNull(remoteAddress, "remoteAddress is null");
        Optional<String> timeZoneId = Optional.ofNullable(headers.getFirst(protocolHeaders.requestTimeZone()));
        Optional<String> language = Optional.ofNullable(headers.getFirst(protocolHeaders.requestLanguage()));
        Optional<String> clientInfo = Optional.ofNullable(headers.getFirst(protocolHeaders.requestClientInfo()));
        Set<String> clientTags = parseClientTags(protocolHeaders, headers);
        Set<String> clientCapabilities = parseClientCapabilities(protocolHeaders, headers);
        ResourceEstimates resourceEstimates = parseResourceEstimate(protocolHeaders, headers);

        // Session properties are either "name" (system property) or "catalog.name" (catalog property)
        ImmutableMap.Builder<String, String> systemProperties = ImmutableMap.builder();
        Map<String, Map<String, String>> mutableCatalogSessionProperties = new HashMap<>();
        for (Entry<String, String> entry : parseSessionHeaders(protocolHeaders, headers).entrySet()) {
            String fullPropertyName = entry.getKey();
            String propertyValue = entry.getValue();
            List<String> nameParts = DOT_SPLITTER.splitToList(fullPropertyName);
            if (nameParts.size() == 1) {
                String propertyName = nameParts.get(0);

                assertRequest(!propertyName.isEmpty(), "Invalid %s header", protocolHeaders.requestSession());

                // catalog session properties cannot be validated until the transaction has started, so we delay system property validation also
                systemProperties.put(propertyName, propertyValue);
            }
            else if (nameParts.size() == 2) {
                String catalogName = nameParts.get(0);
                String propertyName = nameParts.get(1);

                assertRequest(!catalogName.isEmpty(), "Invalid %s header", protocolHeaders.requestSession());
                assertRequest(!propertyName.isEmpty(), "Invalid %s header", protocolHeaders.requestSession());

                // catalog session properties cannot be validated until the transaction has started
                mutableCatalogSessionProperties.computeIfAbsent(catalogName, id -> new HashMap<>()).put(propertyName, propertyValue);
            }
            else {
                throw badRequest(format("Invalid %s header", protocolHeaders.requestSession()));
            }
        }
        Map<String, Map<String, String>> catalogSessionProperties = mutableCatalogSessionProperties.entrySet().stream()
                .collect(toImmutableMap(Entry::getKey, entry -> ImmutableMap.copyOf(entry.getValue())));

        Map<String, String> preparedStatements = parsePreparedStatementsHeaders(protocolHeaders, headers);

        // the mere presence of the transaction header signals that the client supports transactions
        String transactionIdHeader = headers.getFirst(protocolHeaders.requestTransactionId());
        boolean clientTransactionSupport = transactionIdHeader != null;
        Optional<TransactionId> transactionId = parseTransactionId(transactionIdHeader);

        return new SessionContext(
                protocolHeaders,
                catalog,
                schema,
                path,
                authenticatedIdentity,
                identity,
                originalIdentity,
                selectedRole,
                source,
                traceToken,
                userAgent,
                remoteUserAddress,
                timeZoneId,
                language,
                clientTags,
                clientCapabilities,
                resourceEstimates,
                systemProperties.buildOrThrow(),
                catalogSessionProperties,
                preparedStatements,
                transactionId,
                clientTransactionSupport,
                clientInfo);
    }

    /**
     * Extracts and authorizes the session identity for a servlet request, taking the
     * authenticated identity from the {@link #AUTHENTICATED_IDENTITY} request attribute.
     */
    public Identity extractAuthorizedIdentity(
            HttpServletRequest servletRequest,
            HttpHeaders httpHeaders,
            Optional<String> alternateHeaderName)
    {
        return extractAuthorizedIdentity(
                Optional.ofNullable((Identity) servletRequest.getAttribute(AUTHENTICATED_IDENTITY)),
                httpHeaders.getRequestHeaders(),
                alternateHeaderName);
    }

    /**
     * Builds the session identity from the request headers and checks that the caller may use it.
     * <p>
     * Checks performed: the original user may be set by its principal; if an authenticated
     * identity is present and differs from the original user, it must be allowed to impersonate
     * that user; and if the session user differs from the original user, the original identity
     * must be allowed to set and impersonate the session user.
     *
     * @return the session identity with the system role from the role header applied
     * @throws AccessDeniedException if any access check fails
     * @throws WebApplicationException with status 400 if the headers are malformed
     */
    public Identity extractAuthorizedIdentity(
            Optional<Identity> optionalAuthenticatedIdentity,
            MultivaluedMap<String, String> headers,
            Optional<String> alternateHeaderName)
            throws AccessDeniedException
    {
        ProtocolHeaders protocolHeaders = detectProtocolHeaders(alternateHeaderName, headers);

        Identity identity = buildSessionIdentity(optionalAuthenticatedIdentity, protocolHeaders, headers);
        Identity originalIdentity = buildSessionOriginalIdentity(identity, protocolHeaders, headers);

        accessControl.checkCanSetUser(originalIdentity.getPrincipal(), originalIdentity.getUser());

        // an authenticated identity may not be present for HTTP or if authentication is not set up
        optionalAuthenticatedIdentity.ifPresent(authenticatedIdentity -> {
            // only check impersonation if authenticated user is not the same as the explicitly set user
            if (!authenticatedIdentity.getUser().equals(originalIdentity.getUser())) {
                // load enabled roles for authenticated identity, so impersonation permissions can be assigned to roles
                Identity authenticatedIdentityWithRoles = Identity.from(authenticatedIdentity)
                        .withEnabledRoles(metadata.listEnabledRoles(authenticatedIdentity))
                        .build();
                accessControl.checkCanImpersonateUser(authenticatedIdentityWithRoles, originalIdentity.getUser());
            }
        });

        if (!originalIdentity.getUser().equals(identity.getUser())) {
            accessControl.checkCanSetUser(originalIdentity.getPrincipal(), identity.getUser());
            accessControl.checkCanImpersonateUser(originalIdentity, identity.getUser());
        }

        return addEnabledRoles(identity, parseSystemRoleHeaders(protocolHeaders, headers), metadata);
    }

    /**
     * Applies a selected system role to an identity.
     * <p>
     * For {@code NONE} the identity is returned unchanged. Otherwise the enabled roles are
     * looked up via {@code metadata.listEnabledRoles}. For {@code ROLE} the selected role must be
     * among them (otherwise {@code denySetRole} is invoked) and becomes the only enabled role; for
     * any other type all listed roles are enabled.
     */
    public static Identity addEnabledRoles(Identity identity, SelectedRole selectedRole, Metadata metadata)
    {
        if (selectedRole.getType() == Type.NONE) {
            return identity;
        }
        Set<String> enabledRoles = metadata.listEnabledRoles(identity);
        if (selectedRole.getType() == Type.ROLE) {
            String role = selectedRole.getRole().orElseThrow();
            if (!enabledRoles.contains(role)) {
                denySetRole(role);
            }
            enabledRoles = ImmutableSet.of(role);
        }
        return Identity.from(identity)
                .withEnabledRoles(enabledRoles)
                .build();
    }

    /**
     * Builds the session identity. The user comes from the user header, falling back to the
     * authenticated identity's user; the request is rejected if neither is available. Connector
     * roles, extra credentials and groups from the group provider are added on top.
     */
    private Identity buildSessionIdentity(Optional<Identity> authenticatedIdentity, ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        String trinoUser = trimEmptyToNull(headers.getFirst(protocolHeaders.requestUser()));
        String user = trinoUser != null ? trinoUser : authenticatedIdentity.map(Identity::getUser).orElse(null);
        assertRequest(user != null, "User must be set");
        SelectedRole systemRole = parseSystemRoleHeaders(protocolHeaders, headers);
        ImmutableSet.Builder<String> systemEnabledRoles = ImmutableSet.builder();
        if (systemRole.getType() == Type.ROLE) {
            systemEnabledRoles.add(systemRole.getRole().orElseThrow());
        }
        return authenticatedIdentity
                .map(identity -> Identity.from(identity).withUser(user))
                .orElseGet(() -> Identity.forUser(user))
                .withEnabledRoles(systemEnabledRoles.build())
                .withAdditionalConnectorRoles(parseConnectorRoleHeaders(protocolHeaders, headers))
                .withAdditionalExtraCredentials(parseExtraCredentials(protocolHeaders, headers))
                .withAdditionalGroups(groupProvider.getGroups(user))
                .build();
    }

    /**
     * Derives the original identity from the original-user header. When the header is present,
     * the result is a copy of {@code identity} with that user, no extra credentials and the
     * user's groups; otherwise {@code identity} itself is returned.
     */
    private Identity buildSessionOriginalIdentity(Identity identity, ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        // We derive original identity using this header, but older clients will not send it, so fall back to identity
        Optional<String> optionalOriginalUser = Optional
                .ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestOriginalUser())));
        return optionalOriginalUser.map(originalUser -> Identity.from(identity)
                        .withUser(originalUser)
                        .withExtraCredentials(new HashMap<>())
                        .withGroups(groupProvider.getGroups(originalUser))
                        .build())
                .orElse(identity);
    }

    /**
     * Detects the protocol headers in use, translating detection failures into a 400 response.
     */
    private static ProtocolHeaders detectProtocolHeaders(Optional<String> alternateHeaderName, MultivaluedMap<String, String> headers)
    {
        try {
            return detectProtocol(alternateHeaderName, headers.keySet());
        }
        catch (ProtocolDetectionException e) {
            throw badRequest(e.getMessage());
        }
    }

    /**
     * Returns every comma-separated element across all values of the named header, trimmed and
     * with empty elements dropped. Returns an empty list if the header is absent.
     */
    private static List<String> splitHttpHeader(MultivaluedMap<String, String> headers, String name)
    {
        List<String> values = firstNonNull(headers.get(name), ImmutableList.of());
        return values.stream()
                .map(COMMA_SPLITTER::splitToList)
                .flatMap(Collection::stream)
                .collect(toImmutableList());
    }

    private static Map<String, String> parseSessionHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        return parseProperty(headers, protocolHeaders.requestSession());
    }

    /**
     * Returns the role given for the key {@code system} (case-insensitive) in the role header,
     * or {@code ALL} when no such entry exists.
     */
    private static SelectedRole parseSystemRoleHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        return parseProperty(headers, protocolHeaders.requestRole()).entrySet().stream()
                .filter(entry -> entry.getKey().equalsIgnoreCase("system"))
                .map(Entry::getValue)
                .map(role -> toSelectedRole(protocolHeaders, role))
                .findFirst()
                .orElse(new SelectedRole(Type.ALL, Optional.empty()));
    }

    /**
     * Returns all role-header entries except the {@code system} one, keyed by the header key.
     */
    private static Map<String, SelectedRole> parseConnectorRoleHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        ImmutableMap.Builder<String, SelectedRole> roles = ImmutableMap.builder();
        parseProperty(headers, protocolHeaders.requestRole()).forEach((key, value) -> {
            if (key.equalsIgnoreCase("system")) {
                return;
            }
            roles.put(key, toSelectedRole(protocolHeaders, value));
        });
        return roles.buildOrThrow();
    }

    /**
     * Parses a role header value, rejecting unparseable values with a 400 response.
     */
    private static SelectedRole toSelectedRole(ProtocolHeaders protocolHeaders, String value)
    {
        try {
            return SelectedRole.valueOf(value);
        }
        catch (IllegalArgumentException e) {
            throw badRequest(format("Invalid %s header", protocolHeaders.requestRole()));
        }
    }

    private static Map<String, String> parseExtraCredentials(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        return parseProperty(headers, protocolHeaders.requestExtraCredential());
    }

    /**
     * Parses a header whose values are comma-separated {@code name=value} pairs. Values are
     * URL-decoded (UTF-8); names are not. A later pair with the same name replaces an earlier one.
     *
     * @throws WebApplicationException with status 400 if an element is not exactly one
     *         {@code name=value} pair or a value cannot be URL-decoded
     */
    private static Map<String, String> parseProperty(MultivaluedMap<String, String> headers, String headerName)
    {
        Map<String, String> properties = new HashMap<>();
        for (String header : splitHttpHeader(headers, headerName)) {
            List<String> nameValue = EQUALS_SPLITTER.splitToList(header);
            assertRequest(nameValue.size() == 2, "Invalid %s header", headerName);
            try {
                properties.put(nameValue.get(0), urlDecode(nameValue.get(1)));
            }
            catch (IllegalArgumentException e) {
                throw badRequest(format("Invalid %s header: %s", headerName, e));
            }
        }
        return properties;
    }

    /**
     * Returns the comma-separated client tags from the first value of the client-tags header.
     */
    private static Set<String> parseClientTags(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        return ImmutableSet.copyOf(COMMA_SPLITTER.split(nullToEmpty(headers.getFirst(protocolHeaders.requestClientTags()))));
    }

    /**
     * Returns the comma-separated capabilities from the first value of the client-capabilities header.
     */
    private static Set<String> parseClientCapabilities(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        return ImmutableSet.copyOf(COMMA_SPLITTER.split(nullToEmpty(headers.getFirst(protocolHeaders.requestClientCapabilities()))));
    }

    /**
     * Parses the resource-estimate header. Names are matched case-insensitively against the
     * execution-time, CPU-time (both durations) and peak-memory (data size) estimates; unknown
     * names or unparseable values are rejected with a 400 response.
     */
    private static ResourceEstimates parseResourceEstimate(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        ResourceEstimateBuilder builder = new ResourceEstimateBuilder();
        parseProperty(headers, protocolHeaders.requestResourceEstimate()).forEach((name, value) -> {
            try {
                switch (name.toUpperCase(ENGLISH)) {
                    case ResourceEstimates.EXECUTION_TIME:
                        builder.setExecutionTime(Duration.valueOf(value));
                        return;
                    case ResourceEstimates.CPU_TIME:
                        builder.setCpuTime(Duration.valueOf(value));
                        return;
                    case ResourceEstimates.PEAK_MEMORY:
                        builder.setPeakMemory(DataSize.valueOf(value));
                        return;
                }
                throw badRequest(format("Unsupported resource name %s", name));
            }
            catch (IllegalArgumentException e) {
                throw badRequest(format("Unsupported format for resource estimate '%s': %s", value, e));
            }
        });

        return builder.build();
    }

    /**
     * Throws a 400 response with the formatted message when {@code expression} is false.
     */
    private static void assertRequest(boolean expression, String format, Object... args)
    {
        if (!expression) {
            throw badRequest(format(format, args));
        }
    }

    /**
     * Parses the prepared-statement header into a map of statement name to SQL text. Names are
     * URL-decoded; values are decoded by the {@link PreparedStatementEncoder} and then checked
     * to be parseable SQL.
     *
     * @throws WebApplicationException with status 400 if a name cannot be decoded or a
     *         statement fails to parse
     */
    private Map<String, String> parsePreparedStatementsHeaders(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
    {
        ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder();
        // one parser is enough to validate every statement in this request
        SqlParser sqlParser = new SqlParser();
        parseProperty(headers, protocolHeaders.requestPreparedStatement()).forEach((key, value) -> {
            String statementName;
            try {
                statementName = urlDecode(key);
            }
            catch (IllegalArgumentException e) {
                throw badRequest(format("Invalid %s header: %s", protocolHeaders.requestPreparedStatement(), e.getMessage()));
            }
            String sqlString = preparedStatementEncoder.decodePreparedStatementFromHeader(value);

            // Validate statement
            try {
                sqlParser.createStatement(sqlString);
            }
            catch (ParsingException e) {
                throw badRequest(format("Invalid %s header: %s", protocolHeaders.requestPreparedStatement(), e.getMessage()));
            }

            preparedStatements.put(statementName, sqlString);
        });

        return preparedStatements.buildOrThrow();
    }

    /**
     * Parses the transaction-id header. A missing or blank header, or the value {@code none}
     * (case-insensitive), yields an empty result; any parse failure becomes a 400 response.
     */
    private static Optional<TransactionId> parseTransactionId(String transactionId)
    {
        String trimmedTransactionId = trimEmptyToNull(transactionId);
        if (trimmedTransactionId == null || trimmedTransactionId.equalsIgnoreCase("none")) {
            return Optional.empty();
        }
        try {
            return Optional.of(TransactionId.valueOf(trimmedTransactionId));
        }
        catch (Exception e) {
            throw badRequest(e.getMessage());
        }
    }

    /**
     * Creates (but does not throw) a {@link WebApplicationException} carrying a
     * {@code 400 Bad Request} plain-text response whose body is {@code message}.
     * Callers throw the returned exception: {@code throw badRequest(...)}.
     */
    private static WebApplicationException badRequest(String message)
    {
        return new WebApplicationException(message, Response
                .status(Status.BAD_REQUEST)
                .type(MediaType.TEXT_PLAIN)
                .entity(message)
                .build());
    }

    /**
     * Trims {@code value}, mapping {@code null} and blank-after-trim strings to {@code null}.
     */
    private static String trimEmptyToNull(String value)
    {
        return emptyToNull(nullToEmpty(value).trim());
    }

    private static String urlDecode(String value)
    {
        return URLDecoder.decode(value, UTF_8);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy