// io.deephaven.server.util.GrpcServiceOverrideBuilder Maven / Gradle / Ivy
// The newest version!
//
// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
//
package io.deephaven.server.util;
import com.google.rpc.Code;
import io.deephaven.proto.util.Exceptions;
import io.deephaven.server.browserstreaming.BrowserStream;
import io.deephaven.server.browserstreaming.BrowserStreamInterceptor;
import io.deephaven.server.browserstreaming.StreamData;
import io.deephaven.server.session.SessionService;
import io.deephaven.server.session.SessionState;
import io.deephaven.io.logger.Logger;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.ServerCalls;
import io.grpc.stub.StreamObserver;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public class GrpcServiceOverrideBuilder {
private static class GrpcOverride {
private final MethodDescriptor method;
private final ServerCallHandler handler;
private GrpcOverride(@NotNull MethodDescriptor method,
@NotNull ServerCallHandler handler) {
this.method = method;
this.handler = handler;
}
private void addMethod(ServerServiceDefinition.Builder builder) {
builder.addMethod(method, handler);
}
}
private final ServerServiceDefinition baseDefinition;
private final List> overrides = new ArrayList<>();
private final BrowserStreamInterceptor browserStreamInterceptor = new BrowserStreamInterceptor();
private boolean needsBrowserInterceptor = false;
private GrpcServiceOverrideBuilder(ServerServiceDefinition baseDefinition) {
this.baseDefinition = baseDefinition;
}
public static GrpcServiceOverrideBuilder newBuilder(ServerServiceDefinition baseDefinition) {
return new GrpcServiceOverrideBuilder(baseDefinition);
}
private GrpcServiceOverrideBuilder override(MethodDescriptor method,
ServerCalls.BidiStreamingMethod handler) {
validateMethodType(method.getType(), MethodDescriptor.MethodType.BIDI_STREAMING);
overrides.add(new GrpcOverride<>(method, ServerCalls.asyncBidiStreamingCall(handler)));
return this;
}
private GrpcServiceOverrideBuilder override(MethodDescriptor method,
ServerCalls.ServerStreamingMethod handler) {
validateMethodType(method.getType(), MethodDescriptor.MethodType.SERVER_STREAMING);
overrides.add(new GrpcOverride<>(method, ServerCalls.asyncServerStreamingCall(handler)));
return this;
}
private GrpcServiceOverrideBuilder override(MethodDescriptor method,
ServerCalls.UnaryMethod handler) {
validateMethodType(method.getType(), MethodDescriptor.MethodType.UNARY);
overrides.add(new GrpcOverride<>(method, ServerCalls.asyncUnaryCall(handler)));
return this;
}
public GrpcServiceOverrideBuilder onServerStreamingOverride(
final Delegate delegate,
final MethodDescriptor, ?> descriptor,
final MethodDescriptor.Marshaller requestMarshaller,
final MethodDescriptor.Marshaller responseMarshaller) {
return override(MethodDescriptor.newBuilder()
.setType(MethodDescriptor.MethodType.SERVER_STREAMING)
.setFullMethodName(descriptor.getFullMethodName())
.setSampledToLocalTracing(false)
.setRequestMarshaller(requestMarshaller)
.setResponseMarshaller(responseMarshaller)
.setSchemaDescriptor(descriptor.getSchemaDescriptor())
.build(), new OpenBrowserStreamMethod<>(delegate));
}
public GrpcServiceOverrideBuilder onBidiOverride(
final BidiDelegate delegate,
final MethodDescriptor, ?> descriptor,
final MethodDescriptor.Marshaller requestMarshaller,
final MethodDescriptor.Marshaller responseMarshaller) {
return override(MethodDescriptor.newBuilder()
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setFullMethodName(descriptor.getFullMethodName())
.setSampledToLocalTracing(false)
.setRequestMarshaller(requestMarshaller)
.setResponseMarshaller(responseMarshaller)
.setSchemaDescriptor(descriptor.getSchemaDescriptor())
.build(), new BidiStreamMethod<>(delegate));
}
public GrpcServiceOverrideBuilder onBidiOverrideWithBrowserSupport(
final BidiDelegate delegate,
final MethodDescriptor, ?> bidiDescriptor,
final MethodDescriptor, ?> openDescriptor,
final MethodDescriptor, ?> nextDescriptor,
final MethodDescriptor.Marshaller requestMarshaller,
final MethodDescriptor.Marshaller responseMarshaller,
final MethodDescriptor.Marshaller nextResponseMarshaller,
BrowserStream.Mode mode,
Logger log, SessionService sessionService) {
return this
.onBidiOverride(
delegate,
bidiDescriptor,
requestMarshaller,
responseMarshaller)
.onBidiBrowserSupport(delegate,
openDescriptor,
nextDescriptor,
requestMarshaller,
responseMarshaller,
nextResponseMarshaller,
mode,
log,
sessionService);
}
public GrpcServiceOverrideBuilder onBidiBrowserSupport(
final BidiDelegate delegate,
final MethodDescriptor, ?> openDescriptor,
final MethodDescriptor, ?> nextDescriptor,
final MethodDescriptor.Marshaller requestMarshaller,
final MethodDescriptor.Marshaller responseMarshaller,
final MethodDescriptor.Marshaller nextResponseMarshaller,
BrowserStream.Mode mode,
Logger log, SessionService sessionService) {
BrowserStreamMethod method =
new BrowserStreamMethod<>(log, mode, delegate, sessionService);
needsBrowserInterceptor = true;
return this
.override(MethodDescriptor.newBuilder()
.setType(MethodDescriptor.MethodType.SERVER_STREAMING)
.setFullMethodName(openDescriptor.getFullMethodName())
.setSampledToLocalTracing(false)
.setRequestMarshaller(requestMarshaller)
.setResponseMarshaller(responseMarshaller)
.setSchemaDescriptor(openDescriptor.getSchemaDescriptor())
.build(), method.open())
.override(MethodDescriptor.newBuilder()
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName(nextDescriptor.getFullMethodName())
.setSampledToLocalTracing(false)
.setRequestMarshaller(requestMarshaller)
.setResponseMarshaller(nextResponseMarshaller)
.setSchemaDescriptor(nextDescriptor.getSchemaDescriptor())
.build(), method.next());
}
public ServerServiceDefinition build() {
final String service = baseDefinition.getServiceDescriptor().getName();
final Set overrideMethodNames = overrides.stream()
.map(o -> o.method.getFullMethodName())
.collect(Collectors.toSet());
// Make sure we preserve SchemaDescriptor fields on methods so that gRPC reflection still works.
final ServiceDescriptor.Builder serviceDescriptorBuilder = ServiceDescriptor.newBuilder(service)
.setSchemaDescriptor(baseDefinition.getServiceDescriptor().getSchemaDescriptor());
// define descriptor overrides
overrides.forEach(o -> serviceDescriptorBuilder.addMethod(o.method));
// keep non-overridden descriptors
baseDefinition.getServiceDescriptor().getMethods().stream()
.filter(d -> !overrideMethodNames.contains(d.getFullMethodName()))
.forEach(serviceDescriptorBuilder::addMethod);
final ServiceDescriptor serviceDescriptor = serviceDescriptorBuilder.build();
ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition.builder(serviceDescriptor);
// add method overrides
overrides.forEach(dp -> dp.addMethod(serviceBuilder));
// add non-overridden methods
baseDefinition.getMethods().stream()
.filter(d -> !overrideMethodNames.contains(d.getMethodDescriptor().getFullMethodName()))
.forEach(serviceBuilder::addMethod);
ServerServiceDefinition serviceDef = serviceBuilder.build();
if (needsBrowserInterceptor) {
return ServerInterceptors.intercept(serviceDef, browserStreamInterceptor);
}
return serviceDef;
}
@FunctionalInterface
public interface Delegate {
void doInvoke(final ReqT request, final StreamObserver responseObserver);
}
public static final class BrowserStreamMethod {
private final BrowserStream.Factory factory;
private final SessionService sessionService;
private final Logger log;
public BrowserStreamMethod(Logger log, BrowserStream.Mode mode, BidiDelegate delegate,
SessionService sessionService) {
this.log = log;
this.factory = BrowserStream.factory(mode, delegate);
this.sessionService = sessionService;
}
public ServerCalls.ServerStreamingMethod open() {
return this::invokeOpen;
}
public ServerCalls.UnaryMethod next() {
return this::invokeNext;
}
public void invokeOpen(
@NotNull final ReqT request,
@NotNull final StreamObserver responseObserver) {
StreamData streamData = StreamData.STREAM_DATA_KEY.get();
SessionState session = sessionService.getCurrentSession();
if (streamData == null) {
throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
"no x-deephaven-stream headers, cannot handle open request");
}
BrowserStream browserStream = factory.create(session, responseObserver);
browserStream.onMessageReceived(request, streamData);
if (!streamData.isHalfClose()) {
// if this isn't a half-close, we should export it for later calls - if it is, the client won't send
// more messages
session.newExport(streamData.getRpcTicket(), "rpcTicket")
// not setting an onError here, failure can only happen if the session ends
.submit(() -> browserStream);
}
}
public void invokeNext(
@NotNull final ReqT request,
@NotNull final StreamObserver responseObserver) {
StreamData streamData = StreamData.STREAM_DATA_KEY.get();
if (streamData == null || streamData.getRpcTicket() == null) {
throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
"no x-deephaven-stream headers, cannot handle next request");
}
final SessionState session = sessionService.getCurrentSession();
final SessionState.ExportObject> browserStream =
session.getExport(streamData.getRpcTicket(), "rpcTicket");
session.nonExport()
.require(browserStream)
.onError(responseObserver)
.submit(() -> {
browserStream.get().onMessageReceived(request, streamData);
responseObserver.onNext(null);// TODO simple response payload
responseObserver.onCompleted();
});
}
}
public static class OpenBrowserStreamMethod implements ServerCalls.ServerStreamingMethod {
private final Delegate delegate;
public OpenBrowserStreamMethod(final Delegate delegate) {
this.delegate = delegate;
}
@Override
public void invoke(final ReqT request, final StreamObserver responseObserver) {
final ServerCallStreamObserver serverCall = (ServerCallStreamObserver) responseObserver;
serverCall.disableAutoInboundFlowControl();
serverCall.request(Integer.MAX_VALUE);
delegate.doInvoke(request, responseObserver);
}
}
@FunctionalInterface
public interface BidiDelegate {
StreamObserver doInvoke(final StreamObserver responseObserver);
}
public static class BidiStreamMethod implements ServerCalls.BidiStreamingMethod {
private final BidiDelegate delegate;
public BidiStreamMethod(final BidiDelegate delegate) {
this.delegate = delegate;
}
@Override
public StreamObserver invoke(final StreamObserver responseObserver) {
final ServerCallStreamObserver serverCall = (ServerCallStreamObserver) responseObserver;
serverCall.disableAutoInboundFlowControl();
serverCall.request(Integer.MAX_VALUE);
return delegate.doInvoke(responseObserver);
}
}
private static void validateMethodType(MethodDescriptor.MethodType methodType,
MethodDescriptor.MethodType handlerType) {
if (methodType != handlerType) {
throw new IllegalArgumentException("Provided method's type (" + methodType.name()
+ ") does not match handler's type of " + handlerType.name());
}
}
}