All downloads are free. Search and download functionality uses the official Maven repository.

com.yandex.ydb.core.grpc.impl.grpc.GrpcTransportImpl Maven / Gradle / Ivy

There is a newer version: 1.45.6
Show newest version
package com.yandex.ydb.core.grpc.impl.grpc;

import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import com.google.common.net.HostAndPort;
import com.yandex.ydb.core.Result;
import com.yandex.ydb.core.grpc.AsyncBidiStreamingInAdapter;
import com.yandex.ydb.core.grpc.AsyncBidiStreamingOutAdapter;
import com.yandex.ydb.core.grpc.ChannelSettings;
import com.yandex.ydb.core.grpc.GrpcTransport;
import com.yandex.ydb.core.grpc.ServerStreamToObserver;
import com.yandex.ydb.core.grpc.UnaryStreamToBiConsumer;
import com.yandex.ydb.core.grpc.UnaryStreamToConsumer;
import com.yandex.ydb.core.grpc.UnaryStreamToFuture;
import com.yandex.ydb.core.rpc.OutStreamObserver;
import com.yandex.ydb.core.rpc.StreamControl;
import com.yandex.ydb.core.rpc.StreamObserver;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelOption;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author Sergey Polovko
 * @author Nikolay Perfilov
 */
public class GrpcTransportImpl extends GrpcTransport {
    private static final Logger logger = LoggerFactory.getLogger(GrpcTransportImpl.class);

    private final ManagedChannel realChannel;
    private final Channel channel;

    public GrpcTransportImpl(GrpcTransport.Builder builder) {
        super(builder);

        ChannelSettings channelSettings = ChannelSettings.fromBuilder(builder);
        this.realChannel = createChannel(builder, channelSettings);
        this.channel = interceptChannel(realChannel, channelSettings);
        init();
    }

    private void init() {
        switch (discoveryMode) {
            case SYNC:
                try {
                    Instant start = Instant.now();
                    tryToConnect().get(WAIT_FOR_CONNECTION_MS, TimeUnit.MILLISECONDS);
                    logger.info("GrpcTransport sync initialization took {} ms",
                            Duration.between(start, Instant.now()).toMillis());
                } catch (TimeoutException ignore) {
                    logger.warn("Couldn't establish YDB transport connection in {} ms", WAIT_FOR_CONNECTION_MS);
                    // Keep going
                    // Use ASYNC discovery mode and tryToConnect() method to add actions in case of connection timeout
                } catch (InterruptedException | ExecutionException e) {
                    logger.error("Exception thrown while establishing YDB transport connection: " + e);
                    throw new RuntimeException("Exception thrown while establishing YDB transport connection", e);
                }
                break;
            case ASYNC:
            default:
                break;
        }
    }

    /**
     * Establish connection for grpc channel(s) if its currently IDLE
     * Returns a future to a first {@link ConnectivityState} that is not IDLE or CONNECTING
     */
    public CompletableFuture tryToConnect() {
        CompletableFuture promise = new CompletableFuture<>();
        ConnectivityState initialState = realChannel.getState(true);
        logger.debug("GrpcTransport channel initial state: {}", initialState);
        if (!TEMPORARY_STATES.contains(initialState)) {
            promise.complete(initialState);
        } else {
            realChannel.notifyWhenStateChanged(
                    initialState,
                    new Runnable() {
                        @Override
                        public void run() {
                            ConnectivityState currState = realChannel.getState(false);
                            logger.debug("GrpcTransport channel new state: {}", currState);
                            if (TEMPORARY_STATES.contains(currState)) {
                                realChannel.notifyWhenStateChanged(currState, this);
                            } else {
                                promise.complete(currState);
                            }
                        }
                    }
            );
        }
        return promise;
    }

    private static ManagedChannel createChannel(GrpcTransport.Builder builder, ChannelSettings channelSettings) {
        String endpoint = builder.getEndpoint();
        String database = builder.getDatabase();
        List hosts = builder.getHosts();
        assert endpoint == null || database != null;
        assert (endpoint == null) != (hosts == null);

        final String localDc = builder.getLocalDc();

        // Always use random choice policy, may be add option for that?
        String defaultPolicy = registerYdbLoadBalancer(localDc, true);

        final NettyChannelBuilder channelBuilder;
        if (endpoint != null) {
            channelBuilder = NettyChannelBuilder.forTarget(YdbNameResolver.makeTarget(endpoint, database))
                    .nameResolverFactory(YdbNameResolver.newFactory(
                            builder.getAuthProvider(),
                            builder.getCert(),
                            builder.getUseTls(),
                            builder.getEndpointsDiscoveryPeriod(),
                            builder.getCallExecutor(),
                            builder.getChannelInitializer()))
                    .defaultLoadBalancingPolicy(defaultPolicy);
        } else if (hosts.size() > 1) {
            channelBuilder = NettyChannelBuilder.forTarget(HostsNameResolver.makeTarget(hosts))
                    .nameResolverFactory(HostsNameResolver.newFactory(hosts, builder.getCallExecutor()))
                    .defaultLoadBalancingPolicy(defaultPolicy);
        } else {
            channelBuilder = NettyChannelBuilder.forAddress(
                    hosts.get(0).getHost(),
                    hosts.get(0).getPortOrDefault(DEFAULT_PORT));
        }

        channelSettings.configureSecureConnection(channelBuilder);

        channelBuilder
                .maxInboundMessageSize(64 << 20) // 64 MiB
                .withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);

        builder.getChannelInitializer().accept(channelBuilder);
        return channelBuilder.build();
    }

    private static String registerYdbLoadBalancer(String localDc, boolean randomChoice) {
        String policyName = "ydb_load_balancer";
        final LoadBalancerRegistry lbr = LoadBalancerRegistry.getDefaultRegistry();

        lbr.register(new LoadBalancerProvider() {
            @Override
            public boolean isAvailable() {
                return true;
            }

            @Override
            public int getPriority() {
                return 10;
            }

            @Override
            public String getPolicyName() {
                return policyName;
            }

            @Override
            public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) {
                return new YdbLoadBalancer(helper, randomChoice, localDc);
            }
        });
        return policyName;
    }

    @Override
    protected  CompletableFuture> makeUnaryCall(
            MethodDescriptor method,
            ReqT request,
            CallOptions callOptions,
            CompletableFuture> promise) {
        ClientCall call = channel.newCall(method, callOptions);
        if (logger.isDebugEnabled()) {
            logger.debug("Sending request to {}, method `{}', request: `{}'", channel.authority(), method, request );
        }
        sendOneRequest(call, request, new UnaryStreamToFuture<>(promise));
        return promise;
    }

    @Override
    protected  void makeUnaryCall(
            MethodDescriptor method,
            ReqT request,
            CallOptions callOptions,
            Consumer> consumer) {
        ClientCall call = channel.newCall(method, callOptions);
        if (logger.isDebugEnabled()) {
            logger.debug("Sending request to {}, method `{}', request: `{}'", channel.authority(), method, request );
        }
        sendOneRequest(call, request, new UnaryStreamToConsumer<>(consumer));
    }

    @Override
    protected  void makeUnaryCall(
            MethodDescriptor method,
            ReqT request,
            CallOptions callOptions,
            BiConsumer consumer) {
        ClientCall call = channel.newCall(method, callOptions);
        if (logger.isDebugEnabled()) {
            logger.debug("Sending request to {}, method `{}', request: `{}'", channel.authority(), method, request );
        }
        sendOneRequest(call, request, new UnaryStreamToBiConsumer<>(consumer));
    }

    @Override
    protected  StreamControl makeServerStreamCall(
            MethodDescriptor method,
            ReqT request,
            CallOptions callOptions,
            StreamObserver observer) {
        ClientCall call = channel.newCall(method, callOptions);
        sendOneRequest(call, request, new ServerStreamToObserver<>(observer, call));
        return () -> {
            call.cancel("Cancelled on user request", new CancellationException());
        };
    }

    @Override
    protected  OutStreamObserver makeBidirectionalStreamCall(
            MethodDescriptor method,
            CallOptions callOptions,
            StreamObserver observer) {
        ClientCall call = channel.newCall(method, callOptions);
        AsyncBidiStreamingOutAdapter adapter
                = new AsyncBidiStreamingOutAdapter<>(call);
        AsyncBidiStreamingInAdapter responseListener
                = new AsyncBidiStreamingInAdapter<>(observer, adapter);
        call.start(responseListener, new Metadata());
        responseListener.onStart();
        return adapter;
    }

    @Override
    public void close() {
        super.close();
        try {
            boolean closed = realChannel.shutdown()
                    .awaitTermination(WAIT_FOR_CLOSING_MS, TimeUnit.MILLISECONDS);
            if (!closed) {
                logger.warn("closing transport timeout exceeded, terminate");
                closed = realChannel.shutdownNow()
                        .awaitTermination(WAIT_FOR_CLOSING_MS, TimeUnit.MILLISECONDS);
                if (!closed) {
                    logger.warn("closing transport problem");
                }
            }
        } catch (InterruptedException e) {
            logger.error("transport shutdown interrupted", e);
            Thread.currentThread().interrupt();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy