io.kroxylicious.proxy.internal.ProxyChannelStateMachine Maven / Gradle / Ivy
Show all versions of kroxylicious-runtime Show documentation
/*
* Copyright Kroxylicious Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/
package io.kroxylicious.proxy.internal;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.apache.kafka.common.errors.ApiException;
import org.apache.kafka.common.message.ApiVersionsRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.slf4j.Logger;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.NetFilter;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.RequestFrame;
import io.kroxylicious.proxy.internal.ProxyChannelState.Closed;
import io.kroxylicious.proxy.internal.ProxyChannelState.Forwarding;
import io.kroxylicious.proxy.internal.codec.FrameOversizedException;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.kroxylicious.proxy.service.HostPort;
import io.kroxylicious.proxy.tag.VisibleForTesting;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import static io.kroxylicious.proxy.internal.ProxyChannelState.Startup.STARTING_STATE;
import static org.slf4j.LoggerFactory.getLogger;
/**
 * <p>The state machine for a single client's proxy session.
 * The "session state" is held in the {@link #state} field and is represented by an immutable
 * subclass of {@link ProxyChannelState} which contains state-specific data.
 * Events which cause state transitions are represented by the {@code on*()} family of methods.
 * Depending on the transition the frontend or backend handlers may get notified via one of their
 * {@code in*()} methods.</p>
*
*
 * <pre>
 *   «start»
 *      │
 *      ↓ frontend.{@link KafkaProxyFrontendHandler#channelActive(ChannelHandlerContext) channelActive}
 *     {@link ProxyChannelState.ClientActive ClientActive} ╌╌╌╌⤍ error ╌╌╌╌⤍
 *  ╭───┤
 *  ↓   ↓ frontend.{@link KafkaProxyFrontendHandler#channelRead(ChannelHandlerContext, Object) channelRead} receives a PROXY header
 *  │  {@link ProxyChannelState.HaProxy HaProxy} ╌╌╌╌⤍ error ╌╌╌╌⤍
 *  ╰───┤
 *  ╭───┤
 *  ↓   ↓ frontend.{@link KafkaProxyFrontendHandler#channelRead(ChannelHandlerContext, Object) channelRead} receives an ApiVersions request
 *  │  {@link ProxyChannelState.ApiVersions ApiVersions} ╌╌╌╌⤍ error ╌╌╌╌⤍
 *  ╰───┤
 *      ↓ frontend.{@link KafkaProxyFrontendHandler#channelRead(ChannelHandlerContext, Object) channelRead} receives any other KRPC request
 *     {@link ProxyChannelState.SelectingServer SelectingServer} ╌╌╌╌⤍ error ╌╌╌╌⤍
 *      │
 *      ↓ netFilter.{@link NetFilter#selectServer(NetFilter.NetFilterContext) selectServer} calls frontend.{@link KafkaProxyFrontendHandler#initiateConnect(HostPort, List) initiateConnect}
 *     {@link ProxyChannelState.Connecting Connecting} ╌╌╌╌⤍ error ╌╌╌╌⤍
 *      │
 *      ↓
 *     {@link Forwarding Forwarding} ╌╌╌╌⤍ error ╌╌╌╌⤍
 *      │ backend.{@link KafkaProxyBackendHandler#channelInactive(ChannelHandlerContext) channelInactive}
 *      │ or frontend.{@link KafkaProxyFrontendHandler#channelInactive(ChannelHandlerContext) channelInactive}
 *      ↓
 *     {@link Closed Closed} ⇠╌╌╌╌ error ⇠╌╌╌╌
 * </pre>
*
 * <p>In addition to the "session state" this class also manages a second state machine for
 * handling TCP backpressure via the {@link #clientReadsBlocked} and {@link #serverReadsBlocked} fields:</p>
 *
 * <pre>
 *   bothBlocked ←────────────────→ serverBlocked
 *       ↑                              ↑
 *       │                              │
 *       ↓                              ↓
 *   clientBlocked ←───────────────→ neitherBlocked
 * </pre>
 * <p>Note that this backpressure state machine is not tied to the
 * session state machine: in general backpressure could happen in
 * several of the session states and is independent of them.</p>
 *
 *
 * <p>When either side of the proxy starts applying back pressure the proxy should propagate that fact to the other peer.
 * Thus when the proxy is notified that a peer is applying back pressure it results in action on the channel with the opposite peer.</p>
*
*/
public class ProxyChannelStateMachine {
private static final String DUPLICATE_INITIATE_CONNECT_ERROR = "NetFilter called NetFilterContext.initiateConnect() more than once";
private static final Logger LOGGER = getLogger(ProxyChannelStateMachine.class);
/**
* The current state. This can be changed via a call to one of the {@code on*()} methods.
*/
@NonNull
private ProxyChannelState state = STARTING_STATE;
/*
* The netty autoread flag is volatile =>
* expensive to set in every call to channelRead.
* So we track autoread states via these non-volatile fields,
* allowing us to only touch the volatile when it needs to be changed
*/
@VisibleForTesting
boolean serverReadsBlocked;
@VisibleForTesting
boolean clientReadsBlocked;
/**
* The frontend handler. Non-null if we got as far as ClientActive.
*/
@SuppressWarnings({ "DataFlowIssue", "java:S2637" })
@NonNull
private KafkaProxyFrontendHandler frontendHandler = null;
/**
* The backend handler. Non-null if {@link #onNetFilterInitiateConnect(HostPort, List, VirtualCluster, NetFilter)}
* has been called
*/
@VisibleForTesting
@Nullable
private KafkaProxyBackendHandler backendHandler;
ProxyChannelState state() {
return state;
}
/**
* Purely for tests DO NOT USE IN PRODUCTION code!!
* Sonar will complain if one uses this in prod code listen to it.
*/
@VisibleForTesting
void forceState(@NonNull ProxyChannelState state, @NonNull KafkaProxyFrontendHandler frontendHandler, @Nullable KafkaProxyBackendHandler backendHandler) {
LOGGER.info("Forcing state to {} with {} and {}", state, frontendHandler, backendHandler);
this.state = state;
this.frontendHandler = frontendHandler;
this.backendHandler = backendHandler;
}
@Override
public String toString() {
return "StateHolder{" +
"state=" + state +
", serverReadsBlocked=" + serverReadsBlocked +
", clientReadsBlocked=" + clientReadsBlocked +
", frontendHandler=" + frontendHandler +
", backendHandler=" + backendHandler +
'}';
}
public String currentState() {
return this.state().getClass().getSimpleName();
}
/**
* Notify the state machine when the client applies back pressure.
*/
public void onClientUnwritable() {
if (!serverReadsBlocked) {
serverReadsBlocked = true;
Objects.requireNonNull(backendHandler).applyBackpressure();
}
}
/**
* Notify the state machine when the client stops applying back pressure
*/
public void onClientWritable() {
if (serverReadsBlocked) {
serverReadsBlocked = false;
Objects.requireNonNull(backendHandler).relieveBackpressure();
}
}
/**
* Notify the state machine when the server applies back pressure
*/
public void onServerUnwritable() {
if (!clientReadsBlocked) {
clientReadsBlocked = true;
frontendHandler.applyBackpressure();
}
}
/**
* Notify the state machine when the server stops applying back pressure
*/
public void onServerWritable() {
if (clientReadsBlocked) {
clientReadsBlocked = false;
frontendHandler.relieveBackpressure();
}
}
/**
* Notify the statemachine that the client channel has an active TCP connection.
* @param frontendHandler with active connection
*/
void onClientActive(@NonNull KafkaProxyFrontendHandler frontendHandler) {
if (STARTING_STATE.equals(this.state)) {
this.frontendHandler = frontendHandler;
toClientActive(STARTING_STATE.toClientActive(), frontendHandler);
}
else {
illegalState("Client activation while not in the start state");
}
}
/**
* Notify the statemachine that the netfilter has chosen an outbound peer.
* @param peer the upstream host to connect to.
* @param filters the set of filters to be applied to the session
* @param virtualCluster the virtual cluster the client is connecting too
* @param netFilter the netFilter which selected the upstream peer.
*/
void onNetFilterInitiateConnect(
@NonNull HostPort peer,
@NonNull List filters,
VirtualCluster virtualCluster,
NetFilter netFilter) {
if (state instanceof ProxyChannelState.SelectingServer selectingServerState) {
toConnecting(selectingServerState.toConnecting(peer), filters, virtualCluster);
}
else {
illegalState(DUPLICATE_INITIATE_CONNECT_ERROR + " : netFilter='" + netFilter + "'");
}
}
/**
* Notify the statemachine that the upstream connection is ready for RPC calls.
*/
void onServerActive() {
if (state() instanceof ProxyChannelState.Connecting connectedState) {
toForwarding(connectedState.toForwarding());
}
else {
illegalState("Server became active while not in the connecting state");
}
}
/**
* Notify the state machine of an unexpected event.
* The definition of unexpected events is up to the callers.
* An example would be trying to forward an event upstream before the upstream connection is established.
*
* illegalState implies termination of the proxy session. As this really represents a programming error NO error messages are propagated to clients.
*
* @param msg the message to be logged in explanation of the error condition
*/
void illegalState(@NonNull String msg) {
if (!(state instanceof Closed)) {
LOGGER.error("Unexpected event while in {} message: {}, closing channels with no client response.", state, msg);
toClosed(null);
}
}
/**
* A message has been received from the upstream node which should be passed to the downstream client
* @param msg the object received from the upstream
*/
void messageFromServer(Object msg) {
Objects.requireNonNull(frontendHandler).forwardToClient(msg);
}
/**
* Called to notify the state machine that reading the upstream batch is complete.
*/
void serverReadComplete() {
Objects.requireNonNull(frontendHandler).flushToClient();
}
/**
* A message has been received from the downstream client which should be passed to the upstream node
* @param msg the RPC received from the upstream
*/
void messageFromClient(Object msg) {
Objects.requireNonNull(backendHandler).forwardToServer(msg);
}
/**
* Called to notify the state machine that reading the downstream the batch is complete.
*/
void clientReadComplete() {
if (state instanceof Forwarding) {
Objects.requireNonNull(backendHandler).flushToServer();
}
}
/**
* The proxy has received something from the client. The current state of the session determines what happens to it.
* @param dp the decode predicate to be used if the session is still being negotiated
* @param msg the RPC received from the downstream client
*/
void onClientRequest(
@NonNull SaslDecodePredicate dp,
Object msg) {
Objects.requireNonNull(frontendHandler);
if (state() instanceof Forwarding) { // post-backend connection
messageFromClient(msg);
}
else if (!onClientRequestBeforeForwarding(dp, msg)) {
illegalState("Unexpected message received: " + (msg == null ? "null" : "message class=" + msg.getClass()));
}
}
/**
* ensure the state machine is in the connecting state.
* @param msg to be logged if in another state.
*/
void assertIsConnecting(String msg) {
if (!(state instanceof ProxyChannelState.Connecting)) {
illegalState(msg);
}
}
/**
* ensure the state machine is in the selecting server state.
*
* @return the SelectingServer state
* @throws IllegalStateException if the state is not {@link ProxyChannelState.SelectingServer}.
*/
ProxyChannelState.SelectingServer enforceInSelectingServer(String errorMessage) {
if (state instanceof ProxyChannelState.SelectingServer selectingServerState) {
return selectingServerState;
}
else {
illegalState(errorMessage);
throw new IllegalStateException("State required to be "
+ ProxyChannelState.SelectingServer.class.getSimpleName()
+ " but was "
+ currentState()
+ ":"
+ errorMessage);
}
}
/**
* Notify the statemachine that the connection to the upstream node has been disconnected.
*
* This will result in the proxy session being torn down.
*
*/
void onServerInactive() {
toClosed(null);
}
/**
* Notify the statemachine that the connection to the downstream client has been disconnected.
*
* This will result in the proxy session being torn down.
*
*/
void onClientInactive() {
toClosed(null);
}
/**
* Notify the state machine that something exceptional and un-recoverable has happened on the upstream side.
* @param cause the exception that triggered the issue
*/
void onServerException(Throwable cause) {
LOGGER.atWarn()
.setCause(LOGGER.isDebugEnabled() ? cause : null)
.addArgument(cause != null ? cause.getMessage() : "")
.log("Exception from the server channel: {}. Increase log level to DEBUG for stacktrace");
toClosed(cause);
}
/**
* Notify the state machine that something exceptional and un-recoverable has happened on the downstream side.
* @param cause the exception that triggered the issue
*/
void onClientException(Throwable cause, boolean tlsEnabled) {
ApiException errorCodeEx;
if (cause instanceof DecoderException de
&& de.getCause() instanceof FrameOversizedException e) {
var tlsHint = tlsEnabled ? "" : " or an unexpected TLS handshake";
LOGGER.warn(
"Received over-sized frame from the client, max frame size bytes {}, received frame size bytes {} "
+ "(hint: are we decoding a Kafka frame, or something unexpected like an HTTP request{}?)",
e.getMaxFrameSizeBytes(), e.getReceivedFrameSizeBytes(), tlsHint);
errorCodeEx = Errors.INVALID_REQUEST.exception();
}
else {
LOGGER.atWarn()
.setCause(LOGGER.isDebugEnabled() ? cause : null)
.addArgument(cause != null ? cause.getMessage() : "")
.log("Exception from the client channel: {}. Increase log level to DEBUG for stacktrace");
errorCodeEx = Errors.UNKNOWN_SERVER_ERROR.exception();
}
toClosed(errorCodeEx);
}
private void toClientActive(
@NonNull ProxyChannelState.ClientActive clientActive,
@NonNull KafkaProxyFrontendHandler frontendHandler) {
setState(clientActive);
frontendHandler.inClientActive();
}
private void toConnecting(
ProxyChannelState.Connecting connecting,
@NonNull List filters,
VirtualCluster virtualCluster) {
setState(connecting);
backendHandler = new KafkaProxyBackendHandler(this, virtualCluster);
frontendHandler.inConnecting(connecting.remote(), filters, backendHandler);
}
private void toForwarding(Forwarding forwarding) {
setState(forwarding);
Objects.requireNonNull(frontendHandler).inForwarding();
}
/**
* handle a message received from the client prior to connecting to the upstream node
* @param dp DecodePredicate to cope with SASL offload
* @param msg Message received from the downstream client.
* @return false
for unsupported message types
*/
private boolean onClientRequestBeforeForwarding(@NonNull SaslDecodePredicate dp, Object msg) {
frontendHandler.bufferMsg(msg);
if (state() instanceof ProxyChannelState.ClientActive clientActive) {
return onClientRequestInClientActiveState(dp, msg, clientActive);
}
else if (state() instanceof ProxyChannelState.HaProxy haProxy) {
return onClientRequestInHaProxyState(dp, msg, haProxy);
}
else if (state() instanceof ProxyChannelState.ApiVersions apiVersions) {
return onClientRequestInApiVersionsState(dp, msg, apiVersions);
}
else if (state() instanceof ProxyChannelState.SelectingServer) {
return msg instanceof RequestFrame;
}
else {
return state() instanceof ProxyChannelState.Connecting && msg instanceof RequestFrame;
}
}
@SuppressWarnings("java:S1172")
// We keep dp as we should need it and it gives consistency with the other onClientRequestIn methods (sue me)
private boolean onClientRequestInApiVersionsState(@NonNull SaslDecodePredicate dp, Object msg, ProxyChannelState.ApiVersions apiVersions) {
if (msg instanceof RequestFrame) {
// TODO if dp.isAuthenticationOffloadEnabled() then we need to forward to that handler
// TODO we only do the connection once we know the authenticated identity
toSelectingServer(apiVersions.toSelectingServer());
return true;
}
return false;
}
private boolean onClientRequestInHaProxyState(@NonNull SaslDecodePredicate dp, Object msg, ProxyChannelState.HaProxy haProxy) {
return transitionClientRequest(dp, msg, haProxy::toApiVersions, haProxy::toSelectingServer);
}
private boolean transitionClientRequest(
@NonNull SaslDecodePredicate dp,
Object msg,
Function, ProxyChannelState.ApiVersions> apiVersionsFactory,
Function, ProxyChannelState.SelectingServer> selectingServerFactory) {
if (isMessageApiVersionsRequest(msg)) {
// We know it's an API Versions request even if the compiler doesn't
@SuppressWarnings("unchecked")
DecodedRequestFrame apiVersionsFrame = (DecodedRequestFrame) msg;
if (dp.isAuthenticationOffloadEnabled()) {
toApiVersions(apiVersionsFactory.apply(apiVersionsFrame), apiVersionsFrame);
}
else {
toSelectingServer(selectingServerFactory.apply(apiVersionsFrame));
}
return true;
}
else if (msg instanceof RequestFrame) {
toSelectingServer(selectingServerFactory.apply(null));
return true;
}
return false;
}
private boolean onClientRequestInClientActiveState(@NonNull SaslDecodePredicate dp, Object msg, ProxyChannelState.ClientActive clientActive) {
if (msg instanceof HAProxyMessage haProxyMessage) {
toHaProxy(clientActive.toHaProxy(haProxyMessage));
return true;
}
else {
return transitionClientRequest(dp, msg, clientActive::toApiVersions, clientActive::toSelectingServer);
}
}
private void toHaProxy(ProxyChannelState.HaProxy haProxy) {
setState(haProxy);
}
private void toApiVersions(
ProxyChannelState.ApiVersions apiVersions,
DecodedRequestFrame apiVersionsFrame) {
setState(apiVersions);
Objects.requireNonNull(frontendHandler).inApiVersions(apiVersionsFrame);
}
private void toSelectingServer(ProxyChannelState.SelectingServer selectingServer) {
setState(selectingServer);
Objects.requireNonNull(frontendHandler).inSelectingServer();
}
private void toClosed(@Nullable Throwable errorCodeEx) {
if (state instanceof Closed) {
return;
}
setState(new Closed());
// Close the server connection
if (backendHandler != null) {
backendHandler.inClosed();
}
// Close the client connection with any error code
Objects.requireNonNull(frontendHandler).inClosed(errorCodeEx);
}
private void setState(@NonNull ProxyChannelState state) {
LOGGER.trace("{} transitioning to {}", this, state);
this.state = state;
}
private static boolean isMessageApiVersionsRequest(Object msg) {
return msg instanceof DecodedRequestFrame
&& ((DecodedRequestFrame>) msg).apiKey() == ApiKeys.API_VERSIONS;
}
}