All downloads are free. Search and download functionality uses the official Maven repository.

org.mariadb.r2dbc.message.flow.AuthenticationFlow Maven / Gradle / Ivy

The newest version!
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2020-2024 MariaDB Corporation Ab

package org.mariadb.r2dbc.message.flow;

import io.r2dbc.spi.R2dbcException;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import java.util.Arrays;
import org.mariadb.r2dbc.ExceptionFactory;
import org.mariadb.r2dbc.MariadbConnectionConfiguration;
import org.mariadb.r2dbc.SslMode;
import org.mariadb.r2dbc.authentication.AuthenticationFlowPluginLoader;
import org.mariadb.r2dbc.authentication.AuthenticationPlugin;
import org.mariadb.r2dbc.authentication.standard.CachingSha2PasswordFlow;
import org.mariadb.r2dbc.client.Client;
import org.mariadb.r2dbc.client.DecoderState;
import org.mariadb.r2dbc.client.SimpleClient;
import org.mariadb.r2dbc.message.ClientMessage;
import org.mariadb.r2dbc.message.ServerMessage;
import org.mariadb.r2dbc.message.client.HandshakeResponse;
import org.mariadb.r2dbc.message.client.SslRequestPacket;
import org.mariadb.r2dbc.message.server.*;
import org.mariadb.r2dbc.util.Assert;
import org.mariadb.r2dbc.util.HostAddress;
import org.mariadb.r2dbc.util.constants.Capabilities;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.util.Logger;
import reactor.util.Loggers;

public final class AuthenticationFlow {
  private static final Logger logger = Loggers.getLogger(AuthenticationFlow.class);

  private final MariadbConnectionConfiguration configuration;
  private final SimpleClient client;
  private final HostAddress hostAddress;
  private InitialHandshakePacket initialHandshakePacket;
  private AuthenticationPlugin pluginHandler;
  private byte[] seed;
  private Sequencer sequencer;
  private AuthMoreDataPacket authMoreDataPacket;
  private FluxSink sink;
  private long clientCapabilities;

  private AuthenticationFlow(
      SimpleClient client, MariadbConnectionConfiguration configuration, HostAddress hostAddress) {
    this.client = client;
    this.configuration = configuration;
    this.hostAddress = hostAddress;
  }

  public static Mono exchange(
      SimpleClient client, MariadbConnectionConfiguration configuration, HostAddress hostAddress) {
    AuthenticationFlow flow = new AuthenticationFlow(client, configuration, hostAddress);
    Assert.requireNonNull(client, "client must not be null");

    return Flux.create(
            sink -> {
              flow.sink = sink;
              State.INIT.handle(flow).subscribe(sink::next, sink::error);
            })
        .doOnNext(
            state -> {
              if (State.COMPLETED == state) {
                if (flow.authMoreDataPacket != null) {
                  flow.authMoreDataPacket.release();
                  flow.authMoreDataPacket = null;
                }
                flow.sink.complete();
              } else {
                if (logger.isTraceEnabled()) {
                  logger.trace("authentication state {}", state);
                }
                state.handle(flow).subscribe(flow.sink::next, flow.sink::error);
              }
            })
        .doOnComplete(
            () -> {
              if (logger.isDebugEnabled()) {
                logger.debug("Authentication success");
              }
            })
        .doOnError(
            e -> {
              logger.warn("Authentication failed", e);
              flow.client.handleConnectionError(e);
            })
        .doFinally(
            c -> {
              if (flow.authMoreDataPacket != null) {
                flow.authMoreDataPacket.release();
                flow.authMoreDataPacket = null;
              }
            })
        .then(Mono.just(client));
  }

  private static long initializeClientCapabilities(
      final long serverCapabilities, MariadbConnectionConfiguration configuration) {
    long capabilities =
        Capabilities.IGNORE_SPACE
            | Capabilities.CLIENT_PROTOCOL_41
            | Capabilities.TRANSACTIONS
            | Capabilities.SECURE_CONNECTION
            | Capabilities.MULTI_RESULTS
            | Capabilities.PS_MULTI_RESULTS
            | Capabilities.PLUGIN_AUTH
            | Capabilities.CONNECT_ATTRS
            | Capabilities.PLUGIN_AUTH_LENENC_CLIENT_DATA
            | Capabilities.CLIENT_SESSION_TRACK
            | Capabilities.FOUND_ROWS
            | Capabilities.MARIADB_CLIENT_CACHE_METADATA;

    if (configuration.allowMultiQueries()) {
      capabilities |= Capabilities.MULTI_STATEMENTS;
    }

    if (configuration.getDatabase() != null) {
      capabilities |= Capabilities.CONNECT_WITH_DB;
    }

    return capabilities & serverCapabilities;
  }

  private HandshakeResponse createHandshakeResponse(long clientCapabilities) {
    return new HandshakeResponse(
        this.initialHandshakePacket,
        this.configuration.getUsername(),
        this.configuration.getPassword(),
        this.configuration.getDatabase(),
        configuration.getConnectionAttributes(),
        this.hostAddress,
        clientCapabilities);
  }

  private SslRequestPacket createSslRequest(long clientCapabilities) {
    return new SslRequestPacket(this.initialHandshakePacket, clientCapabilities);
  }

  public enum State {
    INIT {
      @Override
      Mono handle(AuthenticationFlow flow) {
        // Server send first, so no need send anything to server in here.
        return flow.client
            .receive(DecoderState.INIT_HANDSHAKE)
            .handle(
                (message, sink) -> {
                  if (message instanceof ErrorPacket) {
                    sink.error(ExceptionFactory.INSTANCE.from((ErrorPacket) message));
                  } else if (message instanceof InitialHandshakePacket) {
                    InitialHandshakePacket packet = (InitialHandshakePacket) message;
                    flow.initialHandshakePacket = packet;
                    flow.clientCapabilities =
                        initializeClientCapabilities(
                            flow.initialHandshakePacket.getCapabilities(), flow.configuration);
                    flow.client.setContext(packet, flow.clientCapabilities);

                    SslMode sslMode = flow.configuration.getSslConfig().getSslMode();
                    if (sslMode != SslMode.DISABLE && sslMode != SslMode.TUNNEL) {
                      if ((packet.getCapabilities() & Capabilities.SSL) == 0) {
                        sink.error(
                            new R2dbcNonTransientResourceException(
                                "Trying to connect with ssl, but ssl not enabled in the server",
                                "08000"));
                      } else {
                        sink.next(SSL_REQUEST);
                      }
                    } else {
                      sink.next(HANDSHAKE);
                    }

                  } else {
                    sink.error(
                        new IllegalStateException(
                            String.format(
                                "Unexpected message type '%s' in handshake init phase",
                                message.getClass().getSimpleName())));
                  }
                })
            .next();
      }
    },

    SSL_REQUEST {
      @Override
      Mono handle(AuthenticationFlow flow) {
        flow.clientCapabilities |= Capabilities.SSL;
        SslRequestPacket sslRequest = flow.createSslRequest(flow.clientCapabilities);
        return flow.client
            .sendSslRequest(sslRequest, flow.configuration)
            .then(Mono.just(HANDSHAKE));
      }
    },

    HANDSHAKE {
      @Override
      Mono handle(AuthenticationFlow flow) {
        flow.seed = flow.initialHandshakePacket.getSeed();
        flow.sequencer = flow.initialHandshakePacket.getSequencer();

        if (flow.initialHandshakePacket
            .getAuthenticationPluginType()
            .equals(CachingSha2PasswordFlow.TYPE)) {
          AuthenticationPlugin authPlugin =
              AuthenticationFlowPluginLoader.get(CachingSha2PasswordFlow.TYPE);
          ((CachingSha2PasswordFlow) authPlugin).setStateFastAuth();
          flow.authMoreDataPacket = null;
          flow.pluginHandler = authPlugin;
        }

        return flow.client
            .sendCommand(
                flow.createHandshakeResponse(flow.clientCapabilities),
                DecoderState.AUTHENTICATION_SWITCH_RESPONSE,
                false)
            .handle(
                (message, sink) -> {
                  if (message instanceof ErrorPacket) {
                    sink.error(ExceptionFactory.createException((ErrorPacket) message, null));
                  } else if (message instanceof OkPacket) {
                    sink.next(COMPLETED);
                  } else if (message instanceof AuthSwitchPacket) {
                    AuthSwitchPacket authSwitchPacket = ((AuthSwitchPacket) message);
                    flow.seed = authSwitchPacket.getSeed();
                    flow.sequencer = authSwitchPacket.getSequencer();
                    String plugin = authSwitchPacket.getPlugin();
                    if (flow.configuration.getRestrictedAuth() != null
                        && !Arrays.stream(flow.configuration.getRestrictedAuth())
                            .anyMatch(s -> plugin.equals(s))) {
                      sink.error(
                          new R2dbcPermissionDeniedException(
                              String.format(
                                  "Unsupported authentication plugin %s. Authorized plugin: %s",
                                  plugin,
                                  Arrays.toString(flow.configuration.getRestrictedAuth()))));
                    } else {
                      AuthenticationPlugin authPlugin = AuthenticationFlowPluginLoader.get(plugin);
                      flow.authMoreDataPacket = null;
                      flow.pluginHandler = authPlugin;
                      sink.next(AUTH_SWITCH);
                    }
                  } else if (flow.pluginHandler != null && message instanceof AuthMoreDataPacket) {
                    flow.authMoreDataPacket = (AuthMoreDataPacket) message;
                    sink.next(AUTH_SWITCH);
                  } else {
                    sink.error(
                        new IllegalStateException(
                            String.format(
                                "Unexpected message type '%s' in handshake response phase",
                                message.getClass().getSimpleName())));
                  }
                })
            .next();
      }
    },

    AUTH_SWITCH {
      @Override
      Mono handle(AuthenticationFlow flow) {
        ClientMessage clientMessage;
        try {
          clientMessage =
              flow.pluginHandler.next(
                  flow.configuration, flow.seed, flow.sequencer, flow.authMoreDataPacket);
        } catch (R2dbcException ex) {
          return Mono.error(ex);
        }

        Flux flux;
        if (clientMessage != null) {
          // this can occur when there is a "finishing" message for authentication plugin
          // example CachingSha2PasswordFlow that finish with a successful FAST_AUTH
          flux =
              flow.client.sendCommand(
                  clientMessage, DecoderState.AUTHENTICATION_SWITCH_RESPONSE, false);
        } else {
          flux = flow.client.receive(DecoderState.AUTHENTICATION_SWITCH_RESPONSE);
        }
        if (flow.authMoreDataPacket != null) {
          flow.authMoreDataPacket.release();
          flow.authMoreDataPacket = null;
        }
        return flux.handle(
                (message, sink) -> {
                  if (message instanceof ErrorPacket) {
                    sink.error(
                        new R2dbcNonTransientResourceException(((ErrorPacket) message).message()));
                  } else if (message instanceof OkPacket) {
                    sink.next(COMPLETED);
                  } else if (message instanceof AuthSwitchPacket) {
                    AuthSwitchPacket authSwitchPacket = ((AuthSwitchPacket) message);
                    flow.seed = authSwitchPacket.getSeed();
                    flow.sequencer = authSwitchPacket.getSequencer();
                    String plugin = authSwitchPacket.getPlugin();
                    if (flow.configuration.getRestrictedAuth() != null
                        && !Arrays.stream(flow.configuration.getRestrictedAuth())
                            .anyMatch(s -> plugin.equals(s))) {
                      sink.error(
                          new R2dbcPermissionDeniedException(
                              String.format(
                                  "Unsupported authentication plugin %s. Authorized plugin: %s",
                                  plugin,
                                  Arrays.toString(flow.configuration.getRestrictedAuth()))));
                    } else {
                      AuthenticationPlugin authPlugin = AuthenticationFlowPluginLoader.get(plugin);
                      flow.authMoreDataPacket = null;
                      flow.pluginHandler = authPlugin;
                      sink.next(AUTH_SWITCH);
                    }
                  } else if (message instanceof AuthMoreDataPacket) {
                    flow.authMoreDataPacket = (AuthMoreDataPacket) message;
                    flow.sequencer = (Sequencer) ((AuthMoreDataPacket) message).getSequencer();
                    sink.next(AUTH_SWITCH);
                  } else {
                    sink.error(
                        new IllegalStateException(
                            String.format(
                                "Unexpected message type '%s' in handshake response phase",
                                message.getClass().getSimpleName())));
                  }
                })
            .next();
      }
    },

    COMPLETED {
      @Override
      Mono handle(AuthenticationFlow flow) {
        return Mono.just(COMPLETED);
      }
    };

    abstract Mono handle(AuthenticationFlow flow);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy