
com.github.mauricio.async.db.postgresql.codec.PostgreSQLConnectionHandler.scala Maven / Gradle / Ivy
/*
* Copyright 2013 Maurício Linhares
*
* Maurício Linhares licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.github.mauricio.async.db.postgresql.codec
import com.github.mauricio.async.db.Configuration
import com.github.mauricio.async.db.SSLConfiguration.Mode
import com.github.mauricio.async.db.column.{
ColumnDecoderRegistry,
ColumnEncoderRegistry
}
import com.github.mauricio.async.db.postgresql.exceptions._
import com.github.mauricio.async.db.postgresql.messages.backend._
import com.github.mauricio.async.db.postgresql.messages.frontend._
import com.github.mauricio.async.db.util.ChannelFutureTransformer.toFuture
import com.github.mauricio.async.db.util._
import java.net.InetSocketAddress
import scala.annotation.switch
import scala.concurrent._
import io.netty.channel._
import io.netty.bootstrap.Bootstrap
import io.netty.channel
import scala.util.Failure
import com.github.mauricio.async.db.postgresql.messages.backend.DataRowMessage
import com.github.mauricio.async.db.postgresql.messages.backend.CommandCompleteMessage
import com.github.mauricio.async.db.postgresql.messages.backend.ProcessData
import scala.util.Success
import com.github.mauricio.async.db.postgresql.messages.backend.RowDescriptionMessage
import com.github.mauricio.async.db.postgresql.messages.backend.ParameterStatusMessage
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.codec.CodecException
import io.netty.handler.ssl.{SslContextBuilder, SslHandler}
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
import io.netty.util.concurrent.FutureListener
import javax.net.ssl.{SSLParameters, TrustManagerFactory}
import java.security.KeyStore
import java.io.FileInputStream
/** Companion of [[PostgreSQLConnectionHandler]].
  *
  * Holds the logger that every handler instance imports and shares.
  */
object PostgreSQLConnectionHandler {
// Logger obtained for the PostgreSQLConnectionHandler class; shared by all instances.
final val log = Log.get[PostgreSQLConnectionHandler]
}
/** Netty inbound handler that owns a single PostgreSQL connection.
  *
  * It builds the channel pipeline (message decoder, message encoder, this
  * handler), runs the optional SSL negotiation followed by the startup
  * message, and dispatches every decoded backend message to
  * `connectionDelegate`.
  *
  * @param configuration      host, port, credentials, charset, SSL and allocator settings
  * @param encoderRegistry    passed to the [[MessageEncoder]] installed on the pipeline
  * @param decoderRegistry    not referenced by this class body
  * @param connectionDelegate receives authentication, data, status and error callbacks
  * @param group              event loop group the Netty bootstrap runs on
  * @param executionContext   runs the Scala future callbacks registered by this class
  */
class PostgreSQLConnectionHandler(
    configuration: Configuration,
    encoderRegistry: ColumnEncoderRegistry,
    decoderRegistry: ColumnDecoderRegistry,
    connectionDelegate: PostgreSQLConnectionDelegate,
    group: EventLoopGroup,
    executionContext: ExecutionContext
) extends SimpleChannelInboundHandler[Object] {

  import PostgreSQLConnectionHandler.log

  // Key/value pairs sent in the StartupMessage once the (optional) SSL phase is over.
  private val properties = List(
    "user" -> configuration.username,
    "database" -> configuration.database,
    "client_encoding" -> configuration.charset.name(),
    "DateStyle" -> "ISO",
    "extra_float_digits" -> "2"
  )

  private implicit final val _executionContext: ExecutionContext =
    executionContext

  private final val bootstrap = new Bootstrap()

  // Only ever failed (by `connect`); nothing in this class completes it successfully.
  private final val connectionFuture = Promise[PostgreSQLConnectionHandler]()

  // Completed by `disconnect` once the close message has been written and the channel closed.
  private final val disconnectionPromise =
    Promise[PostgreSQLConnectionHandler]()

  // Last BackendKeyData message received; assigned in channelRead0, not read in this class.
  private var processData: ProcessData = null

  // Set in handlerAdded; null until this handler has been added to a pipeline.
  private var currentContext: ChannelHandlerContext = null

  /** Configures the bootstrap and opens a TCP connection to
    * `configuration.host:configuration.port`.
    *
    * @return a future that is failed if the TCP connect fails. Nothing in
    *         this class completes it successfully, so a successful connection
    *         has to be observed elsewhere (e.g. through the delegate callbacks).
    */
  def connect: Future[PostgreSQLConnectionHandler] = {
    this.bootstrap.group(this.group)
    this.bootstrap.channel(classOf[NioSocketChannel])
    this.bootstrap.handler(new ChannelInitializer[channel.Channel]() {
      override def initChannel(ch: channel.Channel): Unit = {
        ch.pipeline.addLast(
          new MessageDecoder(
            configuration.ssl.mode != Mode.Disable,
            configuration.charset,
            configuration.maximumMessageSize
          ),
          new MessageEncoder(configuration.charset, encoderRegistry),
          PostgreSQLConnectionHandler.this
        )
      }
    })
    this.bootstrap.option[java.lang.Boolean](ChannelOption.SO_KEEPALIVE, true)
    this.bootstrap.option(ChannelOption.ALLOCATOR, configuration.allocator)
    this.bootstrap
      .connect(new InetSocketAddress(configuration.host, configuration.port))
      .failed
      .foreach(e => connectionFuture.tryFailure(e))
    this.connectionFuture.future
  }

  /** Writes `CloseMessage` to the server, then closes the channel.
    *
    * @return when connected, a future completed after the channel close
    *         finishes, or failed with the write/close error. When not
    *         connected, the result of an earlier disconnect if there was one,
    *         otherwise an already-successful future. The not-connected case
    *         does not return a future that would never complete.
    */
  def disconnect: Future[PostgreSQLConnectionHandler] = {
    if (this.isConnected) {
      this.currentContext.channel.writeAndFlush(CloseMessage).onComplete {
        case Success(writeFuture) =>
          writeFuture.channel.close().onComplete {
            case Success(_) => this.disconnectionPromise.trySuccess(this)
            case Failure(e) => this.disconnectionPromise.tryFailure(e)
          }
        case Failure(e) => this.disconnectionPromise.tryFailure(e)
      }
      this.disconnectionPromise.future
    } else if (this.disconnectionPromise.isCompleted) {
      // A previous disconnect already finished; report its outcome.
      this.disconnectionPromise.future
    } else {
      // Nothing to close: the channel was never activated or is already inactive.
      Future.successful(this)
    }
  }

  /** True when this handler has been added to a pipeline and its channel is active. */
  def isConnected: Boolean =
    Option(this.currentContext).exists(_.channel.isActive)

  /** Starts the protocol: sends an SSLRequest unless SSL is disabled, in which
    * case the StartupMessage is sent right away.
    */
  override def channelActive(ctx: ChannelHandlerContext): Unit = {
    if (configuration.ssl.mode == Mode.Disable)
      ctx.writeAndFlush(new StartupMessage(this.properties))
    else
      ctx.writeAndFlush(SSLRequestMessage)
  }

  /** Dispatches one decoded inbound message.
    *
    * SSL responses continue the handshake. Server messages are routed to the
    * delegate by `kind`. Anything else is logged and reported to the delegate
    * as an `IllegalArgumentException`.
    */
  override def channelRead0(ctx: ChannelHandlerContext, msg: Object): Unit = {
    msg match {
      case SSLResponseMessage(supported) =>
        handleSslResponse(ctx, supported)

      case m: ServerMessage =>
        (m.kind: @switch) match {
          case ServerMessage.BackendKeyData =>
            this.processData = m.asInstanceOf[ProcessData]
          case ServerMessage.BindComplete => ()
          case ServerMessage.Authentication =>
            log.debug("Authentication response received {}", m)
            connectionDelegate.onAuthenticationResponse(
              m.asInstanceOf[AuthenticationMessage]
            )
          case ServerMessage.CommandComplete =>
            connectionDelegate.onCommandComplete(
              m.asInstanceOf[CommandCompleteMessage]
            )
          case ServerMessage.CloseComplete => ()
          case ServerMessage.DataRow =>
            connectionDelegate.onDataRow(m.asInstanceOf[DataRowMessage])
          case ServerMessage.Error =>
            connectionDelegate.onError(m.asInstanceOf[ErrorMessage])
          case ServerMessage.EmptyQueryString =>
            val exception = new QueryMustNotBeNullOrEmptyException(null)
            exception.fillInStackTrace()
            connectionDelegate.onError(exception)
          case ServerMessage.NoData => ()
          case ServerMessage.Notice =>
            log.info("Received notice {}", m)
          case ServerMessage.NotificationResponse =>
            connectionDelegate.onNotificationResponse(
              m.asInstanceOf[NotificationResponse]
            )
          case ServerMessage.ParameterStatus =>
            connectionDelegate.onParameterStatus(
              m.asInstanceOf[ParameterStatusMessage]
            )
          case ServerMessage.ParseComplete => ()
          case ServerMessage.ReadyForQuery =>
            connectionDelegate.onReadyForQuery()
          case ServerMessage.RowDescription =>
            connectionDelegate.onRowDescription(
              m.asInstanceOf[RowDescriptionMessage]
            )
          case _ =>
            val exception = new IllegalStateException(
              "Handler not implemented for message %s".format(m.kind)
            )
            exception.fillInStackTrace()
            connectionDelegate.onError(exception)
        }

      case _ =>
        log.error("Unknown message type - {}", msg)
        val exception = new IllegalArgumentException(
          "Unknown message type - %s".format(msg)
        )
        exception.fillInStackTrace()
        connectionDelegate.onError(exception)
    }
  }

  /** Handles the server's reply to our SSLRequest.
    *
    * If the server accepts, an [[SslHandler]] is put at the front of the
    * pipeline and the StartupMessage is sent once its handshake succeeds. A
    * failed handshake is reported to the delegate. If the server refuses,
    * startup continues in plaintext only when the configured mode is below
    * `Mode.Require`; otherwise an error is reported.
    */
  private def handleSslResponse(
      ctx: ChannelHandlerContext,
      supported: Boolean
  ): Unit = {
    if (supported) {
      val ctxBuilder = SslContextBuilder.forClient()
      if (configuration.ssl.mode >= Mode.VerifyCA) {
        // Verify the server certificate: use the configured root cert, or fall back to the JRE store.
        configuration.ssl.rootCert.fold {
          ctxBuilder.trustManager(jreCacertsTrustManagerFactory())
        } { path =>
          ctxBuilder.trustManager(path)
        }
      } else {
        // Weaker modes accept any server certificate.
        ctxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE)
      }
      val sslContext = ctxBuilder.build()
      val sslEngine = sslContext.newEngine(
        ctx.alloc(),
        configuration.host,
        configuration.port
      )
      if (configuration.ssl.mode >= Mode.VerifyFull) {
        // Additionally check that the certificate matches the host name.
        val sslParams = sslEngine.getSSLParameters()
        sslParams.setEndpointIdentificationAlgorithm("HTTPS")
        sslEngine.setSSLParameters(sslParams)
      }
      val handler = new SslHandler(sslEngine)
      ctx.pipeline().addFirst(handler)
      handler.handshakeFuture.addListener(
        new FutureListener[channel.Channel]() {
          def operationComplete(
              future: io.netty.util.concurrent.Future[channel.Channel]
          ): Unit = {
            if (future.isSuccess()) {
              ctx.writeAndFlush(new StartupMessage(properties))
            } else {
              connectionDelegate.onError(future.cause())
            }
          }
        }
      )
    } else if (configuration.ssl.mode < Mode.Require) {
      ctx.writeAndFlush(new StartupMessage(properties))
    } else {
      connectionDelegate.onError(
        new IllegalArgumentException("SSL is not supported on server")
      )
    }
  }

  /** Builds a trust manager factory from `${java.home}/lib/security/cacerts`.
    *
    * NOTE(review): the path and the "changeit" password are hard-coded, and
    * the `javax.net.ssl.trustStore` system property is not consulted. Confirm
    * this is intended before relying on custom trust stores.
    */
  private def jreCacertsTrustManagerFactory(): TrustManagerFactory = {
    val tmf = TrustManagerFactory.getInstance(
      TrustManagerFactory.getDefaultAlgorithm()
    )
    val ks = KeyStore.getInstance(KeyStore.getDefaultType())
    val cacerts = new FileInputStream(
      System.getProperty("java.home") + "/lib/security/cacerts"
    )
    try {
      ks.load(cacerts, "changeit".toCharArray)
    } finally {
      cacerts.close()
    }
    tmf.init(ks)
    tmf
  }

  /** Reports pipeline exceptions to the delegate.
    *
    * A `CodecException` is unwrapped to its cause when it has one. A
    * cause-less `CodecException` is reported as-is, so the delegate never
    * receives `null`.
    */
  override def exceptionCaught(
      ctx: ChannelHandlerContext,
      cause: Throwable
  ): Unit = {
    val reported = cause match {
      case codec: CodecException if codec.getCause != null => codec.getCause
      case other                                           => other
    }
    connectionDelegate.onError(reported)
  }

  /** Logs the remote address when the channel goes inactive. */
  override def channelInactive(ctx: ChannelHandlerContext): Unit = {
    log.info("Connection disconnected - {}", ctx.channel.remoteAddress)
  }

  /** Captures the context so `write`, `disconnect` and `isConnected` can reach the channel. */
  override def handlerAdded(ctx: ChannelHandlerContext): Unit = {
    this.currentContext = ctx
  }

  /** Writes and flushes a frontend message, reporting a failed write to the delegate.
    *
    * Expects `handlerAdded` to have run already, so that `currentContext` is non-null.
    */
  def write(message: ClientMessage): Unit = {
    this.currentContext.writeAndFlush(message).failed.foreach { e =>
      connectionDelegate.onError(e)
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy