All downloads are free. The search and download functionality uses the official Maven repository.

cn.bestwu.api.test.RememberMeService Maven / Gradle / Ivy

package cn.bestwu.api.test;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Date;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
 * Remember-me service that issues and validates a signed login cookie.
 *
 * <p>The cookie value is the Base64 encoding (trailing {@code =} padding stripped) of
 * {@code username:expiryTimeMillis:signature}, where {@code signature} is the hex-encoded MD5
 * digest of {@code username:expiryTimeMillis:password:key}.
 *
 * <p>NOTE(review): MD5 is a weak digest for this purpose. It is kept here because switching the
 * algorithm would invalidate every cookie that has already been issued.
 *
 * @author Peter Wu
 */
public class RememberMeService {

  private static final Logger logger = LoggerFactory.getLogger(RememberMeService.class);

  /** Charset used for every String/byte conversion so results do not depend on the platform. */
  private static final Charset CHARSET = Charset.forName("UTF-8");

  /** Default name of the remember-me cookie. */
  public static final String SECURITY_REMEMBER_ME_COOKIE_KEY = "api-test-remember-me";
  /** Two weeks, expressed in seconds. */
  public static final int TWO_WEEKS_S = 1209600;
  /** Separator between the tokens inside the decoded cookie value. */
  private static final String DELIMITER = ":";

  private int tokenValiditySeconds = TWO_WEEKS_S;
  private String cookieName = SECURITY_REMEMBER_ME_COOKIE_KEY;
  /** Server-side secret that is mixed into every token signature. */
  private final String key;

  private final ApiTestUserRepository apiTestUserRepository;

  /**
   * Creates the service.
   *
   * @param key secret mixed into the token signature
   * @param apiTestUserRepository repository used to look users up by username
   */
  public RememberMeService(String key, ApiTestUserRepository apiTestUserRepository) {
    this.key = key;
    this.apiTestUserRepository = apiTestUserRepository;
  }

  /**
   * Sets the cookie lifetime.
   *
   * @param tokenValiditySeconds lifetime in seconds; it is passed to the cookie as its max-age,
   *     and a negative value makes the token itself valid for two weeks
   */
  public void setTokenValiditySeconds(int tokenValiditySeconds) {
    this.tokenValiditySeconds = tokenValiditySeconds;
  }

  /**
   * Sets the name of the cookie that is read and written by this service.
   *
   * @param cookieName cookie name
   */
  public void setCookieName(String cookieName) {
    this.cookieName = cookieName;
  }

  /**
   * Tries to authenticate the request from its remember-me cookie.
   *
   * <p>If a cookie is present but cannot be validated (empty, malformed, expired, unknown user or
   * wrong signature) the cookie is cancelled on the response.
   *
   * @param request current request
   * @param response current response
   * @return the authenticated user, or {@code null} when there is no valid remember-me cookie
   */
  public ApiTestUser autoLogin(HttpServletRequest request, HttpServletResponse response) {
    String rememberMeCookie = extractRememberMeCookie(request);

    if (rememberMeCookie == null) {
      return null;
    }
    logger.debug("api-test-remember-me cookie detected");
    if (rememberMeCookie.length() == 0) {
      logger.debug("Cookie was empty");
      cancelCookie(request, response);
      return null;
    }

    try {
      String[] cookieTokens = decodeCookie(rememberMeCookie);
      if (cookieTokens.length != 3) {
        throw new RuntimeException("Cookie token did not contain 3"
            + " tokens, but contained '" + Arrays.asList(cookieTokens) + "'");
      }

      long tokenExpiryTime;
      try {
        tokenExpiryTime = Long.parseLong(cookieTokens[1]);
      } catch (NumberFormatException nfe) {
        throw new RuntimeException(
            "Cookie token[1] did not contain a valid number (contained '"
                + cookieTokens[1] + "')", nfe);
      }

      if (isTokenExpired(tokenExpiryTime)) {
        throw new RuntimeException("Cookie token[1] has expired (expired on '"
            + new Date(tokenExpiryTime) + "'; current time is '" + new Date()
            + "')");
      }

      ApiTestUser apiTestUser = apiTestUserRepository.findByUsername(cookieTokens[0]);
      if (apiTestUser == null) {
        throw new RuntimeException(
            "api-test-remember-me login was valid but corresponding user not found.");
      }

      String expectedTokenSignature = makeTokenSignature(tokenExpiryTime,
          apiTestUser.getUsername(), apiTestUser.getPassword());
      if (!equals(expectedTokenSignature, cookieTokens[2])) {
        // The expected signature is deliberately not logged: anyone able to read the log could
        // otherwise build a valid cookie from it.
        throw new RuntimeException("Cookie token[2] contained signature '"
            + cookieTokens[2] + "' but it did not match the expected signature");
      }
      logger.debug("api-test-remember-me cookie accepted");

      return apiTestUser;
    } catch (Exception e) {
      // Best effort: any failure (including malformed Base64) just means "not logged in".
      logger.debug("api-test-remember-me cookie rejected", e);
    }

    cancelCookie(request, response);
    return null;
  }

  /**
   * Decodes a cookie value into its delimiter-separated tokens.
   *
   * <p>If the first two tokens look like the two halves of an {@code http://} or
   * {@code https://} URL that was split on the delimiter, they are merged back into one token.
   *
   * @param cookieValue Base64 cookie value, with or without trailing padding
   * @return the decoded tokens; may contain fewer than three entries for a malformed cookie
   */
  protected String[] decodeCookie(String cookieValue) {
    // Restore the '=' padding that encodeCookie strips.
    StringBuilder padded = new StringBuilder(cookieValue);
    while (padded.length() % 4 != 0) {
      padded.append('=');
    }

    String cookieAsPlainText = new String(
        Base64.getDecoder().decode(padded.toString().getBytes(CHARSET)), CHARSET);

    String[] tokens = StringUtils.delimitedListToStringArray(cookieAsPlainText,
        DELIMITER);

    // Guard the index accesses: a malformed cookie may decode to fewer than two tokens.
    if (tokens.length >= 2
        && (tokens[0].equalsIgnoreCase("http") || tokens[0].equalsIgnoreCase("https"))
        && tokens[1].startsWith("//")) {
      String[] newTokens = new String[tokens.length - 1];
      newTokens[0] = tokens[0] + ":" + tokens[1];
      System.arraycopy(tokens, 2, newTokens, 1, newTokens.length - 1);
      tokens = newTokens;
    }

    return tokens;
  }

  /**
   * Adds a cookie with a {@code null} value and max-age 0 to the response so the browser drops
   * the remember-me cookie.
   *
   * @param request current request, used to determine the cookie path
   * @param response response the cancelling cookie is added to
   */
  public void cancelCookie(HttpServletRequest request, HttpServletResponse response) {
    logger.debug("Cancelling cookie");
    Cookie cookie = new Cookie(cookieName, null);
    cookie.setMaxAge(0);
    cookie.setPath(getCookiePath(request));
    response.addCookie(cookie);
  }

  /**
   * Finds the remember-me cookie in the request.
   *
   * @param request current request
   * @return the value of the first cookie whose name matches the configured cookie name, or
   *     {@code null} if the request carries no such cookie
   */
  protected String extractRememberMeCookie(HttpServletRequest request) {
    Cookie[] cookies = request.getCookies();

    if ((cookies == null) || (cookies.length == 0)) {
      return null;
    }

    for (Cookie cookie : cookies) {
      if (cookieName.equals(cookie.getName())) {
        return cookie.getValue();
      }
    }

    return null;
  }

  /**
   * Computes the token signature: the hex-encoded MD5 digest of
   * {@code username:tokenExpiryTime:password:key}.
   *
   * @param tokenExpiryTime expiry time in milliseconds since the epoch
   * @param username user name
   * @param password user password as stored on the user
   * @return hex-encoded signature
   * @throws IllegalStateException if the MD5 algorithm is not available
   */
  protected String makeTokenSignature(long tokenExpiryTime, String username,
      String password) {
    String data = username + ":" + tokenExpiryTime + ":" + password + ":" + key;
    MessageDigest digest;
    try {
      digest = MessageDigest.getInstance("MD5");
    } catch (NoSuchAlgorithmException e) {
      throw new IllegalStateException("No MD5 algorithm available!", e);
    }

    // Explicit charset so the signature is identical on every JVM.
    return new String(Sha1DigestUtil.Hex.encode(digest.digest(data.getBytes(CHARSET))));
  }

  /**
   * Tells whether a token has expired.
   *
   * @param tokenExpiryTime expiry time in milliseconds since the epoch
   * @return {@code true} if the expiry time lies before the current time
   */
  protected boolean isTokenExpired(long tokenExpiryTime) {
    return tokenExpiryTime < System.currentTimeMillis();
  }

  /**
   * Issues a remember-me cookie after a successful login.
   *
   * @param request current request, used to determine the cookie path
   * @param response response the cookie is added to
   * @param username name of the logged-in user; must contain text
   * @param password password of the logged-in user; must contain text
   * @throws IllegalArgumentException if {@code username} or {@code password} has no text
   */
  public void onLoginSuccess(HttpServletRequest request, HttpServletResponse response,
      String username, String password) {
    Assert.hasText(username, "用户名不能为空");
    Assert.hasText(password, "密码不能为空");

    int tokenLifetime = tokenValiditySeconds;
    // A negative lifetime still yields a two-week token; the cookie itself keeps the negative
    // max-age that was configured.
    long expiryTime = System.currentTimeMillis()
        + 1000L * (tokenLifetime < 0 ? TWO_WEEKS_S : tokenLifetime);

    String signatureValue = makeTokenSignature(expiryTime, username, password);
    setCookie(new String[]{username, Long.toString(expiryTime), signatureValue},
        tokenLifetime, request, response);

    if (logger.isDebugEnabled()) {
      logger.debug("Added api-test-remember-me cookie for user '{}', expiry: '{}'", username,
          new Date(expiryTime));
    }
  }

  /**
   * Encodes the tokens and adds them to the response as the remember-me cookie.
   *
   * @param tokens tokens to store in the cookie
   * @param maxAge cookie max-age in seconds
   * @param request current request, used to determine the cookie path
   * @param response response the cookie is added to
   */
  protected void setCookie(String[] tokens, int maxAge, HttpServletRequest request,
      HttpServletResponse response) {
    String cookieValue = encodeCookie(tokens);
    Cookie cookie = new Cookie(cookieName, cookieValue);
    cookie.setMaxAge(maxAge);
    cookie.setPath(getCookiePath(request));
    if (maxAge < 1) {
      cookie.setVersion(1);
    }
    // The cookie is only ever read by the server, so keep it away from client-side scripts.
    cookie.setHttpOnly(true);
    // NOTE(review): the Secure flag is not set (the original call was commented out), so the
    // cookie is also sent over plain HTTP. Confirm whether that is intended.

    response.addCookie(cookie);
  }

  /** Returns the request's context path, or {@code "/"} when the context path is empty. */
  private String getCookiePath(HttpServletRequest request) {
    String contextPath = request.getContextPath();
    return contextPath.length() > 0 ? contextPath : "/";
  }

  /**
   * Joins the tokens with the delimiter and Base64-encodes the result, stripping trailing
   * {@code =} padding.
   *
   * @param cookieTokens tokens to encode
   * @return encoded cookie value; an empty string when there is nothing to encode
   */
  protected String encodeCookie(String[] cookieTokens) {
    String value = String.join(DELIMITER, cookieTokens);
    String encoded = new String(Base64.getEncoder().encode(value.getBytes(CHARSET)), CHARSET);

    // Strip the padding; the bounds check keeps an empty encoding from failing.
    int end = encoded.length();
    while (end > 0 && encoded.charAt(end - 1) == '=') {
      end--;
    }

    return encoded.substring(0, end);
  }

  /**
   * Constant time comparison to prevent against timing attacks.
   *
   * @return {@code true} only if both values are non-null and have identical UTF-8 bytes
   */
  private static boolean equals(String expected, String actual) {
    if (expected == null || actual == null) {
      return false;
    }
    return MessageDigest.isEqual(encode(expected), encode(actual));
  }

  /**
   * Encodes a character sequence as UTF-8.
   *
   * @param string characters to encode
   * @return the UTF-8 bytes
   * @throws IllegalArgumentException if the sequence cannot be encoded
   */
  public static byte[] encode(CharSequence string) {
    try {
      ByteBuffer bytes = CHARSET.newEncoder().encode(CharBuffer.wrap(string));
      // Copy through the buffer API rather than the backing array, so the result does not
      // depend on the buffer's array offset.
      byte[] bytesCopy = new byte[bytes.remaining()];
      bytes.get(bytesCopy);

      return bytesCopy;
    } catch (CharacterCodingException e) {
      throw new IllegalArgumentException("Encoding failed", e);
    }
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy