com.linecorp.armeria.server.thrift.THttpService Maven / Gradle / Ivy
Show all versions of armeria-thrift0.9 Show documentation
/*
* Copyright 2016 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.server.thrift;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.meta_data.FieldMetaData;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.ImmutableSet;
import com.linecorp.armeria.common.AggregatedHttpRequest;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.RpcResponse;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.logging.RequestLogProperty;
import com.linecorp.armeria.common.thrift.ThriftCall;
import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.common.thrift.ThriftReply;
import com.linecorp.armeria.common.thrift.ThriftSerializationFormats;
import com.linecorp.armeria.common.util.CompletionActions;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.internal.common.thrift.TByteBufTransport;
import com.linecorp.armeria.internal.common.thrift.ThriftFieldAccess;
import com.linecorp.armeria.internal.common.thrift.ThriftFunction;
import com.linecorp.armeria.server.DecoratingService;
import com.linecorp.armeria.server.HttpResponseException;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.HttpStatusException;
import com.linecorp.armeria.server.RpcService;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceRequestContext;
import io.netty.buffer.ByteBuf;
/**
* An {@link HttpService} that handles a Thrift call.
*
* @see ThriftProtocolFactories
*/
public final class THttpService extends DecoratingService
implements HttpService {
// Logger used for debug-level reporting of Thrift decode failures.
private static final Logger logger = LoggerFactory.getLogger(THttpService.class);
// Plain-text body returned with 415 (Unsupported Media Type) when no serialization format can be chosen.
private static final String PROTOCOL_NOT_SUPPORTED = "Specified content-type not supported";
// Plain-text body returned with 406 (Not Acceptable) when the accept header does not match the
// serialization format selected for the request.
private static final String ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE =
"Thrift protocol specified in Accept header must match " +
"the one specified in the content-type header";
/**
 * Returns a fresh {@link THttpServiceBuilder} for assembling a {@link THttpService} fluently.
 *
 * <p>The default SerializationFormat {@link ThriftSerializationFormats#BINARY} will be used when the
 * client does not specify one, and every format in {@link ThriftSerializationFormats#values()} is
 * supported as well.
 *
 * <p>Currently, the only way to specify a serialization format is by using the HTTP session
 * protocol and setting the {@code "Content-Type"} header to the appropriate
 * {@link SerializationFormat#mediaType()}.
 */
public static THttpServiceBuilder builder() {
    final THttpServiceBuilder serviceBuilder = new THttpServiceBuilder();
    return serviceBuilder;
}
/**
 * Wraps the given service implementation in a new {@link THttpService} that supports all thrift
 * protocols and falls back to the {@link ThriftSerializationFormats#BINARY TBinary} protocol when the
 * client doesn't specify one.
 *
 * <p>Currently, the only way to specify a serialization format is by using the HTTP session
 * protocol and setting the {@code "Content-Type"} header to the appropriate
 * {@link SerializationFormat#mediaType()}.
 *
 * @param implementation an implementation of {@code *.Iface} or {@code *.AsyncIface} service interface
 *                       generated by the Apache Thrift compiler
 */
public static THttpService of(Object implementation) {
    final SerializationFormat binaryFormat = ThriftSerializationFormats.BINARY;
    return of(implementation, binaryFormat);
}
/**
 * Wraps the given service implementation in a new {@link THttpService} that supports all thrift
 * protocols and falls back to {@code defaultSerializationFormat} when the client doesn't specify one.
 *
 * <p>Currently, the only way to specify a serialization format is by using the HTTP session
 * protocol and setting the {@code "Content-Type"} header to the appropriate
 * {@link SerializationFormat#mediaType()}.
 *
 * @param implementation an implementation of {@code *.Iface} or {@code *.AsyncIface} service interface
 *                       generated by the Apache Thrift compiler
 * @param defaultSerializationFormat the default serialization format to use when not specified by the
 *                                   client
 */
public static THttpService of(Object implementation,
                              SerializationFormat defaultSerializationFormat) {
    final THttpServiceBuilder serviceBuilder = builder();
    return serviceBuilder.addService(implementation)
                         .defaultSerializationFormat(defaultSerializationFormat)
                         .build();
}
/**
 * Wraps the given service implementation in a new {@link THttpService} that supports only the listed
 * formats and falls back to {@code defaultSerializationFormat} when the client doesn't specify one.
 *
 * <p>Currently, the only way to specify a serialization format is by using the HTTP session protocol and
 * setting the {@code "Content-Type"} header to the appropriate {@link SerializationFormat#mediaType()}.
 *
 * @param implementation an implementation of {@code *.Iface} or {@code *.AsyncIface} service interface
 *                       generated by the Apache Thrift compiler
 * @param defaultSerializationFormat the default serialization format to use when not specified by the
 *                                   client
 * @param otherSupportedSerializationFormats other serialization formats that should be supported by this
 *                                           service in addition to the default
 */
public static THttpService ofFormats(
        Object implementation,
        SerializationFormat defaultSerializationFormat,
        SerializationFormat... otherSupportedSerializationFormats) {
    requireNonNull(otherSupportedSerializationFormats, "otherSupportedSerializationFormats");
    // View the varargs array as a list so the Iterable-based overload does the actual work.
    final List<SerializationFormat> otherFormats = Arrays.asList(otherSupportedSerializationFormats);
    return ofFormats(implementation, defaultSerializationFormat, otherFormats);
}
/**
 * Creates a new {@link THttpService} with the specified service implementation, supporting the protocols
 * specified in {@code otherSupportedSerializationFormats} and defaulting to the specified
 * {@code defaultSerializationFormat} when the client doesn't specify one.
 *
 * <p>Currently, the only way to specify a serialization format is by using the HTTP session protocol and
 * setting the {@code "Content-Type"} header to the appropriate {@link SerializationFormat#mediaType()}.
 *
 * @param implementation an implementation of {@code *.Iface} or {@code *.AsyncIface} service interface
 *                       generated by the Apache Thrift compiler
 * @param defaultSerializationFormat the default serialization format to use when not specified by the
 *                                   client
 * @param otherSupportedSerializationFormats other serialization formats that should be supported by this
 *                                           service in addition to the default
 */
public static THttpService ofFormats(
        Object implementation,
        SerializationFormat defaultSerializationFormat,
        Iterable<SerializationFormat> otherSupportedSerializationFormats) {
    return builder().addService(implementation)
                    .defaultSerializationFormat(defaultSerializationFormat)
                    .otherSerializationFormats(otherSupportedSerializationFormats)
                    .build();
}
/**
* Creates a new decorator that supports all thrift protocols and defaults to
* {@link ThriftSerializationFormats#BINARY TBinary} protocol when the client doesn't specify one.
*
* Currently, the only way to specify a serialization format is by using the HTTP session
* protocol and setting the {@code "Content-Type"} header to the appropriate
* {@link SerializationFormat#mediaType()}.
*/
public static Function super RpcService, THttpService> newDecorator() {
return newDecorator(ThriftSerializationFormats.BINARY);
}
/**
* Creates a new decorator that supports all thrift protocols and defaults to the specified
* {@code defaultSerializationFormat} when the client doesn't specify one.
* Currently, the only way to specify a serialization format is by using the HTTP session
* protocol and setting the {@code "Content-Type"} header to the appropriate
* {@link SerializationFormat#mediaType()}.
*
* @param defaultSerializationFormat the default serialization format to use when not specified by the
* client
*/
public static Function super RpcService, THttpService> newDecorator(
SerializationFormat defaultSerializationFormat) {
return builder().defaultSerializationFormat(defaultSerializationFormat).newDecorator();
}
/**
* Creates a new decorator that supports only the formats specified and defaults to the specified
* {@code defaultSerializationFormat} when the client doesn't specify one.
* Currently, the only way to specify a serialization format is by using the HTTP session protocol and
* setting the {@code "Content-Type"} header to the appropriate {@link SerializationFormat#mediaType()}.
*
* @param defaultSerializationFormat the default serialization format to use when not specified by the
* client
* @param otherSupportedSerializationFormats other serialization formats that should be supported by this
* service in addition to the default
*/
public static Function super RpcService, THttpService> newDecorator(
SerializationFormat defaultSerializationFormat,
SerializationFormat... otherSupportedSerializationFormats) {
requireNonNull(otherSupportedSerializationFormats, "otherSupportedSerializationFormats");
return newDecorator(defaultSerializationFormat,
ImmutableSet.copyOf(otherSupportedSerializationFormats));
}
/**
* Creates a new decorator that supports the protocols specified in
* {@code otherSupportedSerializationFormats} and defaults to the specified
* {@code defaultSerializationFormat} when the client doesn't specify one.
* Currently, the only way to specify a serialization format is by using the HTTP session protocol and
* setting the {@code "Content-Type"} header to the appropriate {@link SerializationFormat#mediaType()}.
*
* @param defaultSerializationFormat the default serialization format to use when not specified by the
* client
* @param otherSupportedSerializationFormats other serialization formats that should be supported by this
* service in addition to the default
*/
public static Function super RpcService, THttpService> newDecorator(
SerializationFormat defaultSerializationFormat,
Iterable otherSupportedSerializationFormats) {
return builder().defaultSerializationFormat(defaultSerializationFormat)
.otherSerializationFormats(otherSupportedSerializationFormats)
.newDecorator();
}
private final ThriftCallService thriftService;
private final SerializationFormat defaultSerializationFormat;
private final Set supportedSerializationFormats;
private final BiFunction super ServiceRequestContext, ? super Throwable, ? extends RpcResponse>
exceptionHandler;
/**
 * Creates a new instance that decorates the specified {@link RpcService}.
 *
 * @param delegate the service to decorate; must be, or wrap, a {@link ThriftCallService}
 * @param defaultSerializationFormat the format to use when the client does not specify one
 * @param supportedSerializationFormats the formats accepted by this service; copied defensively
 * @param exceptionHandler the handler stored for converting failures into an {@link RpcResponse}
 * @throws IllegalStateException if {@code delegate} cannot be unwrapped to a {@link ThriftCallService}
 */
THttpService(RpcService delegate, SerializationFormat defaultSerializationFormat,
             Set<SerializationFormat> supportedSerializationFormats,
             BiFunction<? super ServiceRequestContext, ? super Throwable, ? extends RpcResponse>
                     exceptionHandler) {
    super(delegate);
    thriftService = findThriftService(delegate);
    this.defaultSerializationFormat = defaultSerializationFormat;
    // Snapshot the formats so later changes to the caller's set cannot affect this service.
    this.supportedSerializationFormats = ImmutableSet.copyOf(supportedSerializationFormats);
    this.exceptionHandler = exceptionHandler;
}
private static ThriftCallService findThriftService(Service, ?> delegate) {
final ThriftCallService thriftService = delegate.as(ThriftCallService.class);
checkState(thriftService != null,
"service being decorated is not a ThriftCallService: %s", delegate);
return thriftService;
}
/**
 * Returns the information about the Thrift services being served.
 *
 * @return a {@link Map} whose key is a service name, which could be an empty string if this service
 *         is not multiplexed
 */
public Map<String, ThriftServiceEntry> entries() {
    return thriftService.entries();
}
/**
 * Returns the {@link SerializationFormat}s supported by this service.
 *
 * @return the immutable set of supported formats captured at construction time
 */
public Set<SerializationFormat> supportedSerializationFormats() {
    return supportedSerializationFormats;
}
/**
 * Returns the {@link SerializationFormat} this service falls back to when a request does not
 * identify one.
 */
public SerializationFormat defaultSerializationFormat() {
    return this.defaultSerializationFormat;
}
/**
 * Handles a Thrift-over-HTTP request.
 *
 * <p>Only {@code POST} is accepted. The serialization format is derived from the request headers;
 * a request whose content type cannot be mapped to a format, or whose accept header does not match
 * the chosen format, is rejected before the body is read. Otherwise the body is aggregated
 * asynchronously and passed on for decoding and invocation.
 *
 * @param ctx the context of the request being served
 * @param req the incoming HTTP request
 * @return {@code 405}, {@code 415} or {@code 406} for the early rejections above; otherwise a
 *         response backed by a future that is completed once the call has been processed
 */
@Override
public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception {
    if (req.method() != HttpMethod.POST) {
        return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED);
    }

    final SerializationFormat serializationFormat = determineSerializationFormat(req);
    if (serializationFormat == null) {
        return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE,
                               MediaType.PLAIN_TEXT_UTF_8, PROTOCOL_NOT_SUPPORTED);
    }

    if (!validateAcceptHeaders(req, serializationFormat)) {
        return HttpResponse.of(HttpStatus.NOT_ACCEPTABLE,
                               MediaType.PLAIN_TEXT_UTF_8, ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE);
    }

    // The response is returned immediately and filled in later through this future.
    final CompletableFuture<HttpResponse> responseFuture = new CompletableFuture<>();
    final HttpResponse res = HttpResponse.from(responseFuture);
    ctx.logBuilder().serializationFormat(serializationFormat);
    // The request content is recorded later, once the Thrift message has been decoded.
    ctx.logBuilder().defer(RequestLogProperty.REQUEST_CONTENT);
    req.aggregateWithPooledObjects(ctx.eventLoop(), ctx.alloc()).handle((aReq, cause) -> {
        if (cause != null) {
            // Aggregation failed; expose the stack trace only when verbose responses are enabled.
            final HttpResponse errorRes;
            if (ctx.config().verboseResponses()) {
                errorRes = HttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR,
                                           MediaType.PLAIN_TEXT_UTF_8,
                                           Exceptions.traceText(cause));
            } else {
                errorRes = HttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR);
            }
            responseFuture.complete(errorRes);
            return null;
        }
        decodeAndInvoke(ctx, aReq, serializationFormat, responseFuture);
        return null;
    }).exceptionally(CompletionActions::log);
    return res;
}
/**
 * Picks the serialization format for a request from its {@code content-type} header.
 *
 * @param req the incoming request
 * @return the supported format matching the content type; the default format when the header is
 *         absent or is {@code text/plain} or {@code application/octet-stream}; {@code null} for any
 *         other unsupported content type
 */
@Nullable
private SerializationFormat determineSerializationFormat(HttpRequest req) {
    final MediaType contentType = req.headers().contentType();
    if (contentType == null) {
        return defaultSerializationFormat();
    }

    final SerializationFormat matched = findSerializationFormat(contentType);
    if (matched != null) {
        return matched;
    }

    // Browser clients often send a non-Thrift content type.
    // Choose the default serialization format for some vague media types.
    final boolean plainText = "text".equals(contentType.type()) &&
                              "plain".equals(contentType.subtype());
    final boolean octetStream = "application".equals(contentType.type()) &&
                                "octet-stream".equals(contentType.subtype());
    if (plainText || octetStream) {
        return defaultSerializationFormat();
    }
    return null;
}
/**
 * Checks that the request's accept header, if any, is compatible with the chosen format.
 *
 * @param req the incoming request
 * @param serializationFormat the format already selected for the request
 * @return {@code true} if there is no accept header or one of its media types matches
 *         {@code serializationFormat}; {@code false} otherwise
 */
private static boolean validateAcceptHeaders(HttpRequest req, SerializationFormat serializationFormat) {
    // If accept header is present, make sure it is sane. Currently, we do not support accept
    // headers with a different format than the content type header.
    final List<MediaType> acceptTypes = req.headers().accept();
    return acceptTypes.isEmpty() || serializationFormat.mediaTypes().match(acceptTypes) != null;
}
/**
 * Looks up the first supported serialization format that accepts the given content type.
 *
 * @param contentType the media type taken from the request
 * @return the first accepting format in iteration order, or {@code null} if none accepts it
 */
@Nullable
private SerializationFormat findSerializationFormat(MediaType contentType) {
    return supportedSerializationFormats.stream()
                                        .filter(candidate -> candidate.isAccepted(contentType))
                                        .findFirst()
                                        .orElse(null);
}
private void decodeAndInvoke(
ServiceRequestContext ctx, AggregatedHttpRequest req,
SerializationFormat serializationFormat, CompletableFuture httpRes) {
final int seqId;
final ThriftFunction f;
final RpcRequest decodedReq;
try (HttpData content = req.content()) {
final TByteBufTransport inTransport = new TByteBufTransport(content.byteBuf());
final TProtocol inProto = ThriftSerializationFormats.protocolFactory(serializationFormat)
.getProtocol(inTransport);
final TMessage header;
final TBase, ?> args;
try {
header = inProto.readMessageBegin();
} catch (Exception e) {
logger.debug("{} Failed to decode a {} header:", ctx, serializationFormat, e);
final HttpResponse errorRes;
if (ctx.config().verboseResponses()) {
errorRes = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
"Failed to decode a %s header: %s", serializationFormat,
Exceptions.traceText(e));
} else {
errorRes = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
"Failed to decode a %s header", serializationFormat);
}
httpRes.complete(errorRes);
return;
}
seqId = header.seqid;
final byte typeValue = header.type;
final int colonIdx = header.name.indexOf(':');
final String serviceName;
final String methodName;
if (colonIdx < 0) {
serviceName = "";
methodName = header.name;
} else {
serviceName = header.name.substring(0, colonIdx);
methodName = header.name.substring(colonIdx + 1);
}
// Basic sanity check. We usually should never fail here.
if (typeValue != TMessageType.CALL && typeValue != TMessageType.ONEWAY) {
final TApplicationException cause = new TApplicationException(
TApplicationException.INVALID_MESSAGE_TYPE,
"unexpected TMessageType: " + typeString(typeValue));
handlePreDecodeException(ctx, httpRes, cause, serializationFormat, seqId, methodName);
return;
}
// Ensure that such a method exists.
final ThriftServiceEntry entry = entries().get(serviceName);
f = entry != null ? entry.metadata.function(methodName) : null;
if (f == null) {
final TApplicationException cause = new TApplicationException(
TApplicationException.UNKNOWN_METHOD, "unknown method: " + header.name);
handlePreDecodeException(ctx, httpRes, cause, serializationFormat, seqId, methodName);
return;
}
// Decode the invocation parameters.
try {
args = f.newArgs();
args.read(inProto);
inProto.readMessageEnd();
decodedReq = toRpcRequest(f.serviceType(), header.name, args);
ctx.logBuilder().requestContent(decodedReq, new ThriftCall(header, args));
} catch (Exception e) {
// Failed to decode the invocation parameters.
logger.debug("{} Failed to decode Thrift arguments:", ctx, e);
final TApplicationException cause = new TApplicationException(
TApplicationException.PROTOCOL_ERROR, "failed to decode arguments: " + e);
handlePreDecodeException(ctx, httpRes, cause, serializationFormat, seqId, methodName);
return;
}
} finally {
ctx.logBuilder().requestContent(null, null);
}
invoke(ctx, serializationFormat, seqId, f, decodedReq, httpRes);
}
/**
 * Converts a Thrift message type code into a readable label for error messages.
 *
 * @param typeValue the raw message type byte read from the message header
 * @return {@code "CALL"}, {@code "REPLY"}, {@code "EXCEPTION"} or {@code "ONEWAY"} for the known
 *         types; otherwise {@code "UNKNOWN(n)"} where {@code n} is the unsigned value of the byte
 */
private static String typeString(byte typeValue) {
    if (typeValue == TMessageType.CALL) {
        return "CALL";
    }
    if (typeValue == TMessageType.REPLY) {
        return "REPLY";
    }
    if (typeValue == TMessageType.EXCEPTION) {
        return "EXCEPTION";
    }
    if (typeValue == TMessageType.ONEWAY) {
        return "ONEWAY";
    }
    // Mask to print the byte as an unsigned value.
    return "UNKNOWN(" + (typeValue & 0xFF) + ')';
}
/**
 * Invokes the decorated service with the decoded call and routes the outcome to the matching
 * handler.
 *
 * <p>An exception thrown synchronously by the delegate is passed to {@code handleException}. Once
 * the reply completes, a one-way function always goes to {@code handleOneWaySuccess}, whether or not
 * the reply failed; any other function goes to {@code handleException} on failure and to
 * {@code handleSuccess} otherwise. A failure inside {@code handleSuccess} is also passed to
 * {@code handleException}.
 *
 * @param ctx the context of the request being served
 * @param serializationFormat the format selected for the request
 * @param seqId the sequence ID read from the Thrift message header
 * @param func the Thrift function being called
 * @param call the decoded RPC request
 * @param res the future to complete with the HTTP response
 */
private void invoke(
        ServiceRequestContext ctx, SerializationFormat serializationFormat, int seqId,
        ThriftFunction func, RpcRequest call, CompletableFuture<HttpResponse> res) {
    final RpcResponse reply;

    // Push the context so it is the current one while the delegate runs.
    try (SafeCloseable ignored = ctx.push()) {
        reply = unwrap().serve(ctx, call);
    } catch (Throwable cause) {
        handleException(ctx, res, serializationFormat, seqId, func, cause);
        return;
    }

    reply.handle((result, cause) -> {
        // One-way calls are checked first, so their failure cause is not inspected here.
        if (func.isOneWay()) {
            handleOneWaySuccess(ctx, reply, res, serializationFormat);
            return null;
        }

        if (cause != null) {
            handleException(ctx, res, serializationFormat, seqId, func, cause);
            return null;
        }

        try {
            handleSuccess(ctx, reply, res, serializationFormat, seqId, func, result);
        } catch (Throwable t) {
            handleException(ctx, res, serializationFormat, seqId, func, t);
        }
        return null;
    }).exceptionally(CompletionActions::log);
}
private static RpcRequest toRpcRequest(Class> serviceType, String method, TBase, ?> thriftArgs) {
requireNonNull(thriftArgs, "thriftArgs");
// NB: The map returned by FieldMetaData.getStructMetaDataMap() is an EnumMap,
// so the parameter ordering is preserved correctly during iteration.
final Set extends TFieldIdEnum> fields =
FieldMetaData.getStructMetaDataMap(thriftArgs.getClass()).keySet();
// Handle the case where the number of arguments is 0 or 1.
final int numFields = fields.size();
switch (numFields) {
case 0:
return RpcRequest.of(serviceType, method);
case 1:
return RpcRequest.of(serviceType, method,
ThriftFieldAccess.get(thriftArgs, fields.iterator().next()));
}
// Handle the case where the number of arguments is greater than 1.
final List