com.amazonaws.athena.connector.lambda.security.CachableSecretsManager Maven / Gradle / Ivy
package com.amazonaws.athena.connector.lambda.security;
/*-
* #%L
* Amazon Athena Query Federation SDK
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest;
import com.amazonaws.services.secretsmanager.model.GetSecretValueResult;
import org.apache.arrow.util.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Since Athena may call your connector or UDF at a high TPS or concurrency you may want to have a short lived
 * cache in front of SecretsManager to avoid bottlenecking on SecretsManager. This class offers such a cache. This class
 * also has utilities for identifying and replacing secrets in scripts. For example: MyString${WithSecret} would have
 * ${WithSecret} replaced by the corresponding value of the secret in AWS Secrets Manager with that name.
 * <p>
 * A cached value is reused for up to MAX_CACHE_AGE_MS (one minute) after it was fetched, and at most
 * MAX_CACHE_SIZE (10) secrets are held at once.
 * <p>
 * This class is not thread-safe: the cache is an unsynchronized {@link LinkedHashMap}, so an instance must not be
 * shared across threads without external synchronization.
 */
public class CachableSecretsManager
{
    private static final Logger logger = LoggerFactory.getLogger(CachableSecretsManager.class);

    private static final long MAX_CACHE_AGE_MS = 60_000;
    protected static final int MAX_CACHE_SIZE = 10;

    /**
     * Matches a secret reference of the form ${name} and captures the name in group 1. The allowed characters are
     * those AWS Secrets Manager permits in secret names (letters, digits and /_+=.@-) plus ':' so that a full secret
     * ARN can be referenced as well.
     */
    private static final Pattern SECRET_PATTERN = Pattern.compile("\\$\\{([a-zA-Z0-9:/_+=.@-]+)\\}");

    // Iterates in insertion order, which evictCache relies on to find the oldest entry.
    private final LinkedHashMap<String, CacheEntry> cache = new LinkedHashMap<>();
    private final AWSSecretsManager secretsManager;

    /**
     * @param secretsManager The Secrets Manager client used to fetch secrets that are missing from, or expired in,
     * the cache.
     */
    public CachableSecretsManager(AWSSecretsManager secretsManager)
    {
        this.secretsManager = secretsManager;
    }

    /**
     * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret}
     * replaced by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is
     * found the function throws.
     * <p>
     * The string is scanned once, left to right; injected secret values are never re-scanned, so a secret whose value
     * itself contains "${...}" is inserted verbatim.
     *
     * @param rawString The string in which to find and replace/inject secrets, may be null.
     * @return The processed rawString that has had all secrets replaced with their secret value from SecretsManager,
     * or null if rawString was null.
     * @throws IllegalStateException if a referenced secret exists but has no string value.
     * Also throws whatever the Secrets Manager client throws if any of the secrets can not be found.
     */
    public String resolveSecrets(String rawString)
    {
        if (rawString == null) {
            return null;
        }

        Matcher matcher = SECRET_PATTERN.matcher(rawString);
        StringBuilder result = new StringBuilder(rawString.length());
        int copiedUpTo = 0;
        while (matcher.find()) {
            String secretName = matcher.group(1);
            String secretValue = getSecret(secretName);
            if (secretValue == null) {
                // Guard against silently injecting the text "null" (e.g. for a secret stored only as binary).
                throw new IllegalStateException("Secret[" + secretName + "] has no string value.");
            }
            // Copy the literal text preceding this reference, then the secret value in place of the reference.
            result.append(rawString, copiedUpTo, matcher.start()).append(secretValue);
            copiedUpTo = matcher.end();
        }
        result.append(rawString, copiedUpTo, rawString.length());
        return result.toString();
    }

    /**
     * Retrieves a secret from SecretsManager, first checking the cache. Newly fetched secrets are added to the cache.
     *
     * @param secretName The name (or ARN) of the secret to retrieve.
     * @return The string value of the secret (null if the secret has no string value), throws if no such secret
     * is found.
     */
    public String getSecret(String secretName)
    {
        CacheEntry cacheEntry = cache.get(secretName);
        if (cacheEntry == null || cacheEntry.getAge() > MAX_CACHE_AGE_MS) {
            logger.info("getSecret: Resolving secret[{}].", secretName);
            GetSecretValueResult secretValueResult = secretsManager.getSecretValue(new GetSecretValueRequest()
                    .withSecretId(secretName));
            cacheEntry = new CacheEntry(secretValueResult.getSecretString(), System.currentTimeMillis());
            // Make room before inserting. Evicting also drops this name's stale entry (if any), so the refreshed
            // entry is appended at the end of the insertion order.
            evictCache(cache.size() >= MAX_CACHE_SIZE);
            cache.put(secretName, cacheEntry);
        }
        return cacheEntry.getValue();
    }

    /**
     * Removes every expired entry from the cache. If {@code force} is set and nothing had expired, the oldest
     * inserted entry is removed instead.
     *
     * @param force true when the cache is full and at least one entry must be removed.
     */
    private void evictCache(boolean force)
    {
        Iterator<Map.Entry<String, CacheEntry>> itr = cache.entrySet().iterator();
        int removed = 0;
        while (itr.hasNext()) {
            CacheEntry entry = itr.next().getValue();
            if (entry.getAge() > MAX_CACHE_AGE_MS) {
                itr.remove();
                removed++;
            }
        }

        if (removed == 0 && force) {
            // Remove the oldest since we found no expired entries; LinkedHashMap iterates in insertion order.
            itr = cache.entrySet().iterator();
            if (itr.hasNext()) {
                itr.next();
                itr.remove();
            }
        }
    }

    /**
     * Inserts (or overwrites) a cache entry with an explicit creation time, allowing tests to simulate aged entries.
     *
     * @param name The secret name used as the cache key.
     * @param value The secret value to cache.
     * @param createTime The entry's creation time in epoch milliseconds.
     */
    @VisibleForTesting
    protected void addCacheEntry(String name, String value, long createTime)
    {
        cache.put(name, new CacheEntry(value, createTime));
    }

    /**
     * An immutable cached secret value together with the wall-clock time (epoch ms) at which it was created.
     */
    private static final class CacheEntry
    {
        private final String value;
        private final long createTime;

        CacheEntry(String value, long createTime)
        {
            this.value = value;
            this.createTime = createTime;
        }

        String getValue()
        {
            return value;
        }

        /**
         * @return Milliseconds elapsed since this entry was created.
         */
        long getAge()
        {
            return System.currentTimeMillis() - createTime;
        }
    }
}