All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.arrow.adbc.driver.flightsql.FlightSqlConnection Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.adbc.driver.flightsql;

import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.github.benmanes.caffeine.cache.RemovalCause;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcDriver;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.core.BulkIngestMode;
import org.apache.arrow.adbc.sql.SqlQuirks;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightCallHeaders;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.HeaderCallOption;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter;
import org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.checkerframework.checker.initialization.qual.UnknownInitialization;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * An {@link AdbcConnection} backed by an Arrow Flight SQL server.
 *
 * <p>The connection owns a primary client for the location it was opened against. It also owns a
 * cache of clients keyed by {@link Location}, which is shared with the statements and readers it
 * creates. Transactions are not supported: the connection always reports auto-commit mode.
 */
public class FlightSqlConnection implements AdbcConnection {
  private final BufferAllocator allocator;
  // Gives each FlightClient's child allocator a unique name.
  private final AtomicInteger counter = new AtomicInteger(0);
  // Client for the location passed to the constructor; owned and closed by this connection.
  private final FlightSqlClientWithCallOptions client;
  private final SqlQuirks quirks;
  private final Map<String, Object> parameters;
  // Clients keyed by location, shared with statements/readers created from this connection.
  private final LoadingCache<Location, FlightSqlClientWithCallOptions> clientCache;

  // Cached data to use across additional connections.
  private ClientCookieMiddleware.@Nullable Factory cookieMiddlewareFactory;
  // Header and credential options applied to every client this connection creates.
  private CallOption[] callOptions;

  // Used to cache the InputStream content as a byte array since
  // subsequent connections may need to use it but it is supplied as a stream.
  private byte @Nullable [] mtlsCertChainBytes;
  private byte @Nullable [] mtlsPrivateKeyBytes;
  private byte @Nullable [] tlsRootCertsBytes;

  /**
   * Opens and authenticates a connection to {@code location}.
   *
   * @param allocator parent allocator; each underlying FlightClient gets a child of it
   * @param quirks SQL dialect quirks handed to statements
   * @param location location of the Flight SQL server
   * @param parameters connection options (TLS, cookies, RPC headers, credentials); a null map is
   *     treated as empty
   * @throws AdbcException if a TLS option stream cannot be read, a header value has an unsupported
   *     type, or only one of the mTLS certificate chain / private key options is given
   */
  FlightSqlConnection(
      BufferAllocator allocator,
      SqlQuirks quirks,
      Location location,
      Map<String, Object> parameters)
      throws AdbcException {
    this.allocator = allocator;
    this.quirks = quirks;
    this.parameters = parameters;
    this.callOptions = new CallOption[0];
    FlightSqlClient flightSqlClient = new FlightSqlClient(createInitialConnection(location));
    this.client = new FlightSqlClientWithCallOptions(flightSqlClient, callOptions);
    this.clientCache =
        Caffeine.newBuilder()
            .expireAfterAccess(5, TimeUnit.MINUTES)
            // Deliver removal notifications on the calling thread. By default Caffeine sends them
            // asynchronously, so close() could close the allocator while evicted clients (and
            // their child allocators) were still open.
            .executor(Runnable::run)
            .removalListener(
                (@Nullable Location key,
                    @Nullable FlightSqlClientWithCallOptions value,
                    RemovalCause cause) -> {
                  // The primary client is also held directly by this connection and is closed in
                  // close(). Expiring its cache entry must not close it while still in use.
                  if (value == null || value == this.client) {
                    return;
                  }
                  try {
                    value.close();
                  } catch (Exception ex) {
                    if (ex instanceof InterruptedException) {
                      Thread.currentThread().interrupt();
                    }
                    throw new RuntimeException(ex);
                  }
                })
            .build(
                loc -> {
                  final FlightClient newClient = buildClient(loc);
                  try {
                    newClient.handshake(callOptions);
                  } catch (RuntimeException e) {
                    // Do not leak the client (and its child allocator) if the handshake fails.
                    closeAfterFailure(newClient, e);
                    throw e;
                  }
                  return new FlightSqlClientWithCallOptions(
                      new FlightSqlClient(newClient), callOptions);
                });
    this.clientCache.put(location, this.client);
  }

  @Override
  public void commit() throws AdbcException {
    throw AdbcException.notImplemented("[Flight SQL] Transaction methods are not supported");
  }

  @Override
  public AdbcStatement createStatement() throws AdbcException {
    return new FlightSqlStatement(allocator, client, clientCache, quirks);
  }

  /**
   * Decodes {@code descriptor} as a serialized {@code Flight.FlightEndpoint} protobuf and returns a
   * reader over that single endpoint.
   *
   * @throws AdbcException with an invalid-argument status if the protobuf or one of its location
   *     URIs cannot be parsed
   */
  @Override
  public ArrowReader readPartition(ByteBuffer descriptor) throws AdbcException {
    final FlightEndpoint endpoint;
    try {
      final Flight.FlightEndpoint protoEndpoint = Flight.FlightEndpoint.parseFrom(descriptor);
      Location[] locations = new Location[protoEndpoint.getLocationCount()];
      int index = 0;
      for (Flight.Location protoLocation : protoEndpoint.getLocationList()) {
        Location location = new Location(protoLocation.getUri());
        locations[index++] = location;
      }

      endpoint =
          new FlightEndpoint(
              new Ticket(protoEndpoint.getTicket().getTicket().toByteArray()), locations);
    } catch (InvalidProtocolBufferException | URISyntaxException e) {
      throw AdbcException.invalidArgument(
              "[Flight SQL] Partition descriptor is invalid: " + e.getMessage())
          .withCause(e);
    }
    return new FlightInfoReader(
        allocator, client, clientCache, Collections.singletonList(endpoint));
  }

  @Override
  public AdbcStatement bulkIngest(String targetTableName, BulkIngestMode mode)
      throws AdbcException {
    return FlightSqlStatement.ingestRoot(
        allocator, client, clientCache, quirks, targetTableName, mode);
  }

  @Override
  public ArrowReader getObjects(
      GetObjectsDepth depth,
      String catalogPattern,
      String dbSchemaPattern,
      String tableNamePattern,
      String[] tableTypes,
      String columnNamePattern)
      throws AdbcException {
    return GetObjectsMetadataReaders.CreateGetObjectsReader(
        allocator,
        client,
        clientCache,
        depth,
        catalogPattern,
        dbSchemaPattern,
        tableNamePattern,
        tableTypes,
        columnNamePattern);
  }

  /**
   * Returns driver/server metadata for {@code infoCodes} (all codes when null, presumably — see
   * {@code GetInfoMetadataReader}). Any failure is rethrown as an invalid-state AdbcException.
   */
  @Override
  public ArrowReader getInfo(int @Nullable [] infoCodes) throws AdbcException {
    try {
      return GetInfoMetadataReader.CreateGetInfoMetadataReader(
          allocator, client, clientCache, infoCodes);
    } catch (Exception e) {
      throw AdbcException.invalidState("[Flight SQL] Failed to get info").withCause(e);
    }
  }

  @Override
  public void rollback() throws AdbcException {
    throw AdbcException.notImplemented("[Flight SQL] Transaction methods are not supported");
  }

  @Override
  public boolean getAutoCommit() throws AdbcException {
    return true;
  }

  @Override
  public void setAutoCommit(boolean enableAutoCommit) throws AdbcException {
    if (!enableAutoCommit) {
      throw AdbcException.notImplemented("[Flight SQL] Transaction methods are not supported");
    }
  }

  /** Closes every cached client, then the primary client, then the allocator. */
  @Override
  public void close() throws Exception {
    // Removal notifications run on this thread (direct executor set in the constructor), so every
    // cached client is closed before invalidateAll() returns. The listener skips the primary
    // client, which is therefore closed exactly once, here.
    clientCache.invalidateAll();
    AutoCloseables.close(client, allocator);
  }

  @Override
  public String toString() {
    return "FlightSqlConnection{" + "client=" + client + '}';
  }

  /**
   * Initialize cached data to share between connections and create, test, and authenticate the
   * first connection.
   *
   * <p>Reads the TLS stream options into byte arrays, enables cookie middleware if requested, and
   * collects user-specified RPC headers. It then builds the client and either authenticates with
   * username/password or performs a plain handshake. As a side effect it sets {@code callOptions}.
   * If authentication or the handshake throws, the new client is closed before the exception
   * propagates.
   *
   * @throws AdbcException if a TLS stream cannot be read, a header value is neither String nor
   *     byte[], or the mTLS options are inconsistent
   */
  private FlightClient createInitialConnection(
      @UnknownInitialization FlightSqlConnection this, Location location) throws AdbcException {
    // Treat a missing parameter map as empty so every lookup below is null-safe.
    final Map<String, Object> params =
        parameters != null ? parameters : Collections.<String, Object>emptyMap();

    // Setup cached pre-connection properties.
    try {
      final InputStream mtlsCertChain = FlightSqlConnectionProperties.MTLS_CERT_CHAIN.get(params);
      if (mtlsCertChain != null) {
        this.mtlsCertChainBytes = inputStreamToBytes(mtlsCertChain);
      }

      final InputStream mtlsPrivateKey =
          FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.get(params);
      if (mtlsPrivateKey != null) {
        this.mtlsPrivateKeyBytes = inputStreamToBytes(mtlsPrivateKey);
      }

      final InputStream tlsRootCerts = FlightSqlConnectionProperties.TLS_ROOT_CERTS.get(params);
      if (tlsRootCerts != null) {
        this.tlsRootCertsBytes = inputStreamToBytes(tlsRootCerts);
      }
    } catch (IOException ex) {
      throw new AdbcException(
          String.format(
              "Error reading stream for one of the options %s, %s, %s.",
              FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(),
              FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(),
              FlightSqlConnectionProperties.TLS_ROOT_CERTS.getKey()),
          ex,
          AdbcStatusCode.IO,
          null,
          0);
    }

    final boolean useCookieMiddleware =
        Boolean.TRUE.equals(FlightSqlConnectionProperties.WITH_COOKIE_MIDDLEWARE.get(params));
    if (useCookieMiddleware) {
      this.cookieMiddlewareFactory = new ClientCookieMiddleware.Factory();
    }

    // Validate user-specified headers before building the client, so an invalid header value
    // cannot leave an unclosed client behind.
    final ArrayList<CallOption> options = new ArrayList<>();
    options.add(new HeaderCallOption(buildCallHeaders(params)));

    // Build the client using the above properties.
    final FlightClient client = buildClient(location);

    // Test the connection.
    try {
      final String username = AdbcDriver.PARAM_USERNAME.get(params);
      final String password = AdbcDriver.PARAM_PASSWORD.get(params);
      if (username != null && password != null) {
        final Optional<CredentialCallOption> bearerToken =
            client.authenticateBasicToken(username, password);
        options.add(
            bearerToken.orElse(
                new CredentialCallOption(new BasicAuthCredentialWriter(username, password))));
        this.callOptions = options.toArray(new CallOption[0]);
      } else {
        this.callOptions = options.toArray(new CallOption[0]);
        client.handshake(this.callOptions);
      }
    } catch (RuntimeException e) {
      closeAfterFailure(client, e);
      throw e;
    }

    return client;
  }

  /**
   * Builds the RPC call headers from every option whose key starts with {@code
   * FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX}. The header name is the key with that
   * prefix stripped.
   *
   * @throws AdbcException if a matching option value is neither a String nor a byte[]
   */
  private static FlightCallHeaders buildCallHeaders(Map<String, Object> params)
      throws AdbcException {
    final FlightCallHeaders callHeaders = new FlightCallHeaders();
    final String prefix = FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX;
    for (Map.Entry<String, Object> parameter : params.entrySet()) {
      if (!parameter.getKey().startsWith(prefix)) {
        continue;
      }
      final String userHeaderName = parameter.getKey().substring(prefix.length());
      final Object value = parameter.getValue();
      if (value instanceof String) {
        callHeaders.insert(userHeaderName, (String) value);
      } else if (value instanceof byte[]) {
        callHeaders.insert(userHeaderName, (byte[]) value);
      } else {
        throw new AdbcException(
            String.format(
                "Header values must be String or byte[]. The header failing was %s.",
                parameter.getKey()),
            null,
            AdbcStatusCode.INVALID_ARGUMENT,
            null,
            0);
      }
    }
    return callHeaders;
  }

  /**
   * Returns a yet-to-be authenticated FlightClient.
   *
   * <p>The client uses a fresh child of {@link #allocator}. If building the client fails, that
   * child allocator is closed before the exception propagates.
   *
   * @throws AdbcException if only one of the mTLS certificate chain / private key was provided
   */
  private FlightClient buildClient(
      @UnknownInitialization FlightSqlConnection this, Location location) throws AdbcException {
    if (allocator == null) {
      throw new IllegalStateException("Internal error: allocator was not initialized");
    }

    // Validate the mTLS options before allocating anything that would need cleanup.
    if (mtlsCertChainBytes != null && mtlsPrivateKeyBytes == null) {
      throw new AdbcException(
          String.format(
              "Must provide both %s and %s or neither. %s provided only.",
              FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(),
              FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(),
              FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey()),
          null,
          AdbcStatusCode.INVALID_ARGUMENT,
          null,
          0);
    }
    if (mtlsPrivateKeyBytes != null && mtlsCertChainBytes == null) {
      throw new AdbcException(
          String.format(
              "Must provide both %s and %s or neither. %s provided only.",
              FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey(),
              FlightSqlConnectionProperties.MTLS_CERT_CHAIN.getKey(),
              FlightSqlConnectionProperties.MTLS_PRIVATE_KEY.getKey()),
          null,
          AdbcStatusCode.INVALID_ARGUMENT,
          null,
          0);
    }

    final BufferAllocator childAllocator =
        allocator.newChildAllocator(
            "adbc-flightclient-connection-" + counter.getAndIncrement(), 0, allocator.getLimit());
    try {
      final FlightClient.Builder builder =
          FlightClient.builder().allocator(childAllocator).location(location);

      // Configure TLS options.
      if (mtlsCertChainBytes != null && mtlsPrivateKeyBytes != null) {
        builder.clientCertificate(
            new ByteArrayInputStream(mtlsCertChainBytes),
            new ByteArrayInputStream(mtlsPrivateKeyBytes));
      }

      if (tlsRootCertsBytes != null) {
        builder.trustedCertificates(new ByteArrayInputStream(tlsRootCertsBytes));
      }

      if (parameters != null) {
        if (Boolean.TRUE.equals(FlightSqlConnectionProperties.TLS_SKIP_VERIFY.get(parameters))) {
          builder.verifyServer(false);
        }

        String hostnameOverride =
            FlightSqlConnectionProperties.TLS_OVERRIDE_HOSTNAME.get(parameters);
        if (hostnameOverride != null) {
          builder.overrideHostname(hostnameOverride);
        }
      }

      // Setup cookies if needed.
      if (cookieMiddlewareFactory != null) {
        builder.intercept(cookieMiddlewareFactory);
      }

      return builder.build();
    } catch (RuntimeException e) {
      // The client was never created, so the child allocator is still ours to release.
      closeAfterFailure(childAllocator, e);
      throw e;
    }
  }

  /**
   * Closes {@code resource} while {@code failure} is propagating. Any exception thrown by close is
   * attached to {@code failure} as a suppressed exception rather than replacing it.
   */
  private static void closeAfterFailure(AutoCloseable resource, Throwable failure) {
    try {
      resource.close();
    } catch (Exception closeError) {
      if (closeError instanceof InterruptedException) {
        Thread.currentThread().interrupt();
      }
      failure.addSuppressed(closeError);
    }
  }

  /**
   * Reads {@code stream} to end-of-stream and returns its content. The stream is not closed.
   *
   * <p>{@link InputStream#available()} is only an estimate of the bytes readable without blocking,
   * so it cannot be used to size the result; this method reads until EOF instead.
   */
  private static byte[] inputStreamToBytes(InputStream stream) throws IOException {
    final ByteArrayOutputStream out = new ByteArrayOutputStream();
    final byte[] buffer = new byte[8192];
    int read;
    while ((read = stream.read(buffer)) != -1) {
      out.write(buffer, 0, read);
    }
    return out.toByteArray();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy