// io.trino.proxy.ProxyResource Maven / Gradle / Ivy
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.proxy;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.google.common.hash.HashCode;
import com.google.common.hash.HashFunction;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.proxy.ProxyResponseHandler.ProxyResponse;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.AsyncResponse;
import jakarta.ws.rs.container.Suspended;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.ResponseBuilder;
import jakarta.ws.rs.core.Response.Status;
import jakarta.ws.rs.core.UriInfo;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import static com.fasterxml.jackson.core.JsonFactory.Feature.CANONICALIZE_FIELD_NAMES;
import static com.fasterxml.jackson.core.JsonToken.END_OBJECT;
import static com.fasterxml.jackson.core.JsonToken.FIELD_NAME;
import static com.fasterxml.jackson.core.JsonToken.START_OBJECT;
import static com.fasterxml.jackson.core.JsonToken.VALUE_STRING;
import static com.google.common.hash.Hashing.hmacSha256;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static com.google.common.net.HttpHeaders.COOKIE;
import static com.google.common.net.HttpHeaders.SET_COOKIE;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static io.airlift.http.client.Request.Builder.prepareDelete;
import static io.airlift.http.client.Request.Builder.prepareGet;
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.plugin.base.util.JsonUtils.jsonFactoryBuilder;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY;
import static jakarta.ws.rs.core.Response.Status.FORBIDDEN;
import static jakarta.ws.rs.core.Response.noContent;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.nio.file.Files.readAllBytes;
import static java.util.Collections.list;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.TimeUnit.MINUTES;
/**
 * JAX-RS resource that forwards client requests to the remote server
 * configured in {@link ProxyConfig#getUri()}.
 *
 * <p>JSON responses from the remote server have their top-level
 * {@code nextUri} and {@code partialCancelUri} fields rewritten so that they
 * point back at this proxy's {@code /v1/proxy} endpoint. Each rewritten link
 * carries the original URI plus an HMAC-SHA256 of it, computed with the shared
 * secret; {@code /v1/proxy} only forwards to a URI whose HMAC validates.
 */
@Path("/")
public class ProxyResource
{
    private static final Logger log = Logger.get(ProxyResource.class);

    private static final String X509_ATTRIBUTE = "jakarta.servlet.request.X509Certificate";
    private static final Duration ASYNC_TIMEOUT = new Duration(2, MINUTES);
    private static final JsonFactory JSON_FACTORY = jsonFactoryBuilder().disable(CANONICALIZE_FIELD_NAMES).build();

    private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("proxy-%s"));
    private final HttpClient httpClient;
    private final JsonWebTokenHandler jwtHandler;
    private final URI remoteUri;
    private final HashFunction hmac;

    /**
     * @param httpClient client used for all requests to the remote server
     * @param jwtHandler produces bearer tokens for forwarded requests when configured
     * @param config supplies the remote URI and the shared secret file used to key the HMAC
     */
    @Inject
    public ProxyResource(@ForProxy HttpClient httpClient, JsonWebTokenHandler jwtHandler, ProxyConfig config)
    {
        this.httpClient = requireNonNull(httpClient, "httpClient is null");
        this.jwtHandler = requireNonNull(jwtHandler, "jwtHandler is null");
        this.remoteUri = requireNonNull(config.getUri(), "uri is null");
        this.hmac = hmacSha256(loadSharedSecret(config.getSharedSecretFile()));
    }

    /**
     * Stops the executor used for response processing.
     */
    @PreDestroy
    public void shutdown()
    {
        executor.shutdownNow();
    }

    /**
     * Forwards {@code GET /v1/info} to the remote server and returns its body unchanged.
     */
    @GET
    @Path("/v1/info")
    @Produces(APPLICATION_JSON)
    public void getInfo(
            @Context HttpServletRequest servletRequest,
            @Suspended AsyncResponse asyncResponse)
    {
        Request.Builder request = prepareGet()
                .setUri(uriBuilderFrom(remoteUri).replacePath("/v1/info").build());

        performRequest(servletRequest, asyncResponse, request, response ->
                responseWithHeaders(Response.ok(response.getBody()), response));
    }

    /**
     * Forwards a statement to {@code POST /v1/statement} on the remote server
     * and returns the response with its follow-up URIs rewritten.
     */
    @POST
    @Path("/v1/statement")
    @Produces(APPLICATION_JSON)
    public void postStatement(
            String statement,
            @Context HttpServletRequest servletRequest,
            @Context UriInfo uriInfo,
            @Suspended AsyncResponse asyncResponse)
    {
        Request.Builder request = preparePost()
                .setUri(uriBuilderFrom(remoteUri).replacePath("/v1/statement").build())
                .setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));

        performRequest(servletRequest, asyncResponse, request, response -> buildResponse(uriInfo, response));
    }

    /**
     * Issues a GET against a previously rewritten URI after validating its HMAC,
     * and returns the response with its follow-up URIs rewritten.
     *
     * @throws WebApplicationException with status 403 if the HMAC is missing, malformed or does not match
     */
    @GET
    @Path("/v1/proxy")
    @Produces(APPLICATION_JSON)
    public void getNext(
            @QueryParam("uri") String uri,
            @QueryParam("hmac") String hash,
            @Context HttpServletRequest servletRequest,
            @Context UriInfo uriInfo,
            @Suspended AsyncResponse asyncResponse)
    {
        validateHmac(uri, hash);

        Request.Builder request = prepareGet().setUri(URI.create(uri));

        performRequest(servletRequest, asyncResponse, request, response -> buildResponse(uriInfo, response));
    }

    /**
     * Issues a DELETE against a previously rewritten URI after validating its HMAC,
     * and answers with 204 No Content plus the forwarded response headers.
     *
     * @throws WebApplicationException with status 403 if the HMAC is missing, malformed or does not match
     */
    @DELETE
    @Path("/v1/proxy")
    @Produces(APPLICATION_JSON)
    public void cancelQuery(
            @QueryParam("uri") String uri,
            @QueryParam("hmac") String hash,
            @Context HttpServletRequest servletRequest,
            @Suspended AsyncResponse asyncResponse)
    {
        validateHmac(uri, hash);

        Request.Builder request = prepareDelete().setUri(URI.create(uri));

        performRequest(servletRequest, asyncResponse, request, response -> responseWithHeaders(noContent(), response));
    }

    /**
     * Rejects the request with 403 unless {@code hash} is the hex-encoded HMAC of {@code uri}.
     * Missing parameters and values that are not valid hex are treated as a
     * validation failure rather than being allowed to surface as a server error.
     */
    private void validateHmac(String uri, String hash)
    {
        if (uri == null || hash == null) {
            throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
        }

        HashCode provided;
        try {
            provided = HashCode.fromString(hash);
        }
        catch (IllegalArgumentException ignored) {
            // not a well-formed hex string, so it cannot be a valid HMAC
            throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
        }

        // Guava documents HashCode.equals as a constant-time comparison
        if (!hmac.hashString(uri, UTF_8).equals(provided)) {
            throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
        }
    }

    /**
     * Completes {@code requestBuilder} from the incoming servlet request, executes it
     * asynchronously and binds the converted result to {@code asyncResponse}.
     *
     * <p>Only {@code x-trino-*} and {@code Cookie} headers are copied verbatim;
     * {@code User-Agent} values are forwarded with a {@code "[Trino Proxy] "} prefix.
     */
    private void performRequest(
            HttpServletRequest servletRequest,
            AsyncResponse asyncResponse,
            Request.Builder requestBuilder,
            Function<ProxyResponse, Response> responseBuilder)
    {
        setupBearerToken(servletRequest, requestBuilder);

        for (String name : list(servletRequest.getHeaderNames())) {
            if (isTrinoHeader(name) || name.equalsIgnoreCase(COOKIE)) {
                for (String value : list(servletRequest.getHeaders(name))) {
                    requestBuilder.addHeader(name, value);
                }
            }
            else if (name.equalsIgnoreCase(USER_AGENT)) {
                for (String value : list(servletRequest.getHeaders(name))) {
                    requestBuilder.addHeader(name, "[Trino Proxy] " + value);
                }
            }
        }

        Request request = requestBuilder
                .setPreserveAuthorizationOnRedirect(true)
                .build();

        // catching() is applied after transform(), so a ProxyException raised while
        // converting the response is handled the same way as one from the HTTP call
        ListenableFuture<Response> future = executeHttp(request)
                .transform(responseBuilder::apply, executor)
                .catching(ProxyException.class, e -> handleProxyException(request, e), directExecutor());

        setupAsyncResponse(asyncResponse, future);
    }

    /**
     * Builds a 200 response whose body is the remote body with follow-up URIs rewritten.
     */
    private Response buildResponse(UriInfo uriInfo, ProxyResponse response)
    {
        byte[] body = rewriteResponse(response.getBody(), uri -> rewriteUri(uriInfo, uri));
        return responseWithHeaders(Response.ok(body), response);
    }

    /**
     * Produces a {@code /v1/proxy} link on this server that carries {@code uri}
     * and its HMAC as query parameters.
     */
    private String rewriteUri(UriInfo uriInfo, String uri)
    {
        return uriInfo.getAbsolutePathBuilder()
                .replacePath("/v1/proxy")
                .queryParam("uri", uri)
                .queryParam("hmac", hmac.hashString(uri, UTF_8))
                .build()
                .toString();
    }

    /**
     * Binds {@code future} to {@code asyncResponse}, answering with 502 if it has
     * not completed within {@link #ASYNC_TIMEOUT}.
     */
    private void setupAsyncResponse(AsyncResponse asyncResponse, ListenableFuture<Response> future)
    {
        bindAsyncResponse(asyncResponse, future, executor)
                .withTimeout(ASYNC_TIMEOUT, () -> Response
                        .status(BAD_GATEWAY)
                        .type(TEXT_PLAIN_TYPE)
                        .entity("Request to remote Trino server timed out after " + ASYNC_TIMEOUT)
                        .build());
    }

    private FluentFuture<ProxyResponse> executeHttp(Request request)
    {
        return FluentFuture.from(httpClient.executeAsync(request, new ProxyResponseHandler()));
    }

    /**
     * When the JWT handler is configured, adds an {@code Authorization: Bearer} header
     * whose token is issued for the subject of the first client certificate on the request.
     *
     * @throws WebApplicationException with status 403 if no client certificate is present
     */
    private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder requestBuilder)
    {
        if (!jwtHandler.isConfigured()) {
            return;
        }

        X509Certificate[] certs = (X509Certificate[]) servletRequest.getAttribute(X509_ATTRIBUTE);
        if ((certs == null) || (certs.length == 0)) {
            throw badRequest(FORBIDDEN, "No TLS certificate present for request");
        }

        String principal = certs[0].getSubjectX500Principal().getName();
        String accessToken = jwtHandler.getBearerToken(principal);
        requestBuilder.addHeader(AUTHORIZATION, "Bearer " + accessToken);
    }

    /**
     * Logs the failed request and converts the failure into a 502 response.
     * Never returns normally; the generic return type only lets it be used
     * as a fallback function for a future of any type.
     */
    private static <T> T handleProxyException(Request request, ProxyException e)
    {
        log.warn(e, "Proxy request failed: %s %s", request.getMethod(), request.getUri());
        throw badRequest(BAD_GATEWAY, e.getMessage());
    }

    /**
     * Creates (but does not throw) an exception carrying a plain-text response
     * with the given status and message; callers are expected to {@code throw} it.
     */
    private static WebApplicationException badRequest(Status status, String message)
    {
        return new WebApplicationException(
                Response.status(status)
                        .type(TEXT_PLAIN_TYPE)
                        .entity(message)
                        .build());
    }

    private static boolean isTrinoHeader(String name)
    {
        return name.toLowerCase(ENGLISH).startsWith("x-trino-");
    }

    /**
     * Copies the {@code x-trino-*} and {@code Set-Cookie} headers of the remote
     * response onto {@code builder} and builds the response.
     */
    private static Response responseWithHeaders(ResponseBuilder builder, ProxyResponse response)
    {
        response.getHeaders().forEach((headerName, value) -> {
            String name = headerName.toString();
            if (isTrinoHeader(name) || name.equalsIgnoreCase(SET_COOKIE)) {
                builder.header(name, value);
            }
        });
        return builder.build();
    }

    /**
     * Copies a JSON object, replacing the string values of its top-level
     * {@code nextUri} and {@code partialCancelUri} fields with the result of
     * {@code uriRewriter}. All other fields, including nested structures, are
     * copied unchanged.
     *
     * @param input JSON document whose root must be an object
     * @param uriRewriter maps an original URI to its replacement
     * @return the rewritten JSON document
     * @throws ProxyException if the input cannot be parsed or does not have the expected shape
     */
    private static byte[] rewriteResponse(byte[] input, Function<String, String> uriRewriter)
    {
        ByteArrayOutputStream out = new ByteArrayOutputStream(input.length * 2);

        // try-with-resources releases the parser and generator on error paths too
        try (JsonParser parser = JSON_FACTORY.createParser(input);
                JsonGenerator generator = JSON_FACTORY.createGenerator(out)) {
            JsonToken token = parser.nextToken();
            if (token != START_OBJECT) {
                throw invalidJson("bad start token: " + token);
            }
            generator.copyCurrentEvent(parser);

            while (true) {
                token = parser.nextToken();
                if (token == null) {
                    throw invalidJson("unexpected end of stream");
                }

                if (token == END_OBJECT) {
                    generator.copyCurrentEvent(parser);
                    break;
                }

                if (token == FIELD_NAME) {
                    String name = parser.getValueAsString();
                    if (!"nextUri".equals(name) && !"partialCancelUri".equals(name)) {
                        // copies the field name together with its entire value
                        generator.copyCurrentStructure(parser);
                        continue;
                    }

                    token = parser.nextToken();
                    if (token != VALUE_STRING) {
                        throw invalidJson(format("bad %s token: %s", name, token));
                    }
                    String value = parser.getValueAsString();

                    value = uriRewriter.apply(value);
                    generator.writeStringField(name, value);
                    continue;
                }

                throw invalidJson("unexpected token: " + token);
            }

            token = parser.nextToken();
            if (token != null) {
                throw invalidJson("unexpected token after object close: " + token);
            }
        }
        catch (IOException e) {
            throw new ProxyException(e);
        }

        // read the buffer only after the generator has been closed, and therefore flushed
        return out.toByteArray();
    }

    private static IOException invalidJson(String message)
    {
        return new IOException("Invalid JSON response from remote Trino server: " + message);
    }

    /**
     * Reads the shared secret from {@code file}, decoding its contents with the Base64 MIME decoder.
     *
     * @throws RuntimeException if the file cannot be read or is not valid Base64
     */
    private static byte[] loadSharedSecret(File file)
    {
        requireNonNull(file, "shared secret file is null");
        try {
            return Base64.getMimeDecoder().decode(readAllBytes(file.toPath()));
        }
        catch (IOException | IllegalArgumentException e) {
            throw new RuntimeException("Failed to load shared secret file: " + file, e);
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy