All downloads are FREE. The search and download functionality uses the official Maven repository.

io.stargate.grpc.service.interceptors.NewConnectionInterceptor Maven / Gradle / Ivy

There is a newer version: 2.1.0-BETA-14
Show newest version
package io.stargate.grpc.service.interceptors;

import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.Metadata.Key;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.stargate.auth.AuthenticationService;
import io.stargate.auth.AuthenticationSubject;
import io.stargate.auth.UnauthorizedException;
import io.stargate.db.AuthenticatedUser;
import io.stargate.db.ClientInfo;
import io.stargate.db.Persistence;
import io.stargate.db.Persistence.Connection;
import io.stargate.grpc.service.GrpcService;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletionException;
import javax.annotation.Nullable;
import org.apache.cassandra.stargate.exceptions.UnhandledClientException;
import org.immutables.value.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * gRPC {@link ServerInterceptor} that resolves a persistence {@link Connection} for every incoming
 * call and exposes it to downstream handlers through the gRPC {@link Context}.
 *
 * <p>For each call the interceptor reads the token from the {@link #TOKEN_KEY} header, converts the
 * remaining headers to strings, and looks up (or creates) a logged-in connection keyed by the
 * token, the filtered headers and the remote address. Connections are cached for {@code
 * stargate.grpc.connection_cache_ttl_seconds} seconds (default 60) with at most {@code
 * stargate.grpc.connection_cache_max_size} entries (default 10,000).
 *
 * <p>If no token is present, or the connection cannot be created, the call is closed with an
 * appropriate {@link Status} and a no-op listener is returned.
 */
public class NewConnectionInterceptor implements ServerInterceptor {

  private static final Logger logger = LoggerFactory.getLogger(NewConnectionInterceptor.class);

  /** Metadata key of the header that carries the authentication token. */
  public static final Metadata.Key<String> TOKEN_KEY =
      Metadata.Key.of("X-Cassandra-Token", Metadata.ASCII_STRING_MARSHALLER);

  /** Fallback client address used when the transport does not report an inet socket address. */
  private static final InetSocketAddress DUMMY_ADDRESS = new InetSocketAddress(9042);

  private static final int CACHE_TTL_SECS =
      Integer.getInteger("stargate.grpc.connection_cache_ttl_seconds", 60);
  private static final int CACHE_MAX_SIZE =
      Integer.getInteger("stargate.grpc.connection_cache_max_size", 10_000);

  private final Persistence persistence;
  private final AuthenticationService authenticationService;

  /** Logged-in connections keyed by the request details that were used to create them. */
  private final LoadingCache<RequestInfo, Connection> connectionCache =
      Caffeine.newBuilder()
          .expireAfterWrite(Duration.ofSeconds(CACHE_TTL_SECS))
          .maximumSize(CACHE_MAX_SIZE)
          .build(this::newConnection);

  /** Cache key: everything that {@link #newConnection(RequestInfo)} needs to build a connection. */
  @Value.Immutable
  interface RequestInfo {

    /** The value of the {@link #TOKEN_KEY} header. */
    String token();

    /** The filtered, string-valued request headers (plus "host" when an authority is present). */
    Map<String, String> headers();

    /** The transport's remote address, if the call attributes provide one. */
    @Nullable
    SocketAddress remoteAddress();
  }

  /**
   * Creates the interceptor.
   *
   * @param persistence the persistence layer used to open new connections
   * @param authenticationService the service used to validate request tokens
   */
  public NewConnectionInterceptor(
      Persistence persistence, AuthenticationService authenticationService) {
    this.persistence = persistence;
    this.authenticationService = authenticationService;
  }

  /**
   * Attaches a cached or newly created persistence connection, together with the filtered request
   * headers, to the call's context before delegating to {@code next}.
   *
   * @param call the call being intercepted; closed here if authentication or connection setup fails
   * @param headers the request metadata, expected to contain {@link #TOKEN_KEY}
   * @param next the next handler in the chain
   * @return the listener produced by {@code next}, or a no-op listener if the call was closed
   */
  @Override
  public <ReqT, RespT> Listener<ReqT> interceptCall(
      ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
    try {
      String token = headers.get(TOKEN_KEY);
      if (token == null) {
        call.close(Status.UNAUTHENTICATED.withDescription("No token provided"), new Metadata());
        return new NopListener<>();
      }

      Map<String, String> stringHeaders = convertAndFilterHeaders(headers);

      // Some authentication service and persistence implementations depend on the "host" header
      // being set. HTTP/2 uses the ":authority" pseudo-header for this purpose and the
      // `grpc-netty-shaded` implementation will move the "host" header into the ":authority" value:
      // https://github.com/grpc/grpc-java/commit/122b3b2f7cf2b50fe0a0cebc55a84133441a4348
      String authority = call.getAuthority();
      if (authority != null && !authority.isEmpty()) {
        stringHeaders.put("host", authority);
      }

      RequestInfo info =
          ImmutableRequestInfo.builder()
              .token(token)
              .headers(stringHeaders)
              .remoteAddress(call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR))
              .build();

      Connection connection = connectionCache.get(info);

      Context context =
          Context.current()
              .withValues(
                  GrpcService.CONNECTION_KEY, connection, GrpcService.HEADERS_KEY, stringHeaders);
      return Contexts.interceptCall(context, call, headers, next);
    } catch (Exception e) {
      closeWithFailure(call, e);
    }
    return new NopListener<>();
  }

  /**
   * Closes {@code call} with a status derived from the failure that occurred while resolving the
   * connection.
   *
   * @param call the call to close
   * @param e the exception caught while handling the call; a {@link CompletionException} is
   *     unwrapped to its cause before being classified
   */
  private static void closeWithFailure(ServerCall<?, ?> call, Exception e) {
    Throwable cause = e;
    if (cause instanceof CompletionException) {
      cause = e.getCause();
    }
    if (cause instanceof UnauthorizedException) {
      call.close(
          Status.UNAUTHENTICATED.withDescription("Invalid token").withCause(e), new Metadata());
    } else if (cause instanceof UnhandledClientException) {
      // Describe the failure with the unwrapped cause's message so the client does not see the
      // wrapper exception's text when the failure arrived inside a CompletionException.
      call.close(
          Status.UNAVAILABLE.withDescription(cause.getMessage()).withCause(e), new Metadata());
    } else {
      final String message = "Error attempting to create connection to persistence";
      logger.error(message, cause);
      call.close(Status.INTERNAL.withDescription(message).withCause(e), new Metadata());
    }
  }

  /**
   * Cache loader: validates the token and opens a logged-in persistence connection.
   *
   * @param info the token, headers and remote address of the request
   * @return a new connection, logged in as the authenticated user and carrying the request headers
   *     as custom properties
   * @throws UnauthorizedException if the authentication service rejects the token
   */
  private Connection newConnection(RequestInfo info) throws UnauthorizedException {
    AuthenticationSubject authenticationSubject =
        authenticationService.validateToken(info.token(), info.headers());

    AuthenticatedUser user = authenticationSubject.asUser();

    SocketAddress remoteAddress = info.remoteAddress();
    // This is best effort attempt to set the remote address, if the remote address is not the
    // correct type then use a dummy value. Note: `remoteAddress` is almost always a
    // `InetSocketAddress`.
    InetSocketAddress inetSocketAddress =
        remoteAddress instanceof InetSocketAddress
            ? (InetSocketAddress) remoteAddress
            : DUMMY_ADDRESS;
    ClientInfo clientInfo = new ClientInfo(inetSocketAddress, null);
    Connection connection = persistence.newConnection(clientInfo);
    connection.login(user);
    // The user is only recorded on the client info when it carries a token.
    if (user.token() != null) {
      clientInfo.setAuthenticatedUser(user);
    }
    connection.setCustomProperties(info.headers());
    return connection;
  }

  /**
   * Copies the request metadata into a mutable string map, skipping binary and gRPC-specific
   * headers.
   *
   * @param headers the request metadata
   * @return a new map from header name to its value read with the ASCII string marshaller
   */
  private static Map<String, String> convertAndFilterHeaders(Metadata headers) {
    Map<String, String> stringHeaders = new HashMap<>();
    for (String key : headers.keys()) {
      // Ignore gRPC specific (in addition to binary) headers because they're not used by Connection
      // and some of the values contain unique values for each request which prevents effective
      // caching.
      if (key.endsWith("-bin") || key.startsWith("grpc-")) {
        continue;
      }
      String value = headers.get(Key.of(key, Metadata.ASCII_STRING_MARSHALLER));
      stringHeaders.put(key, value);
    }
    return stringHeaders;
  }

  /** Listener that ignores every event; returned for calls that have already been closed. */
  private static class NopListener<ReqT> extends ServerCall.Listener<ReqT> {}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy