// org.lognet.springboot.grpc.security.SecurityInterceptor (Maven / Gradle / Ivy artifact source listing)
// Source of the newest published version.
// Generated by delombok at Wed Sep 27 05:27:18 UTC 2023
package org.lognet.springboot.grpc.security;
import io.grpc.*;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import org.lognet.springboot.grpc.FailureHandlingSupport;
import org.lognet.springboot.grpc.GRpcServicesRegistry;
import org.lognet.springboot.grpc.MessageBlockingServerCallListener;
import org.lognet.springboot.grpc.autoconfigure.GRpcServerProperties;
import org.lognet.springboot.grpc.recovery.GRpcRuntimeExceptionWrapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.Ordered;
import org.springframework.security.access.SecurityMetadataSource;
import org.springframework.security.access.intercept.AbstractSecurityInterceptor;
import org.springframework.security.access.intercept.InterceptorStatusToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.util.SimpleMethodInvocation;
public class SecurityInterceptor extends AbstractSecurityInterceptor
implements ServerInterceptor, Ordered {
@java.lang.SuppressWarnings("all")
private static final org.slf4j.Logger log =
org.slf4j.LoggerFactory.getLogger(SecurityInterceptor.class);
private static final Context.Key INTERCEPTOR_STATUS_TOKEN =
Context.key("INTERCEPTOR_STATUS_TOKEN");
private static final Context.Key> METHOD_INVOCATION =
Context.key("METHOD_INVOCATION");
private final SecurityMetadataSource securityMetadataSource;
private final AuthenticationSchemeSelector schemeSelector;
private GRpcServerProperties.SecurityProperties.Auth authCfg;
private FailureHandlingSupport failureHandlingSupport;
private GRpcServicesRegistry registry;
/**
 * Secured object passed to the security interceptor: a method invocation built from the service
 * object and method of a gRPC service method, bundled with the call, its headers and the next
 * handler in the chain.
 *
 * <p>The argument array is mutable so that it can be filled in per incoming message and reset
 * afterwards.
 *
 * @param <ReqT> request message type of the call
 * @param <RespT> response message type of the call
 */
static class GrpcMethodInvocation<ReqT, RespT> extends SimpleMethodInvocation {
  private final ServerCall<ReqT, RespT> call;
  private final Metadata headers;
  private final ServerCallHandler<ReqT, RespT> next;
  private Object[] arguments;

  /**
   * @param serviceMethod supplies the target service object and reflective method
   * @param call the intercepted call
   * @param headers the request headers of the call
   * @param next the handler that {@link #proceed()} delegates to
   */
  public GrpcMethodInvocation(
      GRpcServicesRegistry.GrpcServiceMethod serviceMethod,
      ServerCall<ReqT, RespT> call,
      Metadata headers,
      ServerCallHandler<ReqT, RespT> next) {
    super(serviceMethod.getService(), serviceMethod.getMethod());
    this.call = call;
    this.headers = headers;
    this.next = next;
  }

  /** Starts the call on the next handler and returns the listener it produces. */
  @Override
  public Object proceed() {
    return next.startCall(call, headers);
  }

  /** Returns the intercepted call. */
  ServerCall<ReqT, RespT> getCall() {
    return call;
  }

  /** Returns the arguments most recently stored with {@link #setArguments(Object[])}. */
  @Override
  @java.lang.SuppressWarnings("all")
  public Object[] getArguments() {
    return this.arguments;
  }

  /** Replaces the argument array exposed by {@link #getArguments()}; {@code null} is allowed. */
  @java.lang.SuppressWarnings("all")
  public void setArguments(final Object[] arguments) {
    this.arguments = arguments;
  }
}
/**
 * Creates the interceptor.
 *
 * @param securityMetadataSource metadata source later returned by
 *     {@code obtainSecurityMetadataSource()}
 * @param schemeSelector selector used to turn an authorization header value into an
 *     authentication
 */
public SecurityInterceptor(
    SecurityMetadataSource securityMetadataSource, AuthenticationSchemeSelector schemeSelector) {
  this.schemeSelector = schemeSelector;
  this.securityMetadataSource = securityMetadataSource;
}
/**
 * Injects the registry used to resolve the service method behind a gRPC method descriptor.
 *
 * @param registry the services registry
 */
@Autowired
public void setGRpcServicesRegistry(GRpcServicesRegistry registry) {
  this.registry = registry;
}
/**
 * Injects the helper used to close calls that failed a security check. The parameter is
 * annotated {@code @Lazy}.
 *
 * @param failureHandlingSupport the failure handling helper
 */
@Autowired
public void setFailureHandlingSupport(@Lazy FailureHandlingSupport failureHandlingSupport) {
  this.failureHandlingSupport = failureHandlingSupport;
}
/**
 * Sets the authentication configuration.
 *
 * @param authCfg the configuration to use; when {@code null}, a freshly constructed default
 *     configuration is used instead
 */
public void setConfig(GRpcServerProperties.SecurityProperties.Auth authCfg) {
  if (authCfg != null) {
    this.authCfg = authCfg;
  } else {
    this.authCfg = new GRpcServerProperties.SecurityProperties.Auth();
  }
}
/**
 * Returns the interceptor order taken from the authentication configuration, or
 * {@code Ordered.HIGHEST_PRECEDENCE + 1} when the configuration does not specify one.
 */
@Override
public int getOrder() {
  final Integer configuredOrder = authCfg.getInterceptorOrder();
  if (configuredOrder == null) {
    return Ordered.HIGHEST_PRECEDENCE + 1;
  }
  return configuredOrder;
}
@Override
public Class> getSecureObjectClass() {
return GrpcMethodInvocation.class;
}
/** Returns the metadata source supplied to the constructor. */
@Override
public SecurityMetadataSource obtainSecurityMetadataSource() {
  return this.securityMetadataSource;
}
/**
 * Resolves the caller's authentication from the {@code Authorization} header, runs the
 * pre-invocation security check and, on success, continues the call inside a gRPC context that
 * carries the authentication, the interceptor status token and the method invocation.
 *
 * <p>Executes the same interceptor flow as the original
 * FilterSecurityInterceptor/MethodSecurityInterceptor:
 *
 * <pre>{@code
 * InterceptorStatusToken token = super.beforeInvocation(mi);
 * Object result;
 * try {
 *   result = mi.proceed();
 * } finally {
 *   super.finallyInvocation(token);
 * }
 * return super.afterInvocation(token, result);
 * }</pre>
 *
 * <p>Any exception raised while setting up the security context is routed to
 * {@code fail(...)}; checked exceptions are wrapped in {@link GRpcRuntimeExceptionWrapper}.
 *
 * @param call the intercepted call
 * @param headers the request headers
 * @param next the next handler in the interceptor chain
 * @return the listener for the call
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
  final CharSequence authorization = readAuthorization(headers);
  try {
    final Context grpcSecurityContext;
    try {
      grpcSecurityContext = setupGRpcSecurityContext(call, headers, next, authorization);
    } catch (RuntimeException e) {
      return fail(next, call, headers, e);
    } catch (Exception e) {
      return fail(next, call, headers, new GRpcRuntimeExceptionWrapper(e));
    }
    return Contexts.interceptCall(
        grpcSecurityContext, call, headers, authenticationPropagatingHandler(next));
  } finally {
    // Remove the context from the holder entirely instead of only nulling its authentication,
    // so nothing stays bound to this thread once the interceptor returns.
    SecurityContextHolder.clearContext();
  }
}

/**
 * Reads the authorization value from the headers, preferring the binary
 * {@code Authorization-bin} header (decoded as UTF-8) over the ASCII {@code Authorization} one.
 *
 * @return the header value, or {@code null} when neither header is present
 */
private static CharSequence readAuthorization(Metadata headers) {
  final byte[] binaryAuthorization =
      headers.get(
          Metadata.Key.of(
              "Authorization" + Metadata.BINARY_HEADER_SUFFIX, Metadata.BINARY_BYTE_MARSHALLER));
  if (binaryAuthorization != null) {
    return StandardCharsets.UTF_8.decode(ByteBuffer.wrap(binaryAuthorization));
  }
  return headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER));
}
/**
 * Wraps {@code next} so that every listener callback runs with the authentication stored in the
 * gRPC context published to {@link SecurityContextHolder}, and with the holder cleared again
 * afterwards.
 *
 * <p>For each incoming message the invocation arguments are populated, the pre-invocation check
 * is executed again, and only then is the message delivered; a failure closes the call through
 * {@code failureHandlingSupport}. On half-close, {@code finallyInvocation} is executed with the
 * token stored in the gRPC context.
 *
 * @param next the handler to delegate to
 * @return a handler whose listener propagates the authentication to each callback
 */
private <ReqT, RespT> ServerCallHandler<ReqT, RespT> authenticationPropagatingHandler(
    ServerCallHandler<ReqT, RespT> next) {
  return (call, headers) ->
      new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(
          next.startCall(afterInvocationPropagator(call), headers)) {
        @Override
        public void onMessage(ReqT message) {
          propagateAuthentication(
              () -> {
                try {
                  switch (call.getMethodDescriptor().getType()) {
                    // server streaming and unary calls generated with 2 parameters,
                    // first one is the actual input
                    case SERVER_STREAMING:
                    case UNARY:
                      METHOD_INVOCATION.get().setArguments(new Object[] {message, null});
                      break;
                    // client streaming and bidi streaming calls generated with 1 parameter
                    case BIDI_STREAMING:
                    case CLIENT_STREAMING:
                    case UNKNOWN:
                      METHOD_INVOCATION.get().setArguments(new Object[] {message});
                      break;
                    default:
                      log.error(
                          "Unsupported call type {}", call.getMethodDescriptor().getType());
                      throw new StatusRuntimeException(Status.UNAUTHENTICATED);
                  }
                  // Re-run the pre-invocation check now that the message is available as an
                  // invocation argument.
                  beforeInvocation(METHOD_INVOCATION.get());
                  super.onMessage(message);
                } catch (RuntimeException e) {
                  failureHandlingSupport.closeCall(e, call, headers);
                } catch (Exception e) {
                  failureHandlingSupport.closeCall(
                      new GRpcRuntimeExceptionWrapper(e), call, headers);
                } finally {
                  // Do not keep a reference to the message once it has been handled.
                  METHOD_INVOCATION.get().setArguments(null);
                }
              });
        }

        @Override
        public void onHalfClose() {
          try {
            propagateAuthentication(super::onHalfClose);
          } finally {
            finallyInvocation(INTERCEPTOR_STATUS_TOKEN.get());
          }
        }

        @Override
        public void onCancel() {
          propagateAuthentication(super::onCancel);
        }

        @Override
        public void onComplete() {
          propagateAuthentication(super::onComplete);
        }

        @Override
        public void onReady() {
          propagateAuthentication(super::onReady);
        }

        /**
         * Runs {@code runnable} with the authentication from the gRPC context set on the
         * current security context, then clears the holder.
         */
        private void propagateAuthentication(Runnable runnable) {
          try {
            SecurityContextHolder.getContext()
                .setAuthentication(GrpcSecurity.AUTHENTICATION_CONTEXT_KEY.get());
            runnable.run();
          } finally {
            SecurityContextHolder.clearContext();
          }
        }
      };
}
/**
 * Wraps {@code call} so that every outgoing message is first passed through
 * {@code afterInvocation} together with the interceptor status token stored in the gRPC
 * context, and the object it returns is what gets sent.
 *
 * @param call the call to decorate
 * @return the decorated call
 */
private <ReqT, RespT> ServerCall<ReqT, RespT> afterInvocationPropagator(
    ServerCall<ReqT, RespT> call) {
  return new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
    // The result of afterInvocation has to be cast back to the response type; this assumes
    // any replacement object it returns is still a RespT.
    @Override
    @SuppressWarnings("unchecked")
    public void sendMessage(RespT message) {
      super.sendMessage((RespT) afterInvocation(INTERCEPTOR_STATUS_TOKEN.get(), message));
    }
  };
}
/**
 * Resolves the authentication, installs it in a fresh {@link SecurityContext}, runs the
 * pre-invocation check for the target service method and returns a gRPC context carrying the
 * results.
 *
 * @param call the intercepted call
 * @param headers the request headers
 * @param next the next handler in the chain, captured by the method invocation
 * @param authorization the raw authorization header value, or {@code null} when the request
 *     carries none (the security context is then set up without an authentication)
 * @return the current gRPC context extended with the authentication, the interceptor status
 *     token and the method invocation
 * @throws StatusRuntimeException with {@code UNAUTHENTICATED} when an authorization value is
 *     present but the scheme selector yields no authentication for it
 */
private <ReqT, RespT> Context setupGRpcSecurityContext(
    ServerCall<ReqT, RespT> call,
    Metadata headers,
    ServerCallHandler<ReqT, RespT> next,
    CharSequence authorization) {
  Authentication authentication = null;
  if (authorization != null) {
    authentication =
        schemeSelector
            .getAuthScheme(authorization)
            .orElseThrow(() -> new StatusRuntimeException(Status.UNAUTHENTICATED));
  }

  final SecurityContext context = SecurityContextHolder.createEmptyContext();
  context.setAuthentication(authentication);
  SecurityContextHolder.setContext(context);

  final GRpcServicesRegistry.GrpcServiceMethod grpcServiceMethod =
      registry.getGrpServiceMethod(call.getMethodDescriptor());
  final GrpcMethodInvocation<ReqT, RespT> methodInvocation =
      new GrpcMethodInvocation<>(grpcServiceMethod, call, headers, next);
  final InterceptorStatusToken interceptorStatusToken = beforeInvocation(methodInvocation);

  // Read the authentication back from the holder rather than reusing the local variable:
  // this stores whatever is current after beforeInvocation has run.
  return Context.current()
      .withValue(
          GrpcSecurity.AUTHENTICATION_CONTEXT_KEY,
          SecurityContextHolder.getContext().getAuthentication())
      .withValue(INTERCEPTOR_STATUS_TOKEN, interceptorStatusToken)
      .withValue(METHOD_INVOCATION, methodInvocation);
}
/**
 * Produces the listener for a call whose security setup failed.
 *
 * <p>With fail-fast enabled in the auth configuration the call is closed immediately and a
 * no-op listener is returned. Otherwise the call is started on {@code next}, and the call is
 * closed with {@code exception} when the first message arrives; that message is blocked via
 * {@code blockMessage()} and handed to the failure handler through the customizer callback.
 *
 * @param next the next handler in the chain
 * @param call the failed call
 * @param headers the request headers
 * @param exception the failure to report when closing the call
 * @return the listener to use for the failed call
 */
private <ReqT, RespT> ServerCall.Listener<ReqT> fail(
    ServerCallHandler<ReqT, RespT> next,
    ServerCall<ReqT, RespT> call,
    Metadata headers,
    RuntimeException exception) {
  if (authCfg.isFailFast()) {
    failureHandlingSupport.closeCall(exception, call, headers);
    return new ServerCall.Listener<ReqT>() {};
  }
  return new MessageBlockingServerCallListener<ReqT>(next.startCall(call, headers)) {
    @Override
    public void onMessage(ReqT message) {
      blockMessage();
      failureHandlingSupport.closeCall(exception, call, headers, b -> b.request(message));
    }
  };
}
}