{{! java-micronaut.client.auth.AuthorizationFilter.mustache Maven / Gradle / Ivy }}
{{>common/licenseInfo}}
package {{invokerPackage}}.auth;
import io.micronaut.context.BeanContext;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.util.CollectionUtils;
import io.micronaut.core.util.StringUtils;
import io.micronaut.core.util.Toggleable;
import io.micronaut.http.HttpRequest;
import io.micronaut.http.HttpResponse;
import io.micronaut.http.MutableHttpRequest;
import io.micronaut.http.annotation.Filter;
import io.micronaut.http.filter.ClientFilterChain;
import io.micronaut.http.filter.HttpClientFilter;
import io.micronaut.inject.qualifiers.Qualifiers;
import io.micronaut.security.oauth2.client.clientcredentials.ClientCredentialsClient;
import io.micronaut.security.oauth2.client.clientcredentials.ClientCredentialsConfiguration;
import io.micronaut.security.oauth2.client.clientcredentials.propagation.ClientCredentialsHttpClientFilter;
import io.micronaut.security.oauth2.client.clientcredentials.propagation.ClientCredentialsTokenPropagator;
import io.micronaut.security.oauth2.configuration.OauthClientConfiguration;
import io.micronaut.security.oauth2.endpoint.token.response.TokenResponse;
import {{invokerPackage}}.auth.configuration.ConfigurableAuthorization;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import {{javaxPackage}}.annotation.Generated;
/**
 * HTTP client filter that applies the authorizations named on an outgoing request.
 *
 * <p>The names are read from the {@code AuthorizationBinder.AUTHORIZATION_NAMES} request
 * attribute. For each name the filter first looks for a {@link ConfigurableAuthorization}
 * registered under that name; otherwise it looks for an enabled
 * {@link OauthClientConfiguration} with that name and requests a token through the matching
 * {@link ClientCredentialsClient}. Names that match neither are skipped.
 *
 * <p>NOTE(review): the two lookup caches are plain {@link HashMap}s that are filled lazily
 * from {@link #doFilter}; confirm whether this filter can be invoked concurrently.
 */
{{>common/generatedAnnotation}}
@Filter({{#configureAuthFilterPattern}}"{{authorizationFilterPattern}}"{{/configureAuthFilterPattern}}{{^configureAuthFilterPattern}}Filter.MATCH_ALL_PATTERN{{/configureAuthFilterPattern}})
public class AuthorizationFilter implements HttpClientFilter {

    private static final Logger LOG = LoggerFactory.getLogger(AuthorizationFilter.class);

    /** Used to look up name-qualified token propagators and client-credentials clients. */
    private final BeanContext beanContext;

    /** Enabled OAuth client configurations, keyed by configuration name. */
    private final Map<String, OauthClientConfiguration> clientConfigurationByName;

    /** Propagator used when no name-qualified propagator bean is found. */
    ClientCredentialsTokenPropagator defaultTokenPropagator;

    /** Lazily filled cache of token propagators, keyed by authorization name. */
    private final Map<String, ClientCredentialsTokenPropagator> tokenPropagatorByName;

    /** Lazily filled cache of client-credentials clients, keyed by authorization name. */
    private final Map<String, ClientCredentialsClient> clientCredentialsClientByName;

    /** Non-OAuth authorizations, keyed by {@link ConfigurableAuthorization#getName()}. */
    public final Map<String, ConfigurableAuthorization> authorizationsByName;

    /**
     * Creates the filter and indexes the supplied configurations by name.
     *
     * @param beanContext the bean context used for later name-qualified bean lookups
     * @param clientConfigurations the OAuth client configurations; disabled ones are dropped
     * @param defaultTokenPropagator the fallback token propagator
     * @param configurableAuthorizations the additional authorizations to index by name
     * @throws IllegalStateException if two configurations or two authorizations share a name
     *         (raised by {@link Collectors#toMap(java.util.function.Function, java.util.function.Function)})
     */
    public AuthorizationFilter(
        BeanContext beanContext,
        Stream<OauthClientConfiguration> clientConfigurations,
        ClientCredentialsTokenPropagator defaultTokenPropagator,
        Stream<ConfigurableAuthorization> configurableAuthorizations
    ) {
        this.beanContext = beanContext;
        this.clientConfigurationByName = clientConfigurations
            .filter(Toggleable::isEnabled)
            .collect(Collectors.toMap(OauthClientConfiguration::getName, v -> v));
        this.defaultTokenPropagator = defaultTokenPropagator;
        this.tokenPropagatorByName = new HashMap<>();
        this.clientCredentialsClientByName = new HashMap<>();
        this.authorizationsByName = configurableAuthorizations
            .collect(Collectors.toMap(ConfigurableAuthorization::getName, v -> v));
    }

    /**
     * Runs every authorizer selected by the request's authorization names and then continues
     * the filter chain.
     *
     * <p>When the attribute is absent or empty the chain is continued immediately. Entries
     * of the attribute list that are not strings are ignored.
     *
     * @param request the outgoing request; authorizers receive this same instance
     * @param chain the remaining filter chain
     * @return the publisher produced by {@code chain.proceed}, subscribed only after all
     *         authorizers have completed
     */
    @Override
    public Publisher<? extends HttpResponse<?>> doFilter(
        @NonNull MutableHttpRequest<?> request,
        @NonNull ClientFilterChain chain
    ) {
        List<?> names = request.getAttribute(AuthorizationBinder.AUTHORIZATION_NAMES, List.class).orElse(null);
        if (!CollectionUtils.isNotEmpty(names)) {
            return chain.proceed(request);
        }

        List<Publisher<HttpRequest<?>>> authorizers = new ArrayList<>(names.size());
        for (Object nameObject : names) {
            if (!(nameObject instanceof String)) {
                continue;
            }
            Publisher<HttpRequest<?>> authorizer = createAuthorizer((String) nameObject, request);
            if (authorizer != null) {
                authorizers.add(authorizer);
            }
        }

        // Wait for all authorizers to complete, then proceed exactly once. Proceeding per
        // emitted element would skip the request entirely when no authorizer was resolved
        // and would re-invoke the chain when more than one authorizer is present.
        Mono<MutableHttpRequest<?>> authorizedRequest = Flux.concat(authorizers)
            .then(Mono.just(request));
        return authorizedRequest.flatMapMany(authorized -> chain.proceed(authorized));
    }

    /**
     * Builds the publisher that applies the authorization registered under {@code name}.
     *
     * @param name the authorization name taken from the request attribute
     * @param request the request handed to the authorizer
     * @return a publisher emitting {@code request} once the authorization has been applied,
     *         or {@code null} when nothing is registered under {@code name}
     */
    @Nullable
    private Publisher<HttpRequest<?>> createAuthorizer(String name, MutableHttpRequest<?> request) {
        // Check if other authorizations have the key
        ConfigurableAuthorization configured = authorizationsByName.get(name);
        if (configured != null) {
            return Mono.fromCallable(() -> {
                configured.applyAuthorization(request);
                return request;
            });
        }

        // Perform OAuth authorization
        OauthClientConfiguration clientConfiguration = clientConfigurationByName.get(name);
        if (clientConfiguration == null) {
            return null;
        }
        ClientCredentialsClient clientCredentialsClient = getClientCredentialsClient(name);
        if (clientCredentialsClient == null) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Could not retrieve client credentials client for OAuth 2.0 client {}", name);
            }
            return null;
        }

        ClientCredentialsTokenPropagator tokenHandler = getTokenPropagator(name);
        Flux<HttpRequest<?>> tokenWriter = Flux.from(clientCredentialsClient
                .requestToken(getScope(clientConfiguration)))
            .map(TokenResponse::getAccessToken)
            .map(accessToken -> {
                // An empty access token is not passed to the propagator.
                if (StringUtils.isNotEmpty(accessToken)) {
                    tokenHandler.writeToken(request, accessToken);
                }
                return request;
            });
        return tokenWriter;
    }

    /**
     * Returns the token propagator for the given name, caching the result of the lookup.
     *
     * @param name the authorization name
     * @return the bean qualified by {@code name}, or {@code defaultTokenPropagator} when no
     *         such bean exists
     */
    protected ClientCredentialsTokenPropagator getTokenPropagator(String name) {
        ClientCredentialsTokenPropagator tokenPropagator = tokenPropagatorByName.get(name);
        if (tokenPropagator == null) {
            tokenPropagator = beanContext.findBean(ClientCredentialsTokenPropagator.class, Qualifiers.byName(name))
                .orElse(defaultTokenPropagator);
            if (tokenPropagator != null) {
                tokenPropagatorByName.put(name, tokenPropagator);
            }
        }
        return tokenPropagator;
    }

    /**
     * Returns the client-credentials client for the given name, caching successful lookups.
     *
     * @param name the authorization name
     * @return the bean qualified by {@code name}, or {@code null} when no such bean exists
     */
    protected ClientCredentialsClient getClientCredentialsClient(String name) {
        ClientCredentialsClient client = clientCredentialsClientByName.get(name);
        if (client == null) {
            client = beanContext.findBean(ClientCredentialsClient.class, Qualifiers.byName(name)).orElse(null);
            if (client != null) {
                clientCredentialsClientByName.put(name, client);
            }
        }
        return client;
    }

    /**
     * Reads the scope from the client-credentials section of an OAuth client configuration.
     *
     * @param oauthClient the OAuth client configuration
     * @return the configured scope, or {@code null} when the section or the scope is absent
     */
    @Nullable
    protected String getScope(@NonNull OauthClientConfiguration oauthClient) {
        return oauthClient.getClientCredentials().flatMap(ClientCredentialsConfiguration::getScope).orElse(null);
    }
}