All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.amazon.redshift.plugin.JwtCredentialsProvider Maven / Gradle / Ivy

package com.amazon.redshift.plugin;

import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithWebIdentityRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithWebIdentityResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.util.StringUtils;
import com.amazonaws.util.json.Jackson;
import com.fasterxml.jackson.databind.JsonNode;
import com.amazon.redshift.CredentialsHolder;
import com.amazon.redshift.IPlugin;
import com.amazon.redshift.RedshiftProperty;
import com.amazon.redshift.CredentialsHolder.IamMetadata;
import com.amazon.redshift.httpclient.log.IamCustomLogFactory;
import com.amazon.redshift.logger.RedshiftLogger;
import com.amazon.redshift.plugin.utils.RequestUtils;

import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.LogFactory;


/**
 * Base class for plugins that exchange a JSON Web Token (JWT) for temporary AWS credentials
 * by calling STS {@code AssumeRoleWithWebIdentity}.
 *
 * <p>Connection parameters are supplied one at a time through
 * {@link #addParameter(String, String)}. Subclasses implement {@link #processJwt(String)} to
 * produce the token that is actually sent to STS. The database user is not read from the
 * connection settings; it is derived from a claim in the token payload (see
 * {@link #deriveDatabaseUser(String[])}).
 *
 * <p>Unless the cache is disabled, credentials are stored in a static map shared by every
 * instance of this class, keyed by role ARN, token, session name, duration and the
 * plugin-specific key. The shared map is a concurrent map; a refresh triggered from
 * {@link #getCredentials()} is additionally serialized per instance.
 */
public abstract class JwtCredentialsProvider implements IPlugin
{
    private static final String KEY_ROLE_ARN = "roleArn";
    private static final String KEY_WEB_IDENTITY_TOKEN = "webIdentityToken";
    private static final String KEY_DURATION = "duration";
    private static final String KEY_ROLE_SESSION_NAME = "roleSessionName";

    private static final String DEFAULT_ROLE_SESSION_NAME = "jwt_redshift_session";

    /**
     * Separator between the components of a cache key, so that different parameter
     * combinations cannot concatenate to the same key.
     */
    private static final String CACHE_KEY_SEPARATOR = "|";

    /**
     * Payload claims inspected, in this order, to find the database user.
     */
    private static final String[] DB_USER_CLAIMS = { "DbUser", "upn", "preferred_username", "email" };

    /**
     * Placeholder written to the log instead of secret token material.
     */
    private static final String REDACTED = "<redacted>";

    // Mandatory parameters
    protected String m_roleArn;
    protected String m_jwt;

    // Optional parameters
    protected String m_roleSessionName = DEFAULT_ROLE_SESSION_NAME;

    /** Requested duration in seconds; only applied to the STS request when greater than zero. */
    protected int m_duration;

    /** Database user derived from the token during {@link #refresh()}. */
    protected String m_dbUser;

    protected String m_stsEndpoint;
    protected String m_region;
    protected RedshiftLogger m_log;
    protected Boolean m_disableCache = false;

    /**
     * Credentials cache shared by all instances. It is static while refreshes only lock the
     * individual instance, so the map itself has to tolerate concurrent access.
     */
    private static final Map<String, CredentialsHolder> m_cache =
            new ConcurrentHashMap<String, CredentialsHolder>();

    /** Result of the latest refresh; only used when the cache is disabled. */
    private CredentialsHolder m_lastRefreshCredentials;

    /**
     * The custom log factory class.
     */
    private static final Class<?> CUSTOM_LOG_FACTORY_CLASS = IamCustomLogFactory.class;

    /**
     * Log properties file name.
     */
    private static final String LOG_PROPERTIES_FILE_NAME = "log-factory.properties";

    /**
     * Log properties file path.
     */
    private static final String LOG_PROPERTIES_FILE_PATH = "META-INF/services/org.apache.commons.logging.LogFactory";

    /**
     * A custom context class loader which allows us to control which LogFactory is loaded.
     * Our CUSTOM_LOG_FACTORY_CLASS will divert any wire logging to NoOpLogger to suppress wire
     * messages being logged.
     */
    private static final ClassLoader CONTEXT_CLASS_LOADER = new ClassLoader(
            JwtCredentialsProvider.class.getClassLoader())
    {
        @Override
        public Class<?> loadClass(String name) throws ClassNotFoundException
        {
            Class<?> clazz = getParent().loadClass(name);
            if (LogFactory.class.isAssignableFrom(clazz))
            {
                // Any LogFactory implementation is replaced by the custom factory.
                return CUSTOM_LOG_FACTORY_CLASS;
            }
            return clazz;
        }

        @Override
        public Enumeration<URL> getResources(String name) throws IOException
        {
            if (LogFactory.FACTORY_PROPERTIES.equals(name))
            {
                // make sure not load any other commons-logging.properties files
                return Collections.<URL>emptyEnumeration();
            }
            return super.getResources(name);
        }

        @Override
        public URL getResource(String name)
        {
            if (LOG_PROPERTIES_FILE_PATH.equals(name))
            {
                return JwtCredentialsProvider.class.getResource(LOG_PROPERTIES_FILE_NAME);
            }
            return super.getResource(name);
        }
    };

    /**
     * Hook for the concrete plugin: if the IDP requires looking into the JWT, decode it and
     * get any custom claim/tag in it.
     *
     * @param jwt the token supplied through the {@code webIdentityToken} parameter
     * @return the token to send to STS as the web identity token
     * @throws IOException if the token cannot be processed
     */
    protected abstract String processJwt(String jwt) throws IOException;

    /**
     * Stores one connection parameter. Keys are matched case-insensitively; unknown keys are
     * ignored.
     *
     * @param key   the parameter name
     * @param value the parameter value
     * @throws NumberFormatException if the {@code duration} value is not a valid integer
     */
    @Override
    public void addParameter(String key, String value)
    {
        if (isLogEnabled())
        {
            m_log.logDebug("key: {0}", key);
        }

        if (KEY_ROLE_ARN.equalsIgnoreCase(key))
        {
            m_roleArn = value;
        }
        else if (KEY_WEB_IDENTITY_TOKEN.equalsIgnoreCase(key))
        {
            m_jwt = value;
        }
        else if (KEY_ROLE_SESSION_NAME.equalsIgnoreCase(key))
        {
            m_roleSessionName = value;
        }
        else if (KEY_DURATION.equalsIgnoreCase(key))
        {
            m_duration = Integer.parseInt(value);
        }
        else if (RedshiftProperty.DB_USER.getName().equalsIgnoreCase(key))
        {
            // Deliberately ignored: the database user is derived from the token,
            // not read from the connection settings.
        }
        else if (RedshiftProperty.AWS_REGION.getName().equalsIgnoreCase(key))
        {
            m_region = value;
        }
        else if (RedshiftProperty.STS_ENDPOINT_URL.getName().equalsIgnoreCase(key))
        {
            m_stsEndpoint = value;
        }
        else if (RedshiftProperty.IAM_DISABLE_CACHE.getName().equalsIgnoreCase(key))
        {
            m_disableCache = Boolean.valueOf(value);
        }
    }

    /**
     * Sets the logger used by this plugin.
     *
     * @param log the logger; when it is never set, this class skips its own log statements
     */
    @Override
    public void setLogger(RedshiftLogger log)
    {
        m_log = log;
    }

    /**
     * Returns credentials, from the shared cache when caching is enabled and the cached entry
     * has not expired, otherwise from a fresh {@link #refresh()}.
     *
     * <p>Credentials served from the cache have their refresh flag set to {@code false};
     * credentials produced by a refresh carry {@code true}.
     *
     * @return the credentials holder, never {@code null}
     * @throws SdkClientException if the refresh fails or produced no credentials
     */
    @Override
    public CredentialsHolder getCredentials()
    {
        CredentialsHolder credentials = null;

        if (!m_disableCache)
        {
            credentials = m_cache.get(getCacheKey());
        }

        if (credentials == null || credentials.isExpired())
        {
            if (isLogEnabled())
            {
                m_log.logInfo("JWT getCredentials NOT from cache");
            }

            synchronized (this)
            {
                refresh();

                if (m_disableCache)
                {
                    // Hand over the refreshed credentials and drop the instance reference.
                    credentials = m_lastRefreshCredentials;
                    m_lastRefreshCredentials = null;
                }
                else
                {
                    credentials = m_cache.get(getCacheKey());
                }
            }
        }
        else
        {
            credentials.setRefresh(false);
            if (isLogEnabled())
            {
                m_log.logInfo("JWT getCredentials from cache");
            }
        }

        if (credentials == null)
        {
            throw new SdkClientException("Unable to load AWS credentials from JWT");
        }
        return credentials;
    }

    /**
     * Obtains new credentials from STS and stores them in the shared cache, or in
     * {@code m_lastRefreshCredentials} when the cache is disabled.
     *
     * <p>While the call is in progress the thread's context class loader is replaced by
     * {@code CONTEXT_CLASS_LOADER}; the previous loader is always restored.
     *
     * @throws SdkClientException wrapping any failure raised while processing the token or
     *         calling STS
     */
    @Override
    public void refresh()
    {
        // Get the current thread and set the context loader with our custom load class method.
        Thread currentThread = Thread.currentThread();
        ClassLoader originalLoader = currentThread.getContextClassLoader();

        currentThread.setContextClassLoader(CONTEXT_CLASS_LOADER);

        try
        {
            String jwt = processJwt(m_jwt);
            String[] decodedJwt = decodeJwt(m_jwt);

            if (isLogEnabled())
            {
                // The token is a bearer credential, so only its size is logged.
                m_log.logDebug(
                        String.format("JWT : %s (%d chars)", REDACTED, jwt == null ? 0 : jwt.length()));
            }

            m_dbUser = deriveDatabaseUser(decodedJwt);

            AssumeRoleWithWebIdentityRequest jwtRequest = new AssumeRoleWithWebIdentityRequest();
            jwtRequest.setWebIdentityToken(jwt);
            jwtRequest.setRoleArn(m_roleArn);
            jwtRequest.setRoleSessionName(m_roleSessionName);
            if (m_duration > 0)
            {
                jwtRequest.setDurationSeconds(m_duration);
            }

            // The STS client is built with anonymous AWS credentials; the request carries
            // the web identity token instead.
            AWSCredentialsProvider anonymousProvider =
                    new AWSStaticCredentialsProvider(new AnonymousAWSCredentials());
            AWSSecurityTokenServiceClientBuilder builder = AWSSecurityTokenServiceClientBuilder.standard();

            AWSSecurityTokenService stsSvc =
                    RequestUtils.buildSts(m_stsEndpoint, m_region, builder, anonymousProvider, m_log);
            AssumeRoleWithWebIdentityResult result = stsSvc.assumeRoleWithWebIdentity(jwtRequest);
            Credentials cred = result.getCredentials();
            Date expiration = cred.getExpiration();
            AWSCredentials sessionCredentials = new BasicSessionCredentials(cred.getAccessKeyId(),
                    cred.getSecretAccessKey(), cred.getSessionToken());
            CredentialsHolder credentials = CredentialsHolder.newInstance(sessionCredentials, expiration);
            credentials.setMetadata(readMetadata());
            credentials.setRefresh(true);

            if (!m_disableCache)
            {
                m_cache.put(getCacheKey(), credentials);
            }
            else
            {
                m_lastRefreshCredentials = credentials;
            }
        }
        catch (Exception e)
        {
            if (isLogEnabled())
            {
                m_log.logError(e);
            }

            throw new SdkClientException("JWT error: " + e.getMessage(), e);
        }
        finally
        {
            currentThread.setContextClassLoader(originalLoader);
        }
    }

    /**
     * Returns the plugin-specific part of the cache key.
     *
     * @return an empty string; override this in each derived plugin
     */
    @Override
    public String getPluginSpecificCacheKey()
    {
        // Override this in each derived plugin.
        return "";
    }

    /**
     * Builds the key under which credentials for the current parameters are cached.
     *
     * @return role ARN, token, session name, duration and plugin-specific key, joined with a
     *         separator
     */
    private String getCacheKey()
    {
        String pluginSpecificKey = getPluginSpecificCacheKey();

        return m_roleArn
                + CACHE_KEY_SEPARATOR + m_jwt
                + CACHE_KEY_SEPARATOR + m_roleSessionName
                + CACHE_KEY_SEPARATOR + m_duration
                + CACHE_KEY_SEPARATOR + pluginSpecificKey;
    }

    /**
     * Verifies that the mandatory parameters have been supplied.
     *
     * @throws IOException if the role ARN or the web identity token is null or empty
     */
    protected void checkRequiredParameters() throws IOException
    {
        if (StringUtils.isNullOrEmpty(m_roleArn))
        {
            throw new IOException("Missing required property: " + KEY_ROLE_ARN);
        }
        if (StringUtils.isNullOrEmpty(m_jwt))
        {
            throw new IOException("Missing required property: " + KEY_WEB_IDENTITY_TOKEN);
        }
    }

    /**
     * Splits a token of the form {@code Base64(JOSE header).Base64(Payload).Base64(Signature)}
     * and decodes the header and the payload as UTF-8 text. The signature is returned as it
     * appears in the token, without decoding.
     *
     * @param jwt the token, may be {@code null}
     * @return {@code {header, payload, signature}}, or {@code null} when {@code jwt} is
     *         {@code null} or does not split into exactly three parts
     */
    protected String[] decodeJwt(String jwt)
    {
        if (jwt == null)
        {
            return null;
        }

        // Base64(JOSE header).Base64(Payload).Base64(Signature)
        String[] headerPayloadSig = jwt.split("\\.");
        if (headerPayloadSig.length != 3)
        {
            return null;
        }

        // Decode with an explicit charset instead of the platform default.
        String header = new String(Base64.decodeBase64(headerPayloadSig[0]), StandardCharsets.UTF_8);
        String payload = new String(Base64.decodeBase64(headerPayloadSig[1]), StandardCharsets.UTF_8);
        String signature = headerPayloadSig[2];

        if (isLogEnabled())
        {
            // The signature is kept out of the log so the token cannot be rebuilt from it.
            m_log.logDebug(
                    String.format("Decoded JWT : Header: %s payload: %s signature:%s", header, payload, REDACTED));
        }

        return new String[] { header, payload, signature };
    }

    /**
     * Picks the database user from the token payload. The claims {@code DbUser}, {@code upn},
     * {@code preferred_username} and {@code email} are tried in that order and the first one
     * with a non-empty text value wins.
     *
     * @param decodedJwt the result of {@link #decodeJwt(String)}
     * @return the database user, never null or empty
     * @throws SdkClientException if {@code decodedJwt} is not a three-element array, the
     *         payload cannot be read as JSON, or none of the claims has a usable value
     */
    protected String deriveDatabaseUser(String[] decodedJwt)
    {
        if (decodedJwt == null || decodedJwt.length != 3)
        {
            throw new SdkClientException("JWT decoding error");
        }

        // Use payload
        JsonNode entityJson = Jackson.jsonNodeOf(decodedJwt[1]);
        if (entityJson == null)
        {
            // Nothing to search; report it instead of failing with a NullPointerException.
            throw new SdkClientException("JWT decoding error");
        }

        for (String claim : DB_USER_CLAIMS)
        {
            JsonNode userTokenField = entityJson.findValue(claim);
            if (userTokenField == null)
            {
                continue;
            }

            String databaseUser = userTokenField.textValue();
            if (!StringUtils.isNullOrEmpty(databaseUser))
            {
                if (isLogEnabled())
                {
                    m_log.logDebug(
                            String.format("JWT claim: %s as database user: %s", claim, databaseUser));
                }
                return databaseUser;
            }
        }

        throw new SdkClientException("No database user claim found in JWT");
    }

    /**
     * Builds the metadata attached to freshly obtained credentials.
     *
     * @return metadata carrying the derived database user, with auto-create set to true
     */
    private IamMetadata readMetadata()
    {
        IamMetadata metadata = new IamMetadata();

        metadata.setDbUser(m_dbUser);
        metadata.setAutoCreate(true);

        return metadata;
    }

    /**
     * Tells whether a log statement may be issued.
     *
     * @return {@code true} when logging is enabled and a logger has been set
     */
    private boolean isLogEnabled()
    {
        return RedshiftLogger.isEnable() && m_log != null;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy