All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apereo.cas.ticket.registry.CassandraTicketRegistry Maven / Gradle / Ivy

There is a newer version: 7.1.0
Show newest version
package org.apereo.cas.ticket.registry;

import org.apereo.cas.cassandra.CassandraSessionFactory;
import org.apereo.cas.configuration.model.support.cassandra.ticketregistry.CassandraTicketRegistryProperties;
import org.apereo.cas.configuration.support.Beans;
import org.apereo.cas.ticket.Ticket;
import org.apereo.cas.ticket.TicketCatalog;
import org.apereo.cas.ticket.TicketDefinition;
import org.apereo.cas.ticket.TicketGrantingTicket;
import org.apereo.cas.ticket.serialization.TicketSerializationManager;
import org.apereo.cas.util.crypto.CipherExecutor;
import org.apereo.cas.util.function.FunctionUtils;
import org.apereo.cas.util.serialization.JacksonObjectMapperFactory;

import com.datastax.oss.driver.api.core.DefaultConsistencyLevel;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.data.cassandra.core.cql.BeanPropertyRowMapper;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * This is {@link CassandraTicketRegistry}.
 *
 * @author Misagh Moayyed
 * @author doomviking
 * @since 6.1.0
 */
@Slf4j
public class CassandraTicketRegistry extends AbstractTicketRegistry implements DisposableBean, InitializingBean {
    private static final ObjectMapper MAPPER = JacksonObjectMapperFactory.builder()
        .defaultTypingEnabled(false).build().toObjectMapper();

    private final CassandraSessionFactory cassandraSessionFactory;

    private final CassandraTicketRegistryProperties properties;

    public CassandraTicketRegistry(final CipherExecutor cipherExecutor,
                                   final TicketSerializationManager ticketSerializationManager,
                                   final TicketCatalog ticketCatalog,
                                   final CassandraSessionFactory cassandraSessionFactory,
                                   final CassandraTicketRegistryProperties properties) {
        super(cipherExecutor, ticketSerializationManager, ticketCatalog);
        this.cassandraSessionFactory = cassandraSessionFactory;
        this.properties = properties;
    }


    private static int getTimeToLive(final Ticket ticket) {
        val timeToLive = ticket.getExpirationPolicy().getTimeToLive();
        val ttl = Long.MAX_VALUE == timeToLive ? Long.valueOf(Integer.MAX_VALUE) : timeToLive;
        if (ttl >= CassandraSessionFactory.MAX_TTL) {
            return CassandraSessionFactory.MAX_TTL;
        }
        return ttl.intValue();
    }

    @Override
    public Ticket getTicket(final String ticketId, final Predicate predicate) {
        LOGGER.trace("Locating ticket [{}]", ticketId);
        val encodedTicketId = digestIdentifier(ticketId);
        if (StringUtils.isBlank(encodedTicketId)) {
            LOGGER.debug("Ticket id [{}] could not be found", ticketId);
            return null;
        }

        val definition = ticketCatalog.find(ticketId);
        if (definition == null) {
            LOGGER.debug("Ticket definition [{}] could not be found in the ticket catalog", ticketId);
            return null;
        }

        val holder = findCassandraTicketBy(definition, encodedTicketId);
        if (holder.isEmpty()) {
            LOGGER.debug("Ticket [{}] could not be found in Cassandra", encodedTicketId);
            return null;
        }

        val object = deserialize(holder.iterator().next());
        val result = decodeTicket(object);
        return FunctionUtils.doAndReturn(result != null && predicate.test(result), () -> result, () -> {
            LOGGER.trace("The condition enforced by the predicate [{}] cannot successfully accept/test the ticket id [{}]", encodedTicketId,
                predicate.getClass().getSimpleName());
            return null;
        });
    }

    @Override
    public void addTicketInternal(final Ticket ticket) throws Exception {
        addTicketToCassandra(ticket, true);
    }

    @Override
    public Ticket updateTicket(final Ticket ticket) throws Exception {
        addTicketToCassandra(ticket, false);
        return ticket;
    }

    @Override
    public Collection getTickets() {
        return ticketCatalog.findAll()
            .stream()
            .map(definition -> {
                val results = findCassandraTicketBy(definition);
                return results
                    .stream()
                    .map(holder -> {
                        val result = deserialize(holder);
                        return decodeTicket(result);
                    })
                    .collect(Collectors.toSet());
            })
            .flatMap(Set::stream)
            .filter(Objects::nonNull)
            .filter(ticket -> !ticket.isExpired())
            .collect(Collectors.toSet());
    }

    @Override
    public long deleteSingleTicket(final Ticket ticketToDelete) {
        val ticketId = digestIdentifier(ticketToDelete.getId());
        LOGGER.debug("Deleting ticket [{}]", ticketId);
        val definition = ticketCatalog.find(ticketToDelete);
        val delete = QueryBuilder
            .deleteFrom(properties.getKeyspace(), definition.getProperties().getStorageName())
            .whereColumn("id").isEqualTo(QueryBuilder.literal(ticketId))
            .build()
            .setConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getConsistencyLevel()))
            .setSerialConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getSerialConsistencyLevel()))
            .setTimeout(Beans.newDuration(properties.getTimeout()));
        cassandraSessionFactory.getCqlTemplate().execute(delete);
        return 1;
    }

    @Override
    public long deleteAll() {
        ticketCatalog.findAll()
            .stream()
            .map(definition -> QueryBuilder
                .truncate(properties.getKeyspace(), definition.getProperties().getStorageName())
                .build()
                .setConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getConsistencyLevel()))
                .setSerialConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getSerialConsistencyLevel()))
                .setTimeout(Beans.newDuration(properties.getTimeout())))
            .forEach(delete -> {
                LOGGER.trace("Attempting to delete all via query [{}]", delete);
                cassandraSessionFactory.getCqlTemplate().execute(delete);
            });
        return -1;
    }

    @Override
    public Stream stream() {
        return ticketCatalog.findAll()
            .stream()
            .flatMap(this::streamCassandraTicketBy)
            .map(holder -> {
                val result = deserialize(holder);
                return decodeTicket(result);
            });
    }

    @Override
    public Stream getSessionsWithAttributes(final Map> queryAttributes) {
        val metadata = ticketCatalog.findTicketDefinition(TicketGrantingTicket.class).orElseThrow();
        val queryList = new ArrayList();
        queryAttributes.forEach((key, values) ->
            values.forEach(queryValue -> {
                var cql = "SELECT * FROM %s.%s WHERE prefix='%s' AND ".formatted(properties.getKeyspace(), metadata.getProperties().getStorageName(), metadata.getPrefix());
                cql += "attributes CONTAINS KEY '%s' AND attributes CONTAINS '%s' ALLOW FILTERING;".formatted(digestIdentifier(key), digestIdentifier(queryValue.toString()));
                queryList.add(cql);
            }));
        val rowMapper = new BeanPropertyRowMapper<>(CassandraTicketHolder.class, true);
        return queryList
            .stream()
            .flatMap(query -> cassandraSessionFactory.getCqlTemplate().queryForStream(query, rowMapper))
            .distinct()
            .map(holder -> {
                val result = deserialize(holder);
                return decodeTicket(result);
            })
            .filter(ticket -> !ticket.isExpired());
    }

    @Override
    public void destroy() throws Exception {
        cassandraSessionFactory.close();
    }

    @Override
    public void afterPropertiesSet() {
        createTablesIfNecessary();
    }

    private Ticket deserialize(final CassandraTicketHolder holder) {
        return ticketSerializationManager.deserializeTicket(holder.getData(), holder.getType());
    }

    private Collection findCassandraTicketBy(final TicketDefinition definition) {
        return findCassandraTicketBy(definition, null);
    }

    private Collection findCassandraTicketBy(final TicketDefinition definition, final String ticketId) {
        val builder = QueryBuilder.selectFrom(properties.getKeyspace(), definition.getProperties().getStorageName()).all();
        if (StringUtils.isNotBlank(ticketId)) {
            builder.whereColumn("id").isEqualTo(QueryBuilder.literal(ticketId)).limit(1);
        }
        val select = builder.build()
            .setConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getConsistencyLevel()))
            .setSerialConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getSerialConsistencyLevel()))
            .setTimeout(Beans.newDuration(properties.getTimeout()));
        LOGGER.trace("Attempting to locate ticket via query [{}]", select);
        val rowMapper = new BeanPropertyRowMapper<>(CassandraTicketHolder.class, true);
        return cassandraSessionFactory.getCqlTemplate().query(select, rowMapper);
    }

    private Stream streamCassandraTicketBy(final TicketDefinition definition) {
        val builder = QueryBuilder.selectFrom(properties.getKeyspace(), definition.getProperties().getStorageName()).all();
        val select = builder.build()
            .setConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getConsistencyLevel()))
            .setSerialConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getSerialConsistencyLevel()))
            .setTimeout(Beans.newDuration(properties.getTimeout()));
        LOGGER.trace("Attempting to locate ticket via query [{}]", select);
        val rowMapper = new BeanPropertyRowMapper<>(CassandraTicketHolder.class, true);
        return cassandraSessionFactory.getCqlTemplate().queryForStream(select, rowMapper);
    }

    private void createTablesIfNecessary() {
        val createNs = "CREATE KEYSPACE IF NOT EXISTS %s WITH replication = { 'class':'SimpleStrategy','replication_factor':1 };"
            .formatted(properties.getKeyspace()).stripIndent().strip();
        LOGGER.trace("Creating Cassandra keyspace with query [{}]", createNs);
        cassandraSessionFactory.getCqlTemplate().execute(createNs);

        ticketCatalog.findAll()
            .stream()
            .filter(metadata -> StringUtils.isNotBlank(metadata.getProperties().getStorageName()))
            .forEach(metadata -> {
                if (properties.isDropTablesOnStartup()) {
                    val drop = "DROP TABLE IF EXISTS %s.%s;".formatted(properties.getKeyspace(), metadata.getProperties().getStorageName());
                    LOGGER.trace("Dropping Cassandra table with query [{}]", drop);
                    cassandraSessionFactory.getCqlTemplate().execute(drop);
                }
                val createTable = "CREATE TABLE IF NOT EXISTS %s.%s(id text,type text,prefix text,attributes map,data text, PRIMARY KEY(id,type));"
                    .formatted(properties.getKeyspace(), metadata.getProperties().getStorageName());
                LOGGER.trace("Creating Cassandra table with query [{}]", createTable);
                cassandraSessionFactory.getCqlTemplate().execute(createTable);

                cassandraSessionFactory.getCqlTemplate().execute("DROP INDEX IF EXISTS " + metadata.getProperties().getStorageName() + "_entries_index");
                val createIndexAttributeNames = "CREATE INDEX " + metadata.getProperties().getStorageName() + "_entries_index ON "
                                                + properties.getKeyspace() + '.' + metadata.getProperties().getStorageName() + " (ENTRIES(attributes));";
                LOGGER.trace("Creating Cassandra index with query [{}]", createIndexAttributeNames);
                cassandraSessionFactory.getCqlTemplate().execute(createIndexAttributeNames);

                cassandraSessionFactory.getCqlTemplate().execute("DROP INDEX IF EXISTS " + metadata.getProperties().getStorageName() + "_values_index");
                val createIndexAttributeValues = "CREATE INDEX " + metadata.getProperties().getStorageName() + "_values_index ON "
                                                 + properties.getKeyspace() + '.' + metadata.getProperties().getStorageName() + " (VALUES(attributes));";
                LOGGER.trace("Creating Cassandra index with query [{}]", createIndexAttributeValues);
                cassandraSessionFactory.getCqlTemplate().execute(createIndexAttributeValues);

                cassandraSessionFactory.getCqlTemplate().execute("DROP INDEX IF EXISTS " + metadata.getProperties().getStorageName() + "_keys_index");
                val createIndexAttributeNames3 = "CREATE INDEX " + metadata.getProperties().getStorageName() + "_keys_index ON "
                                                 + properties.getKeyspace() + '.' + metadata.getProperties().getStorageName() + " (KEYS(attributes));";
                LOGGER.trace("Creating Cassandra index with query [{}]", createIndexAttributeNames3);
                cassandraSessionFactory.getCqlTemplate().execute(createIndexAttributeNames3);
            });
    }


    private void addTicketToCassandra(final Ticket ticket, final boolean inserting) throws Exception {
        LOGGER.debug("Adding ticket [{}]", ticket.getId());
        val metadata = ticketCatalog.find(ticket);
        LOGGER.trace("Located ticket definition [{}] in the ticket catalog", metadata);
        val encTicket = encodeTicket(ticket);
        val data = ticketSerializationManager.serializeTicket(encTicket);
        val ttl = getTimeToLive(ticket);
        var statement = (SimpleStatement) null;

        val attributeMap = (Map) collectAndDigestTicketAttributes(ticket);
        val attributesEncoded = attributeMap
            .entrySet()
            .stream()
            .map(entry -> {
                val entryValues = (List) entry.getValue();
                val valueList = entryValues.stream().map(Object::toString).collect(Collectors.joining(","));
                return Pair.of(entry.getKey(), valueList);
            })
            .collect(Collectors.toMap(Pair::getKey, v -> v.getValue().toString()));

        if (inserting) {
            val document = CassandraTicketHolder.builder()
                .id(encTicket.getId())
                .data(data)
                .prefix(ticket.getPrefix())
                .type(encTicket.getClass().getName())
                .attributes(attributesEncoded)
                .build();
            val json = MAPPER.writeValueAsString(document);
            statement = QueryBuilder.insertInto(properties.getKeyspace(), metadata.getProperties().getStorageName())
                .json(json)
                .usingTtl(ttl)
                .build();
        } else {
            statement = QueryBuilder.update(properties.getKeyspace(), metadata.getProperties().getStorageName())
                .usingTtl(ttl)
                .setColumn("data", QueryBuilder.literal(data))
                .setColumn("attributes", QueryBuilder.literal(attributesEncoded))
                .whereColumn("id").isEqualTo(QueryBuilder.literal(encTicket.getId()))
                .whereColumn("type").isEqualTo(QueryBuilder.literal(encTicket.getClass().getName()))
                .build();
        }
        statement = statement.setConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getConsistencyLevel()))
            .setSerialConsistencyLevel(DefaultConsistencyLevel.valueOf(properties.getSerialConsistencyLevel()))
            .setTimeout(Beans.newDuration(properties.getTimeout()));

        LOGGER.trace("Attempting to locate ticket via query [{}]", statement.getQuery());
        cassandraSessionFactory.getCqlTemplate().execute(statement);
        LOGGER.debug("Added ticket [{}]", encTicket.getId());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy