All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.deephaven.server.arrow.FlightServiceGrpcImpl Maven / Gradle / Ivy

The newest version!
//
// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
//
package io.deephaven.server.arrow;

import com.google.protobuf.ByteString;
import com.google.protobuf.ByteStringAccess;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.rpc.Code;
import io.deephaven.auth.AuthContext;
import io.deephaven.auth.AuthenticationException;
import io.deephaven.auth.AuthenticationRequestHandler;
import io.deephaven.auth.BasicAuthMarshaller;
import io.deephaven.engine.table.impl.perf.QueryPerformanceNugget;
import io.deephaven.engine.table.impl.perf.QueryPerformanceRecorder;
import io.deephaven.engine.table.impl.util.EngineMetrics;
import io.deephaven.extensions.barrage.BarrageStreamGenerator;
import io.deephaven.extensions.barrage.util.GrpcUtil;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import io.deephaven.proto.backplane.grpc.ExportNotification;
import io.deephaven.proto.backplane.grpc.WrappedAuthenticationRequest;
import io.deephaven.proto.util.Exceptions;
import io.deephaven.server.session.SessionService;
import io.deephaven.server.session.SessionState;
import io.deephaven.server.session.TicketRouter;
import io.deephaven.util.SafeCloseable;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.FlightServiceGrpc;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import javax.inject.Inject;
import javax.inject.Singleton;
import java.io.InputStream;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;

/**
 * Arrow Flight gRPC service implementation backed by Deephaven's {@link SessionService} (for authentication and
 * session lookup) and {@link TicketRouter} (for resolving flight descriptors and tickets).
 *
 * <p>
 * {@code handshake}, {@code listFlights}, {@code getFlightInfo} and {@code getSchema} override the generated Flight
 * service stubs. The {@code doGetCustom}, {@code doPutCustom} and {@code doExchangeCustom} methods are not overrides;
 * they delegate to {@link ArrowFlightUtil} / the exchange factory using the caller's current session.
 */
@Singleton
public class FlightServiceGrpcImpl extends FlightServiceGrpc.FlightServiceImplBase {
    private static final Logger log = LoggerFactory.getLogger(FlightServiceGrpcImpl.class);

    private final ScheduledExecutorService executorService;
    private final BarrageStreamGenerator.Factory streamGeneratorFactory;
    private final SessionService sessionService;
    private final SessionService.ErrorTransformer errorTransformer;
    private final TicketRouter ticketRouter;
    private final ArrowFlightUtil.DoExchangeMarshaller.Factory doExchangeFactory;

    /** Authentication handlers keyed by the auth type string a client names during the handshake. */
    private final Map<String, AuthenticationRequestHandler> authRequestHandlers;

    @Inject
    public FlightServiceGrpcImpl(
            @Nullable final ScheduledExecutorService executorService,
            final BarrageStreamGenerator.Factory streamGeneratorFactory,
            final SessionService sessionService,
            final SessionService.ErrorTransformer errorTransformer,
            final TicketRouter ticketRouter,
            final ArrowFlightUtil.DoExchangeMarshaller.Factory doExchangeFactory,
            Map<String, AuthenticationRequestHandler> authRequestHandlers) {
        this.executorService = executorService;
        this.streamGeneratorFactory = streamGeneratorFactory;
        this.sessionService = sessionService;
        this.errorTransformer = errorTransformer;
        this.ticketRouter = ticketRouter;
        this.doExchangeFactory = doExchangeFactory;
        this.authRequestHandlers = authRequestHandlers;
    }

    /**
     * Opens the Flight handshake stream.
     *
     * @param responseObserver the observer used to reply to the client
     * @return the observer that receives the client's handshake requests
     */
    @Override
    public StreamObserver<Flight.HandshakeRequest> handshake(
            @NotNull final StreamObserver<Flight.HandshakeResponse> responseObserver) {
        return new HandshakeObserver(responseObserver);
    }

    /**
     * Handles the client side of a handshake. It responds at most once, in one of these ways:
     * <ul>
     * <li>the bearer token of an existing or newly created session;</li>
     * <li>a payload that an {@link AuthenticationRequestHandler} sent through its
     * {@link AuthenticationRequestHandler.HandshakeResponseListener};</li>
     * <li>an {@code UNAUTHENTICATED} error.</li>
     * </ul>
     */
    private final class HandshakeObserver implements StreamObserver<Flight.HandshakeRequest> {

        /**
         * True once a terminal signal (completion or error) has been sent on {@link #responseObserver}. After that,
         * nothing more may be sent: gRPC rejects calls on a closed stream.
         */
        private boolean isComplete = false;
        private final StreamObserver<Flight.HandshakeResponse> responseObserver;

        private HandshakeObserver(StreamObserver<Flight.HandshakeResponse> responseObserver) {
            this.responseObserver = responseObserver;
        }

        @Override
        public void onNext(final Flight.HandshakeRequest value) {
            if (isComplete) {
                // we already answered; the response stream is closed and cannot carry another reply
                return;
            }

            // handle the scenario where authentication headers initialized a session
            SessionState session = sessionService.getOptionalSession();
            if (session != null) {
                respondWithAuthTokenBin(session);
                return;
            }

            // a handler may answer the client directly; doing so closes the response stream
            final AuthenticationRequestHandler.HandshakeResponseListener handshakeResponseListener =
                    (protocol, response) -> {
                        isComplete = true;
                        GrpcUtil.safelyComplete(responseObserver, Flight.HandshakeResponse.newBuilder()
                                .setProtocolVersion(protocol)
                                .setPayload(ByteStringAccess.wrap(response))
                                .build());
                    };

            final ByteString payload = value.getPayload();
            final long protocolVersion = value.getProtocolVersion();
            Optional<AuthContext> auth;
            try {
                // first try the payload as Basic auth, then as a WrappedAuthenticationRequest naming its auth type
                auth = login(BasicAuthMarshaller.AUTH_TYPE, protocolVersion, payload, handshakeResponseListener);
                if (auth.isEmpty()) {
                    final WrappedAuthenticationRequest req = WrappedAuthenticationRequest.parseFrom(payload);
                    auth = login(req.getType(), protocolVersion, req.getPayload(), handshakeResponseListener);
                }
            } catch (final AuthenticationException | InvalidProtocolBufferException err) {
                log.error().append("Authentication failed: ").append(err).endl();
                auth = Optional.empty();
            }

            if (isComplete) {
                // a handler already replied through the listener; the stream is closed
                return;
            }

            if (auth.isEmpty()) {
                isComplete = true;
                GrpcUtil.safelyError(responseObserver,
                        Exceptions.statusRuntimeException(Code.UNAUTHENTICATED, "Authentication details invalid"));
                return;
            }

            session = sessionService.newSession(auth.get());
            respondWithAuthTokenBin(session);
        }

        /**
         * Delegates a login attempt to the handler registered for {@code type}.
         *
         * @return the handler's result, or empty when no handler is registered for {@code type}
         * @throws AuthenticationException if the handler rejects the attempt
         */
        private Optional<AuthContext> login(String type, long version, ByteString payload,
                AuthenticationRequestHandler.HandshakeResponseListener listener) throws AuthenticationException {
            AuthenticationRequestHandler handler = authRequestHandlers.get(type);
            if (handler == null) {
                log.info().append("No AuthenticationRequestHandler registered for type ").append(type).endl();
                return Optional.empty();
            }
            return handler.login(version, payload.asReadOnlyByteBuffer(), listener);
        }

        /** send the bearer token as an AuthTokenBin, as headers might have already been sent */
        private void respondWithAuthTokenBin(SessionState session) {
            isComplete = true;
            GrpcUtil.safelyComplete(responseObserver, Flight.HandshakeResponse.newBuilder()
                    .setPayload(session.getExpiration().getTokenAsByteString())
                    .build());
        }

        @Override
        public void onError(final Throwable t) {
            // the client aborted the handshake; this observer holds no resources, so there is nothing to release
        }

        @Override
        public void onCompleted() {
            if (isComplete) {
                return;
            }
            // the client half-closed without a request that produced a response
            isComplete = true;
            GrpcUtil.safelyError(responseObserver,
                    Exceptions.statusRuntimeException(Code.UNAUTHENTICATED, "no authentication details provided"));
        }
    }

    /**
     * Streams every {@link Flight.FlightInfo} the ticket router exposes to the caller's (possibly absent) session.
     */
    @Override
    public void listFlights(
            @NotNull final Flight.Criteria request,
            @NotNull final StreamObserver<Flight.FlightInfo> responseObserver) {
        ticketRouter.visitFlightInfo(sessionService.getOptionalSession(), responseObserver::onNext);
        responseObserver.onCompleted();
    }

    /**
     * Resolves {@code request} through the ticket router and replies with the resulting {@link Flight.FlightInfo}.
     */
    @Override
    public void getFlightInfo(
            @NotNull final Flight.FlightDescriptor request,
            @NotNull final StreamObserver<Flight.FlightInfo> responseObserver) {
        respondWithFlightInfo("getFlightInfo", request, responseObserver, Function.identity());
    }

    /**
     * Resolves {@code request} through the ticket router and replies with the schema of the resulting
     * {@link Flight.FlightInfo}.
     */
    @Override
    public void getSchema(
            @NotNull final Flight.FlightDescriptor request,
            @NotNull final StreamObserver<Flight.SchemaResult> responseObserver) {
        respondWithFlightInfo("getSchema", request, responseObserver,
                info -> Flight.SchemaResult.newBuilder()
                        .setSchema(info.getSchema())
                        .build());
    }

    /**
     * Shared implementation of {@link #getFlightInfo} and {@link #getSchema}. It looks up the flight info for
     * {@code request}, converts it with {@code toResponse}, and sends exactly one terminal signal to
     * {@code responseObserver}. The work is recorded under a {@link QueryPerformanceRecorder}.
     *
     * <ul>
     * <li>With a session, the reply is deferred until the export's dependencies are satisfied. Errors are routed to
     * {@code responseObserver}.</li>
     * <li>Without a session, the export must already be in the {@code EXPORTED} state. Otherwise the call fails
     * with {@code FAILED_PRECONDITION} rather than leaving the RPC without a response.</li>
     * </ul>
     *
     * @param method the RPC name, used in the query-performance description
     * @param request the descriptor to resolve
     * @param responseObserver the observer to reply to
     * @param toResponse converts the resolved flight info into the RPC's response message
     */
    private <T> void respondWithFlightInfo(
            final String method,
            final Flight.FlightDescriptor request,
            final StreamObserver<T> responseObserver,
            final Function<? super Flight.FlightInfo, ? extends T> toResponse) {
        final SessionState session = sessionService.getOptionalSession();

        final String description = "FlightService#" + method + "(request=" + request + ")";
        final QueryPerformanceRecorder queryPerformanceRecorder = QueryPerformanceRecorder.newQuery(
                description, session == null ? null : session.getSessionId(), QueryPerformanceNugget.DEFAULT_FACTORY);

        try (final SafeCloseable ignored = queryPerformanceRecorder.startQuery()) {
            final SessionState.ExportObject<Flight.FlightInfo> export =
                    ticketRouter.flightInfoFor(session, request, "request");

            if (session != null) {
                session.nonExport()
                        .queryPerformanceRecorder(queryPerformanceRecorder)
                        .require(export)
                        .onError(responseObserver)
                        .submit(() -> {
                            responseObserver.onNext(toResponse.apply(export.get()));
                            responseObserver.onCompleted();
                        });
                return;
            }

            StatusRuntimeException exception = null;
            if (export.tryRetainReference()) {
                try {
                    if (export.getState() == ExportNotification.State.EXPORTED) {
                        GrpcUtil.safelyOnNext(responseObserver, toResponse.apply(export.get()));
                        GrpcUtil.safelyComplete(responseObserver);
                    } else {
                        // not (yet) exported: without a session we cannot wait for it, so fail instead of hanging
                        exception = Exceptions.statusRuntimeException(
                                Code.FAILED_PRECONDITION, "Could not find flight info");
                    }
                } finally {
                    export.dropReference();
                }
            } else {
                exception = Exceptions.statusRuntimeException(Code.FAILED_PRECONDITION, "Could not find flight info");
            }

            if (exception != null) {
                GrpcUtil.safelyError(responseObserver, exception);
            }

            if (queryPerformanceRecorder.endQuery() || exception != null) {
                EngineMetrics.getInstance().logQueryProcessingResults(queryPerformanceRecorder, exception);
            }
        }
    }

    /**
     * Serves a DoGet request for {@code request} using the caller's current session.
     *
     * @param request the ticket to fetch
     * @param responseObserver the observer to stream the encoded payloads to
     */
    public void doGetCustom(
            final Flight.Ticket request,
            final StreamObserver<InputStream> responseObserver) {
        ArrowFlightUtil.DoGetCustom(
                streamGeneratorFactory, sessionService.getCurrentSession(), ticketRouter, request, responseObserver);
    }

    /**
     * Establish a new DoPut bi-directional stream.
     *
     * @param responseObserver the observer to reply to
     * @return the observer that grpc can delegate received messages to
     */
    public StreamObserver<InputStream> doPutCustom(final StreamObserver<Flight.PutResult> responseObserver) {
        return new ArrowFlightUtil.DoPutObserver(
                sessionService.getCurrentSession(), ticketRouter, errorTransformer, responseObserver);
    }

    /**
     * Establish a new DoExchange bi-directional stream.
     *
     * @param responseObserver the observer to reply to
     * @return the observer that grpc can delegate received messages to
     */
    public StreamObserver<InputStream> doExchangeCustom(final StreamObserver<InputStream> responseObserver) {
        return doExchangeFactory.openExchange(sessionService.getCurrentSession(), responseObserver);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy