All downloads are free. The search and download functionality uses the official Maven repository.

org.ldaptive.transport.ScramSaslClient Maven / Gradle / Ivy

There is a newer version: 2.4.1
Show newest version
/* See LICENSE for licensing and NOTICE for copyright. */
package org.ldaptive.transport;

import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import org.ldaptive.BindResponse;
import org.ldaptive.LdapException;
import org.ldaptive.LdapUtils;
import org.ldaptive.ResultCode;
import org.ldaptive.sasl.Mechanism;
import org.ldaptive.sasl.SaslBindRequest;
import org.ldaptive.sasl.SaslClient;
import org.ldaptive.sasl.ScramBindRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * SASL client that implements the SCRAM protocol. See RFC 5802.
 *
 * @author  Middleware Services
 */
public class ScramSaslClient implements SaslClient
{

  /** Logger for this class. */
  private static final Logger LOGGER = LoggerFactory.getLogger(ScramSaslClient.class);


  /**
   * Executes the two round trips of a SCRAM SASL bind. The client first message is sent, the server's challenge is
   * used to build the client final message, and the server's final response is checked against the signature this
   * client expects.
   *
   * @param  conn  connection used to send both bind requests
   * @param  request  SCRAM request supplying the mechanism, username, password and nonce
   *
   * @return  response to the client final message, or the response to the client first message when that response is
   * neither a SASL bind in progress nor a success
   *
   * @throws  LdapException  if an error occurs
   */
  @Override
  public BindResponse bind(final TransportConnection conn, final ScramBindRequest request)
    throws LdapException
  {
    // round trip 1: announce the username and client nonce
    final ClientFirstMessage firstMsg = new ClientFirstMessage(request.getUsername(), request.getNonce());
    final SaslBindRequest firstRequest = new SaslBindRequest(
      request.getMechanism().mechanism(),
      LdapUtils.utf8Encode(firstMsg.encode(), false));
    final BindResponse firstResponse = conn.operation(firstRequest).execute();

    if (firstResponse.getResultCode() != ResultCode.SASL_BIND_IN_PROGRESS) {
      if (!firstResponse.isSuccess()) {
        // a failure at this stage is handed back to the caller as-is
        LOGGER.warn("Unexpected server result {}", firstResponse);
        return firstResponse;
      }
      // the exchange cannot legitimately complete after a single message
      throw new IllegalStateException(
        "Unexpected success result from SCRAM SASL bind: " + firstResponse.getResultCode());
    }

    // round trip 2: prove knowledge of the password using the server's salt and iteration count
    final ServerFirstMessage challenge = new ServerFirstMessage(firstMsg, firstResponse);
    final ClientFinalMessage finalMsg = new ClientFinalMessage(
      request.getMechanism(),
      request.getPassword(),
      firstMsg,
      challenge);
    final SaslBindRequest finalRequest = new SaslBindRequest(
      request.getMechanism().mechanism(),
      LdapUtils.utf8Encode(finalMsg.encode(), false));
    final BindResponse finalResponse = conn.operation(finalRequest).execute();

    final ServerFinalMessage verification = new ServerFinalMessage(request.getMechanism(), finalMsg, finalResponse);

    // the result code and the signature verification must tell the same story
    final boolean success = finalResponse.isSuccess();
    final boolean verified = verification.isVerified();
    if (verified && !success) {
      throw new IllegalStateException("Verified server message but result was not a success");
    }
    if (success && !verified) {
      throw new IllegalStateException("Received success from server but message could not be verified");
    }
    return finalResponse;
  }


  /** Properties associated with the client first message. */
  static class ClientFirstMessage
  {

    /** GS2 header for no channel binding. */
    private static final String GS2_NO_CHANNEL_BINDING = "n,,";

    /** Size in bytes of the nonce that is generated when the caller does not supply one. */
    private static final int DEFAULT_NONCE_SIZE = 16;

    /** Username to authenticate, as supplied by the caller. */
    private final String clientUsername;

    /** Base64 encoded protocol nonce. */
    private final String clientNonce;

    /** Message produced from the escaped username and the nonce, without the GS2 header. */
    private final String message;


    /**
     * Creates a new client first message. If nonce is null a random nonce is created for this client. The characters
     * ',' and '=' in the username are escaped as required by RFC 5802 section 5.1, since both are delimiters in the
     * SCRAM message syntax.
     *
     * @param  username  to authenticate
     * @param  nonce  to supply to the server or null
     */
    ClientFirstMessage(final String username, final byte[] nonce)
    {
      clientUsername = username;
      if (nonce == null) {
        final SecureRandom random = new SecureRandom();
        // force seeding
        random.nextBytes(new byte[1]);
        final byte[] b = new byte[DEFAULT_NONCE_SIZE];
        random.nextBytes(b);
        clientNonce = LdapUtils.base64Encode(b);
      } else {
        clientNonce = LdapUtils.base64Encode(nonce);
      }
      message = "n=".concat(escapeUsername(clientUsername)).concat(",").concat("r=").concat(clientNonce);
    }


    /**
     * Returns the base64 encoded nonce sent to the server.
     *
     * @return  client nonce
     */
    public String getNonce()
    {
      return clientNonce;
    }


    /**
     * Returns the message without the GS2 header.
     *
     * @return  username and nonce attributes
     */
    public String getMessage()
    {
      return message;
    }


    /**
     * Encodes this message to send to the server. This method prepends the message with a GS2 header indicating that
     * no channel binding is supported.
     *
     * @return  encoded message
     */
    public String encode()
    {
      return GS2_NO_CHANNEL_BINDING.concat(message);
    }


    /**
     * Escapes the SCRAM delimiter characters in a username: '=' is sent as "=3D" and ',' is sent as "=2C".
     *
     * @param  username  to escape
     *
     * @return  username that is safe to place in the 'n' attribute
     */
    private static String escapeUsername(final String username)
    {
      // '=' must be replaced first, otherwise the '=' introduced by "=2C" would be escaped a second time
      return username.replace("=", "=3D").replace(",", "=2C");
    }
  }


  /** Properties associated with the final client message. */
  static class ClientFinalMessage
  {

    /** Base64 encoding of the GS2 header for no channel binding. */
    private static final String GS2_NO_CHANNEL_BINDING = LdapUtils.base64Encode("n,,");

    /** The integer 1 as four octets, most significant octet first. */
    private static final byte[] INTEGER_ONE = {0x00, 0x00, 0x00, 0x01, };

    /** Input that is MAC'ed with the salted password to derive the client key. */
    private static final byte[] CLIENT_KEY_INIT = LdapUtils.utf8Encode("Client Key");

    /** Scram SASL mechanism; its properties supply the digest algorithm at index 0 and the MAC algorithm at index 1. */
    private final Mechanism mechanism;

    /** Channel binding attribute followed by the combined nonce attribute. */
    private final String withoutProof;

    /** Comma separated client first message, server first message and {@link #withoutProof}. */
    private final String message;

    /** Password after salting and iterating with the values sent by the server. */
    private final byte[] saltedPassword;


    /**
     * Creates a new client final message. The salted password is computed here, using the salt and iteration count of
     * the server first message.
     *
     * @param  mech  scram mechanism
     * @param  password  to authenticate the user with
     * @param  clientFirstMessage  first message sent to the server
     * @param  serverFirstMessage  first response from the server
     */
    ClientFinalMessage(
      final Mechanism mech,
      final String password,
      final ClientFirstMessage clientFirstMessage,
      final ServerFirstMessage serverFirstMessage)
    {
      mechanism = mech;
      saltedPassword = createSaltedPassword(
        mech.properties()[1],
        password,
        serverFirstMessage.getSalt(),
        serverFirstMessage.getIterations());

      withoutProof = "c=" + GS2_NO_CHANNEL_BINDING + ",r=" + serverFirstMessage.getCombinedNonce();

      message = String.join(",", clientFirstMessage.getMessage(), serverFirstMessage.getMessage(), withoutProof);
    }


    /**
     * Returns the salted password.
     *
     * @return  salted password
     */
    public byte[] getSaltedPassword()
    {
      return saltedPassword;
    }


    /**
     * Returns the message that the client and server signatures are computed over.
     *
     * @return  client first message, server first message and the message without proof
     */
    public String getMessage()
    {
      return message;
    }


    /**
     * Encodes this message to send to the server. The result is the message without proof followed by a 'p' attribute
     * that carries the base64 encoded client proof.
     *
     * @return  encoded message
     */
    public String encode()
    {
      final String macAlgorithm = mechanism.properties()[1];
      final String digestAlgorithm = mechanism.properties()[0];

      final byte[] clientKey = createMac(macAlgorithm, saltedPassword).doFinal(CLIENT_KEY_INIT);
      final byte[] storedKey = createDigest(digestAlgorithm, clientKey);
      final byte[] signature = createMac(macAlgorithm, storedKey).doFinal(LdapUtils.utf8Encode(message, false));

      // the proof is the client key XOR'ed with the client signature
      final byte[] proof = clientKey.clone();
      for (int i = 0; i < proof.length; i++) {
        proof[i] ^= signature[i];
      }

      return withoutProof + ",p=" + LdapUtils.base64Encode(proof);
    }


    /**
     * Computes a salted password. The first block is the MAC of the salt followed by the integer 1, every following
     * block is the MAC of the block before it, and the result is the XOR of all blocks.
     *
     * @param  algorithm  of the MAC
     * @param  password  to key the MAC with
     * @param  salt  for the first block
     * @param  iterations  number of blocks to combine
     *
     * @return  salted password
     */
    private static byte[] createSaltedPassword(
      final String algorithm,
      final String password,
      final byte[] salt,
      final int iterations)
    {
      // the MAC is keyed with the UTF-8 bytes of the password
      final Mac hmac = createMac(algorithm, LdapUtils.utf8Encode(password, false));

      // input of the first block: salt || INT(1)
      final byte[] seed = Arrays.copyOf(salt, salt.length + INTEGER_ONE.length);
      System.arraycopy(INTEGER_ONE, 0, seed, salt.length, INTEGER_ONE.length);

      byte[] previous = hmac.doFinal(seed);

      // accumulate into a separate array so that the chain input is never modified
      final byte[] accumulated = previous.clone();
      for (int round = 1; round < iterations; round++) {
        final byte[] current = hmac.doFinal(previous);
        for (int k = 0; k < current.length; k++) {
          accumulated[k] ^= current[k];
        }
        previous = current;
      }
      return accumulated;
    }
  }


  /** Properties associated with the first server response. */
  static class ServerFirstMessage
  {
    /** Minimum number of iterations we will allow. */
    private static final int MINIMUM_ITERATION_COUNT = 4096;

    /** The server SASL credentials decoded as a UTF-8 string. */
    private final String message;

    /** Value of the 'r' attribute parsed from the SASL credentials. */
    private final String combinedNonce;

    /** Base64 decoded value of the 's' attribute parsed from the SASL credentials. */
    private final byte[] salt;

    /** Value of the 'i' attribute parsed from the SASL credentials. */
    private final int iterations;


    /**
     * Creates a new server first message by parsing the server SASL credentials of the supplied response.
     *
     * @param  clientFirstMessage  first message sent to the server
     * @param  result  response to the first message
     *
     * @throws  IllegalArgumentException  if the credentials are missing or malformed, if the nonce, salt or iteration
     * count is missing, if the nonce does not start with the client nonce, if the iteration count is not a number or
     * if it is below the allowed minimum
     */
    ServerFirstMessage(final ClientFirstMessage clientFirstMessage, final BindResponse result)
    {
      if (result.getServerSaslCreds() == null || result.getServerSaslCreds().length == 0) {
        throw new IllegalArgumentException("Bind response missing server SASL credentials");
      }

      message = LdapUtils.utf8Encode(result.getServerSaslCreds(), false);
      final Map<String, String> attributes = parseAttributes(message);

      final String r = attributes.get("r");
      if (r == null) {
        throw new IllegalArgumentException("Invalid SASL credentials, missing server nonce");
      }
      // the server must echo the client nonce as the prefix of the combined nonce
      if (!r.startsWith(clientFirstMessage.getNonce())) {
        throw new IllegalArgumentException("Invalid SASL credentials, missing client nonce");
      }
      combinedNonce = r;

      final String s = attributes.get("s");
      if (s == null) {
        throw new IllegalArgumentException("Invalid SASL credentials, missing server salt");
      }
      salt = LdapUtils.base64Decode(s);

      final String i = attributes.get("i");
      if (i == null) {
        throw new IllegalArgumentException("Invalid SASL credentials, missing iteration count");
      }
      // a non-numeric value raises NumberFormatException, which is itself an IllegalArgumentException
      iterations = Integer.parseInt(i);
      if (iterations < MINIMUM_ITERATION_COUNT) {
        throw new IllegalArgumentException("Invalid SASL credentials, iterations minimum value is 4096");
      }
    }


    /**
     * Returns the server SASL credentials as a string.
     *
     * @return  server first message
     */
    public String getMessage()
    {
      return message;
    }


    /**
     * Returns the nonce sent by the server, which starts with the client nonce.
     *
     * @return  combined nonce
     */
    public String getCombinedNonce()
    {
      return combinedNonce;
    }


    /**
     * Returns the salt sent by the server.
     *
     * @return  decoded salt
     */
    public byte[] getSalt()
    {
      return salt;
    }


    /**
     * Returns the iteration count sent by the server.
     *
     * @return  iteration count
     */
    public int getIterations()
    {
      return iterations;
    }


    /**
     * Splits comma separated {@code name=value} attributes into a map. Only the first '=' of each attribute is treated
     * as the separator, so values may themselves contain '=' characters, such as base64 padding.
     *
     * @param  credentials  server SASL credentials to parse
     *
     * @return  attribute values keyed by attribute name
     *
     * @throws  IllegalArgumentException  if an attribute does not contain a '='
     */
    private static Map<String, String> parseAttributes(final String credentials)
    {
      return Stream.of(credentials.split(","))
        .map(
          attribute -> {
            final String[] pair = attribute.split("=", 2);
            if (pair.length != 2) {
              throw new IllegalArgumentException("Invalid SASL credentials, malformed attribute");
            }
            return pair;
          })
        .collect(Collectors.toMap(pair -> pair[0], pair -> pair[1]));
    }
  }


  /** Verifies the final server message. */
  static class ServerFinalMessage
  {

    /** Bytes for the server key hmac. */
    private static final byte[] SERVER_KEY_INIT = LdapUtils.utf8Encode("Server Key");

    /** Server SASL credentials decoded as a UTF-8 string, or null if a failed response did not include any. */
    private final String message;

    /** Whether the server message was successfully verified. */
    private final boolean verified;


    /**
     * Creates a new server final message. A response whose result code is not {@link ResultCode#SUCCESS} is never
     * verified, and such a response is not required to carry server SASL credentials, so that the failure can be
     * returned to the caller. A success response must carry a 'v' attribute that matches the signature computed
     * from the salted password and the client final message.
     *
     * @param  mech  scram mechanism
     * @param  clientFinalMessage  final message sent to the server
     * @param  result  response to the final message
     *
     * @throws  IllegalArgumentException  if a success response has no credentials, if the credentials are malformed,
     * or if the server signature of a success response is missing or incorrect
     */
    ServerFinalMessage(
      final Mechanism mech,
      final ClientFinalMessage clientFinalMessage,
      final BindResponse result)
    {
      final boolean success = result.getResultCode() == ResultCode.SUCCESS;
      if (result.getServerSaslCreds() == null || result.getServerSaslCreds().length == 0) {
        if (success) {
          throw new IllegalArgumentException("Bind response missing server SASL credentials");
        }
        // a failed bind has nothing to verify; leave the result code to describe the failure
        message = null;
        verified = false;
      } else {
        message = LdapUtils.utf8Encode(result.getServerSaslCreds(), false);
        final Map<String, String> attributes = parseAttributes(message);

        final String e = attributes.get("e");
        if (e != null) {
          LOGGER.warn("SASL bind server final message included error: {}", e);
        }

        if (!success) {
          verified = false;
        } else {
          final String serverSignature = attributes.get("v");
          if (serverSignature == null) {
            throw new IllegalArgumentException("Invalid SASL credentials, missing server verification");
          }

          // compare the server signature in the message to what we expect
          final byte[] serverKey =
            createMac(mech.properties()[1], clientFinalMessage.getSaltedPassword()).doFinal(SERVER_KEY_INIT);
          final String expectedServerSignature = LdapUtils.base64Encode(
            createMac(mech.properties()[1], serverKey).doFinal(
              LdapUtils.utf8Encode(clientFinalMessage.getMessage(), false)));
          // MessageDigest#isEqual is used so the comparison time does not depend on where the values differ
          final boolean matches = MessageDigest.isEqual(
            LdapUtils.utf8Encode(expectedServerSignature, false),
            LdapUtils.utf8Encode(serverSignature, false));
          if (!matches) {
            throw new IllegalArgumentException("Invalid SASL credentials, incorrect server verification");
          }
          verified = true;
        }
      }
    }


    /**
     * Returns whether the server final message was successfully verified.
     *
     * @return  whether the server message was verified.
     */
    public boolean isVerified()
    {
      return verified;
    }


    /**
     * Splits comma separated {@code name=value} attributes into a map. Only the first '=' of each attribute is treated
     * as the separator, so values may themselves contain '=' characters, such as base64 padding.
     *
     * @param  credentials  server SASL credentials to parse
     *
     * @return  attribute values keyed by attribute name
     *
     * @throws  IllegalArgumentException  if an attribute does not contain a '='
     */
    private static Map<String, String> parseAttributes(final String credentials)
    {
      return Stream.of(credentials.split(","))
        .map(
          attribute -> {
            final String[] pair = attribute.split("=", 2);
            if (pair.length != 2) {
              throw new IllegalArgumentException("Invalid SASL credentials, malformed attribute");
            }
            return pair;
          })
        .collect(Collectors.toMap(pair -> pair[0], pair -> pair[1]));
    }
  }


  /**
   * Obtains a MAC for the supplied algorithm and initializes it with the supplied key.
   *
   * @param  algorithm  name of the MAC algorithm, also used as the algorithm of the key
   * @param  key  raw key bytes
   *
   * @return  initialized mac
   *
   * @throws  IllegalStateException  if the MAC cannot be obtained or initialized
   */
  private static Mac createMac(final String algorithm, final byte[] key)
  {
    final Mac initialized;
    try {
      initialized = Mac.getInstance(algorithm);
      // the key spec is built inside the try so that a rejected key is reported the same way as a bad algorithm
      initialized.init(new SecretKeySpec(key, algorithm));
    } catch (Exception e) {
      throw new IllegalStateException("Could not create MAC", e);
    }
    return initialized;
  }


  /**
   * Computes the hash of the supplied data with the supplied digest algorithm.
   *
   * @param  algorithm  name of the digest algorithm
   * @param  data  bytes to hash
   *
   * @return  hash of the data
   *
   * @throws  IllegalStateException  if the digest cannot be obtained or computed
   */
  private static byte[] createDigest(final String algorithm, final byte[] data)
  {
    try {
      final MessageDigest digest = MessageDigest.getInstance(algorithm);
      final byte[] hash = digest.digest(data);
      return hash;
    } catch (Exception e) {
      throw new IllegalStateException("Could not create digest", e);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy