All downloads are FREE. Search and download functionality uses the official Maven repository.

com.palantir.atlasdb.keyvalue.cassandra.CassandraClientFactory Maven / Gradle / Ivy

The newest version!
/*
 * (c) Copyright 2018 Palantir Technologies Inc. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.palantir.atlasdb.keyvalue.cassandra;

import com.codahale.metrics.Timer;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import com.palantir.atlasdb.cassandra.CassandraCredentialsConfig;
import com.palantir.atlasdb.cassandra.CassandraKeyValueServiceConfig;
import com.palantir.atlasdb.keyvalue.cassandra.ImmutableCassandraClientConfig.SocketTimeoutMillisBuildStage;
import com.palantir.atlasdb.keyvalue.cassandra.pool.CassandraServer;
import com.palantir.atlasdb.util.AtlasDbMetrics;
import com.palantir.atlasdb.util.MetricsManager;
import com.palantir.common.annotations.ImmutablesStyles.StagedBuilderStyle;
import com.palantir.common.exception.AtlasDbDependencyException;
import com.palantir.conjure.java.api.config.ssl.SslConfiguration;
import com.palantir.conjure.java.config.ssl.SslSocketFactories;
import com.palantir.exception.SafeSSLPeerUnverifiedException;
import com.palantir.logsafe.DoNotLog;
import com.palantir.logsafe.SafeArg;
import com.palantir.logsafe.UnsafeArg;
import com.palantir.logsafe.exceptions.SafeRuntimeException;
import com.palantir.logsafe.logger.SafeLogger;
import com.palantir.logsafe.logger.SafeLoggerFactory;
import com.palantir.util.TimedRunner;
import com.palantir.util.TimedRunner.TaskContext;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import org.apache.cassandra.thrift.AuthenticationRequest;
import org.apache.cassandra.thrift.Cassandra;
import org.apache.cassandra.thrift.Cassandra.Client;
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.hc.client5.http.ssl.DefaultHostnameVerifier;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.layered.TFramedTransport;
import org.immutables.value.Value;

/**
 * Pooled-object factory that creates, validates and destroys {@link CassandraClient}s for a single
 * {@link CassandraServer}.
 *
 * Clients talk to Cassandra through a framed Thrift transport, optionally layered over SSL, and are wrapped in
 * metrics, profiling, tracing and QoS decorators before being returned from {@link #create()}.
 */
public class CassandraClientFactory extends BasePooledObjectFactory<CassandraClient> {
    private static final SafeLogger log = SafeLoggerFactory.get(CassandraClientFactory.class);

    /** SSL socket factories keyed by their configuration; values are held weakly (see {@code weakValues()}). */
    private static final LoadingCache<SslConfiguration, SSLSocketFactory> sslSocketFactoryCache =
            Caffeine.newBuilder().weakValues().build(SslSocketFactories::createSslSocketFactory);

    private static final DefaultHostnameVerifier hostnameVerifier = new DefaultHostnameVerifier();

    private final MetricsManager metricsManager;
    private final CassandraServer cassandraServer;
    private final CassandraClientConfig clientConfig;
    private final SSLSocketFactory sslSocketFactory;
    private final TimedRunner timedRunner;
    private final TSocketFactory tSocketFactory;
    private final CassandraClientMetrics cassandraClientMetrics;

    /**
     * Creates a factory producing clients for the given server.
     *
     * @param metricsManager used to register timers and to instrument created clients and sockets
     * @param cassandraServer the server whose {@code proxy()} address clients connect to
     * @param clientConfig socket, SSL, credential and keyspace settings applied to every client
     * @param cassandraClientMetrics metrics passed to the {@link InstrumentedCassandraClient} wrapper
     */
    public CassandraClientFactory(
            MetricsManager metricsManager,
            CassandraServer cassandraServer,
            CassandraClientConfig clientConfig,
            CassandraClientMetrics cassandraClientMetrics) {
        this.metricsManager = metricsManager;
        this.cassandraServer = cassandraServer;
        this.clientConfig = clientConfig;
        this.sslSocketFactory = createSslSocketFactory(clientConfig.sslConfiguration());
        this.timedRunner = TimedRunner.create(clientConfig.timeoutOnConnectionClose());
        this.tSocketFactory = new InstrumentedTSocket.Factory(metricsManager);
        this.cassandraClientMetrics = cassandraClientMetrics;
    }

    /**
     * Creates a new, instrumented client with the configured keyspace already selected.
     *
     * @return the decorated client
     * @throws ClientCreationFailedException wrapping any exception raised while connecting, authenticating,
     *         selecting the keyspace or instrumenting the client
     */
    @Override
    public CassandraClient create() {
        try {
            return instrumentClient(getRawClientWithKeyspaceSet());
        } catch (Exception e) {
            String message = String.format(
                    "Failed to construct client for %s/%s", cassandraServer.proxy(), clientConfig.keyspace());
            if (clientConfig.usingSsl()) {
                message += " over SSL";
            }
            throw new ClientCreationFailedException(message, e);
        }
    }

    /**
     * Wraps a raw Thrift client in the decorator chain. Wrappers are applied innermost first, so the QoS wrapper
     * is the outermost layer of the returned client.
     */
    private CassandraClient instrumentClient(Client rawClient) {
        CassandraClient client = new CassandraClientImpl(rawClient);
        // TODO(ssouza): use the kvsMethodName to tag the timers.
        client = AtlasDbMetrics.instrumentTimed(metricsManager.getRegistry(), CassandraClient.class, client);
        client = new ProfilingCassandraClient(client);
        client = new TracingCassandraClient(client);
        client = new InstrumentedCassandraClient(client, cassandraClientMetrics);
        client = QosCassandraClient.instrumentWithMetrics(client, metricsManager);
        return client;
    }

    /**
     * Opens a raw client and selects the configured keyspace on it.
     *
     * @throws TException if connecting or {@code set_keyspace} fails; in the latter case the transport is closed
     *         before the exception is rethrown
     */
    private Cassandra.Client getRawClientWithKeyspaceSet() throws TException {
        Client ret = getRawClientWithTimedCreation();
        try {
            ret.set_keyspace(clientConfig.keyspace());
            if (log.isDebugEnabled()) {
                log.debug(
                        "Created new client for {}/{}{}{}",
                        SafeArg.of("address", CassandraLogHelper.host(cassandraServer.proxy())),
                        UnsafeArg.of("keyspace", clientConfig.keyspace()),
                        SafeArg.of("usingSsl", clientConfig.usingSsl() ? " over SSL" : ""),
                        UnsafeArg.of(
                                "usernameConfig",
                                " as user " + clientConfig.credentials().username()));
            }
            return ret;
        } catch (TException e) {
            // The connection is already open at this point, so release it before propagating.
            ret.getOutputProtocol().getTransport().close();
            throw e;
        }
    }

    /**
     * Creates an authenticated client without the instrumentation wrappers and without selecting a keyspace,
     * using the default (non-instrumented) {@link TSocketFactory}.
     *
     * @throws TException if connecting or authenticating fails
     */
    static CassandraClient getClientInternal(CassandraServer cassandraServer, CassandraClientConfig clientConfig)
            throws TException {
        return new CassandraClientImpl(getRawClient(
                cassandraServer,
                clientConfig,
                createSslSocketFactory(clientConfig.sslConfiguration()),
                TSocketFactory.Default.INSTANCE));
    }

    /**
     * Returns the cached socket factory for the given SSL configuration, or the JVM default factory when no
     * configuration is present.
     */
    private static SSLSocketFactory createSslSocketFactory(Optional<SslConfiguration> sslConfiguration) {
        return sslConfiguration
                .map(sslSocketFactoryCache::get)
                // This path should never be hit in production code, since we expect SSL is configured explicitly.
                .orElseGet(() -> (SSLSocketFactory) SSLSocketFactory.getDefault());
    }

    /** Opens a raw client while recording the elapsed time on the {@code clientCreation} timer. */
    private Cassandra.Client getRawClientWithTimedCreation() throws TException {
        Timer clientCreation = metricsManager.registerOrGetTimer(CassandraClientFactory.class, "clientCreation");
        // The context is only needed for its close(), which stops the timer.
        try (Timer.Context ignored = clientCreation.time()) {
            return getRawClient(cassandraServer, clientConfig, sslSocketFactory, tSocketFactory);
        }
    }

    /**
     * Opens a Thrift connection to the server's proxy address and authenticates on it.
     *
     * The socket is opened with the initial query timeout, upgraded to SSL (with endpoint verification) when
     * {@code usingSsl} is set, wrapped in a framed transport with a binary protocol, and finally switched to the
     * regular query timeout once login has succeeded.
     *
     * @param cassandraServer server to connect to
     * @param clientConfig socket timeouts, SSL flags and credentials
     * @param sslSocketFactory factory used to layer SSL over the plain socket; only used when SSL is enabled
     * @param tSocketFactory factory for the Thrift sockets
     * @return an authenticated raw Thrift client
     * @throws TException if the socket cannot be opened, the SSL upgrade or verification fails (wrapped in a
     *         {@link TTransportException}), or login fails
     */
    private static Cassandra.Client getRawClient(
            CassandraServer cassandraServer,
            CassandraClientConfig clientConfig,
            SSLSocketFactory sslSocketFactory,
            TSocketFactory tSocketFactory)
            throws TException {
        InetSocketAddress addr = cassandraServer.proxy();
        TSocket thriftSocket =
                tSocketFactory.create(addr.getHostString(), addr.getPort(), clientConfig.socketTimeoutMillis());
        thriftSocket.open();
        setSocketOptions(
                thriftSocket,
                socket -> {
                    socket.getSocket().setKeepAlive(true);
                    socket.getSocket().setSoTimeout(clientConfig.initialSocketQueryTimeoutMillis());
                },
                addr);

        if (clientConfig.usingSsl()) {
            boolean success = false;
            try {
                // autoClose = true: closing the SSL socket also closes the underlying plain socket.
                SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
                        thriftSocket.getSocket(), addr.getHostString(), addr.getPort(), true);
                thriftSocket = tSocketFactory.create(socket);
                verifyEndpoint(cassandraServer, socket, clientConfig.enableEndpointVerification());
                success = true;
            } catch (IOException e) {
                throw new TTransportException(e);
            } finally {
                if (!success) {
                    // Closes whichever socket thriftSocket currently refers to (plain or SSL-wrapped).
                    thriftSocket.close();
                }
            }
        }
        TTransport thriftFramedTransport =
                new TFramedTransport(thriftSocket, CassandraConstants.CLIENT_MAX_THRIFT_FRAME_SIZE_BYTES);
        TProtocol protocol = new TBinaryProtocol(thriftFramedTransport);
        Cassandra.Client client = new Cassandra.Client(protocol);

        try {
            login(client, clientConfig.credentials());
            // Only after a successful login is the initial timeout replaced by the regular query timeout.
            setSocketOptions(
                    thriftSocket,
                    socket -> socket.getSocket().setSoTimeout(clientConfig.socketQueryTimeoutMillis()),
                    addr);
        } catch (TException e) {
            client.getOutputProtocol().getTransport().close();
            log.error("Exception thrown attempting to authenticate with config provided credentials", e);
            throw e;
        }

        return client;
    }

    /** Authenticates the client using the username and password from the given credentials. */
    private static void login(Client client, CassandraCredentialsConfig config) throws TException {
        Map<String, String> credsMap = new HashMap<>();
        credsMap.put("username", config.username());
        credsMap.put("password", config.password());
        client.login(new AuthenticationRequest(credsMap));
    }

    /**
     * Verifies that the current SSL connection hostname/ip address matches what the certificate has served.
     * This will check both ip address/hostname, and uses the IP address associated with the socket, rather
     * than what has been provided. Hostname and ip address both need to be checked, as historically we've
     * connected to Cassandra directly using IP addresses, and therefore need to support such cases.
     *
     * Verification is attempted first; the socket's closed state is inspected afterwards and takes precedence
     * over the verification result.
     *
     * Will only throw when throwOnFailure is true, even if the socket is closed during verification.
     *
     * @param cassandraServer server whose configured host name is one of the candidates to verify
     * @param socket the SSL socket whose session and remote address are checked
     * @param throwOnFailure whether a failed or impossible verification raises an exception
     * @throws SafeSSLPeerUnverifiedException if {@code throwOnFailure} is set and either the socket is closed or
     *         none of the candidate endpoints verifies against the session
     */
    @VisibleForTesting
    static void verifyEndpoint(CassandraServer cassandraServer, SSLSocket socket, boolean throwOnFailure)
            throws SafeSSLPeerUnverifiedException {
        Set<String> endpointsToCheck = getEndpointsToCheck(cassandraServer, socket);
        boolean endpointVerified =
                endpointsToCheck.stream().anyMatch(address -> hostnameVerifier.verify(address, socket.getSession()));
        if (socket.isClosed()) {
            if (throwOnFailure) {
                throw new SafeSSLPeerUnverifiedException(
                        "Unable to verify endpoints as socket is closed.",
                        SafeArg.of("endpoint", socket.getInetAddress()));
            }
            return;
        }
        if (!endpointVerified) {
            log.warn("Endpoint verification failed for host.", SafeArg.of("endpointsChecked", endpointsToCheck));
            if (throwOnFailure) {
                throw new SafeSSLPeerUnverifiedException(
                        "Endpoint verification failed for host.", SafeArg.of("endpointsChecked", endpointsToCheck));
            }
        }
    }

    /**
     * Returns the candidate names to verify against the certificate: the textual IP address of the socket's
     * remote end and the server's configured Cassandra host name.
     */
    @VisibleForTesting
    static Set<String> getEndpointsToCheck(CassandraServer cassandraServer, SSLSocket socket) {
        return ImmutableSet.of(socket.getInetAddress().getHostAddress(), cassandraServer.cassandraHostName());
    }

    /**
     * Reports whether the pooled client's output transport is still open.
     *
     * @return {@code true} if the transport is open; {@code false} if it is closed or if the check itself throws
     */
    @Override
    public boolean validateObject(PooledObject<CassandraClient> client) {
        try {
            return client.getObject().getOutputProtocol().getTransport().isOpen();
        } catch (Throwable t) {
            log.info(
                    "Failed when attempting to validate a Cassandra client in the Cassandra client pool."
                            + " Defensively believing that this object is NOT valid.",
                    SafeArg.of("cassandraClient", CassandraLogHelper.host(cassandraServer.proxy())),
                    t);
            return false;
        }
    }

    /** Wraps the client in a {@link DefaultPooledObject} for use by the pool. */
    @Override
    public PooledObject<CassandraClient> wrap(CassandraClient client) {
        return new DefaultPooledObject<>(client);
    }

    /**
     * Closes the pooled client through the {@link TimedRunner} that was created with
     * {@code timeoutOnConnectionClose}.
     *
     * @throws SafeRuntimeException wrapping anything thrown while closing
     */
    @Override
    public void destroyObject(PooledObject<CassandraClient> client) {
        if (log.isDebugEnabled()) {
            log.debug(
                    "Attempting to close transport for client {} of host {}",
                    UnsafeArg.of("client", client),
                    SafeArg.of("cassandraClient", CassandraLogHelper.host(cassandraServer.proxy())));
        }
        try {
            // The second runnable is the failure handler; nothing extra needs to happen there.
            timedRunner.run(TaskContext.createRunnable(() -> client.getObject().close(), () -> {}));
        } catch (Throwable t) {
            if (log.isDebugEnabled()) {
                log.debug(
                        "Failed to close transport for client {} of host {}",
                        UnsafeArg.of("client", client),
                        SafeArg.of("cassandraClient", CassandraLogHelper.host(cassandraServer.proxy())),
                        t);
            }
            throw new SafeRuntimeException("Threw while attempting to close transport for client", t);
        }
        if (log.isDebugEnabled()) {
            log.debug(
                    "Closed transport for client {} of host {}",
                    UnsafeArg.of("client", client),
                    SafeArg.of("cassandraClient", CassandraLogHelper.host(cassandraServer.proxy())));
        }
    }

    /** Thrown by {@link #create()} when a client cannot be constructed; always carries the underlying cause. */
    static final class ClientCreationFailedException extends AtlasDbDependencyException {
        private static final long serialVersionUID = 1L;

        ClientCreationFailedException(String message, Exception cause) {
            super(message, cause);
        }
    }

    /**
     * Applies socket options on a best-effort basis: a {@link SocketException} is logged at error level and
     * otherwise ignored, so the caller continues with whatever options were already in effect.
     */
    private static void setSocketOptions(TSocket thriftSocket, SocketConsumer consumer, InetSocketAddress addr) {
        try {
            consumer.accept(thriftSocket);
        } catch (SocketException e) {
            log.error(
                    "Couldn't set socket options for host {}", SafeArg.of("address", CassandraLogHelper.host(addr)), e);
        }
    }

    /** A consumer of a {@link TSocket} that is allowed to throw {@link SocketException}. */
    @FunctionalInterface
    private interface SocketConsumer {
        void accept(TSocket value) throws SocketException;
    }

    /**
     * Immutable settings needed to create a client: socket timeouts, credentials, SSL options, keyspace and the
     * timeout applied when closing connections.
     */
    @DoNotLog
    @Value.Immutable
    @StagedBuilderStyle
    interface CassandraClientConfig {
        /** Timeout passed to the {@link TSocketFactory} when the socket is created. */
        int socketTimeoutMillis();

        /** Socket read timeout applied after login has succeeded. */
        int socketQueryTimeoutMillis();

        /** Socket read timeout applied from socket open until login has succeeded. */
        int initialSocketQueryTimeoutMillis();

        CassandraCredentialsConfig credentials();

        boolean usingSsl();

        /** Whether a failed endpoint verification aborts client creation instead of only being logged. */
        boolean enableEndpointVerification();

        /** SSL configuration; when absent, the JVM default socket factory is used. */
        Optional<SslConfiguration> sslConfiguration();

        String keyspace();

        Duration timeoutOnConnectionClose();

        /** Copies the client-relevant settings out of the key-value-service configuration. */
        static CassandraClientConfig of(CassandraKeyValueServiceConfig config) {
            return builder()
                    .socketTimeoutMillis(config.socketTimeoutMillis())
                    .socketQueryTimeoutMillis(config.socketQueryTimeoutMillis())
                    .initialSocketQueryTimeoutMillis(config.initialSocketQueryTimeoutMillis())
                    .credentials(config.credentials())
                    .usingSsl(config.usingSsl())
                    .enableEndpointVerification(config.enableEndpointVerification())
                    .keyspace(config.getKeyspaceOrThrow())
                    .timeoutOnConnectionClose(config.timeoutOnConnectionClose())
                    .sslConfiguration(config.sslConfiguration())
                    .build();
        }

        /** Returns the first stage of the staged builder. */
        static SocketTimeoutMillisBuildStage builder() {
            return ImmutableCassandraClientConfig.builder();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy