All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.amazon.redshift.plugin.SamlCredentialsProvider Maven / Gradle / Ivy

There is a newer version: 2.1.0.30
Show newest version
package com.amazon.redshift.plugin;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.util.StringUtils;
import com.amazon.redshift.CredentialsHolder;
import com.amazon.redshift.CredentialsHolder.IamMetadata;
import com.amazon.redshift.IPlugin;
import com.amazon.redshift.RedshiftProperty;
import com.amazon.redshift.core.IamHelper;
import com.amazon.redshift.httpclient.log.IamCustomLogFactory;
import com.amazon.redshift.logger.LogLevel;
import com.amazon.redshift.logger.RedshiftLogger;
import com.amazon.redshift.plugin.utils.RequestUtils;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.LogFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

public abstract class SamlCredentialsProvider extends IdpCredentialsProvider implements IPlugin
{
    protected static final String KEY_IDP_HOST = "idp_host";
    private static final String KEY_IDP_PORT = "idp_port";
    private static final String KEY_DURATION = "duration";
    private static final String KEY_PREFERRED_ROLE = "preferred_role";

    protected String m_userName;
    protected String m_password;
    protected String m_idpHost;
    protected int m_idpPort = 443;
    protected int m_duration;
    protected String m_preferredRole;
    protected String m_dbUser;
    protected String m_dbGroups;
    protected String m_dbGroupsFilter;
    protected Boolean m_forceLowercase;
    protected Boolean m_autoCreate;
    protected String m_stsEndpoint;
    protected String m_region;
    protected Boolean m_disableCache = false;
    protected Boolean m_groupFederation = false;

    private static Map m_cache = new HashMap();
    
    private CredentialsHolder m_lastRefreshCredentials; // Used when cache is disable.

    /**
     * The custom log factory class.
     */
    private static final Class CUSTOM_LOG_FACTORY_CLASS = IamCustomLogFactory.class;

    /**
     * Log properties file name.
     */
    private static final String LOG_PROPERTIES_FILE_NAME = "log-factory.properties";

    /**
     * Log properties file path.
     */
    private static final String LOG_PROPERTIES_FILE_PATH = "META-INF/services/org.apache.commons.logging.LogFactory";

    /**
     * A custom context class loader which allows us to control which LogFactory is loaded.
     * Our CUSTOM_LOG_FACTORY_CLASS will divert any wire logging to NoOpLogger to suppress wire
     * messages being logged.
     */
    private static final ClassLoader CONTEXT_CLASS_LOADER = new ClassLoader(
        SamlCredentialsProvider.class.getClassLoader())
    {
        @Override
        public Class loadClass(String name) throws ClassNotFoundException
        {
            Class clazz = getParent().loadClass(name);
            if (org.apache.commons.logging.LogFactory.class.isAssignableFrom(clazz))
            {
                return CUSTOM_LOG_FACTORY_CLASS;
            }
            return clazz;
        }

        @Override
        public Enumeration getResources(String name) throws IOException
        {
            if (LogFactory.FACTORY_PROPERTIES.equals(name))
            {
                // make sure not load any other commons-logging.properties files
                return Collections.enumeration(Collections.emptyList());
            }
            return super.getResources(name);
        }

        @Override
        public URL getResource(String name)
        {
            if (LOG_PROPERTIES_FILE_PATH.equals(name))
            {
                return SamlCredentialsProvider.class.getResource(LOG_PROPERTIES_FILE_NAME);
            }
            return super.getResource(name);
        }
    };

    protected abstract String getSamlAssertion() throws IOException;

    @Override
    public void addParameter(String key, String value)
    {
	      if (RedshiftLogger.isEnable())
	    		m_log.logDebug("key: {0}", key);
    	
        if (RedshiftProperty.UID.getName().equalsIgnoreCase(key)
                || RedshiftProperty.USER.getName().equalsIgnoreCase(key))
        {
            m_userName = value;
        }
        else if (RedshiftProperty.PWD.getName().equalsIgnoreCase(key)
                || RedshiftProperty.PASSWORD.getName().equalsIgnoreCase(key))
        {
            m_password = value;
        }
        else if (KEY_IDP_HOST.equalsIgnoreCase(key))
        {
            m_idpHost = value;
        }
        else if (KEY_IDP_PORT.equalsIgnoreCase(key))
        {
            m_idpPort = Integer.parseInt(value);
        }
        else if (KEY_DURATION.equalsIgnoreCase(key))
        {
            m_duration = Integer.parseInt(value);
        }
        else if (KEY_PREFERRED_ROLE.equalsIgnoreCase(key))
        {
            m_preferredRole = value;
        }
        else if (KEY_SSL_INSECURE.equalsIgnoreCase(key))
        {
            m_sslInsecure = Boolean.parseBoolean(value);
        }
        else if (RedshiftProperty.DB_USER.getName().equalsIgnoreCase(key))
        {
            m_dbUser = value;
        }
        else if (RedshiftProperty.DB_GROUPS.getName().equalsIgnoreCase(key))
        {
            m_dbGroups = value;
        }
        else if (RedshiftProperty.DB_GROUPS_FILTER.getName().equalsIgnoreCase(key))
        {
            m_dbGroupsFilter = value;
        }
        else if (RedshiftProperty.FORCE_LOWERCASE.getName().equalsIgnoreCase(key))
        {
            m_forceLowercase = Boolean.valueOf(value);
        }
        else if (RedshiftProperty.USER_AUTOCREATE.getName().equalsIgnoreCase(key))
        {
            m_autoCreate = Boolean.valueOf(value);
        }
        else if (RedshiftProperty.AWS_REGION.getName().equalsIgnoreCase(key))
        {
            m_region = value;
        }
        else if (RedshiftProperty.STS_ENDPOINT_URL.getName().equalsIgnoreCase(key))
        {
            m_stsEndpoint = value;
        }
        else if (RedshiftProperty.IAM_DISABLE_CACHE.getName().equalsIgnoreCase(key))
        {
            m_disableCache = Boolean.valueOf(value);
        }
    }

    @Override
    public void setLogger(RedshiftLogger log)
    {
        m_log = log;
    }

    @Override
    public int getSubType()
    {
        return IamHelper.SAML_PLUGIN;
    }
    
    @Override
    public CredentialsHolder getCredentials()
    {
    		CredentialsHolder credentials = null;
    		
    		if(!m_disableCache) {
    			String key = getCacheKey();
    			credentials = m_cache.get(key);
    		}
    		
        if (credentials == null || credentials.isExpired())
        {
	          if(RedshiftLogger.isEnable()) 
	            m_log.logInfo("SAML getCredentials NOT from cache");
        	
	          synchronized(this) {
	          	
	          	refresh();
	          	
	          	if(m_disableCache) {
	          		credentials = m_lastRefreshCredentials;
	          		m_lastRefreshCredentials = null;
	          	}
	          }
        }
        else {
          credentials.setRefresh(false);
          if(RedshiftLogger.isEnable()) 
            m_log.logInfo("SAML getCredentials from cache");
        }
        
        if(!m_disableCache) {
          // if the SAML response has dbUser argument, it will be picked up at this point.
        	credentials = m_cache.get(getCacheKey());
        }

        // if dbUser argument has been passed in the connection string, add it to metadata.
        if (!StringUtils.isNullOrEmpty(m_dbUser))
        {
            credentials.getThisMetadata().setDbUser(this.m_dbUser);
        }

        if (credentials == null)
        {
            throw new SdkClientException("Unable to load AWS credentials from ADFS");
        }
      
        if(RedshiftLogger.isEnable()) {
            Date now = new Date();
            m_log.logInfo(now + ": Using entry for SamlCredentialsProvider.getCredentials cache with expiration " + credentials.getExpiration());
        }
        
        return credentials;
    }

    @Override
    public void refresh()
    {
        // Get the current thread and set the context loader with our custom load class method.
        Thread currentThread = Thread.currentThread();
        ClassLoader cl = currentThread.getContextClassLoader();

        Thread.currentThread().setContextClassLoader(CONTEXT_CLASS_LOADER);

        try
        {
            String samlAssertion = getSamlAssertion();

            if (RedshiftLogger.isEnable())
            		m_log.logDebug("SamlCredentialsProvider: Received SAML assertion of length={0}", samlAssertion != null ? samlAssertion.length() : -1);
                    
            final Pattern SAML_PROVIDER_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:saml-provider/\\S+");
            final Pattern ROLE_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:role/\\S+");
            Document doc = parse(Base64.decodeBase64(samlAssertion));
            XPath xPath =  XPathFactory.newInstance().newXPath();
            String expression = "//*[local-name()='Attribute'][@Name='https://aws.amazon.com/SAML/Attributes/Role']/*[local-name()='AttributeValue']/text()";
            NodeList nodeList = (NodeList) xPath.compile(expression)
                    .evaluate(doc, XPathConstants.NODESET);

            Map roles = new HashMap();
            if (nodeList != null)
            {
                for (int i = 0; i < nodeList.getLength(); ++i)
                {
                    Node node = nodeList.item(i);
                    String value = node.getNodeValue();
                    String[] arns = value.split(",");
                    if (arns.length >= 2)
                    {
                        String provider = null;
                        String role = null;
                        for (String arn : arns)
                        {
                            Matcher providerMatcher = SAML_PROVIDER_PATTERN.matcher(arn);
                            if (providerMatcher.find())
                            {
                                provider = providerMatcher.group(0);
                                continue;
                            }
                            Matcher roleMatcher = ROLE_PATTERN.matcher(arn);
                            if (roleMatcher.find())
                            {
                                role = roleMatcher.group(0);
                            }
                        }
                        if (!StringUtils.isNullOrEmpty(role) && !StringUtils.isNullOrEmpty(provider))
                        {
                            roles.put(role, provider);
                        }
                    }
                }
            }

            if (roles.isEmpty())
            {
                throw new SdkClientException("No role found in SamlAssertion: " + samlAssertion);
            }

            String roleArn;
            String principal;
            if (m_preferredRole != null)
            {
                roleArn = m_preferredRole;
                principal = roles.get(m_preferredRole);
                if (principal == null)
                {
                    throw new SdkClientException("Preferred role not found in SamlAssertion: " + samlAssertion);
                }
            }
            else
            {
                Map.Entry entry = roles.entrySet().iterator().next();
                roleArn = entry.getKey();
                principal = entry.getValue();
            }

            AssumeRoleWithSAMLRequest samlRequest = new AssumeRoleWithSAMLRequest();
            samlRequest.setSAMLAssertion(samlAssertion);
            samlRequest.setRoleArn(roleArn);
            samlRequest.setPrincipalArn(principal);
            if (m_duration > 0)
            {
                samlRequest.setDurationSeconds(m_duration);
            }

            AWSCredentialsProvider p = new AWSStaticCredentialsProvider(new AnonymousAWSCredentials());
            AWSSecurityTokenServiceClientBuilder builder = AWSSecurityTokenServiceClientBuilder.standard();
            ClientConfiguration config = null;
            builder.withClientConfiguration(config);
            
            AWSSecurityTokenService stsSvc =
            		RequestUtils.buildSts(m_stsEndpoint, m_region, builder, p, m_log);
            
            AssumeRoleWithSAMLResult result = stsSvc.assumeRoleWithSAML(samlRequest);
            Credentials cred = result.getCredentials();
            Date expiration = cred.getExpiration();
            AWSCredentials c = new BasicSessionCredentials(cred.getAccessKeyId(),
                    cred.getSecretAccessKey(), cred.getSessionToken());
            CredentialsHolder credentials = CredentialsHolder.newInstance(c, expiration);
            credentials.setMetadata(readMetadata(doc));
            credentials.setRefresh(true);
            
            if(!m_disableCache)
            	m_cache.put(getCacheKey(), credentials);
            else
              m_lastRefreshCredentials = credentials;
        }
        catch (IOException e)
        {
          if (RedshiftLogger.isEnable())
        		m_log.logError(e);
        	
          throw new SdkClientException("SAML error: " + e.getMessage(), e);
        }
        catch (SAXException e)
        {
          if (RedshiftLogger.isEnable())
        		m_log.logError(e);
          
          throw new SdkClientException("SAML error: " + e.getMessage(), e);
        }
        catch (ParserConfigurationException e)
        {
          if (RedshiftLogger.isEnable())
        		m_log.logError(e);
        	
          throw new SdkClientException("SAML error: " + e.getMessage(), e);
        }
        catch (XPathExpressionException e)
        {
          if (RedshiftLogger.isEnable())
        		m_log.logError(e);
        	
          throw new SdkClientException("SAML error: " + e.getMessage(), e);
        }
        catch (Exception e)
        {
          if (RedshiftLogger.isEnable())
        		m_log.logError(e);
        	
          throw new SdkClientException("SAML error: " + e.getMessage(), e);
        }
        finally
        {
          currentThread.setContextClassLoader(cl);
        }
    }
    
    @Override
    public String getPluginSpecificCacheKey() {
    	// Override this in each derived plugin such as Azure, Browser, Okta, Ping etc.
    	return "";
    }

    @Override
    public String getIdpToken() {
    	String samlAssertion = null;
      // Get the current thread and set the context loader with our custom load class method.
      Thread currentThread = Thread.currentThread();
      ClassLoader cl = currentThread.getContextClassLoader();

      Thread.currentThread().setContextClassLoader(CONTEXT_CLASS_LOADER);

      try
      {
          samlAssertion = getSamlAssertion();

          if (RedshiftLogger.isEnable())
              m_log.logDebug("SamlCredentialsProvider: Got SAML assertion of " +
                      "length={0}", samlAssertion != null ? samlAssertion.length() : -1);
      }
      catch (IOException e)
      {
        if (RedshiftLogger.isEnable())
      		m_log.logError(e);
      	
        throw new SdkClientException("SAML error: " + e.getMessage(), e);
      }
      catch (Exception e)
      {
        if (RedshiftLogger.isEnable())
      		m_log.logError(e);
      	
        throw new SdkClientException("SAML error: " + e.getMessage(), e);
      }
      finally
      {
        currentThread.setContextClassLoader(cl);
      }
      
      return  samlAssertion;
    }
    
    @Override
    public void setGroupFederation(boolean groupFederation) {
    	m_groupFederation = groupFederation;
    }
    
    @Override
    public String getCacheKey()
    {
    		String pluginSpecificKey = getPluginSpecificCacheKey();
        return m_userName + m_password + m_idpHost + m_idpPort + m_duration + m_preferredRole + pluginSpecificKey;
    }

    private IamMetadata readMetadata(Document doc) throws XPathExpressionException
    {
        IamMetadata metadata = new IamMetadata();
        XPath xPath = XPathFactory.newInstance().newXPath();


        List attributeValues = GetSAMLAttributeValues(xPath, doc,
            "https://redshift.amazon.com/SAML/Attributes/AllowDbUserOverride");
        if (!attributeValues.isEmpty())
        {
            metadata.setAllowDbUserOverride(Boolean.valueOf(attributeValues.get(0)));
        }

        attributeValues = GetSAMLAttributeValues(xPath, doc,
                "https://redshift.amazon.com/SAML/Attributes/DbUser");
        if (!attributeValues.isEmpty())
        {
            metadata.setSamlDbUser(attributeValues.get(0));
        }
        else
        {
            attributeValues = GetSAMLAttributeValues(xPath, doc,
                    "https://aws.amazon.com/SAML/Attributes/RoleSessionName");
            if (!attributeValues.isEmpty())
            {
                metadata.setSamlDbUser(attributeValues.get(0));
            }
        }

        attributeValues = GetSAMLAttributeValues(xPath, doc,
                "https://redshift.amazon.com/SAML/Attributes/AutoCreate");
        if (!attributeValues.isEmpty())
        {
            metadata.setAutoCreate(Boolean.valueOf(attributeValues.get(0)));
        }

        attributeValues = GetSAMLAttributeValues(xPath, doc,
                "https://redshift.amazon.com/SAML/Attributes/DbGroups");
        if (!attributeValues.isEmpty())
        {
            attributeValues = filterOutGroups(attributeValues);
            if (!attributeValues.isEmpty())
            {
                StringBuilder sb = new StringBuilder();
                for (String value : attributeValues)
                {
                    if (sb.length() > 0)
                    {
                        sb.append(',');
                    }
                    sb.append(value);
                }
                metadata.setDbGroups(sb.toString());
            }
        }
        
        attributeValues = GetSAMLAttributeValues(xPath, doc,
		                "https://redshift.amazon.com/SAML/Attributes/ForceLowercase");
        if (!attributeValues.isEmpty())
        {
            metadata.setForceLowercase(Boolean.valueOf(attributeValues.get(0)));
        }
        		
        return metadata;
    }
    
    /**
     * Method removes all groups from given lists matching {@link m_dbGroupsFilter}
     *  regex.
     * @param attributeValues in
     * @return attributeValues filtered
     */
    private List filterOutGroups(List attributeValues) {
        if ( m_dbGroupsFilter != null )
        {
            final Pattern groupsFilter = Pattern.compile(m_dbGroupsFilter);
            List ret = new ArrayList<>();
            for (String attributeValue : attributeValues)
            {
                m_log.logDebug("Check group {0} with regexp {1}",
                                             attributeValue, m_dbGroupsFilter);
                if (!groupsFilter.matcher(attributeValue).matches())
                {
                    m_log.logDebug("Add {0} to dbgroups", attributeValue);
                    ret.add(attributeValue);
                }
            }
            return ret;
        }
        else {
            return attributeValues;
        }
    }

    private static Document parse(byte[] samlAssertion) throws IOException, SAXException,
            ParserConfigurationException
    {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
        factory.setXIncludeAware(false);
        factory.setExpandEntityReferences(false);        
        factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
        factory.setFeature("http://xml.org/sax/features/external-general-entities", false);
        DocumentBuilder db = factory.newDocumentBuilder();
        return db.parse(new ByteArrayInputStream(samlAssertion));
    }

    private static List GetSAMLAttributeValues(XPath xPath, Document doc, String attributeName)
            throws XPathExpressionException
    {
        String expression = String.format("//Attribute[@Name='%s']/AttributeValue/text()", attributeName);
        NodeList nodeList = (NodeList) xPath.compile(expression).evaluate(doc, XPathConstants.NODESET);
        if (null == nodeList || nodeList.getLength() == 0)
        {
            return Collections.emptyList();
        }
        List attributeValues = new ArrayList(nodeList.getLength());
        for (int i = 0; i < nodeList.getLength(); ++i)
        {
            Node node = nodeList.item(i);
            attributeValues.add(node.getNodeValue());
        }
        return attributeValues;
    }

    protected List getInputTagsfromHTML(String body)
    {
    		Set distinctInputTags = new HashSet<>();    	
        List inputTags = new ArrayList();
        Pattern inputTagPattern = Pattern.compile("", Pattern.DOTALL);
        Matcher inputTagMatcher = inputTagPattern.matcher(body);
        while (inputTagMatcher.find())
        {
           String tag = inputTagMatcher.group(0);
           String tagNameLower = getValueByKey(tag, "name").toLowerCase();
           if (!tagNameLower.isEmpty() && distinctInputTags.add(tagNameLower)) 
           {
          	 inputTags.add(tag);
           }
        }
        return inputTags;
    }

    protected String getFormAction(String body)
    {
        Pattern pattern = Pattern.compile("');
                i += 4;
            }
            else
            {
                sb.append(c);
                ++i;
            }
        }
        return sb.toString();
    }

    protected void checkRequiredParameters() throws IOException
    {
        if (StringUtils.isNullOrEmpty(m_userName))
        {
            throw new IOException("Missing required property: " + RedshiftProperty.USER.getName());
        }
        if (StringUtils.isNullOrEmpty(m_password))
        {
            throw new IOException("Missing required property: " + RedshiftProperty.PASSWORD.getName());
        }
        if (StringUtils.isNullOrEmpty(m_idpHost))
        {
            throw new IOException("Missing required property: " + KEY_IDP_HOST);
        }
    }
    
    protected boolean isText(String inputTag)
    {
    		String typeVal = getValueByKey(inputTag, "type");
    		if(typeVal == null
    				|| typeVal.length() == 0)
    		{
    			typeVal = getValueByKeyWithoutQuotesAndValueInSingleQuote(inputTag, "type");
    		}
    		
        return "text".equals(typeVal);
    }

    protected boolean isPassword(String inputTag)
    {
  		String typeVal = getValueByKey(inputTag, "type");
  		if(typeVal == null
  				|| typeVal.length() == 0)
  		{
  			typeVal = getValueByKeyWithoutQuotesAndValueInSingleQuote(inputTag, "type");
  		}
  		
      return "password".equals(typeVal);
    }   
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy