All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.impossibl.postgres.protocol.v30.ServerConnectionFactory Maven / Gradle / Ivy

The newest version!
/**
 * Copyright (c) 2013, impossibl.com
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *  * Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of impossibl.com nor the names of its contributors may
 *    be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
package com.impossibl.postgres.protocol.v30;

import com.impossibl.postgres.protocol.CopyFormat;
import com.impossibl.postgres.protocol.FieldFormat;
import com.impossibl.postgres.protocol.Notice;
import com.impossibl.postgres.protocol.ssl.SSLEngineFactory;
import com.impossibl.postgres.protocol.ssl.SSLMode;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.CommandError;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.CopyData;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.CopyDone;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.CopyFail;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.CopyInResponse;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.CopyOutResponse;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.Notification;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.ParameterStatus;
import com.impossibl.postgres.protocol.v30.ProtocolHandler.ReportNotice;
import com.impossibl.postgres.system.Configuration;
import com.impossibl.postgres.system.NoticeException;
import com.impossibl.postgres.system.ParameterNames;
import com.impossibl.postgres.system.ServerInfo;
import com.impossibl.postgres.system.SystemSettings;
import com.impossibl.postgres.system.Version;

import static com.impossibl.postgres.protocol.ServerConnection.KeyData;
import static com.impossibl.postgres.protocol.v30.HostNameVerifier.verifyHostName;
import static com.impossibl.postgres.system.SystemSettings.APPLICATION_NAME;
import static com.impossibl.postgres.system.SystemSettings.CREDENTIALS_USERNAME;
import static com.impossibl.postgres.system.SystemSettings.DATABASE_NAME;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_BUFFER_POOLING;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_ENCODING;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_IO_MODE;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_IO_THREADS;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_MESSAGE_SIZE_MAX;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_SOCKET_RECV_BUFFER_SIZE;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_SOCKET_SEND_BUFFER_SIZE;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_TRACE;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_TRACE_FILE;
import static com.impossibl.postgres.system.SystemSettings.PROTOCOL_VERSION;
import static com.impossibl.postgres.system.SystemSettings.SSL_MODE;
import static com.impossibl.postgres.utils.Await.awaitUninterruptibly;
import static com.impossibl.postgres.utils.Nulls.firstNonNull;

import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.lang.ref.WeakReference;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;

import static java.util.concurrent.TimeUnit.SECONDS;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollDomainSocketChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.kqueue.KQueueDomainSocketChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.channel.unix.DomainSocketAddress;
import io.netty.channel.unix.DomainSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.ssl.SslHandler;


/**
 * Creates protocol v3.0 {@link ServerConnection}s over Netty channels: opens the socket
 * (TCP or Unix domain), negotiates SSL according to the requested {@link SSLMode}, runs the
 * startup/authentication exchange and installs a {@link DefaultHandler} for server-initiated
 * messages.
 */
public class ServerConnectionFactory implements com.impossibl.postgres.protocol.ServerConnectionFactory {

  private static final Logger logger = Logger.getLogger(ServerConnectionFactory.class.getName());

  /** Seconds to wait for the startup (authentication) exchange to complete. */
  private static final long DEFAULT_STARTUP_TIMEOUT = 60;

  /** Seconds to wait for the server's answer to the SSL query request. */
  private static final long DEFAULT_SSL_TIMEOUT = 60;

  /**
   * A channel whose connection has been initiated, paired with the shared event-loop
   * reference it was created on.
   */
  static class CreatedChannel {
    ServerConnectionShared.Ref sharedRef;
    ChannelFuture channelFuture;

    CreatedChannel(ServerConnectionShared.Ref sharedRef, ChannelFuture channelFuture) {
      this.sharedRef = sharedRef;
      this.channelFuture = channelFuture;
    }
  }

  /**
   * Opens a connection to {@code address} using the SSL mode from the {@code SSL_MODE}
   * setting of {@code config}.
   *
   * @param config connection configuration
   * @param address server address; {@link InetSocketAddress} and {@link DomainSocketAddress}
   *                are supported
   * @param listener receives parameter status changes, notifications and close events
   * @return the started-up server connection
   * @throws IOException if connecting, SSL negotiation or startup fails
   */
  public ServerConnection connect(Configuration config, SocketAddress address, ServerConnection.Listener listener) throws IOException {

    SSLMode sslMode = config.getSetting(SSL_MODE);

    return connect(config, sslMode, address, listener);
  }

  /**
   * Opens a connection using an explicit SSL mode.
   *
   * <p>SSL is requested from the server unless {@code sslMode} is {@code Disable} or
   * {@code Allow}. When an attempt fails, {@code Prefer} retries with SSL disabled and
   * {@code Allow} retries with SSL required; other modes fail. A failed SSL handshake, a
   * refused SSL requirement or a failed startup closes that attempt's channel before
   * retrying or throwing, so abandoned sockets are not left open.
   *
   * @param config connection configuration
   * @param sslMode SSL mode for this attempt
   * @param address server address
   * @param listener connection listener
   * @return the started-up server connection
   * @throws IOException if the connection cannot be established; {@link NoticeException}s
   *                     are rethrown unchanged, other failures are translated
   */
  private ServerConnection connect(Configuration config, SSLMode sslMode, SocketAddress address, ServerConnection.Listener listener) throws IOException {

    try {

      CreatedChannel createdChannel = createChannel(address, config);

      ServerConnectionShared.Ref sharedRef = createdChannel.sharedRef;
      Channel channel = createdChannel.channelFuture.syncUninterruptibly().channel();

      // NOTE(review): sharedRef is not released on the failure paths below; confirm whether
      // ServerConnectionShared.Ref requires an explicit release when an attempt is abandoned.

      if (sslMode != SSLMode.Disable && sslMode != SSLMode.Allow) {

        // Ask the server whether it accepts SSL on this connection

        SSLQueryRequest sslQueryRequest = new SSLQueryRequest();
        channel.writeAndFlush(sslQueryRequest).syncUninterruptibly();

        boolean sslQueryCompleted = awaitUninterruptibly(DEFAULT_SSL_TIMEOUT, SECONDS, sslQueryRequest::await);

        if (sslQueryCompleted && sslQueryRequest.isAllowed()) {

          // Attach the actual handler in front of the protocol handlers

          SSLEngine sslEngine = SSLEngineFactory.create(sslMode, config);

          final SslHandler sslHandler = new SslHandler(sslEngine);

          channel.pipeline().addFirst("ssl", sslHandler);

          try {

            sslHandler.handshakeFuture().syncUninterruptibly();

          }
          catch (Exception e) {

            // The channel is unusable after a failed handshake; don't leak it
            channel.close();

            // Retry with no SSL
            if (sslMode == SSLMode.Prefer) {
              return connect(config, SSLMode.Disable, address, listener);
            }

            throw e;
          }

        }
        else if (sslMode.isRequired()) {

          channel.close();

          throw new IOException("SSL not allowed by server");
        }

      }

      try {

        Map<String, String> parameterStatuses = new HashMap<>();
        ServerConnection serverConnection = startup(config, channel, parameterStatuses, sharedRef);

        if (sslMode == SSLMode.VerifyFull) {

          SslHandler sslHandler = channel.pipeline().get(SslHandler.class);
          if (sslHandler != null) {

            // Non-inet addresses (e.g. domain sockets) carry no host name; verify against ""
            String hostname;
            if (address instanceof InetSocketAddress) {
              hostname = ((InetSocketAddress) address).getHostString();
            }
            else {
              hostname = "";
            }

            verifyHostName(hostname, sslHandler.engine().getSession());
          }

        }


        // Finally successfully connected...

        serverConnection.getMessageDispatchHandler().setDefaultHandler(new DefaultHandler(listener));

        // Replay the parameters reported during startup to the listener
        parameterStatuses.forEach(listener::parameterStatusChanged);

        return serverConnection;
      }
      catch (Exception e) {

        // Abandon this attempt's channel whether we retry or give up
        channel.close();

        switch (sslMode) {
          case Allow:
            return connect(config, SSLMode.Require, address, listener);

          case Prefer:
            return connect(config, SSLMode.Disable, address, listener);

          default:
            throw e;
        }

      }

    }
    catch (NoticeException e) {

      // Errors reported by the server pass through untranslated
      throw e;
    }
    catch (Exception e) {

      throw translateConnectionException(e);
    }

  }

  /**
   * Starts connecting a channel of the transport matching {@code address}.
   *
   * @throws IllegalArgumentException if the address type is neither
   *         {@link InetSocketAddress} nor {@link DomainSocketAddress}
   */
  private CreatedChannel createChannel(SocketAddress address, Configuration config) {

    if (address instanceof InetSocketAddress) {
      return createInetSocketChannel((InetSocketAddress) address, config);
    }
    else if (address instanceof DomainSocketAddress) {
      return createDomainSocketChannel((DomainSocketAddress) address, config);
    }
    else {
      throw new IllegalArgumentException("Unsupported socket address: " + address.getClass().getSimpleName());
    }
  }

  /**
   * Starts connecting a TCP channel. The transport is chosen by {@code PROTOCOL_IO_MODE}:
   * {@code OIO} (with a thread count of 0); {@code NATIVE} (KQueue, then Epoll, failing if
   * neither is available); {@code ANY} (native if available, otherwise NIO); or {@code NIO}.
   *
   * @throws IllegalStateException if the requested io mode cannot be satisfied
   */
  @SuppressWarnings("deprecation")
  private CreatedChannel createInetSocketChannel(InetSocketAddress address, Configuration config) {

    int maxMessageSize = config.getSetting(PROTOCOL_MESSAGE_SIZE_MAX);
    Charset clientEncoding = config.getSetting(PROTOCOL_ENCODING);

    Class<? extends Channel> channelType;
    Class<? extends EventLoopGroup> groupType;

    int maxThreads = config.getSetting(PROTOCOL_IO_THREADS);

    SystemSettings.ProtocolIOMode ioMode = config.getSetting(PROTOCOL_IO_MODE);
    switch (ioMode) {
      case OIO:
        channelType = io.netty.channel.socket.oio.OioSocketChannel.class;
        groupType = io.netty.channel.oio.OioEventLoopGroup.class;
        maxThreads = 0;
        break;

      case ANY:

        // Fallthrough to try in order...

      case NATIVE:
        if (KQueue.isAvailable()) {
          channelType = KQueueSocketChannel.class;
          groupType = KQueueEventLoopGroup.class;
          break;
        }
        else if (Epoll.isAvailable()) {
          channelType = EpollSocketChannel.class;
          groupType = EpollEventLoopGroup.class;
          break;
        }
        else if (ioMode != SystemSettings.ProtocolIOMode.ANY) {
          throw new IllegalStateException("Unsupported io mode: native: no native library loaded");
        }

        // ANY without a native library: fall through to NIO

      case NIO:
        channelType = NioSocketChannel.class;
        groupType = NioEventLoopGroup.class;
        break;

      default:
        throw new IllegalStateException("Unsupported io mode: " + ioMode);
    }

    ServerConnectionShared.Ref sharedRef = ServerConnectionShared.acquire(groupType, maxThreads);

    Writer protocolTraceWriter = createProtocolTracer(config);

    Bootstrap bootstrap = new Bootstrap()
        .group(sharedRef.get().getEventLoopGroup())
        .channel(channelType)
        .handler(new ChannelInitializer<SocketChannel>() {
          @Override
          protected void initChannel(SocketChannel ch) {
            ch.pipeline().addLast(
                // Length field: 4 bytes at offset 1, counting itself (-4); nothing stripped
                new LengthFieldBasedFrameDecoder(maxMessageSize, 1, 4, -4, 0),
                new MessageDispatchHandler(clientEncoding, protocolTraceWriter)
            );
          }
        })
        .option(ChannelOption.TCP_NODELAY, true);

    configureChannelOptions(config, bootstrap);

    ChannelFuture channelFuture = bootstrap.connect(address);

    return new CreatedChannel(sharedRef, channelFuture);
  }

  /**
   * Starts connecting a Unix domain socket channel using KQueue or Epoll.
   *
   * @throws IllegalArgumentException if neither native transport is available
   */
  private CreatedChannel createDomainSocketChannel(DomainSocketAddress address, Configuration config) {

    int maxMessageSize = config.getSetting(PROTOCOL_MESSAGE_SIZE_MAX);
    Charset clientEncoding = config.getSetting(PROTOCOL_ENCODING);

    Class<? extends Channel> channelType;
    Class<? extends EventLoopGroup> groupType;
    if (KQueue.isAvailable()) {
      channelType = KQueueDomainSocketChannel.class;
      groupType = KQueueEventLoopGroup.class;
    }
    else if (Epoll.isAvailable()) {
      channelType = EpollDomainSocketChannel.class;
      groupType = EpollEventLoopGroup.class;
    }
    else {
      throw new IllegalArgumentException("Unix domain sockets not supported: missing native libraries");
    }

    int maxThreads = config.getSetting(PROTOCOL_IO_THREADS);

    ServerConnectionShared.Ref sharedRef = ServerConnectionShared.acquire(groupType, maxThreads);

    Writer protocolTraceWriter = createProtocolTracer(config);

    Bootstrap bootstrap = new Bootstrap()
        .group(sharedRef.get().getEventLoopGroup())
        .channel(channelType)
        .handler(new ChannelInitializer<DomainSocketChannel>() {
          @Override
          protected void initChannel(DomainSocketChannel ch) {
            ch.pipeline().addLast(
                // Length field: 4 bytes at offset 1, counting itself (-4); nothing stripped
                new LengthFieldBasedFrameDecoder(maxMessageSize, 1, 4, -4, 0),
                new MessageDispatchHandler(clientEncoding, protocolTraceWriter)
            );
          }
        });

    configureChannelOptions(config, bootstrap);

    ChannelFuture channelFuture = bootstrap.connect(address);

    return new CreatedChannel(sharedRef, channelFuture);
  }

  /**
   * Applies the optional socket receive/send buffer sizes and selects a pooled or unpooled
   * buffer allocator according to {@code PROTOCOL_BUFFER_POOLING}.
   */
  private void configureChannelOptions(Configuration config, Bootstrap bootstrap) {

    Integer receiveBufferSize = config.getSetting(PROTOCOL_SOCKET_RECV_BUFFER_SIZE);
    if (receiveBufferSize != null) {
      bootstrap.option(ChannelOption.SO_RCVBUF, receiveBufferSize);
    }

    Integer sendBufferSize = config.getSetting(PROTOCOL_SOCKET_SEND_BUFFER_SIZE);
    if (sendBufferSize != null) {
      bootstrap.option(ChannelOption.SO_SNDBUF, sendBufferSize);
    }

    boolean usePooledAllocator = config.getSetting(PROTOCOL_BUFFER_POOLING);
    bootstrap.option(ChannelOption.ALLOCATOR, usePooledAllocator ? PooledByteBufAllocator.DEFAULT : UnpooledByteBufAllocator.DEFAULT);
  }

  /**
   * Creates the protocol trace writer.
   *
   * @return {@code null} when {@code PROTOCOL_TRACE} is off; otherwise a buffered writer to
   *         {@code PROTOCOL_TRACE_FILE} (truncated) when set and openable, else to
   *         {@link System#out}
   */
  private Writer createProtocolTracer(Configuration config) {
    if (config.getSetting(PROTOCOL_TRACE)) {
      OutputStream out = System.out;
      String filePath = config.getSetting(PROTOCOL_TRACE_FILE);
      if (filePath != null) {
        try {
          out = new FileOutputStream(filePath, false);
        }
        catch (FileNotFoundException e) {
          // Best effort: keep tracing to standard output, but report why
          logger.log(Level.WARNING, "Unable to open protocol trace file '" + filePath + "', tracing to standard output", e);
        }
      }
      return new BufferedWriter(new OutputStreamWriter(out));
    }
    return null;
  }

  /**
   * Sends the startup request (application name, client encoding, database and user) and
   * waits up to {@link #DEFAULT_STARTUP_TIMEOUT} seconds for it to complete.
   *
   * <p>On success, {@code startupParameterStatuses} receives the parameters reported by the
   * server, except {@code server_version}, {@code server_encoding} and
   * {@code integer_datetimes} (defaulting to "on" when absent), which are removed and
   * captured in the connection's {@link ServerInfo}.
   *
   * @throws IOException on timeout, or when the startup handler reports an IOException
   */
  private static ServerConnection startup(Configuration config, Channel channel, Map<String, String> startupParameterStatuses, ServerConnectionShared.Ref sharedRef) throws IOException {

    Map<String, Object> params = new HashMap<>();
    params.put(ParameterNames.APPLICATION_NAME, config.getSetting(APPLICATION_NAME));
    params.put(ParameterNames.CLIENT_ENCODING, config.getSetting(PROTOCOL_ENCODING));
    params.put(ParameterNames.DATABASE, config.getSetting(DATABASE_NAME));
    params.put(ParameterNames.USER, config.getSetting(CREDENTIALS_USERNAME));

    Version protocolVersion = config.getSetting(PROTOCOL_VERSION);

    // Results are published from the handler callbacks and read after the latch opens
    AtomicReference<Version> startupProtocolVersion = new AtomicReference<>();
    AtomicReference<KeyData> startupKeyData = new AtomicReference<>();
    AtomicReference<Throwable> startupError = new AtomicReference<>();
    CountDownLatch startupCompleted = new CountDownLatch(1);

    StartupRequest startupRequest = new StartupRequest(protocolVersion, params, new AuthenticationHandler(config, channel) {

      @Override
      public void handleNegotiate(Version maxProtocolVersion, List<String> unrecognizedParameters) {
        startupProtocolVersion.set(maxProtocolVersion);
      }

      @Override
      public void handleComplete(int processId, int secretKey, Map<String, String> parameterStatuses, List<Notice> notices) {

        startupParameterStatuses.putAll(parameterStatuses);
        startupKeyData.set(new KeyData(processId, secretKey));

        startupCompleted.countDown();
      }

      @Override
      public void handleError(Throwable error, List<Notice> notices) {
        startupError.set(error);
        startupCompleted.countDown();
      }

    });
    channel.writeAndFlush(startupRequest).syncUninterruptibly();

    if (!awaitUninterruptibly(DEFAULT_STARTUP_TIMEOUT, SECONDS, startupCompleted::await)) {
      throw new IOException("Timeout starting connection");
    }

    Throwable error = startupError.get();
    if (error != null) {
      if (error instanceof IOException) throw (IOException) error;
      if (error instanceof RuntimeException) throw (RuntimeException) error;
      throw new RuntimeException(error);
    }

    // Pull out static server parameters
    ServerInfo serverInfo = new ServerInfo(
        Version.parse(startupParameterStatuses.remove("server_version")),
        startupParameterStatuses.remove("server_encoding"),
        firstNonNull(startupParameterStatuses.remove("integer_datetimes"), "on").equalsIgnoreCase("on")
    );

    // Use the server-negotiated version when one was reported
    Version negotiatedVersion = startupProtocolVersion.get();
    if (negotiatedVersion != null) {
      protocolVersion = negotiatedVersion;
    }

    return new ServerConnection(config, channel, serverInfo, protocolVersion, startupKeyData.get(), sharedRef);
  }

  /**
   * Converts a connection failure into the most descriptive {@link IOException}.
   *
   * <p>Closed channels become {@code IOException("Channel Closed")}; other exceptions are
   * unwrapped to an IOException cause when there is one. {@link SSLHandshakeException}s
   * are unwrapped to their underlying cause, and SSL exceptions get an "SSL Error" message
   * prefix. A null message on an SSL exception is tolerated.
   */
  private static IOException translateConnectionException(Exception e) {

    IOException io;

    // Unwrap
    if (e instanceof ClosedChannelException) {
      io = new IOException("Channel Closed", e);
    }
    else if (e instanceof IOException) {
      io = (IOException) e;
    }
    else if (e.getCause() == null) {
      io = new IOException(e);
    }
    else if (e.getCause() instanceof IOException) {
      io = (IOException) e.getCause();
    }
    else {
      io = new IOException(e.getCause());
    }

    // Unwrap SSL Handshake exceptions

    while (io instanceof SSLHandshakeException) {
      if (io.getCause() instanceof IOException) {
        io = (IOException) io.getCause();
      }
      else if (io.getCause() != null) {
        io = new SSLException(io.getCause().getMessage(), io.getCause());
      }
      else {
        io = new SSLException(io.getMessage(), io);
      }
    }

    if (io instanceof SSLException) {
      String message = io.getMessage();
      if (message == null || !message.startsWith("SSL Error")) {
        // Keep the original exception as the cause when it has none of its own
        Throwable cause = io.getCause() != null ? io.getCause() : io;
        io = new SSLException(message == null ? "SSL Error" : "SSL Error: " + message, cause);
      }
    }

    return io;
  }

  /**
   * Default message handler installed on each new connection. It forwards server-initiated
   * events (parameter changes, notifications, COPY requests, channel closure) to the
   * connection listener, which is held weakly, and logs unhandled errors.
   */
  static class DefaultHandler implements ParameterStatus, ReportNotice, Notification, CopyInResponse, CopyOutResponse, CommandError {

    private static final Logger logger = Logger.getLogger(ServerConnection.class.getName());

    // Weakly held so this handler does not keep the listener reachable
    private final WeakReference<ServerConnection.Listener> listener;

    DefaultHandler(ServerConnection.Listener listener) {
      this.listener = new WeakReference<>(listener);
    }

    /** @return the listener, or {@code null} if it has been garbage collected */
    private ServerConnection.Listener getListener() {
      return listener.get();
    }

    @Override
    public String toString() {
      return "DEFAULT";
    }

    @Override
    public Action parameterStatus(String name, String value) {
      ServerConnection.Listener listener = getListener();
      if (listener != null) {
        listener.parameterStatusChanged(name, value);
      }
      return Action.Resume;
    }

    @Override
    public void notification(int processId, String channelName, String payload) {
      ServerConnection.Listener listener = getListener();
      if (listener != null) {
        listener.notificationReceived(processId, channelName, payload);
      }
    }

    @Override
    public InputStream copyIn(CopyFormat format, FieldFormat[] fieldFormats) {
      ServerConnection.Listener listener = getListener();
      if (listener == null) return null;
      return listener.openStandardInput();
    }

    @Override
    public ProtocolHandler copyOut(CopyFormat format, FieldFormat[] fieldFormats) {
      ServerConnection.Listener listener = getListener();
      if (listener == null) return null;
      return new DefaultCopyOutHandler(listener.openStandardOutput());
    }

    @Override
    public void exception(Channel channel, Throwable cause) {
      // Only a closed channel is reported to the listener
      if (!channel.isOpen()) {
        ServerConnection.Listener listener = getListener();
        if (listener != null) {
          listener.closed();
        }
      }
    }

    @Override
    public void exception(Throwable cause) {
      if (cause instanceof ClosedChannelException) return;
      logger.log(Level.WARNING, "Unhandled connection exception", cause);
    }

    @Override
    public Action notice(Notice notice) {
      // NOTE(review): returns null, while the other handlers return Action.Resume; confirm
      // how the dispatcher treats a null action.
      return null;
    }

    @Override
    public Action error(Notice notice) {
      logger.warning(notice.getMessage());
      return Action.Resume;
    }

  }

  /**
   * COPY OUT handler that writes all received data to the given stream and ignores
   * completion, failure and exceptions.
   */
  static class DefaultCopyOutHandler implements CopyData, CopyDone, CopyFail {

    OutputStream stream;

    DefaultCopyOutHandler(OutputStream stream) {
      this.stream = stream;
    }

    @Override
    public void copyData(ByteBuf data) throws IOException {

      while (data.isReadable()) {
        data.readBytes(stream, data.readableBytes());
      }

    }

    @Override
    public void copyDone() {
    }

    @Override
    public void copyFail(String message) {
    }

    @Override
    public void exception(Throwable cause) {
    }

  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy