All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.zeppelin.realm.jwt.KnoxJwtRealm Maven / Gradle / Ivy

There is a newer version: 0.11.2
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.zeppelin.realm.jwt;

import java.util.Date;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Groups;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAccount;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.security.PublicKey;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import javax.servlet.ServletException;

import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jwt.SignedJWT;

/**
 * Shiro realm that authenticates users from a JSON Web Token carried in a
 * {@link JWTAuthenticationToken} (e.g. a token issued by Knox SSO).
 *
 * <p>A token is accepted when its RSA signature verifies against the X.509 certificate stored in
 * the file at {@link #getPublicKeyPath()}, it has not expired, and, if the current Shiro subject
 * already has a principal, its {@code sub} claim equals that principal. Roles are the user's
 * groups as reported by Hadoop's {@link Groups} service.
 *
 * <p>The {@code providerUrl}, {@code redirectParam}, {@code cookieName}, {@code login},
 * {@code logout} and {@code logoutAPI} properties are only stored by this class; it does not read
 * them itself (presumably other components do — confirm against callers).
 */
public class KnoxJwtRealm extends AuthorizingRealm {
  private static final Logger LOGGER = LoggerFactory.getLogger(KnoxJwtRealm.class);

  /** First line of a PEM-encoded X.509 certificate. */
  private static final String PEM_CERT_HEADER = "-----BEGIN CERTIFICATE-----";

  private String providerUrl;
  private String redirectParam;
  private String cookieName;
  /** Path of the file holding the PEM-encoded X.509 certificate used to verify signatures. */
  private String publicKeyPath;
  private String login;
  private String logout;
  private Boolean logoutAPI;

  private String principalMapping;
  private String groupPrincipalMapping;

  // Loaded from principalMapping/groupPrincipalMapping in onInit().
  // NOTE(review): the loaded mapping is not applied anywhere in this class.
  private SimplePrincipalMapper mapper = new SimplePrincipalMapper();

  /**
   * Configuration object needed by Hadoop classes.
   */
  private Configuration hadoopConfig;

  /**
   * Hadoop Groups implementation; remains {@code null} if its creation fails in {@link #onInit()}.
   */
  private Groups hadoopGroups;

  /**
   * Loads the optional principal mapping tables and creates the Hadoop {@link Groups} service.
   * Failures are logged rather than rethrown, so the realm still initialises.
   */
  @Override
  protected void onInit() {
    super.onInit();
    boolean hasPrincipalMapping = principalMapping != null && !principalMapping.isEmpty();
    boolean hasGroupMapping = groupPrincipalMapping != null && !groupPrincipalMapping.isEmpty();
    if (hasPrincipalMapping || hasGroupMapping) {
      try {
        mapper.loadMappingTable(principalMapping, groupPrincipalMapping);
      } catch (PrincipalMappingException e) {
        LOGGER.error("PrincipalMappingException in onInit", e);
      }
    }

    try {
      hadoopConfig = new Configuration();
      hadoopGroups = new Groups(hadoopConfig);
    } catch (final Exception e) {
      LOGGER.error("Exception in onInit", e);
    }
  }

  /**
   * Returns {@code true} only for {@link JWTAuthenticationToken} instances.
   */
  @Override
  public boolean supports(AuthenticationToken token) {
    // instanceof already yields false for null.
    return token instanceof JWTAuthenticationToken;
  }

  /**
   * Validates the JWT and, on success, builds an account whose principal is the token's
   * {@code sub} claim, whose credentials are the raw token, and whose roles are the user's
   * Hadoop groups.
   *
   * @param token a {@link JWTAuthenticationToken} (guaranteed by {@link #supports})
   * @return the account, or {@code null} if the token is invalid, unparsable or has no subject
   */
  @Override
  protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) {
    JWTAuthenticationToken upToken = (JWTAuthenticationToken) token;
    if (!validateToken(upToken.getToken())) {
      return null;
    }
    try {
      // Parse the subject once and reuse it for both the principal and the group lookup.
      String userName = getName(upToken);
      if (userName == null) {
        // Shiro rejects null principals, so fail authentication cleanly instead.
        LOGGER.warn("JWT token does not contain a subject claim");
        return null;
      }
      SimpleAccount account = new SimpleAccount(userName, upToken.getToken(), getName());
      account.addRole(mapGroupPrincipals(userName));
      return account;
    } catch (ParseException e) {
      LOGGER.error("ParseException in doGetAuthenticationInfo", e);
    }
    return null;
  }

  /**
   * Returns the {@code sub} claim of the JWT held by {@code upToken}; may be {@code null} if the
   * token has no subject.
   *
   * @throws ParseException if the token or its claims set cannot be parsed
   */
  public String getName(JWTAuthenticationToken upToken) throws ParseException {
    SignedJWT signed = SignedJWT.parse(upToken.getToken());
    return signed.getJWTClaimsSet().getSubject();
  }

  /**
   * Checks the token's signature and expiration, and, if the current Shiro subject already has a
   * principal, that the token's subject matches it.
   *
   * @param token the serialized JWT
   * @return {@code true} if all checks pass; {@code false} otherwise (including parse failures)
   */
  protected boolean validateToken(String token) {
    try {
      SignedJWT signed = SignedJWT.parse(token);
      boolean sigValid = validateSignature(signed);
      if (!sigValid) {
        LOGGER.warn("Signature of JWT token could not be verified. Please check the public key");
        return false;
      }
      boolean expValid = validateExpiration(signed);
      if (!expValid) {
        LOGGER.warn("Expiration time validation of JWT token failed.");
        return false;
      }
      String currentUser = (String) org.apache.shiro.SecurityUtils.getSubject().getPrincipal();
      if (currentUser == null) {
        return true;
      }
      String cookieUser = signed.getJWTClaimsSet().getSubject();
      // Compare from the non-null side: a token without a subject must not raise an NPE.
      if (!currentUser.equals(cookieUser)) {
        LOGGER.warn("Subject of JWT token does not match the currently authenticated user.");
        return false;
      }
      return true;
    } catch (ParseException ex) {
      LOGGER.info("ParseException in validateToken", ex);
      return false;
    }
  }

  /**
   * Reads the X.509 certificate stored in the file at {@code pem} and returns its RSA public key.
   *
   * @param pem path of a UTF-8/ASCII file containing a PEM-encoded X.509 certificate, including
   *     its {@code -----BEGIN CERTIFICATE-----} header and matching footer
   * @return the certificate's RSA public key
   * @throws IOException if the file cannot be read
   * @throws ServletException if the path is {@code null}, the certificate cannot be parsed, or its
   *     key is not an RSA key
   */
  public static RSAPublicKey parseRSAPublicKey(String pem) throws IOException, ServletException {
    if (pem == null) {
      throw new ServletException("Public key path is not configured");
    }
    String pemContent = FileUtils.readFileToString(new File(pem), "UTF-8");
    PublicKey key;
    try {
      CertificateFactory fact = CertificateFactory.getInstance("X.509");
      ByteArrayInputStream is = new ByteArrayInputStream(pemContent.getBytes("UTF-8"));
      X509Certificate cer = (X509Certificate) fact.generateCertificate(is);
      key = cer.getPublicKey();
    } catch (CertificateException ce) {
      // Inspect the file contents (not the path) to give a more specific hint.
      String message;
      if (!pemContent.trim().startsWith(PEM_CERT_HEADER)) {
        message = "CertificateException - be sure to include the PEM header and footer "
            + "in the public key file.";
      } else {
        message = "CertificateException - PEM may be corrupt";
      }
      throw new ServletException(message, ce);
    } catch (UnsupportedEncodingException uee) {
      throw new ServletException(uee);
    }
    if (!(key instanceof RSAPublicKey)) {
      throw new ServletException("Public key in " + pem + " is not an RSA key");
    }
    return (RSAPublicKey) key;
  }

  /**
   * Verifies the RSA signature of a signed JWT against the certificate at {@link #publicKeyPath}.
   * The certificate file is re-read on every call.
   *
   * @return {@code true} if the token is in the SIGNED state, has a signature, and verifies;
   *     {@code false} otherwise (any exception is logged and treated as invalid)
   */
  protected boolean validateSignature(SignedJWT jwtToken) {
    if (JWSObject.State.SIGNED != jwtToken.getState() || jwtToken.getSignature() == null) {
      return false;
    }
    try {
      RSAPublicKey publicKey = parseRSAPublicKey(publicKeyPath);
      JWSVerifier verifier = new RSASSAVerifier(publicKey);
      return jwtToken.verify(verifier);
    } catch (Exception e) {
      LOGGER.info("Exception in validateSignature", e);
      return false;
    }
  }

  /**
   * Validates that the expiration time of the JWT token has not passed. A token without an
   * expiration claim is treated as valid. Override this method in subclasses to customize the
   * expiration validation behavior.
   *
   * @param jwtToken
   *            the token that contains the expiration date to validate
   * @return {@code true} if the token has not expired (or has no expiration); {@code false} if it
   *     has expired or its claims cannot be parsed
   */
  protected boolean validateExpiration(SignedJWT jwtToken) {
    boolean valid = false;
    try {
      Date expires = jwtToken.getJWTClaimsSet().getExpirationTime();
      if (expires == null || new Date().before(expires)) {
        LOGGER.debug("SSO token expiration date has been successfully validated");
        valid = true;
      } else {
        LOGGER.warn("SSO expiration date validation failed.");
      }
    } catch (ParseException pe) {
      LOGGER.warn("SSO expiration date validation failed.", pe);
    }
    return valid;
  }

  /**
   * Returns the Hadoop groups of this realm's principal as its roles.
   */
  @Override
  protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
    // Use this realm's own principal rather than principals.toString(), which joins every
    // principal in the collection with ',' when several realms contributed one.
    Object principal = getAvailablePrincipal(principals);
    if (principal == null) {
      return new SimpleAuthorizationInfo();
    }
    Set<String> roles = mapGroupPrincipals(principal.toString());
    return new SimpleAuthorizationInfo(roles);
  }

  /**
   * Queries the Hadoop implementation of {@link Groups} to retrieve groups for the provided user.
   * Hadoop's user-to-groups mappings are refreshed before every lookup.
   *
   * @param mappedPrincipalName the user name to look up
   * @return a mutable set of group names; empty if none are found, the lookup fails, or the
   *     Hadoop groups service was not initialised
   */
  public Set<String> mapGroupPrincipals(final String mappedPrincipalName) {
    if (hadoopGroups == null) {
      // onInit() failed to create the Groups service; degrade to "no groups" instead of an NPE.
      LOGGER.warn("Hadoop groups service is not initialized; no groups mapped for {}",
          mappedPrincipalName);
      return new HashSet<>();
    }
    try {
      hadoopGroups.refresh();
      final List<String> groupList = hadoopGroups.getGroups(mappedPrincipalName);
      LOGGER.debug("group found {}, {}", mappedPrincipalName, groupList);
      return new HashSet<>(groupList);
    } catch (final IOException e) {
      if (e.toString().contains("No groups found for user")) {
        /* no groups found move on */
        LOGGER.info(String.format("No groups found for user %s", mappedPrincipalName));
      } else {
        /* Log the error (with its cause) and return empty group */
        LOGGER.warn(String.format("errorGettingUserGroups for %s", mappedPrincipalName), e);
      }
      return new HashSet<>();
    }
  }

  public String getProviderUrl() {
    return providerUrl;
  }

  public void setProviderUrl(String providerUrl) {
    this.providerUrl = providerUrl;
  }

  public String getRedirectParam() {
    return redirectParam;
  }

  public void setRedirectParam(String redirectParam) {
    this.redirectParam = redirectParam;
  }

  public String getCookieName() {
    return cookieName;
  }

  public void setCookieName(String cookieName) {
    this.cookieName = cookieName;
  }

  public String getPublicKeyPath() {
    return publicKeyPath;
  }

  public void setPublicKeyPath(String publicKeyPath) {
    this.publicKeyPath = publicKeyPath;
  }

  public String getLogin() {
    return login;
  }

  public void setLogin(String login) {
    this.login = login;
  }

  public String getLogout() {
    return logout;
  }

  public void setLogout(String logout) {
    this.logout = logout;
  }

  public Boolean getLogoutAPI() {
    return logoutAPI;
  }

  public void setLogoutAPI(Boolean logoutAPI) {
    this.logoutAPI = logoutAPI;
  }

  public String getPrincipalMapping() {
    return principalMapping;
  }

  public void setPrincipalMapping(String principalMapping) {
    this.principalMapping = principalMapping;
  }

  public String getGroupPrincipalMapping() {
    return groupPrincipalMapping;
  }

  public void setGroupPrincipalMapping(String groupPrincipalMapping) {
    this.groupPrincipalMapping = groupPrincipalMapping;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy