io.deephaven.server.arrow.FlightServiceGrpcImpl Maven / Gradle / Ivy
The newest version!
//
// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
//
package io.deephaven.server.arrow;
import com.google.protobuf.ByteString;
import com.google.protobuf.ByteStringAccess;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.rpc.Code;
import io.deephaven.auth.AuthContext;
import io.deephaven.auth.AuthenticationException;
import io.deephaven.auth.AuthenticationRequestHandler;
import io.deephaven.auth.BasicAuthMarshaller;
import io.deephaven.engine.table.impl.perf.QueryPerformanceNugget;
import io.deephaven.engine.table.impl.perf.QueryPerformanceRecorder;
import io.deephaven.engine.table.impl.util.EngineMetrics;
import io.deephaven.extensions.barrage.BarrageStreamGenerator;
import io.deephaven.extensions.barrage.util.GrpcUtil;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import io.deephaven.proto.backplane.grpc.ExportNotification;
import io.deephaven.proto.backplane.grpc.WrappedAuthenticationRequest;
import io.deephaven.proto.util.Exceptions;
import io.deephaven.server.session.SessionService;
import io.deephaven.server.session.SessionState;
import io.deephaven.server.session.TicketRouter;
import io.deephaven.util.SafeCloseable;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.FlightServiceGrpc;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import javax.inject.Inject;
import javax.inject.Singleton;
import java.io.InputStream;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
/**
 * gRPC implementation of the Arrow Flight service.
 *
 * <p>
 * {@code Handshake}, {@code ListFlights}, {@code GetFlightInfo} and {@code GetSchema} are implemented directly. The
 * DoGet / DoPut / DoExchange streams are exposed through the {@code *Custom} methods, which delegate to
 * {@link ArrowFlightUtil}.
 */
@Singleton
public class FlightServiceGrpcImpl extends FlightServiceGrpc.FlightServiceImplBase {
    private static final Logger log = LoggerFactory.getLogger(FlightServiceGrpcImpl.class);

    // NOTE(review): stored for injection but not read anywhere in this class.
    private final ScheduledExecutorService executorService;
    private final BarrageStreamGenerator.Factory streamGeneratorFactory;
    private final SessionService sessionService;
    private final SessionService.ErrorTransformer errorTransformer;
    private final TicketRouter ticketRouter;
    private final ArrowFlightUtil.DoExchangeMarshaller.Factory doExchangeFactory;
    /** Authentication handlers keyed by the auth type named in a handshake (e.g. the basic-auth type). */
    private final Map<String, AuthenticationRequestHandler> authRequestHandlers;

    @Inject
    public FlightServiceGrpcImpl(
            @Nullable final ScheduledExecutorService executorService,
            final BarrageStreamGenerator.Factory streamGeneratorFactory,
            final SessionService sessionService,
            final SessionService.ErrorTransformer errorTransformer,
            final TicketRouter ticketRouter,
            final ArrowFlightUtil.DoExchangeMarshaller.Factory doExchangeFactory,
            final Map<String, AuthenticationRequestHandler> authRequestHandlers) {
        this.executorService = executorService;
        this.streamGeneratorFactory = streamGeneratorFactory;
        this.sessionService = sessionService;
        this.errorTransformer = errorTransformer;
        this.ticketRouter = ticketRouter;
        this.doExchangeFactory = doExchangeFactory;
        this.authRequestHandlers = authRequestHandlers;
    }

    /**
     * Opens a Flight handshake stream.
     *
     * @param responseObserver the observer that receives handshake responses
     * @return the observer that gRPC delivers the client's handshake requests to
     */
    @Override
    public StreamObserver<Flight.HandshakeRequest> handshake(
            @NotNull final StreamObserver<Flight.HandshakeResponse> responseObserver) {
        return new HandshakeObserver(responseObserver);
    }

    /**
     * Request observer for one handshake stream.
     *
     * <p>
     * If the call already carries a session (e.g. created from authentication headers), that session's token is
     * returned. Otherwise the payload is offered to the basic-auth handler, then decoded as a
     * {@link WrappedAuthenticationRequest} and offered to the handler registered for its type. On success a new session
     * is created and its token is sent back as the response payload.
     */
    private final class HandshakeObserver implements StreamObserver<Flight.HandshakeRequest> {
        /**
         * Set once this observer has terminated {@link #responseObserver}, either with a token or with an error. It is
         * used so that the response stream is never written to or terminated again.
         */
        private boolean isComplete = false;
        private final StreamObserver<Flight.HandshakeResponse> responseObserver;

        private HandshakeObserver(final StreamObserver<Flight.HandshakeResponse> responseObserver) {
            this.responseObserver = responseObserver;
        }

        @Override
        public void onNext(final Flight.HandshakeRequest value) {
            if (isComplete) {
                // The response stream has already been terminated; writing to it again would fail.
                return;
            }

            // handle the scenario where authentication headers initialized a session
            final SessionState existingSession = sessionService.getOptionalSession();
            if (existingSession != null) {
                respondWithAuthTokenBin(existingSession);
                return;
            }

            final AuthenticationRequestHandler.HandshakeResponseListener handshakeResponseListener =
                    (protocol, response) -> {
                        GrpcUtil.safelyComplete(responseObserver, Flight.HandshakeResponse.newBuilder()
                                .setProtocolVersion(protocol)
                                .setPayload(ByteStringAccess.wrap(response))
                                .build());
                    };

            final ByteString payload = value.getPayload();
            final long protocolVersion = value.getProtocolVersion();
            Optional<AuthContext> auth;
            try {
                auth = login(BasicAuthMarshaller.AUTH_TYPE, protocolVersion, payload, handshakeResponseListener);
                if (auth.isEmpty()) {
                    // Not basic auth: the payload may be a wrapped request naming another handler type.
                    final WrappedAuthenticationRequest req = WrappedAuthenticationRequest.parseFrom(payload);
                    auth = login(req.getType(), protocolVersion, req.getPayload(), handshakeResponseListener);
                }
            } catch (final AuthenticationException | InvalidProtocolBufferException err) {
                log.error().append("Authentication failed: ").append(err).endl();
                auth = Optional.empty();
            }

            if (auth.isEmpty()) {
                // Record termination first so a later client half-close does not terminate the call a second time.
                isComplete = true;
                responseObserver.onError(
                        Exceptions.statusRuntimeException(Code.UNAUTHENTICATED, "Authentication details invalid"));
                return;
            }

            respondWithAuthTokenBin(sessionService.newSession(auth.get()));
        }

        /**
         * Offers the payload to the handler registered for {@code type}.
         *
         * @param type the authentication type used to look up the handler
         * @param version the protocol version from the handshake request
         * @param payload the authentication payload
         * @param listener passed through to the handler so it can send its own handshake response
         * @return the handler's result, or empty if no handler is registered for {@code type}
         * @throws AuthenticationException propagated from the handler
         */
        private Optional<AuthContext> login(
                final String type,
                final long version,
                final ByteString payload,
                final AuthenticationRequestHandler.HandshakeResponseListener listener)
                throws AuthenticationException {
            final AuthenticationRequestHandler handler = authRequestHandlers.get(type);
            if (handler == null) {
                log.info().append("No AuthenticationRequestHandler registered for type ").append(type).endl();
                return Optional.empty();
            }
            return handler.login(version, payload.asReadOnlyByteBuffer(), listener);
        }

        /** send the bearer token as an AuthTokenBin, as headers might have already been sent */
        private void respondWithAuthTokenBin(final SessionState session) {
            isComplete = true;
            responseObserver.onNext(Flight.HandshakeResponse.newBuilder()
                    .setPayload(session.getExpiration().getTokenAsByteString())
                    .build());
            responseObserver.onCompleted();
        }

        @Override
        public void onError(final Throwable t) {
            // ignore: this observer holds nothing that needs releasing when the client side fails
        }

        @Override
        public void onCompleted() {
            if (isComplete) {
                return;
            }
            // The client half-closed without this observer ever producing a token or an error.
            responseObserver.onError(
                    Exceptions.statusRuntimeException(Code.UNAUTHENTICATED, "no authentication details provided"));
        }
    }

    /**
     * Streams every flight visible to the caller's (optional) session, then completes.
     *
     * @param request the listing criteria (not consulted here)
     * @param responseObserver receives one {@link Flight.FlightInfo} per visited flight
     */
    @Override
    public void listFlights(
            @NotNull final Flight.Criteria request,
            @NotNull final StreamObserver<Flight.FlightInfo> responseObserver) {
        ticketRouter.visitFlightInfo(sessionService.getOptionalSession(), responseObserver::onNext);
        responseObserver.onCompleted();
    }

    /**
     * Responds with the {@link Flight.FlightInfo} resolved for {@code request}.
     *
     * @param request the descriptor identifying the flight
     * @param responseObserver receives the flight info, or an error
     */
    @Override
    public void getFlightInfo(
            @NotNull final Flight.FlightDescriptor request,
            @NotNull final StreamObserver<Flight.FlightInfo> responseObserver) {
        respondWithFlightInfo("getFlightInfo", request, responseObserver, Function.identity());
    }

    /**
     * Responds with the schema of the flight resolved for {@code request}.
     *
     * @param request the descriptor identifying the flight
     * @param responseObserver receives a {@link Flight.SchemaResult} carrying the flight's schema, or an error
     */
    @Override
    public void getSchema(
            @NotNull final Flight.FlightDescriptor request,
            @NotNull final StreamObserver<Flight.SchemaResult> responseObserver) {
        respondWithFlightInfo("getSchema", request, responseObserver,
                info -> Flight.SchemaResult.newBuilder()
                        .setSchema(info.getSchema())
                        .build());
    }

    /**
     * Shared implementation of {@link #getFlightInfo} and {@link #getSchema}. It resolves the flight info for
     * {@code request}, converts it with {@code toResponse}, and sends the single result to {@code responseObserver}.
     *
     * <p>
     * With a session, the response is produced by a non-export task that requires the flight-info export, and that
     * task reports errors to {@code responseObserver}. Without a session, the export is consulted directly. If it
     * cannot be retained, or is not in the {@code EXPORTED} state, the call fails with {@code FAILED_PRECONDITION}.
     * On that path the result is logged to {@link EngineMetrics} when {@code endQuery()} returns true or an error was
     * sent.
     *
     * @param methodName the RPC name embedded in the query-performance description
     * @param request the descriptor identifying the flight
     * @param responseObserver the observer that receives exactly one response or an error
     * @param toResponse converts the resolved flight info into the RPC's response message
     * @param <T> the RPC's response message type
     */
    private <T> void respondWithFlightInfo(
            final String methodName,
            final Flight.FlightDescriptor request,
            final StreamObserver<T> responseObserver,
            final Function<Flight.FlightInfo, T> toResponse) {
        final SessionState session = sessionService.getOptionalSession();

        final String description = "FlightService#" + methodName + "(request=" + request + ")";
        final QueryPerformanceRecorder queryPerformanceRecorder = QueryPerformanceRecorder.newQuery(
                description, session == null ? null : session.getSessionId(), QueryPerformanceNugget.DEFAULT_FACTORY);

        try (final SafeCloseable ignored = queryPerformanceRecorder.startQuery()) {
            final SessionState.ExportObject<Flight.FlightInfo> export =
                    ticketRouter.flightInfoFor(session, request, "request");

            if (session != null) {
                session.<T>nonExport()
                        .queryPerformanceRecorder(queryPerformanceRecorder)
                        .require(export)
                        .onError(responseObserver)
                        .submit(() -> {
                            responseObserver.onNext(toResponse.apply(export.get()));
                            responseObserver.onCompleted();
                        });
                return;
            }

            StatusRuntimeException exception = null;
            if (export.tryRetainReference()) {
                try {
                    if (export.getState() == ExportNotification.State.EXPORTED) {
                        GrpcUtil.safelyOnNext(responseObserver, toResponse.apply(export.get()));
                        GrpcUtil.safelyComplete(responseObserver);
                    } else {
                        // Without this the anonymous caller would never receive a response.
                        exception = Exceptions.statusRuntimeException(
                                Code.FAILED_PRECONDITION, "Could not find flight info");
                    }
                } finally {
                    export.dropReference();
                }
            } else {
                exception = Exceptions.statusRuntimeException(Code.FAILED_PRECONDITION, "Could not find flight info");
            }

            if (exception != null) {
                GrpcUtil.safelyError(responseObserver, exception);
            }

            if (queryPerformanceRecorder.endQuery() || exception != null) {
                EngineMetrics.getInstance().logQueryProcessingResults(queryPerformanceRecorder, exception);
            }
        }
    }

    /**
     * Serves a DoGet stream for {@code request} by delegating to {@link ArrowFlightUtil#DoGetCustom} with the caller's
     * current session.
     *
     * @param request the ticket to read
     * @param responseObserver the observer that receives the encoded stream
     */
    public void doGetCustom(
            final Flight.Ticket request,
            final StreamObserver<InputStream> responseObserver) {
        ArrowFlightUtil.DoGetCustom(
                streamGeneratorFactory, sessionService.getCurrentSession(), ticketRouter, request, responseObserver);
    }

    /**
     * Establish a new DoPut bi-directional stream.
     *
     * @param responseObserver the observer to reply to
     * @return the observer that grpc can delegate received messages to
     */
    public StreamObserver<InputStream> doPutCustom(final StreamObserver<Flight.PutResult> responseObserver) {
        return new ArrowFlightUtil.DoPutObserver(
                sessionService.getCurrentSession(), ticketRouter, errorTransformer, responseObserver);
    }

    /**
     * Establish a new DoExchange bi-directional stream.
     *
     * @param responseObserver the observer to reply to
     * @return the observer that grpc can delegate received messages to
     */
    public StreamObserver<InputStream> doExchangeCustom(final StreamObserver<InputStream> responseObserver) {
        return doExchangeFactory.openExchange(sessionService.getCurrentSession(), responseObserver);
    }
}