All downloads are free. Search and download functionality uses the official Maven repository.

com.vesoft.nebula.driver.graph.net.GrpcConnection Maven / Gradle / Ivy

The newest version!
package com.vesoft.nebula.driver.graph.net;

import com.alibaba.fastjson.JSON;
import com.google.common.base.Charsets;
import com.google.protobuf.ByteString;
import com.vesoft.nebula.driver.graph.ErrorCode;
import com.vesoft.nebula.driver.graph.data.HostAddress;
import com.vesoft.nebula.driver.graph.exception.AuthFailedException;
import com.vesoft.nebula.driver.graph.exception.IOErrorException;
import com.vesoft.nebula.driver.graph.utils.ClientVersion;
import com.vesoft.nebula.proto.common.ClientInfo;
import com.vesoft.nebula.proto.common.Common;
import com.vesoft.nebula.proto.graph.AuthRequest;
import com.vesoft.nebula.proto.graph.AuthResponse;
import com.vesoft.nebula.proto.graph.ExecuteRequest;
import com.vesoft.nebula.proto.graph.ExecuteResponse;
import com.vesoft.nebula.proto.graph.GraphServiceGrpc;
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.nio.charset.Charset;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GrpcConnection extends Connection {

    private static final Logger LOGGER = LoggerFactory.getLogger(GrpcConnection.class);

    private static final ConcurrentHashMap channels =
            new ConcurrentHashMap<>();
    private GraphServiceGrpc.GraphServiceBlockingStub stub;
    private long connectTimeout = 0;
    private long requestTimeout = 0;

    private final Charset charset = Charsets.UTF_8;

    private static final ReadWriteLock lock = new ReentrantReadWriteLock();

    @Override
    public void open(HostAddress address, long connectTimeout, long requestTimeout) {
        this.serverAddr = address;
        this.connectTimeout = connectTimeout;
        this.requestTimeout = requestTimeout;
        lock.readLock().lock();
        try {
            channels.computeIfAbsent(serverAddr, key -> createChannel());
        } finally {
            lock.readLock().unlock();
        }
        stub = GraphServiceGrpc.newBlockingStub(channels.get(serverAddr));
    }

    @Override
    public void close() {
        if (!channels.isEmpty()) {
            closeChannel();
        }
        stub = null;
    }

    @Override
    public boolean ping(long sessionID, long timeoutMs) throws IOErrorException {
        ExecuteResponse response = execute(sessionID, "RETURN 1", timeoutMs);
        return ErrorCode.SUCCESSFUL_COMPLETION.code
                .equals(response.getStatus().getCode().toString(charset));
    }

    public AuthResult authenticate(String user, Map authOptions)
            throws AuthFailedException {
        try {
            ClientInfo clientInfo = ClientInfo.newBuilder()
                    .setLang(ClientInfo.Language.JAVA)
                    .setProtocolVersion(Common
                            .getDescriptor()
                            .getOptions()
                            .getExtension(Common.protocolVersion))
                    .setVersion(ByteString.copyFrom(ClientVersion.clientVersion, charset))
                    .build();
            String authInfoString = JSON.toJSONString(authOptions);
            AuthRequest authReq = AuthRequest.newBuilder()
                    .setUsername(ByteString.copyFrom(user, charset))
                    .setAuthInfo(ByteString.copyFrom(authInfoString, charset))
                    .setClientInfo(clientInfo)
                    .build();

            getChannel();
            AuthResponse resp = stub
                    .withDeadlineAfter(connectTimeout, TimeUnit.MILLISECONDS)
                    .authenticate(authReq);
            String code = resp.getStatus().getCode().toString(charset);
            if (!ErrorCode.SUCCESSFUL_COMPLETION.code.equals(code)) {
                throw new AuthFailedException(resp.getStatus().getMessage().toString(charset));
            }
            return new AuthResult(resp.getSessionId());
        } catch (Exception e) {
            // TODO
            throw e;
        }
    }

    public ExecuteResponse execute(long sessionID, String stmt, long timeout)
            throws IOErrorException {
        getChannel();
        try {
            ExecuteRequest request = ExecuteRequest.newBuilder()
                    .setSessionId(sessionID)
                    .setStmt(ByteString.copyFrom(stmt, charset))
                    .build();

            return stub.withDeadlineAfter(timeout, TimeUnit.MILLISECONDS).execute(request);
        } catch (Exception e) {
            // TODO
            throw e;
        }
    }

    public ExecuteResponse execute(long sessionID, String stmt) throws IOErrorException {
        return execute(sessionID, stmt, this.requestTimeout);
    }

    private void getChannel() {
        lock.readLock().lock();
        try {
            channels.computeIfAbsent(serverAddr, key -> {
                ManagedChannel channel = createChannel();
                stub = GraphServiceGrpc.newBlockingStub(channel)
                        .withDeadline(Deadline.after(requestTimeout, TimeUnit.MILLISECONDS));
                return channel;
            });
        } finally {
            lock.readLock().unlock();
        }
    }

    private ManagedChannel createChannel() {
        return ManagedChannelBuilder
                .forAddress(serverAddr.getHost(), serverAddr.getPort()).usePlaintext()
                .build();
    }

    private static void closeChannel() {
        lock.writeLock().lock();
        try {
            for (ManagedChannel channel : channels.values()) {
                if (channel != null && !channel.isShutdown()) {
                    try {
                        channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS);
                    } catch (InterruptedException e) {
                        LOGGER.warn("close grpc connection is interrupted.", e);
                    }
                }
            }
            channels.clear();
        } finally {
            lock.writeLock().unlock();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy