// com.mongodb.connection.ScramSha1Authenticator Maven / Gradle / Ivy
/*
* Copyright (c) 2008-2014 MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mongodb.connection;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.internal.authentication.NativeAuthenticationHelper;
import javax.crypto.Mac;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.UnsupportedEncodingException;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.util.HashMap;
import java.util.Random;
import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_1;
/**
 * Authenticator for the SCRAM-SHA-1 SASL mechanism. The SASL conversation itself is
 * carried out by {@link ScramSha1SaslClient}, created per server in {@link #createSaslClient}.
 */
class ScramSha1Authenticator extends SaslAuthenticator {
    private final RandomStringGenerator randomStringGenerator;

    ScramSha1Authenticator(final MongoCredential credential) {
        this(credential, new DefaultRandomStringGenerator());
    }

    /**
     * @param credential            the credential to authenticate with
     * @param randomStringGenerator the source of the client nonce; injectable so the nonce
     *                              source can be replaced
     */
    ScramSha1Authenticator(final MongoCredential credential, final RandomStringGenerator randomStringGenerator) {
        super(credential);
        this.randomStringGenerator = randomStringGenerator;
    }

    @Override
    public String getMechanismName() {
        return SCRAM_SHA_1.getMechanismName();
    }

    @Override
    protected SaslClient createSaslClient(final ServerAddress serverAddress) {
        return new ScramSha1SaslClient(getCredential(), randomStringGenerator);
    }

    /**
     * Client side of the SCRAM-SHA-1 exchange (RFC 5802).
     *
     * <p>Step 0 produces the client-first-message. Step 1 consumes the server-first-message
     * and produces the client-final-message (with the client proof). Step 2 verifies the
     * server signature in the server-final-message.
     */
    private static class ScramSha1SaslClient implements SaslClient {
        /** GS2 header: no channel binding, no authorization identity. */
        private static final String GS2_HEADER = "n,,";
        /** Number of characters in the client nonce. */
        private static final int RANDOM_LENGTH = 24;

        private final Base64Codec base64Codec;
        private final MongoCredential credential;
        private final RandomStringGenerator randomStringGenerator;
        /** client-first-message-bare, kept because it is part of the AuthMessage. */
        private String clientFirstMessageBare;
        /** The client nonce; the server nonce must start with it. */
        private String rPrefix;
        /** Expected server signature, computed in step 1 and checked in step 2. */
        private byte[] serverSignature;
        /** Number of successfully completed conversation steps. */
        private int step;

        ScramSha1SaslClient(final MongoCredential credential, final RandomStringGenerator randomStringGenerator) {
            this.credential = credential;
            this.base64Codec = new Base64Codec();
            this.randomStringGenerator = randomStringGenerator;
        }

        @Override
        public String getMechanismName() {
            return SCRAM_SHA_1.getMechanismName();
        }

        @Override
        public boolean hasInitialResponse() {
            return true;
        }

        /**
         * Advances the SCRAM conversation by one step.
         *
         * @param challenge the server's message for this step (unused in step 0)
         * @return the client's response for this step; in step 2 the challenge is echoed back
         * @throws SaslException if the server message is malformed, the nonce or server signature
         *                       is invalid, a crypto primitive is unavailable, or too many steps occur
         */
        @Override
        public byte[] evaluateChallenge(final byte[] challenge) throws SaslException {
            byte[] response;
            switch (this.step) {
                case 0:
                    response = computeClientFirstMessage();
                    break;
                case 1:
                    response = computeClientFinalMessage(challenge);
                    break;
                case 2:
                    verifyServerFinalMessage(challenge);
                    response = challenge;
                    break;
                default:
                    throw new SaslException("Too many steps involved in the SCRAM-SHA-1 negotiation.");
            }
            // Advance only after the step succeeded, so a failed step (notably an invalid
            // server signature) never leaves isComplete() reporting success.
            this.step++;
            return response;
        }

        @Override
        public boolean isComplete() {
            return this.step > 2;
        }

        @Override
        public byte[] unwrap(final byte[] incoming, final int offset, final int len) throws SaslException {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public byte[] wrap(final byte[] outgoing, final int offset, final int len) throws SaslException {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public Object getNegotiatedProperty(final String propName) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public void dispose() throws SaslException {
            // nothing to do
        }

        /** Builds "n,,n=<user>,r=<nonce>" and remembers the bare part and the nonce. */
        private byte[] computeClientFirstMessage() throws SaslException {
            String userName = "n=" + prepUserName(this.credential.getUserName());
            this.rPrefix = randomStringGenerator.generate(RANDOM_LENGTH);
            String nonce = "r=" + this.rPrefix;
            this.clientFirstMessageBare = userName + "," + nonce;
            String clientFirstMessage = GS2_HEADER + this.clientFirstMessageBare;
            return decodeUTF8(clientFirstMessage);
        }

        /**
         * Parses the server-first-message (r, s, i), derives the keys, records the expected
         * server signature and returns the client-final-message including the proof.
         */
        private byte[] computeClientFinalMessage(final byte[] challenge) throws SaslException {
            String serverFirstMessage = encodeUTF8(challenge);
            HashMap<String, String> attributes = parseServerResponse(serverFirstMessage);

            String r = getRequiredAttribute(attributes, "r");
            if (!r.startsWith(this.rPrefix)) {
                throw new SaslException("Server sent an invalid nonce.");
            }
            byte[] salt = decodeBase64(getRequiredAttribute(attributes, "s"));
            int iterationCount = parseIterationCount(getRequiredAttribute(attributes, "i"));

            String channelBinding = "c=" + encodeBase64(decodeUTF8(GS2_HEADER));
            String nonce = "r=" + r;
            String clientFinalMessageWithoutProof = channelBinding + "," + nonce;

            // The SCRAM "password" is the MongoDB-specific digest from NativeAuthenticationHelper.
            byte[] saltedPassword = hi(
                    NativeAuthenticationHelper.createAuthenticationHash(this.credential.getUserName(),
                            this.credential.getPassword()),
                    salt,
                    iterationCount);
            byte[] clientKey = hmac(saltedPassword, "Client Key");
            byte[] storedKey = h(clientKey);
            String authMessage = this.clientFirstMessageBare + "," + serverFirstMessage + ","
                    + clientFinalMessageWithoutProof;
            byte[] clientSignature = hmac(storedKey, authMessage);
            byte[] clientProof = xor(clientKey, clientSignature);
            byte[] serverKey = hmac(saltedPassword, "Server Key");
            this.serverSignature = hmac(serverKey, authMessage);

            String proof = "p=" + encodeBase64(clientProof);
            String clientFinalMessage = clientFinalMessageWithoutProof + "," + proof;
            return decodeUTF8(clientFinalMessage);
        }

        /** Checks the server-final-message: reports a server error ("e") or verifies "v". */
        private void verifyServerFinalMessage(final byte[] challenge) throws SaslException {
            HashMap<String, String> attributes = parseServerResponse(encodeUTF8(challenge));
            String serverError = attributes.get("e");
            if (serverError != null) {
                throw new SaslException("SCRAM-SHA-1 server returned an error: " + serverError);
            }
            byte[] receivedSignature = decodeBase64(getRequiredAttribute(attributes, "v"));
            // Constant-time comparison of the server signature.
            if (!MessageDigest.isEqual(receivedSignature, this.serverSignature)) {
                throw new SaslException("Server signature was invalid.");
            }
        }

        /** Parses the "i" attribute, rejecting non-numeric and non-positive values. */
        private int parseIterationCount(final String value) throws SaslException {
            int iterationCount;
            try {
                iterationCount = Integer.parseInt(value);
            } catch (NumberFormatException e) {
                throw new SaslException("Server sent an invalid iteration count.", e);
            }
            if (iterationCount <= 0) {
                throw new SaslException("Server sent an invalid iteration count.");
            }
            return iterationCount;
        }

        /** Returns the attribute's value, or fails if the server omitted it. */
        private String getRequiredAttribute(final HashMap<String, String> attributes, final String key)
                throws SaslException {
            String value = attributes.get(key);
            if (value == null) {
                throw new SaslException("SCRAM-SHA-1 server response is missing the '" + key + "' attribute.");
            }
            return value;
        }

        private byte[] decodeBase64(final String str) {
            return this.base64Codec.decode(str);
        }

        /** Converts a string to its UTF-8 bytes (despite the name, this encodes). */
        private byte[] decodeUTF8(final String str) throws SaslException {
            try {
                return str.getBytes("UTF-8");
            } catch (UnsupportedEncodingException e) {
                throw new SaslException("UTF-8 is not a supported encoding.", e);
            }
        }

        private String encodeBase64(final byte[] bytes) {
            return this.base64Codec.encode(bytes);
        }

        /** Converts UTF-8 bytes to a string (despite the name, this decodes). */
        private String encodeUTF8(final byte[] bytes) throws SaslException {
            try {
                return new String(bytes, "UTF-8");
            } catch (UnsupportedEncodingException e) {
                throw new SaslException("UTF-8 is not a supported encoding.", e);
            }
        }

        /** SCRAM H(): SHA-1 digest. */
        private byte[] h(final byte[] data) throws SaslException {
            try {
                return MessageDigest.getInstance("SHA-1").digest(data);
            } catch (NoSuchAlgorithmException e) {
                throw new SaslException("SHA-1 could not be found.", e);
            }
        }

        /** SCRAM Hi(): PBKDF2 with HMAC-SHA-1, producing a 20-byte salted password. */
        private byte[] hi(final String password, final byte[] salt, final int iterations) throws SaslException {
            PBEKeySpec spec;
            try {
                spec = new PBEKeySpec(password.toCharArray(), salt, iterations, 20 * 8 /* 20 bytes */);
            } catch (IllegalArgumentException e) {
                // PBEKeySpec rejects an empty salt or a non-positive iteration count.
                throw new SaslException("Invalid salt or iteration count for PBKDF2WithHmacSHA1.", e);
            }
            try {
                SecretKeyFactory keyFactory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1");
                return keyFactory.generateSecret(spec).getEncoded();
            } catch (NoSuchAlgorithmException e) {
                throw new SaslException("Unable to find PBKDF2WithHmacSHA1.", e);
            } catch (InvalidKeySpecException e) {
                throw new SaslException("Invalid key spec for PBKDF2WithHmacSHA1.", e);
            } finally {
                // Wipe the spec's internal copy of the password.
                spec.clearPassword();
            }
        }

        /** SCRAM HMAC(): HMAC-SHA-1 of the UTF-8 bytes of {@code data} under {@code key}. */
        private byte[] hmac(final byte[] key, final String data) throws SaslException {
            Mac mac;
            try {
                mac = Mac.getInstance("HmacSHA1");
            } catch (NoSuchAlgorithmException e) {
                throw new SaslException("Could not find HmacSHA1.", e);
            }
            try {
                mac.init(new SecretKeySpec(key, "HmacSHA1"));
            } catch (InvalidKeyException e) {
                throw new SaslException("Could not initialize mac.", e);
            }
            return mac.doFinal(decodeUTF8(data));
        }

        /**
         * The server provides back key value pairs using an = sign and delimited
         * by a comma. All keys are also a single character.
         * For example: a=kg4io3,b=skljsfoiew,c=1203
         *
         * @throws SaslException if an attribute has no '=' separator
         */
        private HashMap<String, String> parseServerResponse(final String response) throws SaslException {
            HashMap<String, String> map = new HashMap<String, String>();
            for (String pair : response.split(",")) {
                String[] parts = pair.split("=", 2);
                if (parts.length != 2) {
                    throw new SaslException("Invalid SCRAM-SHA-1 server response: malformed attribute.");
                }
                map.put(parts[0], parts[1]);
            }
            return map;
        }

        /** Escapes '=' and ',' in the user name as required by RFC 5802 ("=3D" and "=2C"). */
        private String prepUserName(final String userName) {
            return userName.replace("=", "=3D").replace(",", "=2C");
        }

        /** Byte-wise XOR of two equal-length arrays. */
        private byte[] xor(final byte[] a, final byte[] b) {
            if (a.length != b.length) {
                throw new IllegalArgumentException("Cannot xor arrays of different lengths: "
                        + a.length + " and " + b.length);
            }
            byte[] result = new byte[a.length];
            for (int i = 0; i < a.length; i++) {
                result[i] = (byte) (a[i] ^ b[i]);
            }
            return result;
        }
    }

    /** Source of the client nonce. */
    public interface RandomStringGenerator {
        /**
         * @param length the number of characters to generate
         * @return a random string of {@code length} characters
         */
        String generate(int length);
    }

    /**
     * Generates nonces from printable ASCII characters '!' (33) through '~' (126) inclusive,
     * excluding ',' (the SCRAM attribute delimiter), using {@link SecureRandom}.
     */
    public static class DefaultRandomStringGenerator implements RandomStringGenerator {
        private static final int COMMA = ',';
        private static final int LOW = '!';
        private static final int HIGH = '~';

        private final Random random = new SecureRandom();

        @Override
        public String generate(final int length) {
            char[] text = new char[length];
            for (int i = 0; i < length; i++) {
                int next;
                do {
                    // HIGH is inclusive, hence the +1.
                    next = random.nextInt(HIGH - LOW + 1) + LOW;
                } while (next == COMMA);
                text[i] = (char) next;
            }
            return new String(text);
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy