// com.twitter.finagle.postgres.codec.PgCodec.scala Maven / Gradle / Ivy
package com.twitter.finagle.postgres.codec
import com.twitter.finagle._
import com.twitter.finagle.postgres.ResultSet
import com.twitter.finagle.postgres.connection.{AuthenticationRequired, Connection, RequestingSsl}
import com.twitter.finagle.postgres.messages._
import com.twitter.finagle.postgres.values.Md5Encryptor
import com.twitter.logging.Logger
import com.twitter.util.Future
import javax.net.ssl.TrustManagerFactory
import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers}
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.frame.FrameDecoder
import org.jboss.netty.handler.ssl.util.InsecureTrustManagerFactory
import org.jboss.netty.handler.ssl.{SslHandler, SslContext}
import scala.collection.mutable
/**
 * Postgres codec implementation.
 *
 * Used by the client to encode requests and parse responses. Only the client
 * side is implemented; asking for the server codec throws
 * `UnsupportedOperationException`.
 *
 * @param user                login name handed to [[AuthenticationProxy]]
 * @param password            optional password handed to [[AuthenticationProxy]]
 * @param database            database name handed to [[AuthenticationProxy]]
 * @param id                  key under which this codec's custom type map is
 *                            stored in `CustomOIDProxy.serviceOIDMap`
 * @param useSsl              whether the connection starts with an SSL request
 * @param trustManagerFactory trust managers used to build the client SSL context.
 *                            NOTE(review): the default is Netty's
 *                            `InsecureTrustManagerFactory`, which does not verify
 *                            certificates - confirm this default is intended.
 * @param customTypes         when true, custom types are queried from the
 *                            database on service creation; when false an empty
 *                            type map is registered for `id`
 */
class PgCodec(
    user: String,
    password: Option[String],
    database: String,
    id: String,
    useSsl: Boolean,
    trustManagerFactory: TrustManagerFactory = InsecureTrustManagerFactory.INSTANCE,
    customTypes: Boolean = false)
  extends CodecFactory[PgRequest, PgResponse] {

  def server = throw new UnsupportedOperationException("client only")

  val sslContext: SslContext = SslContext.newClientContext(trustManagerFactory)

  /** Assembles a fresh inbound/outbound handler chain for one channel. */
  private[this] def newPipeline(): ChannelPipeline = {
    val stages = Channels.pipeline()
    stages.addLast("binary_to_packet", new PacketDecoder(useSsl))
    stages.addLast("packet_to_backend_messages", new BackendMessageDecoder(new BackendMessageParser))
    stages.addLast("backend_messages_to_postgres_response", new PgClientChannelHandler(sslContext, useSsl))
    stages
  }

  def client = Function.const {
    new Codec[PgRequest, PgResponse] {
      def pipelineFactory = new ChannelPipelineFactory {
        def getPipeline = newPipeline()
      }

      // Connection level: error translation first, then authentication on top of it.
      override def prepareConnFactory(underlying: ServiceFactory[PgRequest, PgResponse]) =
        new AuthenticationProxy(new HandleErrorsProxy(underlying), user, password, database, useSsl)

      override def prepareServiceFactory(underlying: ServiceFactory[PgRequest, PgResponse]) = {
        if (!customTypes) {
          // No lookup requested: register an empty custom type map for this id.
          CustomOIDProxy.serviceOIDMap += id -> Map()
          super.prepareServiceFactory(underlying)
        } else {
          // Query the database for the custom types visible in the current context.
          new CustomOIDProxy(underlying, id)
        }
      }
    }
  }
}
/**
 * Service factory proxy that turns `Error` responses into failed Futures.
 *
 * Every service produced by `delegate` is wrapped in [[HandleErrors]].
 */
class HandleErrorsProxy(
    delegate: ServiceFactory[PgRequest, PgResponse]) extends ServiceFactoryProxy(delegate) {

  override def apply(conn: ClientConnection): Future[Service[PgRequest, PgResponse]] =
    delegate.apply(conn).map { service => HandleErrors.andThen(service) }

  /**
   * Filter that fails the Future with `Errors.server` when the response is an
   * `Error`; the message is the request's string form, a newline, and the
   * error details (empty when absent). Any other response passes through.
   */
  object HandleErrors extends SimpleFilter[PgRequest, PgResponse] {
    def apply(request: PgRequest, service: Service[PgRequest, PgResponse]) = {
      service.apply(request).flatMap {
        case Error(details) =>
          val description = "%s\n%s".format(request.toString(), details.getOrElse(""))
          Future.exception(Errors.server(description))
        case other =>
          Future.value(other)
      }
    }
  }
}
/**
 * Service factory proxy that performs the connection handshake before handing
 * a service to the caller: optional SSL request, startup message, and
 * password authentication when the server asks for it.
 *
 * If any handshake step fails, the freshly acquired service is closed so the
 * underlying connection is not leaked, and the failure is propagated.
 *
 * @param delegate factory producing the raw services
 * @param user     login name sent in the startup message (and used for MD5 hashing)
 * @param password password to send when the server requires one
 * @param database database name sent in the startup message
 * @param useSsl   whether to send an SSL request before the startup message
 */
class AuthenticationProxy(
    delegate: ServiceFactory[PgRequest, PgResponse],
    user: String, password: Option[String],
    database: String,
    useSsl: Boolean) extends ServiceFactoryProxy(delegate) {

  private val logger = Logger(getClass.getName)

  override def apply(conn: ClientConnection): Future[Service[PgRequest, PgResponse]] = {
    delegate.apply(conn).flatMap { service =>
      val handshake = for {
        optionalSslResponse <- sendSslRequest(service)
        _ <- handleSslResponse(optionalSslResponse)
        startupResponse <- service(PgRequest(new StartupMessage(user, database)))
        passwordResponse <- sendPassword(startupResponse, service)
        _ <- verifyResponse(passwordResponse)
      } yield service

      // Nobody else holds a reference to the service yet, so release it here
      // when the handshake does not complete.
      handshake.onFailure { _ =>
        service.close()
        ()
      }
    }
  }

  /** Sends the SSL request when SSL is enabled; yields `None` otherwise. */
  private[this] def sendSslRequest(service: Service[PgRequest, PgResponse]): Future[Option[PgResponse]] = {
    if (useSsl) {
      service(PgRequest(new SslRequestMessage)).map { response => Some(response) }
    } else {
      Future.value(None)
    }
  }

  /** Fails when SSL was requested and the server answered that it is not supported. */
  private[this] def handleSslResponse(optionalSslResponse: Option[PgResponse]): Future[Unit] = {
    logger.ifDebug("SSL response: %s".format(optionalSslResponse))
    if (useSsl && optionalSslResponse == Some(SslNotSupportedResponse)) {
      Future.exception(Errors.server("SSL requested but server doesn't support it"))
    } else {
      Future.Done
    }
  }

  /**
   * Answers a password challenge in the encoding the server asked for.
   * Responses other than `PasswordRequired` are returned unchanged.
   */
  private[this] def sendPassword(
      startupResponse: PgResponse, service: Service[PgRequest, PgResponse]): Future[PgResponse] = {
    startupResponse match {
      case PasswordRequired(encoding) => password match {
        case Some(pass) =>
          // NOTE(review): getBytes / new String use the platform default charset - confirm.
          val msg = encoding match {
            case ClearText => PasswordMessage(pass)
            case Md5(salt) => PasswordMessage(new String(Md5Encryptor.encrypt(user.getBytes, pass.getBytes, salt)))
          }
          service(PgRequest(msg))
        case None =>
          Future.exception(Errors.client("Password has to be specified for authenticated connection"))
      }
      case r => Future.value(r)
    }
  }

  /** Succeeds only for an `AuthenticatedResponse`; anything else fails the handshake. */
  private[this] def verifyResponse(response: PgResponse): Future[Unit] = {
    response match {
      case AuthenticatedResponse(statuses, processId, secretKey) =>
        logger.ifDebug("Authenticated: %d %d\n%s".format(processId, secretKey, statuses))
        Future.Done
      case unexpected =>
        // Previously this surfaced as a bare MatchError.
        Future.exception(
          Errors.server("Authentication failed: unexpected response %s".format(unexpected)))
    }
  }
}
/**
 * Holder for the custom type maps, one entry per codec id.
 *
 * NOTE(review): this is a plain `mutable.HashMap` with no synchronization of
 * its own - confirm how writers are serialized.
 */
object CustomOIDProxy {
  val serviceOIDMap: mutable.HashMap[String, Map[String, String]] =
    mutable.HashMap.empty[String, Map[String, String]]
}
/**
 * Service factory proxy that loads the custom types from the database.
 *
 * For every service obtained from `delegate` it runs the [[customTypes]]
 * query and stores the resulting oid -> type name map in
 * `CustomOIDProxy.serviceOIDMap` under `id`. If the query or its handling
 * fails, the service is closed and the failure is propagated.
 *
 * @param delegate factory producing the services to query through
 * @param id       key under which the type map is registered
 */
class CustomOIDProxy(
    delegate: ServiceFactory[PgRequest, PgResponse], id:String) extends ServiceFactoryProxy(delegate) {

  /** Query returning the name (`type`) and `oid` of the types outside the system schemas. */
  val customTypes = """
    |SELECT t.typname as type, t.oid as oid
    |FROM pg_type t
    |LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
    |WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
    |AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid)
    |AND n.nspname NOT IN ('pg_catalog', 'information_schema')
    |""".stripMargin

  override def apply(conn: ClientConnection): Future[Service[PgRequest, PgResponse]] = {
    delegate.apply(conn).flatMap { service =>
      val loaded = for {
        typeResponse <- service(new PgRequest(new Query(customTypes)))
        _ <- handleTypeResponse(typeResponse)
      } yield service

      // The caller never sees the service on failure, so release it here.
      loaded.onFailure { _ =>
        service.close()
        ()
      }
    }
  }

  /**
   * Registers the oid -> type name map extracted from a `SelectResult`.
   *
   * @return a completed Future on success, or a Future failed with
   *         `Errors.client` when `response` is not a `SelectResult`
   */
  def handleTypeResponse(response:PgResponse):Future[Unit] = {
    response match {
      case SelectResult(fields, rows) =>
        val result: ResultSet = ResultSet(fields, rows, Map())
        val typeMap: Map[String, String] = result.rows.map { row =>
          (row.get[String]("oid"), row.get[String]("type"))
        }.toMap
        CustomOIDProxy.serviceOIDMap += id -> typeMap
        Future.Done
      case _ =>
        // Report through the Future rather than throwing from a Future-returning method.
        Future.exception(Errors.client("Expected a SelectResult"))
    }
  }
}
/**
 * Decodes a Packet into a BackendMessage.
 *
 * A packet that `parser` turns into a message is fired upstream; a packet
 * that cannot be parsed, or any message that is not a Packet, is logged and
 * the channel is disconnected.
 */
class BackendMessageDecoder(val parser: BackendMessageParser) extends SimpleChannelHandler {
  private val logger = Logger(getClass.getName)

  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
    // Shared failure path: warn, then drop the connection.
    def giveUp(reason: String): Unit = {
      logger.warning(reason)
      Channels.disconnect(ctx.getChannel)
    }

    e.getMessage match {
      case packet: Packet =>
        parser.parse(packet) match {
          case None =>
            giveUp("Cannot parse the packet. Disconnecting...")
          case Some(backendMessage) =>
            Channels.fireMessageReceived(ctx, backendMessage)
        }
      case _ =>
        giveUp("Only packet is supported...")
    }
  }
}
/**
 * Decodes a byte stream into a Packet.
 *
 * While `inSslNegotation` is set, the first available byte is emitted on its
 * own as a one-byte packet and the flag is cleared. Afterwards frames are read
 * as a one-byte code followed by a four-byte length (which counts itself) and
 * `length - 4` payload bytes.
 *
 * @param inSslNegotation whether the next byte is the single-byte SSL answer
 */
class PacketDecoder(@volatile var inSslNegotation: Boolean) extends FrameDecoder {
  private val logger = Logger(getClass.getName)

  /**
   * @return the decoded Packet, or `null` when more bytes are needed
   * @throws org.jboss.netty.handler.codec.frame.CorruptedFrameException
   *         when the declared frame length is smaller than the length field itself
   */
  def decode(ctx: ChannelHandlerContext, channel: Channel, buffer: ChannelBuffer): AnyRef = {
    if (inSslNegotation && buffer.readableBytes() >= 1) {
      val sslCode: Char = buffer.readByte().toChar
      logger.ifDebug("Got ssl negotiation char packet: %s".format(sslCode))
      inSslNegotation = false
      new Packet(Some(sslCode), 1, null, true)
    } else if (buffer.readableBytes() < 5) {
      // Code byte plus length field not fully received yet.
      null
    } else {
      buffer.markReaderIndex()
      val code: Char = buffer.readByte().toChar
      val totalLength = buffer.readInt()
      val length = totalLength - 4
      if (length < 0) {
        // A negative payload length would slip past the readable-bytes check
        // below and reach readSlice; reject the frame explicitly instead.
        throw new org.jboss.netty.handler.codec.frame.CorruptedFrameException(
          "Invalid packet length: %d".format(totalLength))
      } else if (buffer.readableBytes() < length) {
        // Payload incomplete: rewind so the header is re-read on the next call.
        buffer.resetReaderIndex()
        null
      } else {
        new Packet(Some(code), totalLength, buffer.readSlice(length))
      }
    }
  }
}
/**
 * Map PgRequest to PgResponse.
 *
 * Outbound `PgRequest`s are encoded into a buffer and reported to the
 * per-channel `Connection`; inbound backend messages are fed to the same
 * `Connection` and whatever it yields is fired upstream.
 *
 * @param sslContext context used to create the SSL engine on `SwitchToSsl`
 * @param useSsl     selects the initial connection state: `RequestingSsl`
 *                   when true, `AuthenticationRequired` otherwise
 */
class PgClientChannelHandler(val sslContext: SslContext, val useSsl: Boolean) extends SimpleChannelHandler {
  private[this] val logger = Logger(getClass.getName)

  private[this] val connection =
    new Connection(startState = if (useSsl) RequestingSsl else AuthenticationRequired)

  override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
    logger.ifDebug("Detected channel disconnected!")
    super.channelDisconnected(ctx, e)
  }

  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
    e.getMessage match {
      case SwitchToSsl =>
        logger.ifDebug("Got switchToSSL message; adding ssl handler into pipeline")
        // Install a client-mode SSL handler at the head of the pipeline
        // before the connection is told about the switch.
        val sslEngine = sslContext.newEngine()
        sslEngine.setUseClientMode(true)
        ctx.getPipeline.addFirst("ssl", new SslHandler(sslEngine))
        connection.receive(SwitchToSsl).map { response =>
          Channels.fireMessageReceived(ctx, response)
        }
      case backendMessage: BackendMessage =>
        connection.receive(backendMessage).map { response =>
          Channels.fireMessageReceived(ctx, response)
        }
      case _ =>
        logger.warning("Only backend messages are supported...")
        Channels.disconnect(ctx.getChannel)
    }
  }

  override def writeRequested(ctx: ChannelHandlerContext, event: MessageEvent): Unit = {
    val outbound = event.getMessage match {
      case PgRequest(msg, flush) =>
        val packet = msg.asPacket()
        val encoded = ChannelBuffers.dynamicBuffer()
        encoded.writeBytes(packet.encode)
        // Append a Flush packet when the request asks for one.
        if (flush) {
          encoded.writeBytes(Flush.asPacket.encode)
        }
        connection.send(msg)
        encoded
      case _ =>
        // Anything that is not a PgRequest is written downstream as-is.
        logger.warning("Cannot convert message... Skipping")
        event.getMessage
    }
    Channels.write(ctx, event.getFuture, outbound, event.getRemoteAddress)
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy