org.finos.tracdap.gateway.proxy.rest.RestApiProxy Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of tracdap-gateway Show documentation
Show all versions of tracdap-gateway Show documentation
TRAC D.A.P. gateway component, provides authentication, routing, load balancing and API translation
/*
* Copyright 2022 Accenture Global Solutions Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.finos.tracdap.gateway.proxy.rest;
import org.finos.tracdap.api.DownloadResponse;
import org.finos.tracdap.common.exception.*;
import org.finos.tracdap.gateway.proxy.grpc.GrpcUtils;
import io.grpc.Status;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http2.*;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class RestApiProxy extends Http2ChannelDuplexHandler {
// REST request headers that translateRequestHeaders() does not copy across to the gRPC request
private static final Set FILTER_REQUEST_HEADERS = Set.of(
Http2Headers.PseudoHeaderName.METHOD.value().toString(),
Http2Headers.PseudoHeaderName.PATH.value().toString(),
HttpHeaderNames.CONTENT_TYPE.toString(),
HttpHeaderNames.CONTENT_LENGTH.toString(),
HttpHeaderNames.CONTENT_ENCODING.toString(),
HttpHeaderNames.ACCEPT.toString());
// gRPC response headers that translateResponseHeaders() does not copy across to the REST response
private static final Set FILTER_RESPONSE_HEADERS = Set.of(
Http2Headers.PseudoHeaderName.STATUS.value().toString(),
HttpHeaderNames.CONTENT_TYPE.toString(),
HttpHeaderNames.CONTENT_LENGTH.toString(),
HttpHeaderNames.CONTENT_ENCODING.toString());
private final Logger log = LoggerFactory.getLogger(getClass());
// REST method definitions, searched in list order by lookupMethod()
private final List> methods;
// Call state for each stream seen by write(), keyed by the HTTP/2 stream of the frame
private final Map callStateMap;
/**
 * Create a proxy for a fixed list of REST methods, starting with an empty call state map.
 *
 * @param methods REST method definitions used to match incoming requests
 */
public RestApiProxy(List> methods) {
this.methods = methods;
this.callStateMap = new HashMap<>();
}
/**
 * Handle an outbound REST request frame.
 *
 * <p>Messages that are HTTP/2 frames but not stream frames are forwarded unchanged. Stream
 * frames are accumulated into the call state for their stream (created on first use), and
 * the request is dispatched once a frame marked end-of-stream has been received.
 *
 * @param ctx handler context used to forward frames
 * @param msg outbound message, released by this method before it returns
 * @param promise completed here for buffered frames, otherwise handed on with the request
 */
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {

    try {

        // Anything that is not an HTTP/2 frame should never reach the REST proxy layer
        if (!(msg instanceof Http2Frame))
            throw new EUnexpected();

        // Control messages are not tied to a stream, let them pass straight through
        // The extra retain balances the release in the finally block
        if (!(msg instanceof Http2StreamFrame)) {
            ReferenceCountUtil.retain(msg);
            ctx.write(msg, promise);
            return;
        }

        var streamFrame = (Http2StreamFrame) msg;

        // Look up the state for this stream, creating it the first time the stream is seen
        var callState = callStateMap.computeIfAbsent(
                streamFrame.stream(),
                newStream -> new RestApiCallState(ctx, newStream));

        var endOfStream = false;

        if (streamFrame instanceof Http2HeadersFrame) {

            var headersFrame = (Http2HeadersFrame) streamFrame;
            callState.requestHeaders.add(headersFrame.headers());
            endOfStream = headersFrame.isEndStream();
        }
        else if (streamFrame instanceof Http2DataFrame) {

            var dataFrame = (Http2DataFrame) streamFrame;
            var payload = dataFrame.content();

            // The composite buffer takes its own reference, since msg is released on exit
            if (payload != null && payload.readableBytes() > 0) {
                ReferenceCountUtil.retain(payload);
                callState.requestContent.addComponent(true, payload);
            }

            endOfStream = dataFrame.isEndStream();
        }
        else {

            log.warn("Unexpected request frame type {} will be dropped", streamFrame.name());
        }

        // Only a complete request is sent on, partial requests are just acknowledged
        if (endOfStream)
            dispatchRequest(callState, ctx, promise);
        else
            promise.setSuccess();
    }
    finally {
        ReferenceCountUtil.release(msg);
    }
}
/**
 * Handle an inbound gRPC response frame and translate it into REST response frames.
 *
 * <p>Messages that are HTTP/2 frames but not stream frames are passed on unchanged. Stream
 * frames are accumulated into the call state for their stream. Unary responses are
 * dispatched when the stream ends, streaming responses are dispatched as content arrives.
 *
 * <p>Frames for a stream that has no call state, or whose request never matched a REST
 * method, are logged and dropped. Once an end-of-stream frame has been handled the call
 * state is removed from the call state map and any buffered response content is released.
 *
 * @param ctx handler context used to fire translated frames
 * @param msg inbound message, released by this method before it returns
 * @throws EUnexpected if msg is not an HTTP/2 frame
 */
@Override
public void channelRead(@Nonnull ChannelHandlerContext ctx, @Nonnull Object msg) throws Exception {

    try {

        // REST proxy layer expects all messages to be HTTP/2 frames
        if (!(msg instanceof Http2Frame))
            throw new EUnexpected();

        // Allow control messages to pass through the REST proxy
        // The extra retain balances the release in the finally block
        if (!(msg instanceof Http2StreamFrame)) {
            ReferenceCountUtil.retain(msg);
            ctx.fireChannelRead(msg);
            return;
        }

        var frame = (Http2StreamFrame) msg;
        var stream = frame.stream();
        var state = callStateMap.get(stream);

        // Without call state or a matched method there is no translator to apply
        // Drop the frame instead of failing with a null pointer error
        if (state == null || state.method == null) {
            log.warn("Response frame type {} for an unknown stream will be dropped", frame.name());
            return;
        }

        var unaryResponse = ! state.method.grpcMethod.isServerStreaming();
        var endOfStream = false;

        try {

            if (frame instanceof Http2HeadersFrame) {

                var grpcFrame = (Http2HeadersFrame) frame;
                var grpcHeaders = grpcFrame.headers();
                endOfStream = grpcFrame.isEndStream();

                state.responseHeaders.add(grpcHeaders);

                if (unaryResponse) {
                    if (grpcFrame.isEndStream())
                        dispatchUnaryResponse(state, ctx);
                }
                else {
                    // Download endpoints need to wait for the first gRPC message before sending headers
                    if (!state.method.isDownload)
                        dispatchStreamHeaders(state, ctx, grpcFrame.isEndStream());
                    if (grpcFrame.isEndStream())
                        dispatchStreamComplete(state, ctx);
                }
            }
            else if (frame instanceof Http2DataFrame) {

                var grpcFrame = (Http2DataFrame) frame;
                endOfStream = grpcFrame.isEndStream();

                // The composite buffer takes its own reference, since msg is released on exit
                if (grpcFrame.content() != null && grpcFrame.content().readableBytes() > 0) {
                    ReferenceCountUtil.retain(grpcFrame.content());
                    state.responseContent.addComponent(true, grpcFrame.content());
                }

                if (unaryResponse) {
                    if (grpcFrame.isEndStream())
                        dispatchUnaryResponse(state, ctx);
                }
                else {
                    dispatchStreamContent(state, ctx);
                    if (grpcFrame.isEndStream())
                        dispatchStreamComplete(state, ctx);
                }
            }
            else {
                log.warn("Unexpected response frame type {} will be dropped", frame.name());
            }
        }
        finally {

            // The call is finished once the response stream ends
            // Forget the stream and free any content still buffered for it
            if (endOfStream) {
                callStateMap.remove(stream);
                if (state.responseContent != null) {
                    state.responseContent.release();
                    state.responseContent = null;
                }
            }
        }
    }
    finally {
        ReferenceCountUtil.release(msg);
    }
}
/**
 * Log an error raised in the pipeline, then hand it to the superclass handler.
 *
 * @param ctx handler context the error was raised on
 * @param cause the error that was caught
 */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.error("There was an error", cause);
super.exceptionCaught(ctx, cause);
}
/**
 * Translate a fully received REST request into a gRPC request and write it to the channel.
 *
 * <p>The REST method is looked up from the request headers. If no method matches, the
 * headers fail validation, or the request cannot be decoded, an error response is sent
 * back and the promise is failed. Otherwise gRPC headers and a single data frame holding
 * the encoded message are written, with the promise attached to the data frame.
 *
 * <p>The buffered request content is always released before this method returns.
 *
 * @param state call state holding the accumulated request headers and content
 * @param ctx handler context used to write the gRPC frames or fire an error response
 * @param promise write promise, failed here on every error path
 */
private void dispatchRequest(RestApiCallState state, ChannelHandlerContext ctx, ChannelPromise promise) {

    try {

        var url = state.requestHeaders.path().toString();

        state.method = lookupMethod(state.requestHeaders);
        state.translator = state.method != null ? state.method.translator : null;

        if (state.method == null) {
            sendErrorResponse(state.stream, ctx, HttpResponseStatus.NOT_FOUND, "REST API METHOD NOT FOUND");
            promise.setFailure(new ENetworkHttp(HttpResponseStatus.NOT_FOUND.code(), "REST API METHOD NOT FOUND"));
            return;
        }

        // TODO: This content type checking should happen in the router for grpc and rest routes
        try {
            checkRequestHeaders(state);
        }
        catch (EInputValidation e) {
            var errorCode = HttpResponseStatus.NOT_ACCEPTABLE;
            sendErrorResponse(state.stream, ctx, errorCode, e.getMessage());
            // Fail the promise as the other error paths do, so the writer is not left waiting
            promise.setFailure(new ENetworkHttp(errorCode.code(), e.getMessage()));
            return;
        }

        var restHeaders = state.requestHeaders;
        var grpcHeaders = translateRequestHeaders(restHeaders, state);

        // Methods without a body build the gRPC request from the URL alone
        var grpcMessage = state.method.hasBody
                ? state.translator.decodeRestRequest(url, state.requestContent)
                : state.translator.decodeRestRequest(url);

        var lpm = state.translator.encodeGrpcRequest(grpcMessage, ctx.alloc());

        var headersFrame = new DefaultHttp2HeadersFrame(grpcHeaders).stream(state.stream);
        var dataFrame = new DefaultHttp2DataFrame(lpm, true).stream(state.stream);

        ctx.write(headersFrame);
        ctx.write(dataFrame, promise);
    }
    catch (EInputValidation error) {

        log.warn("Bad request in REST API: " + error.getMessage(), error);

        // Validation errors can occur in the request builder
        // These are from extracting fields from the URL, or translating JSON -> protobuf
        // In this case, send some helpful information back about what caused the failure
        var errorCode = HttpResponseStatus.BAD_REQUEST;
        var errorMessage = error.getLocalizedMessage();

        sendErrorResponse(state.stream, ctx, errorCode, errorMessage);
        promise.setFailure(new ENetworkHttp(errorCode.code(), errorMessage));
    }
    finally {

        // Request content is no longer needed whether or not the request was sent
        if (state.requestContent != null) {
            state.requestContent.release();
            state.requestContent = null;
        }
    }
}
/**
 * Translate a complete unary gRPC response into a REST response and fire it up the pipeline.
 *
 * <p>If the gRPC headers carry a content length it must match the buffered content. The
 * response body is the translated message, or empty if no content was received. Any error
 * during translation is logged and reported as an internal server error response.
 *
 * <p>The buffered response content is always released before this method returns.
 *
 * @param state call state holding the accumulated response headers and content
 * @param ctx handler context used to fire the REST response frames
 */
private void dispatchUnaryResponse(RestApiCallState state, ChannelHandlerContext ctx) {

    try {

        var grpcHeaders = state.responseHeaders;
        var grpcContentLength = grpcHeaders.getInt(HttpHeaderNames.CONTENT_LENGTH);

        if (grpcContentLength != null && state.responseContent.readableBytes() != grpcContentLength)
            throw new ENetworkHttp(HttpResponseStatus.BAD_GATEWAY.code(), "Garbled response from server");

        ByteBuf restResponse;

        if (state.responseContent.readableBytes() == 0) {
            restResponse = Unpooled.EMPTY_BUFFER;
        }
        else {
            var grpcMessage = state.translator.decodeGrpcResponse(state.responseContent);
            restResponse = state.translator.encodeRestResponse(grpcMessage);
        }

        var restHeaders = translateResponseHeaders(grpcHeaders, state, /* streaming = */ false);
        restHeaders.add(HttpHeaderNames.CONTENT_TYPE, "application/json; charset=UTF-8");
        restHeaders.add(HttpHeaderNames.CONTENT_LENGTH, Integer.toString(restResponse.readableBytes()));

        var headersFrame = new DefaultHttp2HeadersFrame(restHeaders).stream(state.stream);
        var dataFrame = new DefaultHttp2DataFrame(restResponse, true).stream(state.stream);

        ctx.fireChannelRead(headersFrame);
        ctx.fireChannelRead(dataFrame);
    }
    catch (Exception error) {

        log.warn("Error translating response in REST API: " + error.getMessage(), error);

        // Errors here come from checking or translating the gRPC response, not from the request
        // The client cannot fix these, so they are reported as an internal server error
        var errorCode = HttpResponseStatus.INTERNAL_SERVER_ERROR;
        var errorMessage = error.getLocalizedMessage();

        sendErrorResponse(state.stream, ctx, errorCode, errorMessage);
    }
    finally {

        // Response content is no longer needed whether or not the response was sent
        if (state.responseContent != null) {
            state.responseContent.release();
            state.responseContent = null;
        }
    }
}
/**
 * Fire the REST response headers for a streaming response, if they have not been sent yet.
 *
 * <p>The headers frame is associated with the stream of the call, in the same way as the
 * other response frames fired by this handler. The status that was sent is recorded in
 * the call state.
 *
 * @param state call state holding the accumulated gRPC response headers
 * @param ctx handler context used to fire the headers frame
 * @param eos whether the headers frame should be marked as end-of-stream
 */
private void dispatchStreamHeaders(RestApiCallState state, ChannelHandlerContext ctx, boolean eos) {

    // Headers can only be sent once per response
    if (state.responseHeadersSent)
        return;

    var restHeaders = translateResponseHeaders(state.responseHeaders, state, true);
    var restFrame = new DefaultHttp2HeadersFrame(restHeaders, eos).stream(state.stream);

    ctx.fireChannelRead(restFrame);

    state.responseHeadersSent = true;
    state.responseHttpStatus = HttpResponseStatus.parseLine(restHeaders.status());
}
/**
 * Fire the REST response headers for a download stream, if they have not been sent yet.
 *
 * <p>Content type and content length are taken from the download response message when it
 * provides them. The headers frame is associated with the stream of the call, in the same
 * way as the other response frames fired by this handler.
 *
 * @param state call state holding the accumulated gRPC response headers
 * @param ctx handler context used to fire the headers frame
 * @param downloadResponse download message supplying the content headers
 */
private void dispatchDownloadStreamHeaders(RestApiCallState state, ChannelHandlerContext ctx, DownloadResponse downloadResponse) {

    // Headers can only be sent once per response
    if (state.responseHeadersSent)
        return;

    var restHeaders = translateResponseHeaders(state.responseHeaders, state, true);

    if (downloadResponse.hasContentType())
        restHeaders.add(HttpHeaderNames.CONTENT_TYPE, downloadResponse.getContentType());

    if (downloadResponse.hasContentLength())
        restHeaders.addLong(HttpHeaderNames.CONTENT_LENGTH, downloadResponse.getContentLength());

    var restFrame = new DefaultHttp2HeadersFrame(restHeaders).stream(state.stream);

    ctx.fireChannelRead(restFrame);

    state.responseHeadersSent = true;
    state.responseHttpStatus = HttpResponseStatus.parseLine(restHeaders.status());
}
/**
 * Translate every complete gRPC message buffered for a streaming response and fire the
 * results up the pipeline as data frames.
 *
 * <p>Components of the response buffer that have been read are discarded before this
 * method returns, including when translation fails.
 *
 * @param state call state holding the buffered gRPC response content
 * @param ctx handler context used to fire the REST data frames
 */
private void dispatchStreamContent(RestApiCallState state, ChannelHandlerContext ctx) {

    try {

        // Keep going for as long as the buffer holds a decodable message
        while (GrpcUtils.canDecodeLpm(state.responseContent)) {

            var grpcMessage = state.translator.decodeGrpcResponse(state.responseContent);

            // This is some TRAC magic to set the correct content headers for REST-ful data download streams
            // It is not possible to define this behavior using the HTTP options in the proto file
            var needsDownloadHeaders = !state.responseHeadersSent && grpcMessage instanceof DownloadResponse;

            if (needsDownloadHeaders)
                dispatchDownloadStreamHeaders(state, ctx, (DownloadResponse) grpcMessage);

            var restContent = state.translator.encodeRestResponse(grpcMessage);
            var restFrame = new DefaultHttp2DataFrame(restContent).stream(state.stream);

            ctx.fireChannelRead(restFrame);
        }
    }
    finally {
        state.responseContent.discardReadComponents();
    }
}
/**
 * Finish a streaming response once the gRPC stream has ended.
 *
 * <p>Response headers are sent first if they have not gone out already. If the response
 * started with an OK status but the final gRPC status translates to an error, an exception
 * is fired on the pipeline. Otherwise an empty end-of-stream data frame is fired,
 * associated with the stream of the call like the other response frames.
 *
 * @param state call state holding the accumulated gRPC response headers
 * @param ctx handler context used to fire the final frame or the exception
 */
private void dispatchStreamComplete(RestApiCallState state, ChannelHandlerContext ctx) {

    if (!state.responseHeadersSent)
        dispatchStreamHeaders(state, ctx, true);

    // It is possible the stream failed after sending an initial 200 OK HTTP HEADERS frame
    // In this case, fire an exception to try and cause an unclean shutdown of the channel
    // Otherwise, send an empty EOS data frame
    // Sending trailers is not compatible with HTTP/1 clients

    // A missing grpc-status is treated as UNKNOWN
    var finalGrpcStatus = state.responseHeaders.getInt("grpc-status", Status.Code.UNKNOWN.value());
    var finalGrpcCode = Status.fromCodeValue(finalGrpcStatus).getCode();
    var finalHttpStatus = state.translator.translateGrpcErrorCode(finalGrpcCode);

    if (state.responseHttpStatus.equals(HttpResponseStatus.OK) && !finalHttpStatus.equals(HttpResponseStatus.OK)) {
        var error = new ETracInternal("Download stream failed with error code " + finalGrpcCode.name());
        ctx.fireExceptionCaught(error);
    }
    else {
        var eosFrame = new DefaultHttp2DataFrame(true).stream(state.stream);
        ctx.fireChannelRead(eosFrame);
    }
}
/**
 * Find the first REST method whose matcher accepts the HTTP method and path of a request.
 *
 * <p>The HTTP method and URI are parsed once, before the method list is searched. A path
 * that is not a valid URI is reported as an input validation error.
 *
 * @param headers request headers supplying the HTTP method and path
 * @return the first matching method, or null if no method matches
 * @throws EInputValidation if the request path cannot be parsed as a URI
 */
private RestApiMethod, ?> lookupMethod(Http2Headers headers) {

    var httpMethod = HttpMethod.valueOf(headers.method().toString());

    URI uri;

    // The path comes straight from the client and may not be a well-formed URI
    try {
        uri = URI.create(headers.path().toString());
    }
    catch (IllegalArgumentException e) {
        throw new EInputValidation("Invalid request path: " + e.getMessage());
    }

    for (var method: this.methods) {
        if (method.matcher.matches(httpMethod, uri))
            return method;
    }

    return null;
}
/**
 * Build the gRPC request headers for a REST request.
 *
 * <p>Headers not listed in FILTER_REQUEST_HEADERS are copied across. The method, path,
 * TE and content headers are then set for the gRPC method matched to the call.
 *
 * @param restHeaders headers received on the REST request
 * @param state call state supplying the matched gRPC method
 * @return a new set of headers for the gRPC request
 */
private Http2Headers translateRequestHeaders(Http2Headers restHeaders, RestApiCallState state) {

    var grpcHeaders = new DefaultHttp2Headers();

    // Bring across all request headers that do not have special handling
    for (var entry : restHeaders) {

        var headerName = entry.getKey();

        if (FILTER_REQUEST_HEADERS.contains(headerName.toString()))
            continue;

        grpcHeaders.add(headerName, entry.getValue());
    }

    // The gRPC path has the form /<full service name>/<method name>
    var targetMethod = state.method.grpcMethod;
    var serviceName = targetMethod.getService().getFullName();
    var grpcPath = String.format("/%s/%s", serviceName, targetMethod.getName());

    grpcHeaders.method(HttpMethod.POST.asciiName());
    grpcHeaders.path(grpcPath);
    grpcHeaders.add(HttpHeaderNames.TE, "trailers");

    // Content headers
    grpcHeaders.add(HttpHeaderNames.CONTENT_TYPE, "application/grpc+proto");
    grpcHeaders.add(HttpHeaderNames.ACCEPT, "application/grpc+proto");

    return grpcHeaders;
}
/**
 * Build the REST response headers for a gRPC response.
 *
 * <p>Headers not listed in FILTER_RESPONSE_HEADERS are copied across, then the REST status
 * is chosen. A non-OK HTTP status is used as it is. Otherwise the gRPC status is translated
 * if present. Failing that, streaming responses start with OK, and any other response is
 * reported with an unknown-status error.
 *
 * @param grpcHeaders headers accumulated from the gRPC response
 * @param state call state supplying the translator for gRPC status codes
 * @param streaming whether the response is a streaming response
 * @return a new set of headers for the REST response
 */
private Http2Headers translateResponseHeaders(Http2Headers grpcHeaders, RestApiCallState state, boolean streaming) {

    var restHeaders = new DefaultHttp2Headers();

    // Bring across all response headers that do not have special handling
    for (var entry : grpcHeaders) {

        var headerName = entry.getKey();

        if (FILTER_RESPONSE_HEADERS.contains(headerName.toString()))
            continue;

        restHeaders.add(headerName, entry.getValue());
    }

    // Figure out the right HTTP response code
    var httpStatus = HttpResponseStatus.parseLine(grpcHeaders.status());
    var grpcStatus = grpcHeaders.getInt("grpc-status");
    var grpcMessage = grpcHeaders.get("grpc-message");

    String statusLine;

    if (httpStatus.code() != HttpResponseStatus.OK.code()) {
        // If the request fails at the HTTP level, the HTTP error is the response code
        statusLine = httpStatus.toString();
    }
    else if (grpcStatus != null) {
        // If gRPC status code is available, translate that
        // The gRPC message is used as the reason phrase when there is one
        var grpcStatusCode = Status.fromCodeValue(grpcStatus).getCode();
        var restStatusCode = state.method.translator.translateGrpcErrorCode(grpcStatusCode);
        var restMessage = grpcMessage != null ? grpcMessage.toString() : restStatusCode.reasonPhrase();
        statusLine = new HttpResponseStatus(restStatusCode.code(), restMessage).toString();
    }
    else if (streaming) {
        // For streaming responses gRPC status is not known until the stream completes
        // Unless there is an early error we have to send OK to start sending content
        statusLine = HttpResponseStatus.OK.toString();
    }
    else {
        // Otherwise if the status is not known that is an error
        var restStatusCode = HttpResponseStatus.BAD_REQUEST;
        statusLine = new HttpResponseStatus(restStatusCode.code(), "RESPONSE STATUS UNKNOWN").toString();
    }

    restHeaders.status(statusLine);

    return restHeaders;
}
/**
 * Check the content-type and accept headers of a REST request against the matched method.
 *
 * <p>Methods with a body must declare a JSON content type, methods without a body must not
 * send a content type. Unary methods must accept JSON, server streaming methods must not
 * send an accept header. Media types are compared ignoring case, parameters and
 * surrounding whitespace.
 *
 * @param state call state holding the request headers and the matched method
 * @throws EInputValidation if a required header is missing, unexpected or has the wrong value
 */
private void checkRequestHeaders(RestApiCallState state) {

    var restHeaders = state.requestHeaders;

    // POST methods have a JSON payload for the encoded gRPC request type
    // GET methods do not have a body, the request is entirely built from the URL

    if (state.method.hasBody) {

        if (!restHeaders.contains(HttpHeaderNames.CONTENT_TYPE))
            throw new EInputValidation("Missing required HTTP header [" + HttpHeaderNames.CONTENT_TYPE + "]");

        if (!isJsonMediaType(restHeaders.get(HttpHeaderNames.CONTENT_TYPE)))
            throw new EInputValidation("Invalid [content-type] header (expected application/json for REST calls)");
    }
    else {

        if (restHeaders.contains(HttpHeaderNames.CONTENT_TYPE))
            throw new EInputValidation("Unexpected HTTP header [" + HttpHeaderNames.CONTENT_TYPE + "]");
    }

    // Unary methods return a JSON encoding of the gRPC response type
    // Server streaming methods are for downloads and must accept whatever type the server sends
    // E.g. downloading a file, the response content type will depend on the type of the file

    if (!state.method.grpcMethod.isServerStreaming()) {

        if (!restHeaders.contains(HttpHeaderNames.ACCEPT))
            throw new EInputValidation("Missing required HTTP header [" + HttpHeaderNames.ACCEPT + "]");

        if (!isJsonMediaType(restHeaders.get(HttpHeaderNames.ACCEPT)))
            throw new EInputValidation("Invalid [accept] header (expected application/json for REST calls)");
    }
    else {

        if (restHeaders.contains(HttpHeaderNames.ACCEPT))
            throw new EInputValidation("Unexpected HTTP header [" + HttpHeaderNames.ACCEPT + "]");
    }
}

/** Test whether the media type of a header value, ignoring any parameters, is application/json. */
private static boolean isJsonMediaType(CharSequence headerValue) {

    var value = headerValue.toString();

    // Cut at the first parameter separator by index, splitting fails on a value of just ";"
    var paramStart = value.indexOf(';');
    var mediaType = paramStart < 0 ? value : value.substring(0, paramStart);

    return "application/json".equalsIgnoreCase(mediaType.trim());
}
/**
 * Fire an error response for a stream up the pipeline.
 *
 * <p>With no error message the response is a single headers frame marked end-of-stream.
 * With a message, the message is sent as a UTF-8 plain text body in a data frame marked
 * end-of-stream. A read-complete event is fired after the frames in both cases.
 *
 * @param stream stream the error response belongs to
 * @param ctx handler context used to allocate the body and fire the frames
 * @param errorStatus HTTP status for the response
 * @param errorMessage message for the response body, may be null or empty
 */
private void sendErrorResponse(
        Http2FrameStream stream, ChannelHandlerContext ctx,
        HttpResponseStatus errorStatus, String errorMessage) {

    var headers = new DefaultHttp2Headers();
    headers.status(errorStatus.toString());

    var hasMessage = errorMessage != null && !errorMessage.isEmpty();

    if (hasMessage) {

        // Putting error message in response body for now
        // It may be more appropriate in a header
        var body = ctx.alloc().buffer();
        body.writeCharSequence(errorMessage, StandardCharsets.UTF_8);

        headers.set(HttpHeaderNames.CONTENT_TYPE, "text/plain; charset=UTF-8");
        headers.setInt(HttpHeaderNames.CONTENT_LENGTH, body.readableBytes());

        var headersFrame = new DefaultHttp2HeadersFrame(headers, false).stream(stream);
        var bodyFrame = new DefaultHttp2DataFrame(body, true).stream(stream);

        ctx.fireChannelRead(headersFrame);
        ctx.fireChannelRead(bodyFrame);
    }
    else {

        // Nothing to put in a body, so the headers frame ends the stream
        var headersFrame = new DefaultHttp2HeadersFrame(headers, true).stream(stream);

        ctx.fireChannelRead(headersFrame);
    }

    ctx.fireChannelReadComplete();
}
/** Mutable state for a single REST call, the proxy keeps one instance per HTTP/2 stream. */
private static class RestApiCallState {
// Stream this call belongs to
Http2FrameStream stream;
// REST method matched for the request and its translator, both assigned in dispatchRequest()
// method is null if no REST method matched the request
RestApiMethod, ?> method;
RestApiTranslator, ?> translator;
// Request headers and body accumulated from outbound frames in write()
// requestContent is released and set to null once the request has been dispatched
Http2Headers requestHeaders;
CompositeByteBuf requestContent;
// Response headers and body accumulated from inbound frames in channelRead()
Http2Headers responseHeaders;
CompositeByteBuf responseContent;
// Whether headers have been fired for a streaming response, and the status they carried
boolean responseHeadersSent;
HttpResponseStatus responseHttpStatus;
/** Create empty state for a stream, taking the content buffers from the allocator of ctx. */
RestApiCallState(ChannelHandlerContext ctx, Http2FrameStream stream) {
this.stream = stream;
this.requestHeaders = new DefaultHttp2Headers();
this.requestContent = ctx.alloc().compositeBuffer();
this.responseHeaders = new DefaultHttp2Headers();
this.responseContent = ctx.alloc().compositeBuffer();
}
}
}