// de.arbeitsagentur.opdt.keycloak.cassandra.client.CassandraClientProvider (Maven / Gradle / Ivy)
/*
* Copyright 2022 IT-Systemhaus der Bundesagentur fuer Arbeit
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.arbeitsagentur.opdt.keycloak.cassandra.client;
import static de.arbeitsagentur.opdt.keycloak.common.MapProviderObjectType.CLIENT_AFTER_REMOVE;
import static de.arbeitsagentur.opdt.keycloak.common.MapProviderObjectType.CLIENT_BEFORE_REMOVE;
import static org.keycloak.common.util.StackUtil.getShortStackTrace;
import de.arbeitsagentur.opdt.keycloak.cassandra.CompositeRepository;
import de.arbeitsagentur.opdt.keycloak.cassandra.client.persistence.ClientRepository;
import de.arbeitsagentur.opdt.keycloak.cassandra.client.persistence.entities.Client;
import de.arbeitsagentur.opdt.keycloak.cassandra.transaction.TransactionalProvider;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.jbosslog.JBossLog;
import org.keycloak.models.*;
import org.keycloak.models.utils.KeycloakModelUtils;
@JBossLog
public class CassandraClientProvider extends TransactionalProvider
implements ClientProvider {
private final ClientRepository clientRepository;
public CassandraClientProvider(KeycloakSession session, CompositeRepository cassandraRepository) {
super(session);
this.clientRepository = cassandraRepository;
}
@Override
protected CassandraClientAdapter createNewModel(RealmModel realm, Client entity) {
return createNewModel(realm, entity, () -> {});
}
private CassandraClientAdapter createNewModelWithRollback(RealmModel realm, Client entity) {
return createNewModel(
realm,
entity,
() -> {
clientRepository.delete(entity);
models.remove(entity.getId());
});
}
private CassandraClientAdapter createNewModel(
RealmModel realm, Client entity, Runnable rollbackTask) {
CassandraClientAdapter adapter =
new CassandraClientAdapter(entity, session, realm, clientRepository) {
@Override
public void rollback() {
rollbackTask.run();
}
};
return adapter;
}
@Override
public Stream getClientsStream(
RealmModel realm, Integer firstResult, Integer maxResults) {
return Stream.concat(
models.values().stream().filter(m -> m.getRealm().equals(realm)),
clientRepository.findAllClientsWithRealmId(realm.getId()).stream()
.filter(Objects::nonNull)
.map(entityToAdapterFunc(realm)))
.distinct()
.map(ClientModel.class::cast)
.sorted(Comparator.comparing(ClientModel::getClientId))
.skip(firstResult == null || firstResult < 0 ? 0 : firstResult)
.limit(maxResults == null || maxResults < 0 ? Long.MAX_VALUE : maxResults);
}
@Override
public ClientModel addClient(RealmModel realm, String id, String clientId) {
log.tracef("addClient(%s, %s, %s)%s", realm, id, clientId, getShortStackTrace());
if (id != null && getClientById(realm, id) != null) {
throw new ModelDuplicateException("Client with same id exists: " + id);
}
if (clientId != null && getClientByClientId(realm, clientId) != null) {
throw new ModelDuplicateException(
"Client with same clientId in realm " + realm.getName() + " exists: " + clientId);
}
String newId = id == null ? KeycloakModelUtils.generateId() : id;
Client client = new Client(realm.getId(), newId, null, new HashMap<>());
clientRepository.insertOrUpdate(client);
ClientModel adapter =
entityToAdapterFunc(realm, this::createNewModelWithRollback).apply(client);
adapter.setClientId(clientId != null ? clientId : client.getId());
adapter.setEnabled(true);
adapter.setStandardFlowEnabled(true);
// TODO: Sending an event should be extracted to store layer
session.getKeycloakSessionFactory().publish((ClientModel.ClientCreationEvent) () -> adapter);
adapter
.updateClient(); // This is actualy strange contract - it should be the store code to call
// updateClient
return adapter;
}
@Override
public long getClientsCount(RealmModel realm) {
return clientRepository.countClientsByRealm(realm.getId());
}
@Override
public Stream getAlwaysDisplayInConsoleClientsStream(RealmModel realm) {
return getClientsStream(realm).filter(ClientModel::isAlwaysDisplayInConsole);
}
@Override
public boolean removeClient(RealmModel realm, String id) {
Client client = clientRepository.getClientById(realm.getId(), id);
if (client == null) {
return false;
}
ClientModel clientModel = getClientById(realm, id);
session.invalidate(CLIENT_BEFORE_REMOVE, realm, clientModel);
clientRepository.delete(client);
((CassandraClientAdapter) clientModel).markDeleted();
models.remove(client.getId());
session.invalidate(CLIENT_AFTER_REMOVE, clientModel);
return true;
}
@Override
public void removeClients(RealmModel realm) {
log.tracef("removeClients(%s)%s", realm, getShortStackTrace());
getClientsStream(realm).map(ClientModel::getId).forEach(cid -> removeClient(realm, cid));
}
@Override
public void addClientScopes(
RealmModel realm,
ClientModel client,
Set clientScopes,
boolean defaultScope) {
// Defaults to openid-connect
String clientProtocol = client.getProtocol() == null ? "openid-connect" : client.getProtocol();
log.tracef(
"addClientScopes(%s, %s, %s, %b)%s",
realm, client, clientScopes, defaultScope, getShortStackTrace());
Map existingClientScopes = getClientScopes(realm, client, true);
existingClientScopes.putAll(getClientScopes(realm, client, false));
clientScopes.stream()
.filter(clientScope -> !existingClientScopes.containsKey(clientScope.getName()))
.filter(clientScope -> Objects.equals(clientScope.getProtocol(), clientProtocol))
.forEach(clientScope -> client.addClientScope(clientScope, defaultScope));
}
@Override
public void removeClientScope(
RealmModel realm, ClientModel client, ClientScopeModel clientScope) {
log.tracef("removeClientScope(%s, %s, %s)%s", realm, client, clientScope, getShortStackTrace());
client.removeClientScope(clientScope);
}
@Override
public void addClientScopeToAllClients(
RealmModel realmModel, ClientScopeModel clientScopeModel, boolean defaultClientScope) {
log.tracef(
"addClientScopeToAllClients(%s, %s, %b)%s",
realmModel, clientScopeModel, defaultClientScope, getShortStackTrace());
clientRepository
.findAllClientsWithRealmId(realmModel.getId())
.forEach(
client -> {
ClientModel clientModel = entityToAdapterFunc(realmModel).apply(client);
clientModel.addClientScope(clientScopeModel, defaultClientScope);
});
}
@Override
public Map> getAllRedirectUrisOfEnabledClients(RealmModel realm) {
return Stream.concat(
models.values().stream().filter(m -> m.getRealm().equals(realm)),
clientRepository.findAllClientsWithRealmId(realm.getId()).stream()
.map(entityToAdapterFunc(realm)))
.distinct()
.filter(ClientModel::isEnabled)
.filter(c -> !c.getRedirectUris().isEmpty())
.collect(Collectors.toMap(Function.identity(), ClientModel::getRedirectUris));
}
@Override
public ClientModel getClientById(RealmModel realm, String id) {
log.tracef("getClientById(%s, %s)%s", realm, id, getShortStackTrace());
Client client = clientRepository.getClientById(realm.getId(), id);
return entityToAdapterFunc(realm).apply(client);
}
@Override
public ClientModel getClientByClientId(RealmModel realm, String clientId) {
Client byClientId = clientRepository.findByClientId(realm.getId(), clientId);
if (byClientId != null) {
return entityToAdapterFunc(realm).apply(byClientId);
}
return Stream.concat(
models.values().stream().filter(m -> m.getRealm().equals(realm)),
clientRepository.findAllClientsWithRealmId(realm.getId()).stream()
.map(entityToAdapterFunc(realm)))
.distinct()
.filter(e -> Objects.equals(e.getClientId(), clientId))
.findFirst()
.orElse(null);
}
@Override
public Stream searchClientsByClientIdStream(
RealmModel realm, String clientId, Integer firstResult, Integer maxResults) {
if (clientId == null) {
return Stream.empty();
}
return Stream.concat(
models.values().stream().filter(m -> m.getRealm().equals(realm)),
clientRepository.findAllClientsWithRealmId(realm.getId()).stream()
.map(entityToAdapterFunc(realm)))
.distinct()
.map(ClientModel.class::cast)
.filter(
e ->
"%".equals(clientId)
|| e.getAttribute(CassandraClientAdapter.CLIENT_ID)
.toLowerCase()
.contains(clientId.toLowerCase()))
.skip(firstResult == null || firstResult < 0 ? 0 : firstResult)
.limit(maxResults == null || maxResults < 0 ? Long.MAX_VALUE : maxResults);
}
@Override
public Stream searchClientsByAttributes(
RealmModel realm, Map attributes, Integer firstResult, Integer maxResults) {
return Stream.concat(
models.values().stream().filter(m -> m.getRealm().equals(realm)),
clientRepository.findAllClientsWithRealmId(realm.getId()).stream()
.map(entityToAdapterFunc(realm)))
.distinct()
.map(ClientModel.class::cast)
.filter(
c ->
attributes.isEmpty()
|| c.getAttributes().entrySet().containsAll(attributes.entrySet()))
.skip(firstResult == null || firstResult < 0 ? 0 : firstResult)
.limit(maxResults == null || maxResults < 0 ? Long.MAX_VALUE : maxResults);
}
@Override
public Map getClientScopes(
RealmModel realm, ClientModel client, boolean defaultScopes) {
if (client == null) return null;
// Defaults to openid-connect
String clientProtocol = client.getProtocol() == null ? "openid-connect" : client.getProtocol();
log.tracef("getClientScopes(%s, %s, %b)%s", realm, client, defaultScopes, getShortStackTrace());
return client.getClientScopes(defaultScopes).values().stream()
.filter(Objects::nonNull)
.filter(clientScope -> Objects.equals(clientScope.getProtocol(), clientProtocol))
.collect(Collectors.toMap(ClientScopeModel::getName, Function.identity()));
}
public void preRemove(RealmModel realm, RoleModel role) {
realm.getClientsStream().forEach(c -> c.deleteScopeMapping(role));
}
public void preRemove(RealmModel realm) {
this.removeClients(realm);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy