All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache Maven / Gradle / Ivy

package com.amazonaws.encryptionsdk.caching;

import static java.util.Collections.max;

import javax.annotation.concurrent.GuardedBy;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.TreeSet;

import com.amazonaws.encryptionsdk.internal.Utils;
import com.amazonaws.encryptionsdk.model.DecryptionMaterials;
import com.amazonaws.encryptionsdk.model.EncryptionMaterials;

/**
 * A simple implementation of the {@link CryptoMaterialsCache} using a basic LRU cache.
 *
 * Example usage:
 * {@code
 *
 * LocalCryptoMaterialsCache cache = new LocalCryptoMaterialsCache(500);
 *
 * CachingCryptoMaterialsManager materialsManager = CachingCryptoMaterialsManager.builder()
 *      .setMaxAge(5, TimeUnit.MINUTES)
 *      .setCache(cache)
 *      .withMasterKeyProvider(myMasterKeyProvider)
 *      .build();
 *
 * byte[] data = new AwsCrypto().encryptData(materialsManager, plaintext).getResult();
 * }
 */
public class LocalCryptoMaterialsCache implements CryptoMaterialsCache {
    // The maximum number of entries to implicitly prune per access due to TTL expiration. We limit this to avoid
    // latency spikes when a large number of entries have expired since the last cache usage.
    private static final int MAX_TTL_PRUNE = 10;

    // Mockable time source, to allow us to test TTL pruning.
    // package access for tests
    // note: we're not using the java 8 time APIs in order to improve android compatibility
    MsClock clock = MsClock.WALLCLOCK;

    // The magic numbers here are the normal defaults for LinkedHashMap; we have to specify them explicitly if we are to
    // specify accessOrder=true, which enables LRU behavior
    private final LinkedHashMap cacheMap = new LinkedHashMap<>(
        /* capacity */  16, /* loadFactor */ 0.75f, /* accessOrder */ true
    );

    // This is a treeset sorted by TTL to allow us to quickly find expired entries
    private final TreeSet expirationQueue = new TreeSet<>(LocalCryptoMaterialsCache::compareEntries);

    private final int capacity;

    public LocalCryptoMaterialsCache(int capacity) {
        this.capacity = capacity;
    }

    private static int compareEntries(BaseEntry a, BaseEntry b) {
        int result;

        if (a == b) {
            return 0;
        }

        result = Long.compare(a.expirationTimestamp_, b.expirationTimestamp_);
        if (result != 0) {
            return result;
        }

        return Utils.compareObjectIdentity(a, b);
    }

    /**
     * A common base for both encrypt and decrypt entries
     */
    private class BaseEntry {
        final CacheIdentifier identifier_;
        final long expirationTimestamp_;
        final long creationTime = clock.timestamp();

        private BaseEntry(CacheIdentifier identifier, long expiration) {
            this.identifier_ = identifier;
            this.expirationTimestamp_ = expiration;
        }
    }

    /**
     * This wrapper just gives us a usable hashcode over our cache identifiers.
     */
    private static final class CacheIdentifier {
        private final byte[] identifier;
        private final int hashCode;

        private CacheIdentifier(byte[] passed_id) {
            this.identifier = passed_id.clone();
            this.hashCode = Arrays.hashCode(passed_id);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;

            return Arrays.equals(identifier, ((CacheIdentifier)o).identifier);
        }

        @Override
        public int hashCode() {
            return hashCode;
        }
    }

    // Note: We take locks on both cache entries as well as the overall cache.
    // The lock order is overall cache -> cache entry; this means that the entry cannot call back into the parent cache
    // while holding its own lock.
    private final class EncryptCacheEntryInternal extends BaseEntry {
        private final EncryptionMaterials result;

        @GuardedBy("this")
        private UsageStats usageStats = UsageStats.ZERO;

        private EncryptCacheEntryInternal(
                CacheIdentifier identifier,
                long expiration,
                EncryptionMaterials result
        ) {
            super(identifier, expiration);

            this.result = result;
        }

        synchronized UsageStats addAndGetUsageStats(UsageStats delta) {
            this.usageStats = this.usageStats.add(delta);

            return this.usageStats;
        }
    }

    // When returning cache entries, we create a new object to represent the snapshot of usage stats at time of get.
    // This helps avoid races where two gets together push an entry over usage limits, and then both miss when they
    // see the entry over the limit.
    //
    // Not static as invalidate calls back into the cache.
    private final class EncryptCacheEntryExposed implements EncryptCacheEntry {
        private final UsageStats usageStats_;
        private final EncryptCacheEntryInternal internal_;

        private EncryptCacheEntryExposed(
                final UsageStats usageStats,
                final EncryptCacheEntryInternal internal
        ) {
            usageStats_ = usageStats;
            internal_ = internal;
        }

        @Override public UsageStats getUsageStats() {
            return usageStats_;
        }

        @Override public long getEntryCreationTime() {
            return internal_.creationTime;
        }

        @Override public EncryptionMaterials getResult() {
            return internal_.result;
        }

        @Override public void invalidate() {
            removeEntry(internal_);
        }
    }

    private final class DecryptCacheEntryInternal extends BaseEntry implements DecryptCacheEntry {
        final DecryptionMaterials result;

        private DecryptCacheEntryInternal(
                CacheIdentifier identifier,
                long expiration,
                DecryptionMaterials result
        ) {
            super(identifier, expiration);

            this.result = result;
        }

        @Override public DecryptionMaterials getResult() {
            return result;
        }

        @Override public void invalidate() {
            removeEntry(this);
        }

        @Override public long getEntryCreationTime() {
            return creationTime;
        }
    }

    /**
     * Removes an entry from the cache.
     * @param e the entry to remove
     */
    private synchronized void removeEntry(BaseEntry e) {
        expirationQueue.remove(e);
        // This does not update the LRU if the value does not match
        cacheMap.remove(e.identifier_, e);
    }

    /**
     * Prunes all TTL-expired entries, plus LRU entries until we are under capacity limits.
     */
    private synchronized void prune() {
        // Purge maxage-expired entries first, to avoid pruning entries by LRU unnecessarily when we're about to free
        // up space anyway.
        ttlPrune();

        while (cacheMap.size() > capacity) {
            removeEntry(cacheMap.values().iterator().next());
        }
    }

    /**
     * Prunes all TTL-expired entries. Does not check capacity.
     */
    private void ttlPrune() {
        int pruneCount = 0;
        long now = clock.timestamp();

        while (!expirationQueue.isEmpty() && expirationQueue.first().expirationTimestamp_ < now && pruneCount < MAX_TTL_PRUNE) {
            removeEntry(expirationQueue.first());
            pruneCount++;
        }
    }

    private synchronized  T getEntry(Class klass, byte[] identifier) {
        // Perform cache maintenance first
        ttlPrune();

        BaseEntry e = cacheMap.get(new CacheIdentifier(identifier));

        if (e == null) {
            return null;
        } else {
            if (e.expirationTimestamp_ < clock.timestamp()) {
                removeEntry(e);
                return null;
            }

            return klass.cast(e);
        }
    }

    private synchronized void putEntry(final BaseEntry entry) {
        BaseEntry oldEntry = cacheMap.put(entry.identifier_, entry);

        if (oldEntry != null) {
            expirationQueue.remove(oldEntry);
        }
        expirationQueue.add(entry);

        prune();
    }

    @Override public EncryptCacheEntry getEntryForEncrypt(
            byte[] cacheId, final UsageStats usageIncrement
    ) {
        EncryptCacheEntryInternal entry = getEntry(EncryptCacheEntryInternal.class, cacheId);

        if (entry != null) {
            UsageStats stats = entry.addAndGetUsageStats(usageIncrement);
            return new EncryptCacheEntryExposed(stats, entry);
        }

        return null;
    }

    @Override public EncryptCacheEntry putEntryForEncrypt(
            byte[] cacheId,
            EncryptionMaterials encryptionMaterials,
            CacheHint hint,
            UsageStats initialUsage
    ) {
        EncryptCacheEntryInternal entry = new EncryptCacheEntryInternal(
                new CacheIdentifier(cacheId),
                Utils.saturatingAdd(clock.timestamp(), hint.getMaxAgeMillis()),
                encryptionMaterials
        );

        entry.addAndGetUsageStats(initialUsage);

        putEntry(entry);

        return new EncryptCacheEntryExposed(initialUsage, entry);
    }

    @Override public DecryptCacheEntry getEntryForDecrypt(byte[] cacheId) {
        return getEntry(DecryptCacheEntryInternal.class, cacheId);
    }

    @Override public void putEntryForDecrypt(
            byte[] cacheId, DecryptionMaterials decryptionMaterials, CacheHint hint
    ) {
        DecryptCacheEntryInternal entry = new DecryptCacheEntryInternal(
                new CacheIdentifier(cacheId),
                Utils.saturatingAdd(clock.timestamp(), hint.getMaxAgeMillis()),
                decryptionMaterials
        );

        putEntry(entry);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy