// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.hmsonline.storm.cassandra.trident.CassandraMapState (Maven / Gradle / Ivy)

package com.hmsonline.storm.cassandra.trident;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import storm.trident.state.JSONNonTransactionalSerializer;
import storm.trident.state.JSONOpaqueSerializer;
import storm.trident.state.JSONTransactionalSerializer;
import storm.trident.state.OpaqueValue;
import storm.trident.state.Serializer;
import storm.trident.state.State;
import storm.trident.state.StateFactory;
import storm.trident.state.StateType;
import storm.trident.state.TransactionalValue;
import storm.trident.state.map.CachedMap;
import storm.trident.state.map.IBackingMap;
import storm.trident.state.map.MapState;
import storm.trident.state.map.NonTransactionalMap;
import storm.trident.state.map.OpaqueMap;
import storm.trident.state.map.SnapshottableMap;
import storm.trident.state.map.TransactionalMap;
import backtype.storm.task.IMetricsContext;
import backtype.storm.tuple.Values;

import com.google.common.base.Function;
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.hmsonline.storm.cassandra.StormCassandraConstants;
import com.netflix.astyanax.AstyanaxConfiguration;
import com.netflix.astyanax.AstyanaxContext;
import com.netflix.astyanax.Keyspace;
import com.netflix.astyanax.MutationBatch;
import com.netflix.astyanax.connectionpool.ConnectionPoolConfiguration;
import com.netflix.astyanax.connectionpool.ConnectionPoolMonitor;
import com.netflix.astyanax.connectionpool.NodeDiscoveryType;
import com.netflix.astyanax.connectionpool.OperationResult;
import com.netflix.astyanax.connectionpool.exceptions.ConnectionException;
import com.netflix.astyanax.connectionpool.impl.ConnectionPoolConfigurationImpl;
import com.netflix.astyanax.connectionpool.impl.CountingConnectionPoolMonitor;
import com.netflix.astyanax.impl.AstyanaxConfigurationImpl;
import com.netflix.astyanax.model.ColumnFamily;
import com.netflix.astyanax.model.ColumnList;
import com.netflix.astyanax.model.Composite;
import com.netflix.astyanax.model.Rows;
import com.netflix.astyanax.model.Row;
import com.netflix.astyanax.query.RowQuery;
import com.netflix.astyanax.query.RowSliceQuery;
import com.netflix.astyanax.serializers.CompositeSerializer;
import com.netflix.astyanax.serializers.StringSerializer;
import com.netflix.astyanax.thrift.ThriftFamilyFactory;

public class CassandraMapState implements IBackingMap {
    @SuppressWarnings("unused")
    private static final Logger LOG = LoggerFactory.getLogger(CassandraMapState.class);

    @SuppressWarnings("rawtypes")
    private static final Map DEFAULT_SERIALZERS = Maps.newHashMap();

    public static final String CASSANDRA_CLUSTER_NAME = "cassandra.clusterName";
    public static final String ASTYANAX_CONFIGURATION = "astyanax.configuration";
    public static final String ASTYANAX_CONNECTION_POOL_CONFIGURATION = "astyanax.connectionPoolConfiguration";
    public static final String ASTYANAX_CONNECTION_POOL_MONITOR = "astyanax.connectioPoolMonitor";

    private final Map DEFAULTS = new ImmutableMap.Builder()
            .put(CASSANDRA_CLUSTER_NAME, "ClusterName")
            .put(ASTYANAX_CONFIGURATION, new AstyanaxConfigurationImpl().setDiscoveryType(NodeDiscoveryType.RING_DESCRIBE))
            .put(ASTYANAX_CONNECTION_POOL_CONFIGURATION,
                    new ConnectionPoolConfigurationImpl("MyConnectionPool").setMaxConnsPerHost(1))
            .put(ASTYANAX_CONNECTION_POOL_MONITOR, new CountingConnectionPoolMonitor()).build();

    private Options options;
    private Serializer serializer;
    protected Keyspace keyspace;

    static {
        DEFAULT_SERIALZERS.put(StateType.NON_TRANSACTIONAL, new JSONNonTransactionalSerializer());
        DEFAULT_SERIALZERS.put(StateType.TRANSACTIONAL, new JSONTransactionalSerializer());
        DEFAULT_SERIALZERS.put(StateType.OPAQUE, new JSONOpaqueSerializer());
    }

    protected AstyanaxContext createContext(Map config) {
        Map settings = Maps.newHashMap();
        for (Map.Entry defaultEntry : DEFAULTS.entrySet()) {
            if (config.containsKey(defaultEntry.getKey())) {
                settings.put(defaultEntry.getKey(), config.get(defaultEntry.getKey()));
            } else {
                settings.put(defaultEntry.getKey(), defaultEntry.getValue());
            }
        }
        // in the defaults case, we don't know the seed hosts until context
        // creation time
        if (settings.get(ASTYANAX_CONNECTION_POOL_CONFIGURATION) instanceof ConnectionPoolConfigurationImpl) {
            ConnectionPoolConfigurationImpl cpConfig = (ConnectionPoolConfigurationImpl) settings
                    .get(ASTYANAX_CONNECTION_POOL_CONFIGURATION);
            cpConfig.setSeeds((String) config.get(StormCassandraConstants.CASSANDRA_HOST));
        }

        return new AstyanaxContext.Builder()
                .forCluster((String) settings.get(CASSANDRA_CLUSTER_NAME))
                .forKeyspace((String) config.get(StormCassandraConstants.CASSANDRA_STATE_KEYSPACE))
                .withAstyanaxConfiguration((AstyanaxConfiguration) settings.get(ASTYANAX_CONFIGURATION))
                .withConnectionPoolConfiguration(
                        (ConnectionPoolConfiguration) settings.get(ASTYANAX_CONNECTION_POOL_CONFIGURATION))
                .withConnectionPoolMonitor((ConnectionPoolMonitor) settings.get(ASTYANAX_CONNECTION_POOL_MONITOR))
                .buildKeyspace(ThriftFamilyFactory.getInstance());
    }

    @SuppressWarnings("serial")
    public static class Options implements Serializable {

        public Serializer serializer = null;
        public int localCacheSize = 5000;
        public String globalKey = "globalkey";
        public String columnFamily = "cassandra_state";
        public String columnName = "default_cassandra_state";
        public String clientConfigKey = "cassandra.config";
        public Integer ttl = 86400; // 1 day

    }

    @SuppressWarnings("rawtypes")
    public static StateFactory opaque() {
        Options options = new Options();
        return opaque(options);
    }

    @SuppressWarnings("rawtypes")
    public static StateFactory opaque(Options opts) {
        return new Factory(StateType.OPAQUE, opts);
    }

    @SuppressWarnings("rawtypes")
    public static StateFactory transactional() {
        Options options = new Options();
        return transactional(options);
    }

    @SuppressWarnings("rawtypes")
    public static StateFactory transactional(Options opts) {
        return new Factory(StateType.TRANSACTIONAL, opts);
    }

    public static StateFactory nonTransactional() {
        Options options = new Options();
        return nonTransactional(options);
    }

    public static StateFactory nonTransactional(Options opts) {
        return new Factory(StateType.NON_TRANSACTIONAL, opts);
    }

    protected static class Factory implements StateFactory {
        private static final long serialVersionUID = -2644278289157792107L;
        private StateType stateType;
        private Options options;

        @SuppressWarnings({ "rawtypes", "unchecked" })
        public Factory(StateType stateType, Options options) {
            this.stateType = stateType;
            this.options = options;

            if (this.options.serializer == null) {
                this.options.serializer = DEFAULT_SERIALZERS.get(stateType);
            }

            if (this.options.serializer == null) {
                throw new RuntimeException("Serializer should be specified for type: " + stateType);
            }
        }

        @SuppressWarnings({ "rawtypes", "unchecked" })
        public State makeState(Map conf, IMetricsContext metrics, int partitionIndex, int numPartitions) {
            CassandraMapState state = new CassandraMapState(options, conf);

            CachedMap cachedMap = new CachedMap(state, options.localCacheSize);

            MapState mapState;
            if (stateType == StateType.NON_TRANSACTIONAL) {
                mapState = NonTransactionalMap.build(cachedMap);
            } else if (stateType == StateType.OPAQUE) {
                mapState = OpaqueMap.build(cachedMap);
            } else if (stateType == StateType.TRANSACTIONAL) {
                mapState = TransactionalMap.build(cachedMap);
            } else {
                throw new RuntimeException("Unknown state type: " + stateType);
            }

            return new SnapshottableMap(mapState, new Values(options.globalKey));
        }

    }

    @SuppressWarnings({ "rawtypes", "unchecked" })
    public CassandraMapState(Options options, Map conf) {
        this.options = options;
        this.serializer = options.serializer;
        AstyanaxContext context = createContext((Map) conf.get(options.clientConfigKey));
        context.start();
        this.keyspace = context.getEntity();
    }

    @Override
    public List multiGet(List> keys) {
        Collection keyNames = toKeyNames(keys);
        ColumnFamily cf = new ColumnFamily(this.options.columnFamily,
                CompositeSerializer.get(), StringSerializer.get());
        RowSliceQuery query = this.keyspace.prepareQuery(cf).getKeySlice(keyNames);

        Rows result = null;
        try {
            result = query.execute().getResult();
        } catch (ConnectionException e) {
            //TODO throw a specific error.
            throw new RuntimeException(e);
        }
        Map, byte[]> resultMap = new HashMap, byte[]>();
        if (result != null && result.size() > 0) {
            Collection rowKeys = result.getKeys();
            for (Composite rowKey : rowKeys) {
                List dimensions = new ArrayList();
                for (int i = 0; i < rowKey.size(); i++) {
                    dimensions.add(rowKey.get(i, StringSerializer.get()));
                }
                resultMap.put(dimensions, result.getRow(rowKey).getColumns().getByteArrayValue(this.options.columnName, null));
            }
        }

        List values = new ArrayList();
        for (List key : keys) {
            List stringKey = toKeyStrings(key);
            byte[] bytes = resultMap.get(stringKey);
            if (bytes != null) {
                values.add(serializer.deserialize(bytes));
            } else {
                values.add(null);
            }
        }

        return values;
    }

    @Override
    public void multiPut(List> keys, List values) {
        MutationBatch mutation = this.keyspace.prepareMutationBatch();
        ColumnFamily cf = new ColumnFamily(this.options.columnFamily,
                CompositeSerializer.get(), StringSerializer.get());

        for (int i = 0; i < keys.size(); i++) {
            Composite keyName = toKeyName(keys.get(i));
            byte[] bytes = serializer.serialize(values.get(i));
            if (options.ttl != null && options.ttl > 0) {
                mutation.withRow(cf, keyName).putColumn(this.options.columnName, bytes, options.ttl);
            } else {
                mutation.withRow(cf, keyName).putColumn(this.options.columnName, bytes);
            }
        }
        try {
            mutation.execute();
        } catch (ConnectionException e) {
            throw new RuntimeException("Batch mutation for state failed.", e);
        }
    }

    private Collection toKeyNames(List> keys) {
        return Collections2.transform(keys, new Function, Composite>() {
            @Override
            public Composite apply(List key) {
                return toKeyName(key);
            }
        });
    }

    private Composite toKeyName(List key) {
        Composite keyName = new Composite();
        List keyStrings = toKeyStrings(key);
        for (String componentString : keyStrings) {
            keyName.addComponent(componentString, StringSerializer.get());
        }
        return keyName;
    }

    private ArrayList toKeyStrings(List key) {
        ArrayList keyStrings = new ArrayList();
        for (int i = 0; i < key.size(); i++){
            Object component = key.get(i);
            if (component == null) {
                component = "[NULL]";
            }
            keyStrings.add(component.toString());
        }
        return keyStrings;
    }
}