All downloads are free. Search and download functionality uses the official Maven repository.

io.deephaven.server.session.SessionService Maven / Gradle / Ivy

The newest version!
//
// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
//
package io.deephaven.server.session;

import com.github.f4b6a3.uuid.UuidCreator;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.protobuf.ByteString;
import com.google.rpc.Code;
import io.deephaven.auth.AuthenticationException;
import io.deephaven.auth.AuthenticationRequestHandler;
import io.deephaven.configuration.Configuration;
import io.deephaven.extensions.barrage.util.GrpcUtil;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import io.deephaven.proto.backplane.grpc.TerminationNotificationResponse;
import io.deephaven.proto.util.Exceptions;
import io.deephaven.server.util.Scheduler;
import io.deephaven.auth.AuthContext;
import io.deephaven.util.process.ProcessEnvironment;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import org.apache.arrow.flight.auth2.Auth2Constants;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

import javax.inject.Inject;
import javax.inject.Named;
import javax.inject.Singleton;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

@Singleton
public class SessionService {
    private static final Logger log = LoggerFactory.getLogger(SessionService.class);

    /**
     * Implementations of error transformer give the server one last chance to convert errors to useful messages before
     * responding to gRPC users.
     */
    @FunctionalInterface
    public interface ErrorTransformer {
        /**
         * Converts an arbitrary error into the status exception that is sent back to the gRPC user.
         *
         * @param t the error to convert
         * @return the status exception to respond with
         */
        StatusRuntimeException transform(Throwable t);
    }

    /**
     * An {@link ErrorTransformer} that keeps the details of unexpected errors out of gRPC responses. Errors that are
     * already a {@link StatusRuntimeException} are returned unchanged; anything else is logged on the server under a
     * generated id and replaced by a status whose description only carries that id.
     */
    @Singleton
    public static class ObfuscatingErrorTransformer implements ErrorTransformer {
        @VisibleForTesting
        static final int MAX_STACK_TRACE_CAUSAL_DEPTH = 25;
        private static final int MAX_CACHE_BUILDER_SIZE = 1009;
        private static final int MAX_CACHE_DURATION_MIN = 1;

        /**
         * Recently wrapped throwables mapped to the id they were reported under. The type arguments let lookups yield
         * a {@link UUID} directly; access is guarded by synchronizing on the cache itself.
         */
        private final Cache<Throwable, UUID> idCache;

        /**
         * Creates the transformer with a bounded cache whose entries expire after
         * {@link #MAX_CACHE_DURATION_MIN} minute(s) without access and whose keys are weakly referenced.
         */
        @Inject
        public ObfuscatingErrorTransformer() {
            idCache = CacheBuilder.newBuilder()
                    .expireAfterAccess(MAX_CACHE_DURATION_MIN, TimeUnit.MINUTES)
                    .maximumSize(MAX_CACHE_BUILDER_SIZE)
                    .weakKeys()
                    .build();
        }

        /**
         * Converts an error into the status exception reported to the caller.
         *
         * @param err the error to convert
         * @return {@code err} itself when it already is a {@link StatusRuntimeException}; otherwise an obfuscated
         *         status with code {@code UNAVAILABLE} for an {@link InterruptedException} and
         *         {@code INVALID_ARGUMENT} for everything else
         */
        @Override
        public StatusRuntimeException transform(final Throwable err) {
            if (err instanceof StatusRuntimeException) {
                final StatusRuntimeException sre = (StatusRuntimeException) err;
                // unauthenticated and cancelled requests are routine; everything else is logged as an error
                if (sre.getStatus().getCode().equals(Status.UNAUTHENTICATED.getCode())) {
                    log.debug().append("ignoring unauthenticated request").endl();
                } else if (sre.getStatus().getCode().equals(Status.CANCELLED.getCode())) {
                    log.debug().append("ignoring cancelled request").endl();
                } else {
                    log.error().append(sre).endl();
                }
                return sre;
            } else if (err instanceof InterruptedException) {
                return securelyWrapError(err, Code.UNAVAILABLE);
            } else {
                return securelyWrapError(err, Code.INVALID_ARGUMENT);
            }
        }

        /**
         * Replaces an error with a status that only exposes an error id, logging the full error the first time it is
         * seen. An error whose cause chain contains a throwable that was reported recently reuses that throwable's
         * id.
         *
         * @param err the error to hide
         * @param statusCode the status code of the returned exception
         * @return a status exception whose description contains the error id
         */
        private StatusRuntimeException securelyWrapError(@NotNull final Throwable err, final Code statusCode) {
            UUID errorId;
            final boolean shouldLog;

            synchronized (idCache) {
                errorId = idCache.getIfPresent(err);
                // only a top-level error that has not been seen before is logged
                shouldLog = errorId == null;

                // Walk the cause chain until a cached id is found, the chain ends, or the depth limit is hit.
                // Afterwards currDepth is the number of throwables, starting at err, that still need a cache entry.
                int currDepth = 0;
                // @formatter:off
                for (Throwable causeToCheck = err.getCause();
                     errorId == null && ++currDepth < MAX_STACK_TRACE_CAUSAL_DEPTH && causeToCheck != null;
                     causeToCheck = causeToCheck.getCause()) {
                    // @formatter:on
                    errorId = idCache.getIfPresent(causeToCheck);
                }

                if (errorId == null) {
                    errorId = UuidCreator.getRandomBased();
                }

                // @formatter:off
                for (Throwable throwableToAdd = err;
                     currDepth > 0;
                     throwableToAdd = throwableToAdd.getCause(), --currDepth) {
                    // @formatter:on
                    if (throwableToAdd.getStackTrace().length > 0) {
                        // Note that stackless exceptions are singletons, so it would be a bad idea to cache them
                        idCache.put(throwableToAdd, errorId);
                    }
                }
            }

            if (shouldLog) {
                // this is a new top-level error; log it, possibly using an existing errorId
                log.error().append("Internal Error '").append(errorId.toString()).append("' ").append(err).endl();
            }

            return Exceptions.statusRuntimeException(statusCode, "Details Logged w/ID '" + errorId + "'");
        }
    }

    /** Smallest accepted value for the configured token expiration. */
    static final long MIN_COOKIE_EXPIRE_MS = 10_000; // 10 seconds
    /** Maximum number of throwables in a cause chain reported in a termination notification. */
    private static final int MAX_STACK_TRACE_CAUSAL_DEPTH =
            Configuration.getInstance().getIntegerForClassWithDefault(SessionService.class,
                    "maxStackTraceCausedByDepth", 20);
    /** Maximum number of stack trace elements reported per throwable in a termination notification. */
    private static final int MAX_STACK_TRACE_DEPTH =
            Configuration.getInstance().getIntegerForClassWithDefault(SessionService.class,
                    "maxStackTraceDepth", 50);

    private final Scheduler scheduler;
    private final SessionState.Factory sessionFactory;

    /** How long a token stays valid after it is issued, in milliseconds. */
    private final long tokenExpireMs;
    /** Minimum age, in milliseconds, a token must reach before a refresh issues a replacement. */
    private final long tokenRotateMs;

    /** Every issued token, keyed by its UUID. */
    private final Map<UUID, TokenExpiration> tokenToSession = new ConcurrentHashMap<>();
    /** Issued tokens in the order they were created; the cleanup job consumes them from the front. */
    private final Deque<TokenExpiration> outstandingCookies = new ConcurrentLinkedDeque<>();
    /** Whether the cleanup job is currently scheduled; read and written while synchronized on this service. */
    private boolean cleanupJobInstalled = false;
    private final SessionCleanupJob sessionCleanupJob = new SessionCleanupJob();

    /** Observers waiting to be told that the process is terminating. */
    private final List<TerminationNotificationListener> terminationListeners = new CopyOnWriteArrayList<>();

    /** Authentication handlers keyed by the type prefix of the auth header. */
    private final Map<String, AuthenticationRequestHandler> authRequestHandlers;

    private final SessionListener sessionListener;

    /**
     * Creates the session service.
     *
     * @param scheduler supplies the current time and runs the token cleanup job
     * @param sessionFactory creates the {@link SessionState} backing each new session
     * @param tokenExpireMs how long an issued token stays valid, in milliseconds; must be at least
     *        {@link #MIN_COOKIE_EXPIRE_MS} and small enough to be expressed in nanoseconds
     * @param authRequestHandlers the authentication handlers, looked up by the type prefix of the auth header
     * @param sessionListeners the listeners wrapped into a single {@link DelegatingSessionListener}
     * @throws IllegalArgumentException if {@code tokenExpireMs} is too low or too high
     */
    @Inject
    public SessionService(final Scheduler scheduler, final SessionState.Factory sessionFactory,
            @Named("session.tokenExpireMs") final long tokenExpireMs,
            Map authRequestHandlers,
            Set sessionListeners) {
        // Reject unusable expirations up front; a failed construction leaves nothing behind either way.
        if (tokenExpireMs < MIN_COOKIE_EXPIRE_MS) {
            final String problem = "session.tokenExpireMs is set too low. It is configured to "
                    + tokenExpireMs + "ms (minimum is " + MIN_COOKIE_EXPIRE_MS + "ms). At low levels it is difficult "
                    + "to guarantee smooth operability given a distributed system and potential clock drift";
            throw new IllegalArgumentException(problem);
        }
        // TimeUnit conversions saturate at the long limits rather than overflowing
        final long asNanos = TimeUnit.MILLISECONDS.toNanos(tokenExpireMs);
        final boolean saturated = asNanos == Long.MIN_VALUE || asNanos == Long.MAX_VALUE;
        if (saturated) {
            throw new IllegalArgumentException("session.tokenExpireMs is set too high.");
        }

        this.scheduler = scheduler;
        this.sessionFactory = sessionFactory;
        this.authRequestHandlers = authRequestHandlers;
        this.tokenExpireMs = tokenExpireMs;
        // Protect ourselves from rotation spam, but be loose enough that any reasonable run strategy works.
        this.tokenRotateMs = tokenExpireMs / 5;

        // Fatal errors are forwarded to the termination listeners, when a process environment is available
        final boolean hasProcessEnvironment = ProcessEnvironment.tryGet() != null;
        if (hasProcessEnvironment) {
            ProcessEnvironment.getGlobalFatalErrorReporter().addInterceptor(this::onFatalError);
        }

        this.sessionListener = new DelegatingSessionListener(sessionListeners);
    }

    /**
     * Tells every registered termination listener that the process is terminating abnormally, then forgets the
     * listeners. The notification carries the reason and the stack traces of the error and its causes, limited to
     * {@code MAX_STACK_TRACE_CAUSAL_DEPTH} throwables.
     *
     * @param message the reason for the termination
     * @param throwable the fatal error
     * @param isFromUncaught whether the error came from an uncaught exception
     */
    private synchronized void onFatalError(
            @NotNull String message,
            @NotNull Throwable throwable,
            boolean isFromUncaught) {
        final TerminationNotificationResponse.Builder response = TerminationNotificationResponse.newBuilder();
        response.setAbnormalTermination(true);
        response.setIsFromUncaughtException(isFromUncaught);
        response.setReason(message);

        // TODO (core#801): revisit this error communication to properly match the API Error mode
        Throwable current = throwable;
        int reported = 0;
        while (current != null && reported < MAX_STACK_TRACE_CAUSAL_DEPTH) {
            response.addStackTraces(transformToProtoBuf(current));
            current = current.getCause();
            ++reported;
        }

        final TerminationNotificationResponse notification = response.build();
        terminationListeners.forEach(listener -> listener.sendMessage(notification));
        terminationListeners.clear();
    }

    /**
     * Builds the protobuf description of a single throwable: its class name, its message (the text {@code "null"}
     * when it has none) and at most {@code MAX_STACK_TRACE_DEPTH} stack trace elements.
     *
     * @param throwable the throwable to describe
     * @return the protobuf stack trace message
     */
    private static TerminationNotificationResponse.StackTrace transformToProtoBuf(@NotNull final Throwable throwable) {
        final String typeName = throwable.getClass().getName();
        final String messageText = Objects.toString(throwable.getMessage());
        final List<String> frames = Arrays.stream(throwable.getStackTrace())
                .limit(MAX_STACK_TRACE_DEPTH)
                .map(StackTraceElement::toString)
                .collect(Collectors.toList());

        final TerminationNotificationResponse.StackTrace.Builder trace =
                TerminationNotificationResponse.StackTrace.newBuilder();
        trace.setType(typeName);
        trace.setMessage(messageText);
        trace.addAllElements(frames);
        return trace.build();
    }

    /**
     * Tells every registered termination listener that the process is terminating normally, forgets the listeners,
     * and then closes all sessions.
     */
    public synchronized void onShutdown() {
        final TerminationNotificationResponse.Builder response = TerminationNotificationResponse.newBuilder();
        response.setAbnormalTermination(false);
        final TerminationNotificationResponse notification = response.build();

        terminationListeners.forEach(listener -> listener.sendMessage(notification));
        terminationListeners.clear();

        closeAllSessions();
    }

    /**
     * Registers an observer that receives a single notification when this process is exiting while it is still able
     * to communicate with the observer.
     *
     * @param session the session the observer belongs to
     * @param responseObserver the observer to notify
     */
    public void addTerminationListener(
            final SessionState session,
            final StreamObserver responseObserver) {
        final TerminationNotificationListener listener =
                new TerminationNotificationListener(session, responseObserver);
        terminationListeners.add(listener);
    }

    /**
     * Creates a new session for the provided auth context, issues its first token and informs the session listener.
     *
     * @param authContext the auth context of the session
     * @return a new session independent of all other existing sessions
     */
    public SessionState newSession(final AuthContext authContext) {
        final SessionState created = sessionFactory.create(authContext);
        // a brand new session always gets its initial token
        final boolean initialToken = true;
        checkTokenAndRotate(created, initialToken);
        sessionListener.onSessionCreate(created);
        return created;
    }

    /**
     * If enough time has passed since the session's current token was issued, rotate to a new token and reset the
     * expiration deadline.
     *
     * @param session the session whose token should be refreshed
     * @return the most recent token expiration, or {@code null} if the session no longer has a current expiration
     */
    public TokenExpiration refreshToken(final SessionState session) {
        final boolean initialToken = false;
        return checkTokenAndRotate(session, initialToken);
    }

    /**
     * Issues a new token for the session when one is due, records it for lookup and cleanup, and makes sure the
     * cleanup job is scheduled.
     *
     * @param session the session whose token is checked
     * @param initialToken {@code true} to issue the session's first token unconditionally; {@code false} to rotate
     *        only when the current token is old enough
     * @return the session's newest expiration; the unchanged current expiration when it is too young to rotate; or
     *         {@code null} when this is not the initial token and the session has no current expiration
     */
    @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
    private TokenExpiration checkTokenAndRotate(final SessionState session, boolean initialToken) {
        UUID newUUID;
        TokenExpiration expiration;
        final long nowMillis = scheduler.currentTimeMillis();

        // the check and the replacement of the session's expiration happen under the session's monitor
        synchronized (session) {
            expiration = session.getExpiration();
            if (!initialToken) {
                if (expiration == null) {
                    // current token is expired; we should not rotate
                    return null;
                }

                // deadlineMillis - tokenExpireMs is when the current token was issued
                if (expiration.deadlineMillis - tokenExpireMs + tokenRotateMs > nowMillis) {
                    // current token is not old enough to rotate
                    return expiration;
                }
            }

            // retry with a fresh UUID in the event that the generated one is already registered
            do {
                newUUID = UuidCreator.getRandomBased();
                expiration = new TokenExpiration(newUUID, nowMillis + tokenExpireMs, session);
            } while (tokenToSession.putIfAbsent(newUUID, expiration) != null);

            if (initialToken) {
                session.initializeExpiration(expiration);
            } else {
                session.updateExpiration(expiration);
            }
        }
        // queue the token so the cleanup job can retire it once its deadline has passed
        outstandingCookies.addLast(expiration);

        // schedule the cleanup job unless it is already pending
        synchronized (this) {
            if (!cleanupJobInstalled) {
                cleanupJobInstalled = true;
                scheduler.runAtTime(expiration.deadlineMillis, sessionCleanupJob);
            }
        }

        return expiration;
    }

    /**
     * Reports how long an issued token stays valid.
     *
     * @return the configured token duration in milliseconds
     */
    public long getExpirationDelayMs() {
        return this.tokenExpireMs;
    }

    /**
     * Resolves the session for an Authentication header value.
     * <p>
     * A header that starts with the bearer prefix is first treated as the token of an existing session. If that does
     * not yield a session, the header is split at its first space into a handler key and a payload, and the
     * authentication handler registered under that key is asked to log in; a successful login creates a new session.
     *
     * @param token the Authentication header to service
     * @return the existing or newly created session
     * @throws AuthenticationException if no handler is registered for the header's type or the login does not succeed
     */
    public SessionState getSessionForAuthToken(final String token) throws AuthenticationException {
        final boolean looksLikeBearer = token.startsWith(Auth2Constants.BEARER_PREFIX);
        if (looksLikeBearer) {
            final String authToken = token.substring(Auth2Constants.BEARER_PREFIX.length());
            try {
                final SessionState existing = getSessionForToken(UuidCreator.fromString(authToken));
                if (existing != null) {
                    return existing;
                }
            } catch (IllegalArgumentException ignored) {
                // not a usable session token; fall through to the registered authentication handlers
            }
        }

        // "<key> <payload>"; a header without a space is all key and has an empty payload
        final int separator = token.indexOf(' ');
        final String key;
        final String payload;
        if (separator < 0) {
            key = token.substring(0, token.length());
            payload = "";
        } else {
            key = token.substring(0, separator);
            payload = token.substring(separator + 1);
        }

        AuthenticationRequestHandler handler = authRequestHandlers.get(key);
        if (handler != null) {
            return handler.login(payload, SessionServiceGrpcImpl::insertCallHeader)
                    .map(this::newSession)
                    .orElseThrow(AuthenticationException::new);
        }

        log.info().append("No AuthenticationRequestHandler registered for type ").append(key).endl();
        throw new AuthenticationException();
    }

    /**
     * Finds the session a token was issued for.
     *
     * @param token the session secret to look for
     * @return the session, or null if the token is unknown, its session is expired, or its deadline has passed
     */
    public SessionState getSessionForToken(final UUID token) {
        final TokenExpiration known = tokenToSession.get(token);
        if (known == null) {
            return null;
        }
        if (known.session.isExpired()) {
            return null;
        }
        if (known.deadlineMillis <= scheduler.currentTimeMillis()) {
            return null;
        }
        return known.session;
    }

    /**
     * Lookup a session via the SessionServiceGrpcImpl.SESSION_CONTEXT_KEY, insisting that one exists. This method is
     * only valid in the context of the original calling gRPC thread.
     *
     * @return the session attached to this gRPC request
     * @throws StatusRuntimeException if thread is not attached to a session or if the session is expired/closed
     */
    @NotNull
    public SessionState getCurrentSession() {
        final SessionState attached = getOptionalSession();
        if (attached != null) {
            return attached;
        }
        throw Status.UNAUTHENTICATED.asRuntimeException();
    }

    /**
     * Lookup a session via the SessionServiceGrpcImpl.SESSION_CONTEXT_KEY. This method is only valid in the context of
     * the original calling gRPC thread.
     *
     * @return the session attached to this gRPC request; null if no session is attached or the attached session is
     *         expired
     */
    @Nullable
    public SessionState getOptionalSession() {
        final SessionState attached = SessionServiceGrpcImpl.SESSION_CONTEXT_KEY.get();
        final boolean usable = attached != null && !attached.isExpired();
        return usable ? attached : null;
    }

    /**
     * Reduces the liveness of the session. A session that already reports itself as expired is left untouched.
     *
     * @param session the session to close
     */
    public void closeSession(final SessionState session) {
        if (!session.isExpired()) {
            session.onExpired();
        }
    }

    /**
     * Expires the session behind every outstanding token. A session that owns several outstanding tokens is visited
     * once per token.
     */
    public void closeAllSessions() {
        for (final TokenExpiration outstanding : outstandingCookies) {
            final SessionState owner = outstanding.session;
            // close all exports/resources acquired by the session
            owner.onExpired();
        }
    }

    /**
     * Immutable record of an issued token: the token itself, the time at which it stops being accepted, and the
     * session it was issued for.
     */
    public static final class TokenExpiration {
        /** The token value. */
        public final UUID token;
        /** The time, in milliseconds on the scheduler's clock, from which the token is no longer accepted. */
        public final long deadlineMillis;
        /** The session the token was issued for. */
        public final SessionState session;

        /**
         * @param cookie the token value
         * @param deadlineMillis the time in milliseconds from which the token is no longer accepted
         * @param session the session the token was issued for
         */
        public TokenExpiration(final UUID cookie, final long deadlineMillis, final SessionState session) {
            this.session = session;
            this.deadlineMillis = deadlineMillis;
            this.token = cookie;
        }

        /**
         * Returns the bearer token in byte[] friendly format.
         *
         * @return the UTF-8 bytes of the bearer prefix followed by the token's string form
         */
        public ByteString getBearerTokenAsByteString() {
            final String bearer = Auth2Constants.BEARER_PREFIX + UuidCreator.toString(token);
            return ByteString.copyFromUtf8(bearer);
        }

        /**
         * Returns the token, without the bearer prefix, in byte[] friendly format.
         *
         * @return the UTF-8 bytes of the token's string form
         */
        public ByteString getTokenAsByteString() {
            final String plain = UuidCreator.toString(token);
            return ByteString.copyFromUtf8(plain);
        }
    }

    /**
     * Scheduled task that retires every outstanding token whose deadline has passed, expires the sessions that have
     * become expired as a result, and re-schedules itself for the next outstanding token (if any).
     */
    private final class SessionCleanupJob implements Runnable {
        @Override
        public void run() {
            final long nowMillis = scheduler.currentTimeMillis();

            do {
                // tokens are queued in creation order, so the head is the first candidate to retire
                final TokenExpiration next = outstandingCookies.peek();
                if (next == null || next.deadlineMillis > nowMillis) {
                    break;
                }

                // Permanently remove the first token as it is officially expired, note that other tokens may exist for
                // this session, so the session itself does not expire. We allow multiple tokens to co-exist to best
                // support out of order requests and thus allow any reasonable client behavior that respects a given
                // token expiration time.
                outstandingCookies.poll();

                // Drop the lookup entry as well; nothing else removes from tokenToSession, so without this every
                // token ever issued (and the session it references) would stay reachable for the life of the
                // process. Token lookups already reject tokens whose deadline has passed.
                tokenToSession.remove(next.token, next);

                // NOTE(review): presumably isExpired() turns true once the session's newest token has lapsed, so
                // this releases sessions that were not refreshed in time; confirm against SessionState
                if (next.session.isExpired()) {
                    next.session.onExpired();
                }
            } while (true);

            // either re-arm for the next outstanding token or mark the job as no longer scheduled
            synchronized (SessionService.this) {
                final TokenExpiration next = outstandingCookies.peek();
                if (next == null) {
                    cleanupJobInstalled = false;
                } else {
                    scheduler.runAtTime(next.deadlineMillis, this);
                }
            }
        }
    }

    /**
     * A termination listener tied to a session: it forwards the termination notification to its observer and
     * unregisters itself from the service when it is closed.
     */
    private final class TerminationNotificationListener
            extends SessionCloseableObserver {
        public TerminationNotificationListener(
                final SessionState session,
                final StreamObserver responseObserver) {
            super(session, responseObserver);
        }

        /**
         * Hands the termination notification to the observer via {@code GrpcUtil.safelyComplete}.
         *
         * @param response the notification to deliver
         */
        void sendMessage(final TerminationNotificationResponse response) {
            GrpcUtil.safelyComplete(responseObserver, response);
        }

        /**
         * Reports an {@code UNAUTHENTICATED} error to the observer and removes this listener from the service.
         */
        @Override
        protected void onClose() {
            GrpcUtil.safelyError(responseObserver, Code.UNAUTHENTICATED, "Session has ended");
            terminationListeners.remove(this);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy