org.wildfly.security.mechanism.scram.ScramServer Maven / Gradle / Ivy
The newest version!
/*
* JBoss, Home of Professional Open Source.
* Copyright 2015 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.wildfly.security.mechanism.scram;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.util.Arrays.copyOfRange;
import static org.wildfly.security.mechanism._private.ElytronMessages.saslScram;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.SecureRandom;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import org.wildfly.common.Assert;
import org.wildfly.common.bytes.ByteStringBuilder;
import org.wildfly.common.iteration.ByteIterator;
import org.wildfly.security.auth.callback.ChannelBindingCallback;
import org.wildfly.security.mechanism._private.MechanismUtil;
import org.wildfly.security.mechanism.AuthenticationMechanismException;
import org.wildfly.security.mechanism.ScramServerErrorCode;
import org.wildfly.security.mechanism.ScramServerException;
import org.wildfly.security.password.interfaces.ScramDigestPassword;
import org.wildfly.security.password.spec.IteratedPasswordAlgorithmSpec;
import org.wildfly.security.sasl.util.StringPrep;
/**
* A server-side implementation for the SCRAM authentication.
*
* @author David M. Lloyd
*/
public final class ScramServer {
private final Supplier providers;
private final ScramMechanism mechanism;
private final CallbackHandler callbackHandler;
private final SecureRandom random;
private final byte[] bindingData;
private final String bindingType;
private final int minimumIterationCount;
private final int maximumIterationCount;
/**
* Constructs a new {@code ScramServer}.
*
* @param mechanism the SCRAM mechanism used for the authentication.
* @param callbackHandler the callback handler for the authentication.
* @param random an optional secure RNG to use.
* @param bindingData the binding data for the "PLUS" channel binding option.
* @param bindingType the binding type for the "PLUS" channel binding option.
* @param minimumIterationCount the minimum number of iterations for password hashing.
* @param maximumIterationCount the maximum number of iterations for password hashing.
* @param providers the security providers.
*/
ScramServer(final ScramMechanism mechanism, final CallbackHandler callbackHandler, final SecureRandom random, final byte[] bindingData, final String bindingType, final int minimumIterationCount, final int maximumIterationCount, final Supplier providers) {
this.mechanism = mechanism;
this.callbackHandler = callbackHandler;
this.random = random;
this.bindingData = bindingData;
this.bindingType = bindingType;
this.minimumIterationCount = minimumIterationCount;
this.maximumIterationCount = maximumIterationCount;
this.providers = providers;
}
/**
* Construct an initial response object from a byte array.
*
* @param bindingCallback the optional channel binding callback result (may be {@code null})
* @param bytes the message bytes (must not be {@code null})
* @return the constructed initial response (not {@code null})
* @throws AuthenticationMechanismException if the content of the message is invalid
*/
public ScramInitialClientMessage parseInitialClientMessage(ChannelBindingCallback bindingCallback, byte[] bytes) throws AuthenticationMechanismException {
byte[] response = bytes.clone();
ByteIterator bi = ByteIterator.ofBytes(response);
try {
final char cbindFlag = (char) bi.next();
final boolean binding;
final String bindingType;
final byte[] bindingData;
if (bindingCallback != null) {
bindingType = bindingCallback.getBindingType();
bindingData = bindingCallback.getBindingData();
} else {
bindingType = null;
bindingData = null;
}
if (cbindFlag == 'p') {
if (! mechanism.isPlus()) {
throw new ScramServerException(saslScram.mechChannelBindingNotSupported(), ScramServerErrorCode.SERVER_DOES_NOT_SUPPORT_CHANNEL_BINDING);
}
if (bindingType == null || bindingData == null) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.CHANNEL_BINDING_NOT_PROVIDED);
}
if (bi.next() != '=') {
throw saslScram.mechInvalidMessageReceived();
}
if (! bindingType.equals(bi.delimitedBy(',').asUtf8String().drainToString())) {
throw new ScramServerException(saslScram.mechChannelBindingTypeMismatch(), ScramServerErrorCode.UNSUPPORTED_CHANNEL_BINDING_TYPE);
}
binding = true;
} else if (cbindFlag == 'y') {
if (mechanism.isPlus()) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
if (bindingType != null || bindingData != null) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
binding = true;
} else if (cbindFlag == 'n') {
if (mechanism.isPlus()) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
if (bindingType != null || bindingData != null) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
binding = false;
} else {
throw saslScram.mechInvalidMessageReceived();
}
if (bi.next() != ',') {
throw saslScram.mechInvalidMessageReceived();
}
// authorization ID
final int c = bi.next();
final String authorizationID;
if (c == 'a') {
if (bi.next() != '=') {
throw saslScram.mechInvalidClientMessage();
}
authorizationID = bi.delimitedBy(',').asUtf8String().drainToString();
bi.next(); // skip delimiter
} else if (c == ',') {
authorizationID = null;
} else {
throw saslScram.mechInvalidClientMessage();
}
final int initialPartIndex = (int) bi.getIndex();
// user name
final String authenticationName;
if (bi.next() == 'n') {
if (bi.next() != '=') {
throw saslScram.mechInvalidClientMessage();
}
ByteStringBuilder bsb = new ByteStringBuilder();
StringPrep.encode(bi.delimitedBy(',').asUtf8String().drainToString(), bsb, StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS);
authenticationName = new String(bsb.toArray(), StandardCharsets.UTF_8);
bi.next(); // skip delimiter
} else {
throw saslScram.mechInvalidClientMessage();
}
// random nonce
if (bi.next() != 'r' || bi.next() != '=') {
throw saslScram.mechInvalidClientMessage();
}
final byte[] nonce = bi.delimitedBy(',').drain();
if (bi.hasNext()) {
throw saslScram.mechInvalidClientMessage();
}
return new ScramInitialClientMessage(mechanism, authorizationID, authenticationName, binding, bindingType, bindingData, nonce, initialPartIndex, response);
} catch (NoSuchElementException ignored) {
throw saslScram.mechInvalidMessageReceived();
}
}
/**
* Evaluates the initial client response message in SCRAM authentication.
* Generates a server nonce and salted password.
*
* @param clientMessage the initial client response message.
* @return the initial server result, containing the initial server message and the digest password.
* @throws AuthenticationMechanismException if an error occurs during the evaluation.
*/
public ScramInitialServerResult evaluateInitialResponse(final ScramInitialClientMessage clientMessage) throws AuthenticationMechanismException {
final boolean trace = saslScram.isTraceEnabled();
if (clientMessage.getMechanism() != mechanism) {
throw saslScram.mechUnmatchedMechanism(mechanism.toString(), clientMessage.getMechanism().toString());
}
// get salted password
final NameCallback nameCallback = new NameCallback("Remote authentication name", clientMessage.getAuthenticationName());
try {
MechanismUtil.handleCallbacks(saslScram, callbackHandler, nameCallback);
} catch (UnsupportedCallbackException e) {
throw saslScram.mechCallbackHandlerDoesNotSupportUserName(e);
}
final IteratedPasswordAlgorithmSpec generateParameters = new IteratedPasswordAlgorithmSpec(
max(minimumIterationCount, min(maximumIterationCount, ScramDigestPassword.DEFAULT_ITERATION_COUNT))
);
final ScramDigestPassword password = MechanismUtil.getPasswordCredential(clientMessage.getAuthenticationName(), callbackHandler, ScramDigestPassword.class, mechanism.getPasswordAlgorithm(), null, generateParameters, providers, saslScram);
final byte[] saltedPasswordBytes = password.getDigest();
final int iterationCount = password.getIterationCount();
if (iterationCount < minimumIterationCount) {
throw saslScram.mechIterationCountIsTooLow(iterationCount, minimumIterationCount);
}
if (iterationCount > maximumIterationCount) {
throw saslScram.mechIterationCountIsTooHigh(iterationCount, maximumIterationCount);
}
final byte[] salt = password.getSalt();
if(trace) saslScram.tracef("[S] Salt: %s%n", ByteIterator.ofBytes(salt).hexEncode().drainToString());
if(trace) saslScram.tracef("[S] Salted password: %s%n", ByteIterator.ofBytes(saltedPasswordBytes).hexEncode().drainToString());
ByteStringBuilder b = new ByteStringBuilder();
// nonce (client + server nonce)
b.append('r').append('=');
b.append(clientMessage.getRawNonce());
final byte[] serverNonce = ScramUtil.generateNonce(28, getRandom());
b.append(serverNonce);
b.append(',');
// salt
b.append('s').append('=');
b.appendLatin1(ByteIterator.ofBytes(salt).base64Encode());
b.append(',');
b.append('i').append('=');
b.append(Integer.toString(iterationCount));
byte[] messageBytes = b.toArray();
return new ScramInitialServerResult(new ScramInitialServerMessage(clientMessage, serverNonce, salt, iterationCount, messageBytes), password);
}
/**
* Parses the final client message and constructs the {@link ScramFinalClientMessage} from this parsed information.
* Also checks if the message has all necessary properties.
*
* @param initialResponse the initial client response message provided by {@link ScramServer#parseInitialClientMessage(ChannelBindingCallback, byte[])}.
* @param initialResult the initial server result provided by {@link ScramServer#evaluateInitialResponse(ScramInitialClientMessage)}.
* @param bytes the byte array representation of the client response.
* @return the final client message.
* @throws AuthenticationMechanismException if an error occurs during the parsing.
*/
public ScramFinalClientMessage parseFinalClientMessage(final ScramInitialClientMessage initialResponse, final ScramInitialServerResult initialResult, final byte[] bytes) throws AuthenticationMechanismException {
final ScramInitialServerMessage initialChallenge = initialResult.getScramInitialChallenge();
Assert.checkNotNullParam("initialResponse", initialResponse);
Assert.checkNotNullParam("initialChallenge", initialChallenge);
final ScramMechanism mechanism = initialResponse.getMechanism();
if (mechanism != initialChallenge.getMechanism()) {
throw saslScram.mechUnmatchedMechanism(mechanism.toString(), initialChallenge.getMechanism().toString());
}
byte[] response = bytes.clone();
ByteIterator bi = ByteIterator.ofBytes(response);
try {
if (bi.next() != 'c' || bi.next() != '=') {
throw saslScram.mechInvalidMessageReceived();
}
ByteIterator ibi = bi.delimitedBy(',').asUtf8String().base64Decode();
char cbindFlag = (char) ibi.next();
final String bindingType = initialResponse.getBindingType();
final byte[] bindingData = initialResponse.getRawBindingData();
final boolean binding = initialResponse.isBinding();
if (cbindFlag == 'p') {
if (! binding) {
throw new ScramServerException(saslScram.mechChannelBindingNotSupported(), ScramServerErrorCode.CHANNEL_BINDING_NOT_SUPPORTED);
}
if (bindingType == null || bindingData == null) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.CHANNEL_BINDING_NOT_PROVIDED);
}
if (ibi.next() != '=') {
throw saslScram.mechInvalidMessageReceived();
}
if (! bindingType.equals(ibi.delimitedBy(',').asUtf8String().drainToString())) {
throw new ScramServerException(saslScram.mechChannelBindingTypeMismatch(), ScramServerErrorCode.UNSUPPORTED_CHANNEL_BINDING_TYPE);
}
} else if (cbindFlag == 'y') {
if (mechanism.isPlus()) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
if (bindingType != null || bindingData != null) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
} else if (cbindFlag == 'n') {
if (binding) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
if (mechanism.isPlus()) {
throw new ScramServerException(saslScram.mechChannelBindingNotProvided(), ScramServerErrorCode.SERVER_DOES_SUPPORT_CHANNEL_BINDING);
}
} else {
throw saslScram.mechInvalidMessageReceived();
}
if (ibi.next() != ',') {
throw saslScram.mechInvalidMessageReceived();
}
// authorization ID
int c = ibi.next();
final String authorizationID;
if (c == 'a') {
if (ibi.next() != '=') {
throw saslScram.mechInvalidClientMessage();
}
authorizationID = ibi.delimitedBy(',').asUtf8String().drainToString();
ibi.next(); // skip delimiter
if (! authorizationID.equals(initialResponse.getAuthorizationId())) {
throw saslScram.mechAuthorizationIdChanged();
}
} else if (c == ',') {
if (initialResponse.getAuthorizationId() != null) {
throw saslScram.mechAuthorizationIdChanged();
}
} else {
throw saslScram.mechInvalidClientMessage();
}
// channel binding data
if (bindingData != null && ! ibi.contentEquals(ByteIterator.ofBytes(bindingData))) {
throw new ScramServerException(saslScram.mechChannelBindingChanged(), ScramServerErrorCode.CHANNEL_BINDINGS_DONT_MATCH);
}
bi.next(); // skip delim
// random nonce
if (bi.next() != 'r' || bi.next() != '=') {
throw saslScram.mechInvalidClientMessage();
}
final byte[] clientNonce = initialResponse.getRawNonce();
final byte[] serverNonce = initialChallenge.getRawServerNonce();
final byte[] expectedMessageNonce = ByteBuffer.allocate(clientNonce.length + serverNonce.length).put(clientNonce).put(serverNonce).array();
final byte[] messageNonce = bi.delimitedBy(',').drain();
if (! MessageDigest.isEqual(messageNonce, expectedMessageNonce)) {
throw saslScram.mechNoncesDoNotMatch();
}
final int proofOffset = (int) bi.getIndex();
bi.next(); // skip delimiter
// proof
if (bi.next() != 'p' || bi.next() != '=') {
throw saslScram.mechInvalidClientMessage();
}
final byte[] proof;
proof = bi.delimitedBy(',').asUtf8String().base64Decode().drain();
if (bi.hasNext()) {
throw saslScram.mechInvalidClientMessage();
}
return new ScramFinalClientMessage(initialResponse, initialChallenge, initialResult.getScramDigestPassword(), proof, response, proofOffset);
} catch (NoSuchElementException ignored) {
throw saslScram.mechInvalidMessageReceived();
}
}
/**
* Evaluates a SCRAM final client message and authorizes the user.
*
* @param initialResult the result of the initial server message evaluation provided by {@link ScramServer#evaluateInitialResponse(ScramInitialClientMessage)}.
* @param clientMessage the final client message provided by {@link ScramServer#parseFinalClientMessage(ScramInitialClientMessage, ScramInitialServerResult, byte[])}.
* @return the final server message providing the server signature and response.
* @throws AuthenticationMechanismException if an error occurs during the evaluation.
*/
public ScramFinalServerMessage evaluateFinalClientMessage(final ScramInitialServerResult initialResult, final ScramFinalClientMessage clientMessage) throws AuthenticationMechanismException {
final boolean trace = saslScram.isTraceEnabled();
if (clientMessage.getMechanism() != mechanism) {
throw saslScram.mechUnmatchedMechanism(mechanism.toString(), clientMessage.getMechanism().toString());
}
ByteStringBuilder b = new ByteStringBuilder();
try {
final Mac mac = Mac.getInstance(getMechanism().getHmacName());
final MessageDigest messageDigest = MessageDigest.getInstance(getMechanism().getMessageDigestName());
// == verify proof ==
// client key
byte[] clientKey;
mac.reset();
byte[] saltedPassword = initialResult.getScramDigestPassword().getDigest();
mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
mac.update(ScramUtil.CLIENT_KEY_BYTES);
clientKey = mac.doFinal();
if(trace) saslScram.tracef("[S] Client key: %s%n", ByteIterator.ofBytes(clientKey).hexEncode().drainToString());
// stored key
byte[] storedKey;
messageDigest.reset();
messageDigest.update(clientKey);
storedKey = messageDigest.digest();
if(trace) saslScram.tracef("[S] Stored key: %s%n", ByteIterator.ofBytes(storedKey).hexEncode().drainToString());
// client signature
mac.reset();
mac.init(new SecretKeySpec(storedKey, mac.getAlgorithm()));
final byte[] clientFirstMessage = clientMessage.getInitialResponse().getRawMessageBytes();
final int clientFirstMessageBareStart = clientMessage.getInitialResponse().getInitialPartIndex();
mac.update(clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length - clientFirstMessageBareStart);
if(trace) saslScram.tracef("[S] Using client first message: %s%n", ByteIterator.ofBytes(copyOfRange(clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length)).hexEncode().drainToString());
mac.update((byte) ',');
final byte[] serverFirstMessage = initialResult.getScramInitialChallenge().getRawMessageBytes();
mac.update(serverFirstMessage);
if(trace) saslScram.tracef("[S] Using server first message: %s%n", ByteIterator.ofBytes(serverFirstMessage).hexEncode().drainToString());
mac.update((byte) ',');
final byte[] response = clientMessage.getRawMessageBytes();
final int proofOffset = clientMessage.getProofOffset();
mac.update(response, 0, proofOffset); // client-final-message-without-proof
if(trace) saslScram.tracef("[S] Using client final message without proof: %s%n", ByteIterator.ofBytes(copyOfRange(response, 0, proofOffset)).hexEncode().drainToString());
byte[] clientSignature = mac.doFinal();
if(trace) saslScram.tracef("[S] Client signature: %s%n", ByteIterator.ofBytes(clientSignature).hexEncode().drainToString());
// server key
byte[] serverKey;
mac.reset();
mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
mac.update(ScramUtil.SERVER_KEY_BYTES);
serverKey = mac.doFinal();
if(trace) saslScram.tracef("[S] Server key: %s%n", ByteIterator.ofBytes(serverKey).hexEncode().drainToString());
// server signature
byte[] serverSignature;
mac.reset();
mac.init(new SecretKeySpec(serverKey, mac.getAlgorithm()));
mac.update(clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length - clientFirstMessageBareStart);
mac.update((byte) ',');
mac.update(serverFirstMessage);
mac.update((byte) ',');
mac.update(response, 0, proofOffset); // client-final-message-without-proof
serverSignature = mac.doFinal();
if(trace) saslScram.tracef("[S] Server signature: %s%n", ByteIterator.ofBytes(serverSignature).hexEncode().drainToString());
final byte[] recoveredClientProof = clientMessage.getRawClientProof();
if(trace) saslScram.tracef("[S] Client proof: %s%n", ByteIterator.ofBytes(recoveredClientProof).hexEncode().drainToString());
// now check the proof
byte[] recoveredClientKey = clientSignature.clone();
ScramUtil.xor(recoveredClientKey, recoveredClientProof);
if(trace) saslScram.tracef("[S] Recovered client key: %s%n", ByteIterator.ofBytes(recoveredClientKey).hexEncode().drainToString());
if (! MessageDigest.isEqual(recoveredClientKey, clientKey)) {
throw saslScram.mechAuthenticationRejectedInvalidProof();
}
String userName = clientMessage.getInitialResponse().getAuthenticationName();
String authorizationID = clientMessage.getInitialResponse().getAuthorizationId();
if (authorizationID == null || authorizationID.isEmpty()) {
authorizationID = userName;
} else {
ByteStringBuilder bsb = new ByteStringBuilder();
StringPrep.encode(authorizationID, bsb, StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS);
authorizationID = new String(bsb.toArray(), StandardCharsets.UTF_8);
}
final AuthorizeCallback authorizeCallback = new AuthorizeCallback(userName, authorizationID);
try {
MechanismUtil.handleCallbacks(saslScram, callbackHandler, authorizeCallback);
} catch (UnsupportedCallbackException e) {
throw saslScram.mechAuthorizationUnsupported(e);
}
if ( ! authorizeCallback.isAuthorized()) {
throw saslScram.mechAuthorizationFailed(userName, authorizationID);
}
// == send response ==
b.setLength(0);
b.append('v').append('=');
b.appendUtf8(ByteIterator.ofBytes(serverSignature).base64Encode());
return new ScramFinalServerMessage(serverSignature, b.toArray());
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw saslScram.mechMacAlgorithmNotSupported(e);
}
}
/**
* Returns the SCRAM mechanism used for the authentication.
*
* @return the SCRAM mechanism used for the authentication.
*/
public ScramMechanism getMechanism() {
return mechanism;
}
/**
* Returns the callback handler for the authentication.
*
* @return the callback handler for the authentication.
*/
public CallbackHandler getCallbackHandler() {
return callbackHandler;
}
/**
* Returns the RNG used for the authentication.
*
* @return the RNG used for the authentication.
*/
Random getRandom() {
return random != null ? random : ThreadLocalRandom.current();
}
/**
* Returns the copy of the binding data for the "PLUS" channel binding option.
*
* @return the copy of the binding data for the "PLUS" channel binding option.
*/
public byte[] getBindingData() {
return bindingData == null ? null : bindingData.clone();
}
/**
* Returns the binding data for the "PLUS" channel binding option.
*
* @return the binding data for the "PLUS" channel binding option.
*/
byte[] getRawBindingData() {
return bindingData;
}
/**
* Returns the binding type for the "PLUS" channel binding option.
*
* @return the binding type for the "PLUS" channel binding option.
*/
public String getBindingType() {
return bindingType;
}
}