// All downloads are FREE. Search and download functionality uses the official Maven repository.

// io.deephaven.server.util.GrpcServiceOverrideBuilder Maven / Gradle / Ivy

// The newest version!
//
// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
//
package io.deephaven.server.util;

import com.google.rpc.Code;
import io.deephaven.proto.util.Exceptions;
import io.deephaven.server.browserstreaming.BrowserStream;
import io.deephaven.server.browserstreaming.BrowserStreamInterceptor;
import io.deephaven.server.browserstreaming.StreamData;
import io.deephaven.server.session.SessionService;
import io.deephaven.server.session.SessionState;
import io.deephaven.io.logger.Logger;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.ServerCalls;
import io.grpc.stub.StreamObserver;
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class GrpcServiceOverrideBuilder {
    private static class GrpcOverride {
        private final MethodDescriptor method;
        private final ServerCallHandler handler;

        private GrpcOverride(@NotNull MethodDescriptor method,
                @NotNull ServerCallHandler handler) {
            this.method = method;
            this.handler = handler;
        }

        private void addMethod(ServerServiceDefinition.Builder builder) {
            builder.addMethod(method, handler);
        }
    }

    private final ServerServiceDefinition baseDefinition;
    private final List> overrides = new ArrayList<>();
    private final BrowserStreamInterceptor browserStreamInterceptor = new BrowserStreamInterceptor();
    private boolean needsBrowserInterceptor = false;

    private GrpcServiceOverrideBuilder(ServerServiceDefinition baseDefinition) {
        this.baseDefinition = baseDefinition;
    }

    public static GrpcServiceOverrideBuilder newBuilder(ServerServiceDefinition baseDefinition) {
        return new GrpcServiceOverrideBuilder(baseDefinition);
    }

    private  GrpcServiceOverrideBuilder override(MethodDescriptor method,
            ServerCalls.BidiStreamingMethod handler) {
        validateMethodType(method.getType(), MethodDescriptor.MethodType.BIDI_STREAMING);
        overrides.add(new GrpcOverride<>(method, ServerCalls.asyncBidiStreamingCall(handler)));
        return this;
    }

    private  GrpcServiceOverrideBuilder override(MethodDescriptor method,
            ServerCalls.ServerStreamingMethod handler) {
        validateMethodType(method.getType(), MethodDescriptor.MethodType.SERVER_STREAMING);
        overrides.add(new GrpcOverride<>(method, ServerCalls.asyncServerStreamingCall(handler)));
        return this;
    }

    private  GrpcServiceOverrideBuilder override(MethodDescriptor method,
            ServerCalls.UnaryMethod handler) {
        validateMethodType(method.getType(), MethodDescriptor.MethodType.UNARY);
        overrides.add(new GrpcOverride<>(method, ServerCalls.asyncUnaryCall(handler)));
        return this;
    }

    public  GrpcServiceOverrideBuilder onServerStreamingOverride(
            final Delegate delegate,
            final MethodDescriptor descriptor,
            final MethodDescriptor.Marshaller requestMarshaller,
            final MethodDescriptor.Marshaller responseMarshaller) {
        return override(MethodDescriptor.newBuilder()
                .setType(MethodDescriptor.MethodType.SERVER_STREAMING)
                .setFullMethodName(descriptor.getFullMethodName())
                .setSampledToLocalTracing(false)
                .setRequestMarshaller(requestMarshaller)
                .setResponseMarshaller(responseMarshaller)
                .setSchemaDescriptor(descriptor.getSchemaDescriptor())
                .build(), new OpenBrowserStreamMethod<>(delegate));
    }

    public  GrpcServiceOverrideBuilder onBidiOverride(
            final BidiDelegate delegate,
            final MethodDescriptor descriptor,
            final MethodDescriptor.Marshaller requestMarshaller,
            final MethodDescriptor.Marshaller responseMarshaller) {
        return override(MethodDescriptor.newBuilder()
                .setType(MethodDescriptor.MethodType.BIDI_STREAMING)
                .setFullMethodName(descriptor.getFullMethodName())
                .setSampledToLocalTracing(false)
                .setRequestMarshaller(requestMarshaller)
                .setResponseMarshaller(responseMarshaller)
                .setSchemaDescriptor(descriptor.getSchemaDescriptor())
                .build(), new BidiStreamMethod<>(delegate));
    }

    public  GrpcServiceOverrideBuilder onBidiOverrideWithBrowserSupport(
            final BidiDelegate delegate,
            final MethodDescriptor bidiDescriptor,
            final MethodDescriptor openDescriptor,
            final MethodDescriptor nextDescriptor,
            final MethodDescriptor.Marshaller requestMarshaller,
            final MethodDescriptor.Marshaller responseMarshaller,
            final MethodDescriptor.Marshaller nextResponseMarshaller,
            BrowserStream.Mode mode,
            Logger log, SessionService sessionService) {
        return this
                .onBidiOverride(
                        delegate,
                        bidiDescriptor,
                        requestMarshaller,
                        responseMarshaller)
                .onBidiBrowserSupport(delegate,
                        openDescriptor,
                        nextDescriptor,
                        requestMarshaller,
                        responseMarshaller,
                        nextResponseMarshaller,
                        mode,
                        log,
                        sessionService);
    }

    public  GrpcServiceOverrideBuilder onBidiBrowserSupport(
            final BidiDelegate delegate,
            final MethodDescriptor openDescriptor,
            final MethodDescriptor nextDescriptor,
            final MethodDescriptor.Marshaller requestMarshaller,
            final MethodDescriptor.Marshaller responseMarshaller,
            final MethodDescriptor.Marshaller nextResponseMarshaller,
            BrowserStream.Mode mode,
            Logger log, SessionService sessionService) {
        BrowserStreamMethod method =
                new BrowserStreamMethod<>(log, mode, delegate, sessionService);
        needsBrowserInterceptor = true;
        return this
                .override(MethodDescriptor.newBuilder()
                        .setType(MethodDescriptor.MethodType.SERVER_STREAMING)
                        .setFullMethodName(openDescriptor.getFullMethodName())
                        .setSampledToLocalTracing(false)
                        .setRequestMarshaller(requestMarshaller)
                        .setResponseMarshaller(responseMarshaller)
                        .setSchemaDescriptor(openDescriptor.getSchemaDescriptor())
                        .build(), method.open())
                .override(MethodDescriptor.newBuilder()
                        .setType(MethodDescriptor.MethodType.UNARY)
                        .setFullMethodName(nextDescriptor.getFullMethodName())
                        .setSampledToLocalTracing(false)
                        .setRequestMarshaller(requestMarshaller)
                        .setResponseMarshaller(nextResponseMarshaller)
                        .setSchemaDescriptor(nextDescriptor.getSchemaDescriptor())
                        .build(), method.next());
    }

    public ServerServiceDefinition build() {
        final String service = baseDefinition.getServiceDescriptor().getName();

        final Set overrideMethodNames = overrides.stream()
                .map(o -> o.method.getFullMethodName())
                .collect(Collectors.toSet());

        // Make sure we preserve SchemaDescriptor fields on methods so that gRPC reflection still works.
        final ServiceDescriptor.Builder serviceDescriptorBuilder = ServiceDescriptor.newBuilder(service)
                .setSchemaDescriptor(baseDefinition.getServiceDescriptor().getSchemaDescriptor());

        // define descriptor overrides
        overrides.forEach(o -> serviceDescriptorBuilder.addMethod(o.method));

        // keep non-overridden descriptors
        baseDefinition.getServiceDescriptor().getMethods().stream()
                .filter(d -> !overrideMethodNames.contains(d.getFullMethodName()))
                .forEach(serviceDescriptorBuilder::addMethod);

        final ServiceDescriptor serviceDescriptor = serviceDescriptorBuilder.build();
        ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition.builder(serviceDescriptor);

        // add method overrides
        overrides.forEach(dp -> dp.addMethod(serviceBuilder));

        // add non-overridden methods
        baseDefinition.getMethods().stream()
                .filter(d -> !overrideMethodNames.contains(d.getMethodDescriptor().getFullMethodName()))
                .forEach(serviceBuilder::addMethod);

        ServerServiceDefinition serviceDef = serviceBuilder.build();
        if (needsBrowserInterceptor) {
            return ServerInterceptors.intercept(serviceDef, browserStreamInterceptor);
        }
        return serviceDef;
    }

    @FunctionalInterface
    public interface Delegate {
        void doInvoke(final ReqT request, final StreamObserver responseObserver);
    }

    public static final class BrowserStreamMethod {
        private final BrowserStream.Factory factory;
        private final SessionService sessionService;
        private final Logger log;

        public BrowserStreamMethod(Logger log, BrowserStream.Mode mode, BidiDelegate delegate,
                SessionService sessionService) {
            this.log = log;
            this.factory = BrowserStream.factory(mode, delegate);
            this.sessionService = sessionService;
        }

        public ServerCalls.ServerStreamingMethod open() {
            return this::invokeOpen;
        }

        public ServerCalls.UnaryMethod next() {
            return this::invokeNext;
        }

        public void invokeOpen(
                @NotNull final ReqT request,
                @NotNull final StreamObserver responseObserver) {
            StreamData streamData = StreamData.STREAM_DATA_KEY.get();
            SessionState session = sessionService.getCurrentSession();
            if (streamData == null) {
                throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                        "no x-deephaven-stream headers, cannot handle open request");
            }

            BrowserStream browserStream = factory.create(session, responseObserver);
            browserStream.onMessageReceived(request, streamData);

            if (!streamData.isHalfClose()) {
                // if this isn't a half-close, we should export it for later calls - if it is, the client won't send
                // more messages
                session.newExport(streamData.getRpcTicket(), "rpcTicket")
                        // not setting an onError here, failure can only happen if the session ends
                        .submit(() -> browserStream);
            }
        }

        public void invokeNext(
                @NotNull final ReqT request,
                @NotNull final StreamObserver responseObserver) {
            StreamData streamData = StreamData.STREAM_DATA_KEY.get();
            if (streamData == null || streamData.getRpcTicket() == null) {
                throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT,
                        "no x-deephaven-stream headers, cannot handle next request");
            }
            final SessionState session = sessionService.getCurrentSession();

            final SessionState.ExportObject> browserStream =
                    session.getExport(streamData.getRpcTicket(), "rpcTicket");

            session.nonExport()
                    .require(browserStream)
                    .onError(responseObserver)
                    .submit(() -> {
                        browserStream.get().onMessageReceived(request, streamData);
                        responseObserver.onNext(null);// TODO simple response payload
                        responseObserver.onCompleted();
                    });
        }
    }

    public static class OpenBrowserStreamMethod implements ServerCalls.ServerStreamingMethod {

        private final Delegate delegate;

        public OpenBrowserStreamMethod(final Delegate delegate) {
            this.delegate = delegate;
        }

        @Override
        public void invoke(final ReqT request, final StreamObserver responseObserver) {
            final ServerCallStreamObserver serverCall = (ServerCallStreamObserver) responseObserver;
            serverCall.disableAutoInboundFlowControl();
            serverCall.request(Integer.MAX_VALUE);
            delegate.doInvoke(request, responseObserver);
        }
    }

    @FunctionalInterface
    public interface BidiDelegate {
        StreamObserver doInvoke(final StreamObserver responseObserver);
    }

    public static class BidiStreamMethod implements ServerCalls.BidiStreamingMethod {
        private final BidiDelegate delegate;

        public BidiStreamMethod(final BidiDelegate delegate) {
            this.delegate = delegate;
        }

        @Override
        public StreamObserver invoke(final StreamObserver responseObserver) {
            final ServerCallStreamObserver serverCall = (ServerCallStreamObserver) responseObserver;
            serverCall.disableAutoInboundFlowControl();
            serverCall.request(Integer.MAX_VALUE);
            return delegate.doInvoke(responseObserver);
        }
    }

    private static void validateMethodType(MethodDescriptor.MethodType methodType,
            MethodDescriptor.MethodType handlerType) {
        if (methodType != handlerType) {
            throw new IllegalArgumentException("Provided method's type (" + methodType.name()
                    + ") does not match handler's type of " + handlerType.name());
        }
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy