All downloads are free. Search and download functionality uses the official Maven repository.

com.athaydes.protobuf.tcp.internal.ProtobufServer Maven / Gradle / Ivy

package com.athaydes.protobuf.tcp.internal;

import com.athaydes.protobuf.tcp.api.Api;
import com.athaydes.protobuf.tcp.api.ServiceReference;
import com.google.protobuf.Any;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.StringValue;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousCloseException;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.athaydes.protobuf.tcp.internal.MethodResolver.resolveMethods;
import static com.athaydes.protobuf.tcp.internal.Utils.closeQuietly;
import static java.util.Collections.emptyList;

/**
 * A TCP implementation of a Protobuf RPC server that sends method invocations to a local service.
 */
public class ProtobufServer implements ServiceReference {

    private static final Logger log = LoggerFactory.getLogger(ProtobufServer.class);

    private final int port;
    private final T service;
    private final AtomicBoolean running = new AtomicBoolean(false);
    private final AtomicReference serverSocketRef = new AtomicReference<>();
    private final Map> methodsByName;

    public ProtobufServer(T service, int port, Class... exportedInterfaces) {
        this.port = port;
        this.service = service;
        this.methodsByName = resolveMethods(service, exportedInterfaces);
    }

    @Override
    public T getLocalService() {
        return service;
    }

    /**
     * Run this server.
     * 

* This is a non-blocking call. */ @Override public void run() { log.info("Starting ProtobufServer on port " + port); AsynchronousServerSocketChannel serverSocket; try { if (!serverSocketRef.compareAndSet(null, serverSocket = AsynchronousServerSocketChannel.open())) { throw new RuntimeException("Server already running"); } serverSocket.bind(new InetSocketAddress("127.0.0.1", port)); } catch (IOException e) { log.warn("Error starting server", e); throw new RuntimeException(e); } log.info("Accepting client connections"); running.set(true); serverSocket.accept(null, new CompletionHandler() { @Override public void completed(AsynchronousSocketChannel clientSocket, Void ignore) { if (running.get()) { try { log.debug("Accepting connection from: {}", clientSocket.getRemoteAddress()); serverSocket.accept(null, this); new Handler(service, methodsByName, clientSocket).run(); } catch (IOException e) { log.warn("Unable to get client remote address"); closeQuietly(clientSocket); } } else { closeQuietly(serverSocket); } } @Override public void failed(Throwable exc, Void ignore) { if (exc instanceof AsynchronousCloseException) { log.debug("Client closed connection"); } else { log.warn("Failed to accept client socket", exc); } } }); } @Override public void close() { log.info("Stopping server"); running.set(false); AsynchronousServerSocketChannel serverSocket = serverSocketRef.get(); if (serverSocket != null) { closeQuietly(serverSocket); } } private static class Handler implements CompletionHandler { private final Object service; private final Map> methodsByName; private final AsynchronousSocketChannel clientSocket; Handler(Object service, Map> methodsByName, AsynchronousSocketChannel clientSocket) { this.service = service; this.methodsByName = methodsByName; this.clientSocket = clientSocket; } void run() { VarIntReader reader = new VarIntReader(); try { clientSocket.read(reader.buffer(), 5, TimeUnit.SECONDS, reader, this); } catch (IllegalStateException e) { log.debug("Unable to 
continue listening to client socket due to {}", e.toString()); closeQuietly(clientSocket); } } @Override public void completed(Integer bytesCount, VarIntReader lengthReader) { if (bytesCount < 0) { log.debug("Received bytesCount = {}, closing client socket", bytesCount); closeQuietly(clientSocket); return; } OptionalInt length; try { length = lengthReader.read(); if (!length.isPresent()) { clientSocket.read(lengthReader.buffer(), 5, TimeUnit.SECONDS, lengthReader, this); return; // wait for more bytes } } catch (IOException e) { sendError(e); return; } int messageLength = length.getAsInt(); log.debug("Expecting message with length {}", messageLength); if (messageLength <= 0) { sendError(new IllegalArgumentException("Invalid message length")); } else { ByteBuffer msgBuffer = ByteBuffer.allocate(messageLength); clientSocket.read(msgBuffer, 10, TimeUnit.SECONDS, msgBuffer, new ServiceMethodInvoker(messageLength)); } } @Override public void failed(Throwable exc, VarIntReader reader) { log.debug("Handler failed: {}", exc.toString()); closeQuietly(clientSocket); } private void sendError(Throwable error) { Api.Result response = Api.Result.newBuilder().setException(Api.Exception.newBuilder() .setType(error.getClass().getName()) .setMessage(Optional.ofNullable(error.getMessage()).orElse("")) .build()).build(); sendResult(response); } private void sendResult(Api.Result result) { int resultLength = result.getSerializedSize(); ByteArrayOutputStream out = new ByteArrayOutputStream(resultLength + 4); try { log.debug("Sending result to client: {}", result); result.writeDelimitedTo(out); clientSocket.write(ByteBuffer.wrap(out.toByteArray())); } catch (IOException ignore) { // never happens, write is async } finally { // start waiting for new invocations again run(); } } private class ServiceMethodInvoker implements CompletionHandler { private final int messageLength; private final AtomicInteger bytesReceived = new AtomicInteger(0); ServiceMethodInvoker(int messageLength) { 
this.messageLength = messageLength; } @Override public void completed(Integer bytesCount, ByteBuffer msgBuffer) { if (bytesCount < 0) { log.debug("Received bytesCount = {}, closing client socket", bytesCount); closeQuietly(clientSocket); return; } int received = bytesReceived.addAndGet(bytesCount); if (received < messageLength) { log.debug("Received {} bytes so far, waiting for a total of {}.", received, messageLength); clientSocket.read(msgBuffer, 5, TimeUnit.SECONDS, msgBuffer, this); return; } log.debug("Received full message with length {}, parsing it.", messageLength); msgBuffer.flip(); Api.MethodInvocation message; try { message = Api.MethodInvocation.parseFrom(CodedInputStream.newInstance(msgBuffer)); } catch (IOException e) { // should not happen, the msgBuffer is read from the socket already sendError(e); return; } String methodName = message.getMethodName(); List args = message.getArgsList(); log.debug("Looking up method '{}' of service {}", methodName, service); Optional resolvedInvocationInfo = methodsByName .getOrDefault(methodName, emptyList()).stream() .map(m -> MethodInvocationResolver.resolveMethodInvocation(m, args)) .filter(Optional::isPresent) .map(Optional::get) .findAny(); if (resolvedInvocationInfo.isPresent()) { log.debug("Resolved method invocation: {}", resolvedInvocationInfo.get()); try { Any result = resolvedInvocationInfo.get().callWith(service); sendResult(Api.Result.newBuilder().setSuccessResult(result).build()); log.debug("Successfully processed method invocation"); } catch (InvocationTargetException e) { sendError(e.getCause()); } catch (Exception e) { sendError(e); } } else { log.debug("Method not found"); sendError(new NoSuchMethodException(methodName)); } } @Override public void failed(Throwable exc, ByteBuffer attachment) { sendError(exc); } } } public static void main(String[] args) { class HelloService { public StringValue sayHello(StringValue message) { return StringValue.newBuilder() .setValue("Hello " + message.getValue()) 
.build(); } } System.out.println("Starting ProtobufServer with HelloService.\n" + "The only method call allowed has the following signature:\n" + " public StringValue sayHello(StringValue message)\n" + "Send a message to port 5562 and this server will respond!\n"); ProtobufServer server = new ProtobufServer<>(new HelloService(), 5562); server.run(); System.out.println("Press enter to stop the server"); try { //noinspection ResultOfMethodCallIgnored System.in.read(); } catch (IOException e) { e.printStackTrace(); } server.close(); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy