org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClient Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* {@code SaslClient} implementation for SASL/OAUTHBEARER in Kafka. This
* implementation requires an instance of {@code AuthenticateCallbackHandler}
* that can handle an instance of {@link OAuthBearerTokenCallback} and return
* the {@link OAuthBearerToken} generated by the {@code login()} event on the
* {@code LoginContext}. Said handler can also optionally handle an instance of {@link SaslExtensionsCallback}
* to return any extensions generated by the {@code login()} event on the {@code LoginContext}.
*
* @see RFC 6750,
* Section 2.1
*
*/
public class OAuthBearerSaslClient implements SaslClient {
static final byte BYTE_CONTROL_A = (byte) 0x01;
private static final Logger log = LoggerFactory.getLogger(OAuthBearerSaslClient.class);
private final CallbackHandler callbackHandler;
enum State {
SEND_CLIENT_FIRST_MESSAGE, RECEIVE_SERVER_FIRST_MESSAGE, RECEIVE_SERVER_MESSAGE_AFTER_FAILURE, COMPLETE, FAILED
}
private State state;
public OAuthBearerSaslClient(AuthenticateCallbackHandler callbackHandler) {
this.callbackHandler = Objects.requireNonNull(callbackHandler);
setState(State.SEND_CLIENT_FIRST_MESSAGE);
}
public CallbackHandler callbackHandler() {
return callbackHandler;
}
@Override
public String getMechanismName() {
return OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
}
@Override
public boolean hasInitialResponse() {
return true;
}
@Override
public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
switch (state) {
case SEND_CLIENT_FIRST_MESSAGE:
if (challenge != null && challenge.length != 0)
throw new SaslException("Expected empty challenge");
callbackHandler().handle(new Callback[] {callback});
SaslExtensions extensions = retrieveCustomExtensions();
setState(State.RECEIVE_SERVER_FIRST_MESSAGE);
return new OAuthBearerClientInitialResponse(callback.token().value(), extensions).toBytes();
case RECEIVE_SERVER_FIRST_MESSAGE:
if (challenge != null && challenge.length != 0) {
String jsonErrorResponse = new String(challenge, StandardCharsets.UTF_8);
if (log.isDebugEnabled())
log.debug("Sending %%x01 response to server after receiving an error: {}",
jsonErrorResponse);
setState(State.RECEIVE_SERVER_MESSAGE_AFTER_FAILURE);
return new byte[] {BYTE_CONTROL_A};
}
callbackHandler().handle(new Callback[] {callback});
if (log.isDebugEnabled())
log.debug("Successfully authenticated as {}", callback.token().principalName());
setState(State.COMPLETE);
return null;
default:
throw new IllegalSaslStateException("Unexpected challenge in Sasl client state " + state);
}
} catch (SaslException e) {
setState(State.FAILED);
throw e;
} catch (IOException | UnsupportedCallbackException e) {
setState(State.FAILED);
throw new SaslException(e.getMessage(), e);
}
}
@Override
public boolean isComplete() {
return state == State.COMPLETE;
}
@Override
public byte[] unwrap(byte[] incoming, int offset, int len) {
if (!isComplete())
throw new IllegalStateException("Authentication exchange has not completed");
return Arrays.copyOfRange(incoming, offset, offset + len);
}
@Override
public byte[] wrap(byte[] outgoing, int offset, int len) {
if (!isComplete())
throw new IllegalStateException("Authentication exchange has not completed");
return Arrays.copyOfRange(outgoing, offset, offset + len);
}
@Override
public Object getNegotiatedProperty(String propName) {
if (!isComplete())
throw new IllegalStateException("Authentication exchange has not completed");
return null;
}
@Override
public void dispose() {
}
private void setState(State state) {
log.debug("Setting SASL/{} client state to {}", OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, state);
this.state = state;
}
private SaslExtensions retrieveCustomExtensions() throws SaslException {
SaslExtensionsCallback extensionsCallback = new SaslExtensionsCallback();
try {
callbackHandler().handle(new Callback[] {extensionsCallback});
} catch (UnsupportedCallbackException e) {
log.debug("Extensions callback is not supported by client callback handler {}, no extensions will be added",
callbackHandler());
} catch (Exception e) {
throw new SaslException("SASL extensions could not be obtained", e);
}
return extensionsCallback.extensions();
}
public static class OAuthBearerSaslClientFactory implements SaslClientFactory {
@Override
public SaslClient createSaslClient(String[] mechanisms, String authorizationId, String protocol,
String serverName, Map props, CallbackHandler callbackHandler) {
String[] mechanismNamesCompatibleWithPolicy = getMechanismNames(props);
for (String mechanism : mechanisms) {
for (int i = 0; i < mechanismNamesCompatibleWithPolicy.length; i++) {
if (mechanismNamesCompatibleWithPolicy[i].equals(mechanism)) {
if (!(Objects.requireNonNull(callbackHandler) instanceof AuthenticateCallbackHandler))
throw new IllegalArgumentException(String.format(
"Callback handler must be castable to %s: %s",
AuthenticateCallbackHandler.class.getName(), callbackHandler.getClass().getName()));
return new OAuthBearerSaslClient((AuthenticateCallbackHandler) callbackHandler);
}
}
}
return null;
}
@Override
public String[] getMechanismNames(Map props) {
return OAuthBearerSaslServer.mechanismNamesCompatibleWithPolicy(props);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy