All downloads are free. Search and download functionality uses the official Maven repository.

cc.otavia.postgres.PostgresDriver.scala Maven / Gradle / Ivy

/*
 * Copyright 2022 Yan Kun 
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package cc.otavia.postgres

import cc.otavia.buffer.pool.AdaptiveBuffer
import cc.otavia.buffer.{Buffer, BufferUtils}
import cc.otavia.core.channel.{Channel, ChannelHandlerContext, ChannelInflight, ChannelOption}
import cc.otavia.core.slf4a.{Logger, LoggerFactory}
import cc.otavia.postgres.impl.*
import cc.otavia.postgres.protocol.{Constants, DataFormat, DataType}
import cc.otavia.postgres.utils.ScramAuthentication.ScramClientInitialMessage
import cc.otavia.postgres.utils.{MD5Authentication, PgBufferUtils, ScramAuthentication}
import cc.otavia.sql.*
import cc.otavia.sql.impl.HexSequence
import cc.otavia.sql.statement.*
import cc.otavia.sql.statement.PrepareQuery.*
import cc.otavia.sql.statement.SimpleQuery.*

import java.nio.charset.StandardCharsets
import java.util.Collections
import scala.collection.mutable
import scala.language.unsafeNulls

/**
 * Driver speaking the PostgreSQL frontend/backend protocol (startup message announces protocol 3.0).
 *
 * Outbound [[Authentication]], [[SimpleQuery]] and [[PrepareQuery]] messages are encoded in [[encode]]; backend
 * packets are decoded in [[decode]], and the result (or error) of the command at the head of the in-flight queue is
 * fired when its `ReadyForQuery` packet arrives.
 *
 * Prepared queries run in two stages. The first time a SQL string is seen it is sent as Parse/Describe/Sync
 * (stage 1) with the in-flight queue in barrier mode. On the following `ReadyForQuery` the statement is bound and
 * executed (stage 2). Later queries with the same SQL go straight to Bind/Execute/Sync.
 *
 * @param options connection options (user, database, password, startup properties, pipelining limit).
 */
class PostgresDriver(override val options: PostgresConnectOptions) extends Driver(options) {

    import PostgresDriver.*

    private var logger: Logger             = _
    private var ctx: ChannelHandlerContext = _

    // connection life-cycle state, one of the ST_* constants of the companion object
    private var status: Int = ST_CONNECTING

    private var scramAuthentication: ScramAuthentication = _
    private var encoding: String                         = _ // last `client_encoding` reported by the server

    private var currentOutboundMessageId = ChannelInflight.INVALID_CHANNEL_MESSAGE_ID

    private var metadata: PostgresDatabaseMetadata = _
    private var processId: Int                     = 0
    private var secretKey: Int                     = 0

    // reusable holder for the fields of an ErrorResponse
    private val response: Response = new Response

    private val rowDesc: RowDesc = new RowDesc()
    // whether incoming DataRow packets must be decoded for the current command
    private var continueParseRow: Boolean = false

    // rows decoded for the current command; delivered (and cleared) on ReadyForQuery
    private val rowBuffer: mutable.ArrayBuffer[Product]    = mutable.ArrayBuffer.empty
    private val rowOffsets: mutable.ArrayBuffer[RowOffset] = mutable.ArrayBuffer.empty

    private val rowParser: PostgresRowParser = new PostgresRowParser()
    private val rowWriter: PostgresRowWriter = new PostgresRowWriter()

    // prepared statements keyed by their SQL text
    private val prepareStatements: mutable.HashMap[String, PreparedStatement] = mutable.HashMap.empty
    private val psSeq             = new HexSequence() // used for generating named prepared statement name
    private var compiled: Boolean = false             // false only while a prepared query is in stage 1

    // rows affected by INSERT/UPDATE/DELETE, summed over the CommandComplete packets of the current command
    private var modifyRows: Int = 0

    // error recorded while decoding the current command; reported on the next ReadyForQuery
    private var error: Throwable = _

    override def setChannelOptions(channel: Channel): Unit = {
        channel.setOption(ChannelOption.CHANNEL_MAX_FUTURE_INFLIGHT, options.pipeliningLimit)
        channel.setOption(ChannelOption.CHANNEL_FUTURE_BARRIER, this.futureBarrier)
        // TODO:
    }

    /** A prepared query whose SQL has not been registered yet must be handled as a barrier (it needs stage 1). */
    private def futureBarrier(msg: AnyRef): Boolean = msg match
        case prepareQuery: PrepareQuery[?] => !prepareStatements.contains(prepareQuery.sql)
        case _                             => false

    /**
     * Whether `buffer` holds at least one complete backend packet: a 1-byte type followed by a 4-byte length that
     * counts itself and the body, but not the type byte.
     */
    final override protected def checkDecodePacket(buffer: Buffer): Boolean =
        buffer.readableBytes >= 5 && buffer.readableBytes >= buffer.getInt(buffer.readerOffset + 1) + 1

    override protected def encode(ctx: ChannelHandlerContext, output: AdaptiveBuffer, msg: AnyRef, mid: Long): Unit = {
        currentOutboundMessageId = mid
        msg match
            case _: Authentication =>
                if (status == ST_CONNECTING) {
                    sendStartupMessage()
                    status = ST_AUTHENTICATING
                }
            case simpleQuery: SimpleQuery[?]   => encodeSimpleQuery(simpleQuery, output)
            case prepareQuery: PrepareQuery[?] => encodePrepareQuery(prepareQuery, output)
    }

    /** Decodes every complete backend packet in `input`, skipping any bytes a handler did not consume. */
    override protected def decode(ctx: ChannelHandlerContext, input: AdaptiveBuffer): Unit = {
        while (checkDecodePacket(input)) {
            val packetStart = input.readerOffset

            val id     = input.readByte
            val length = input.readInt

            id match
                case Constants.MSG_TYPE_READY_FOR_QUERY  => decodeReadyForQuery(input)
                case Constants.MSG_TYPE_DATA_ROW         => decodeDataRow(input, length)
                case Constants.MSG_TYPE_COMMAND_COMPLETE => decodeCommandComplete(input, length)
                case Constants.MSG_TYPE_BIND_COMPLETE    => decodeBindComplete()
                case Constants.MSG_TYPE_ROW_DESCRIPTION  => decodeRowDescription(input)
                case Constants.MSG_TYPE_ERROR_RESPONSE   => decodeErrorResponse(input, length)
                case Constants.MSG_TYPE_NOTICE_RESPONSE =>
                    // notices may arrive at any time and do not affect the command result; body skipped below
                    logger.debug("ignore postgres notice response")
                case Constants.MSG_TYPE_AUTHENTICATION => decodeAuthentication(input, length)
                case Constants.MSG_TYPE_EMPTY_QUERY_RESPONSE => // empty query string; stands in for CommandComplete
                case Constants.MSG_TYPE_PARSE_COMPLETE       =>
                case Constants.MSG_TYPE_CLOSE_COMPLETE       => // this driver never sends Close
                case Constants.MSG_TYPE_NO_DATA              =>
                case Constants.MSG_TYPE_PORTAL_SUSPENDED     => // Execute is always sent without a row limit
                case Constants.MSG_TYPE_PARAMETER_DESCRIPTION => decodeParameterDescription(input, length)
                case Constants.MSG_TYPE_PARAMETER_STATUS      => decodeParameterStatus(input)
                case Constants.MSG_TYPE_BACKEND_KEY_DATA      => decodeBackendKeyData(input)
                case Constants.MSG_TYPE_NOTIFICATION_RESPONSE =>
                    // asynchronous LISTEN/NOTIFY message, not tied to the in-flight command; body skipped below
                    logger.debug("ignore postgres notification response")
                case _ =>

            // skip remaining data
            if (input.readerOffset - packetStart < length + 1) input.readerOffset(packetStart + length + 1)
        }
        if (input.readableBytes == 0) input.compact()
    }

    /** Writes the StartupMessage (protocol 3.0, user, database and extra properties) to the outbound buffer. */
    private def sendStartupMessage(): Unit = {
        val packet = ctx.outboundAdaptiveBuffer

        val startIdx = packet.writerOffset
        packet.writeInt(0) // set payload length by calculation

        // protocol version
        packet.writeShort(3)
        packet.writeShort(0)

        PgBufferUtils.writeCString(packet, USER)
        PgBufferUtils.writeCString(packet, options.user, StandardCharsets.UTF_8)

        PgBufferUtils.writeCString(packet, DATABASE)
        PgBufferUtils.writeCString(packet, options.database, StandardCharsets.UTF_8)

        for ((key, value) <- options.properties) {
            PgBufferUtils.writeCString(packet, key, StandardCharsets.UTF_8)
            PgBufferUtils.writeCString(packet, value, StandardCharsets.UTF_8)
        }

        packet.writeByte(0)

        packet.setInt(startIdx, packet.writerOffset - startIdx)
    }

    /** Writes a simple-protocol Query packet carrying `sql` as a UTF-8 C string. */
    private def encodeQuery(sql: String, output: Buffer): Unit = {
        val packet = output
        packet.writeByte(QUERY)
        val pos = packet.writerOffset
        packet.writeInt(0)
        PgBufferUtils.writeCString(packet, sql, StandardCharsets.UTF_8)
        packet.setInt(pos, packet.writerOffset - pos)
    }

    private def encodeSimpleQuery(query: SimpleQuery[?], output: Buffer): Unit = encodeQuery(query.sql, output)

    /**
     * Handles ReadyForQuery, which ends the response cycle of the command at the head of the in-flight queue.
     *
     * If an error was recorded for that command it is fired, and the per-command state is reset so that partial
     * results do not leak into the next command. Otherwise, a prepared query that just finished stage 1 is executed
     * (stage 2), and any other command gets its result fired.
     */
    private def decodeReadyForQuery(payload: Buffer): Unit = {
        val transactionStatus = payload.readByte
        // I = idle, T = inside a transaction block; anything else is a failed transaction block.
        // NOTE(review): this replaces an error already recorded from an ErrorResponse - confirm that is intended.
        if (transactionStatus != I && transactionStatus != T) error = new TransactionFailed

        val promise = ctx.inflightFutures.first
        val query   = promise.getAsk()

        // deliver the first decoded row, or fail when none was decoded
        def fireFirstRow(): Unit = {
            if (rowBuffer.nonEmpty) ctx.fireChannelRead(rowBuffer.head.asInstanceOf[Row], promise.messageId)
            else ctx.fireChannelExceptionCaught(new Error("No data fetch"), promise.messageId)
            rowBuffer.clear()
        }

        // deliver every decoded row as a RowSet, or fail when none was decoded
        def fireRowSet(): Unit = {
            if (rowBuffer.nonEmpty) ctx.fireChannelRead(RowSet(rowBuffer.toArray), promise.messageId)
            else ctx.fireChannelExceptionCaught(new Error("No data fetch"), promise.messageId)
            rowBuffer.clear()
        }

        if (error != null) {
            val cause = error
            error = null
            // drop partial results of the failed command so they are not attributed to the next one
            rowBuffer.clear()
            modifyRows = 0
            continueParseRow = false
            query match
                case prepareQuery: PrepareQuery[?] if !compiled => abortPrepare(prepareQuery)
                case _                                          =>
            ctx.fireChannelExceptionCaught(cause, promise.messageId)
        } else {
            query match
                case prepareQuery: PrepareQuery[?] if !compiled =>
                    encodePrepareQuery(prepareQuery, ctx.outboundAdaptiveBuffer) // stage 2: execute prepared statement
                    compiled = true
                    ctx.inflightFutures.setBarrierMode(false)
                case _: Authentication =>
                    if (status == ST_AUTHENTICATED) ctx.fireChannelRead(None, promise.messageId)
                    else ctx.fireChannelExceptionCaught(new UnknownError(), promise.messageId)
                case _: (ExecuteUpdate | PrepareUpdate | PrepareUpdateBatch) =>
                    ctx.fireChannelRead(ModifyRows(modifyRows), promise.messageId)
                    modifyRows = 0
                case _: ExecuteQuery[?]                                        => fireFirstRow()
                case fetchQuery: PrepareQuery.FetchQuery[?] if !fetchQuery.all => fireFirstRow()
                case _: ExecuteQueries[?]                                      => fireRowSet()
                case _: PrepareQuery.FetchQuery[?]                             => fireRowSet()
        }
    }

    /**
     * Cleans up after stage 1 (Parse/Describe) of `query` failed. The statement is forgotten, so the next query with
     * the same SQL prepares it again instead of waiting on a statement that was never parsed. Stage 1 is marked
     * finished, so the next already-prepared query is not executed twice, and barrier mode is left.
     */
    private def abortPrepare(query: PrepareQuery[?]): Unit = {
        prepareStatements.remove(query.sql)
        compiled = true
        ctx.inflightFutures.setBarrierMode(false)
    }

    private def recycleRowOffset(): Unit = {
        for (off <- rowOffsets) off.recycle()
        rowOffsets.clear()
    }

    /**
     * Decodes a DataRow for the in-flight command while `continueParseRow` is set.
     *
     * The column offsets are computed relative to the current reader offset. Rows for commands that do not fetch
     * rows (for example an update with a RETURNING clause) are ignored.
     */
    private def decodeDataRow(payload: Buffer, payloadLength: Int): Unit =
        if (continueParseRow) {
            recycleRowOffset()
            val columnCount = payload.readUnsignedShort
            var column      = 0
            var offset      = 0
            while (column < columnCount) {
                val length = payload.getInt(payload.readerOffset + offset)
                rowOffsets.addOne(RowOffset(offset, length))
                offset += 4 + math.max(length, 0) // a NULL column (length -1) has no value bytes
                column += 1
            }
            rowParser.setPayload(payload)
            rowParser.setRowOffsets(rowOffsets)

            ctx.inflightFutures.first.getAsk() match
                case executeQuery: ExecuteQuery[?] =>
                    rowBuffer.addOne(executeQuery.codec.decode(rowParser))
                    continueParseRow = false // only the first row is needed
                case executeQueries: ExecuteQueries[?] =>
                    rowBuffer.addOne(executeQueries.codec.decode(rowParser))
                case fetchQuery: PrepareQuery.FetchQuery[?] =>
                    decodeDataRow0(fetchQuery.codec, prepareStatements(fetchQuery.sql))
                    if (!fetchQuery.all) continueParseRow = false
                case _ =>
        }

    private def decodeDataRow0(codec: RowCodec[?], ps: PreparedStatement): Unit = {
        rowParser.setRowDesc(ps.rowDesc)
        val row = codec.decode(rowParser)
        rowBuffer.addOne(row)
    }

    /**
     * Decodes a RowDescription. For a prepared query it is stored on the statement, and columns whose type
     * supports binary are switched to binary, matching the result formats requested in Bind. Otherwise it fills
     * the shared `rowDesc` and enables row parsing.
     */
    private def decodeRowDescription(payload: Buffer): Unit = {
        ctx.inflightFutures.first.getAsk() match
            case value: PrepareQuery[?] =>
                val ps           = prepareStatements(value.sql)
                val columnLength = payload.readUnsignedShort
                ps.rowDesc.setLength(columnLength)
                var i = 0
                while (i < columnLength) {
                    val columnDesc = ps.rowDesc(i)
                    decodeColumnDescription(payload, columnDesc)
                    if (columnDesc.dataFormat == DataFormat.TEXT && columnDesc.dataType.supportsBinary)
                        columnDesc.dataFormat = DataFormat.BINARY
                    i += 1
                }
            case _ =>
                val columnLength = payload.readUnsignedShort
                this.rowDesc.setLength(columnLength)
                var i = 0
                while (i < columnLength) {
                    val columnDesc = this.rowDesc(i)
                    decodeColumnDescription(payload, columnDesc)
                    i += 1
                }
                rowParser.setRowDesc(rowDesc)
                continueParseRow = true
    }

    private def decodeColumnDescription(payload: Buffer, columnDesc: ColumnDesc): Unit = {
        columnDesc.name = PgBufferUtils.readCString(payload)
        columnDesc.relationId = payload.readInt
        columnDesc.relationAttributeNo = payload.readShort
        columnDesc.dataType = DataType.fromOid(payload.readInt)
        columnDesc.length = payload.readShort
        columnDesc.typeModifier = payload.readInt
        val df = payload.readUnsignedShort
        columnDesc.dataFormat = DataFormat.fromOrdinal(df)
    }

    /**
     * Adds the row count of UPDATE/DELETE/INSERT command tags to `modifyRows`. Other tags carry nothing this driver
     * reports; their bytes are skipped by [[decode]].
     */
    private def decodeCommandComplete(payload: Buffer, length: Int): Unit =
        if (
          payload.skipIfNextAre(CMD_COMPLETED_UPDATE) ||
          payload.skipIfNextAre(CMD_COMPLETED_DELETE) ||
          payload.skipIfNextAre(CMD_COMPLETED_INSERT)
        ) modifyRows += BufferUtils.readStringAsInt(payload)

    private def decodeBindComplete(): Unit = continueParseRow = true

    /** Answers an Authentication request; failures are recorded in `error` and reported on ReadyForQuery. */
    private def decodeAuthentication(payload: Buffer, length: Int): Unit = {
        val typ = payload.readInt
        typ match
            case Constants.AUTH_TYPE_OK => successAuth()
            case Constants.AUTH_TYPE_MD5_PASSWORD =>
                val salt = new Array[Byte](4)
                payload.readBytes(salt)
                sendPasswordMessage(salt)
            case Constants.AUTH_TYPE_CLEARTEXT_PASSWORD => sendPasswordMessage(null)
            case Constants.AUTH_TYPE_SASL =>
                scramAuthentication = new ScramAuthentication(options.user, options.password)
                val msg = scramAuthentication.initialSaslMsg(payload)
                sendScramClientInitialMessage(msg)
            case Constants.AUTH_TYPE_SASL_CONTINUE =>
                sendScramClientFinalMessage(scramAuthentication.recvServerFirstMsg(payload))
                logger.debug("sasl continue send")
            case Constants.AUTH_TYPE_SASL_FINAL =>
                try {
                    // remaining bytes: packet length minus the length field and the auth type (4 + 4)
                    scramAuthentication.checkServerFinalMsg(payload, length - 8)
                    logger.debug("sasl final")
                } catch {
                    case e: UnsupportedOperationException => error = e
                }
            case _ =>
                // record (not shadow) the error so it is not silently lost
                error = new UnsupportedOperationException(
                  s"Authentication type $typ is not supported in the client"
                )
    }

    /** Sends a PasswordMessage: MD5-hashed when `salt` is given, cleartext otherwise. */
    private def sendPasswordMessage(salt: Array[Byte] | Null): Unit = {
        val packet = ctx.outboundAdaptiveBuffer
        packet.writeByte(PASSWORD_MESSAGE)
        val pos = packet.writerOffset
        packet.writeInt(0)
        val hash =
            if (salt != null) MD5Authentication.encode(options.user, options.password, salt) else options.password
        PgBufferUtils.writeCString(packet, hash, StandardCharsets.UTF_8)
        packet.setInt(pos, packet.writerOffset - pos)

        ctx.writeAndFlush(packet)
    }

    private def sendScramClientInitialMessage(msg: ScramClientInitialMessage): Unit = {
        val packet = ctx.outboundAdaptiveBuffer
        packet.writeByte(PASSWORD_MESSAGE)
        val pos = packet.writerOffset
        packet.writeInt(0)
        PgBufferUtils.writeCString(packet, msg.mechanism, StandardCharsets.UTF_8)

        val msgPos = packet.writerOffset
        packet.writeInt(0)
        packet.writeCharSequence(msg.message, StandardCharsets.UTF_8)

        // rewind to set the message and total length
        packet.setInt(msgPos, packet.writerOffset - msgPos - Integer.BYTES)
        packet.setInt(pos, packet.writerOffset - pos)

        ctx.writeAndFlush(packet)
    }

    private def sendScramClientFinalMessage(msg: String): Unit = {
        val packet = ctx.outboundAdaptiveBuffer
        packet.writeByte(PASSWORD_MESSAGE)
        val pos = packet.writerOffset
        packet.writeInt(0)
        packet.writeCharSequence(msg, StandardCharsets.UTF_8)

        packet.setInt(pos, packet.writerOffset - pos)

        ctx.writeAndFlush(packet)
    }

    private def successAuth(): Unit = {
        status = ST_AUTHENTICATED
    }

    /** Records the `client_encoding` and `server_version` parameters reported by the server. */
    private def decodeParameterStatus(payload: Buffer): Unit = {
        val key   = PgBufferUtils.readCString(payload)
        val value = PgBufferUtils.readCString(payload)
        if (key == "client_encoding") encoding = value
        if (key == "server_version") metadata = PostgresDatabaseMetadata.parse(value)
    }

    private def decodeBackendKeyData(payload: Buffer): Unit = {
        processId = payload.readInt
        secretKey = payload.readInt
    }

    private def decodeErrorResponse(payload: Buffer, length: Int): Unit = {
        decodeResponse(payload, length)
        val exception = response.toExecption()
        error = exception
    }

    /** Reads the (type byte, C string) fields of an error/notice body into `response` until the 0 terminator. */
    private def decodeResponse(payload: Buffer, length: Int): Unit = {
        var tpe: Byte = payload.readByte
        while (tpe != 0) {
            tpe match
                case Constants.ERR_OR_NOTICE_SEVERITY => response.setSeverity(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_CODE     => response.setCode(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_MESSAGE  => response.setMessage(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_DETAIL   => response.setDetail(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_HINT     => response.setHint(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_INTERNAL_POSITION =>
                    response.setInternalPosition(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_INTERNAL_QUERY =>
                    response.setInternalQuery(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_POSITION   => response.setPosition(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_WHERE      => response.setWhere(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_FILE       => response.setFile(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_LINE       => response.setLine(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_ROUTINE    => response.setRoutine(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_SCHEMA     => response.setSchema(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_TABLE      => response.setTable(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_COLUMN     => response.setColumn(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_DATA_TYPE  => response.setDataType(PgBufferUtils.readCString(payload))
                case Constants.ERR_OR_NOTICE_CONSTRAINT => response.setConstraint(PgBufferUtils.readCString(payload))
                case _                                  => payload.skipReadableBytes(payload.bytesBefore(0.toByte) + 1)

            tpe = payload.readByte
        }
    }

    /**
     * Encodes a prepared query. A known, parsed statement is sent as Bind/Execute/Sync. Unknown SQL starts stage 1
     * (Parse/Describe/Sync) under barrier mode.
     */
    private def encodePrepareQuery(query: PrepareQuery[?], output: Buffer): Unit = {
        prepareStatements.get(query.sql) match
            case Some(ps) if ps.parsed =>
                encodePrepareQueryParams(ps, query, output)
                encodeSync(output)
                ctx.writeAndFlush(output)
                output.compact()
            case None =>
                ctx.inflightFutures.setBarrierMode(true)
                compiled = false
                sendPrepareStatement(query.sql, output)
            case _ =>
            // NOTE(review): the statement is registered but not parsed yet, so nothing is sent. Confirm the
            // barrier makes this unreachable; otherwise this query never completes.
    }

    private def sendPrepareStatement(sql: String, output: Buffer): Unit = {
        val statement = nextStatementName()
        val ps        = new PreparedStatement()
        ps.sql = sql
        ps.statement = statement
        prepareStatements.put(ps.sql, ps)
        encodeParse(sql, statement, output)
        encodeDescribe(statement, output)
        encodeSync(output)
        ctx.writeAndFlush(output)
    }

    private def nextStatementName(): Array[Byte] = psSeq.next()

    private def encodeParse(sql: String, statement: Array[Byte], output: Buffer): Unit = {
        output.writeByte(PARSE)
        val pos = output.writerOffset
        output.writeInt(0)
        output.writeBytes(statement)
        PgBufferUtils.writeCString(output, sql)
        // Let pg figure out the parameter types
        output.writeShort(0)
        output.setInt(pos, output.writerOffset - pos) // set packet payload length
    }

    /** Writes Describe('S') for `statement`; a name of at most one byte is sent as the unnamed statement. */
    private def encodeDescribe(statement: Array[Byte], output: Buffer): Unit = {
        output.writeByte(DESCRIBE)
        val pos = output.writerOffset
        output.writeInt(0)
        output.writeByte('S')
        if (statement.length > 1) output.writeBytes(statement) else output.writeByte(0)
        output.setInt(pos, output.writerOffset - pos)
    }

    /** Writes Bind + Execute for each parameter set of `query`. */
    private def encodePrepareQueryParams(ps: PreparedStatement, query: PrepareQuery[?], output: Buffer): Unit = {
        query match
            case updateBatch: PrepareUpdateBatch if updateBatch.parameterCodec == null =>
                for (product <- updateBatch.params) {
                    encodeBindProduct(ps, product, output)
                    encodeExecute(0, output)
                }
            case updateBatch: PrepareUpdateBatch if updateBatch.parameterCodec != null =>
                for (product <- updateBatch.params) {
                    encodeBindTyped(ps, product, updateBatch.parameterCodec, output)
                    encodeExecute(0, output)
                }
            case fetchOneBindProduct: FetchOneBindProduct[?] =>
                encodeBindTyped(ps, fetchOneBindProduct.parm, fetchOneBindProduct.pcodec, output)
                encodeExecute(0, output)
            case fetchAllBindProduct: FetchAllBindProduct[?] =>
                encodeBindTyped(ps, fetchAllBindProduct.parm, fetchAllBindProduct.pcodec, output)
                encodeExecute(0, output)
            case prepareUpdate: PrepareUpdate if prepareUpdate.parameterCodec == null =>
                encodeBindProduct(ps, prepareUpdate.param, output)
                encodeExecute(0, output)
            case prepareUpdate: PrepareUpdate if prepareUpdate.parameterCodec != null =>
                encodeBindTyped(ps, prepareUpdate.param, prepareUpdate.parameterCodec, output)
                encodeExecute(0, output)
            case _ =>
                encodeBindPrimary(ps, query, output)
                encodeExecute(0, output)
    }

    /** Builds the error raised when a parameter would have to be sent in text format, which is not implemented. */
    private def textParamUnsupported(dataType: DataType): UnsupportedOperationException =
        new UnsupportedOperationException(s"Text format parameter encoding is not supported for data type $dataType")

    private def encodeBindPrimary(ps: PreparedStatement, query: PrepareQuery[?], output: Buffer): Unit = {
        val pos = writeBindHead(query.parameterLength, ps, output)

        // encode param args
        query match
            case fetchOneBindInt: FetchOneBindInt[?] =>
                if (ps.paramDesc.head.supportsBinary) {
                    output.writeInt(4)
                    output.writeInt(fetchOneBindInt.parm)
                } else throw textParamUnsupported(ps.paramDesc.head)
            case fetchAllBindInt: FetchAllBindInt[?] =>
                if (ps.paramDesc.head.supportsBinary) {
                    output.writeInt(4)
                    output.writeInt(fetchAllBindInt.parm)
                } else throw textParamUnsupported(ps.paramDesc.head)
            case _ =>

        writeBindTail(ps, pos, output)
    }

    /**
     * Writes the head of a Bind packet up to (not including) the parameter values.
     *
     * @param paramSize number of parameter values the caller writes after this head
     * @return offset of the packet length field, to be patched by [[writeBindTail]]
     */
    private def writeBindHead(paramSize: Short, ps: PreparedStatement, output: Buffer): Int = {
        output.writeByte(BIND)
        val pos = output.writerOffset
        output.writeInt(0)
        output.writeByte(0) // unnamed portal
        output.writeBytes(ps.statement)
        // format codes: the count must match the codes written right after it (one per described parameter)
        output.writeShort(ps.paramDesc.length.toShort)
        for (paramDesc <- ps.paramDesc) output.writeShort(if (paramDesc.supportsBinary) 1 else 0)
        // number of parameter values that follow
        output.writeShort(paramSize)
        pos
    }

    /** Writes the result-column format codes and patches the Bind packet length at `packetLenPos`. */
    private def writeBindTail(ps: PreparedStatement, packetLenPos: Int, output: Buffer): Unit = {
        // MAKE resultColumns non null to avoid null check
        // Result columns are all in Binary format
        if (ps.rowDesc.length > 0) {
            output.writeShort(ps.rowDesc.length.toShort)
            var i = 0
            while (i < ps.rowDesc.length) {
                val colDesc = ps.rowDesc(i)
                output.writeShort(if (colDesc.dataType.supportsBinary) 1 else 0)
                i += 1
            }
        } else {
            // a single format code applies to all result columns
            output.writeShort(1)
            output.writeShort(1)
        }
        output.setInt(packetLenPos, output.writerOffset - packetLenPos)
    }

    /** Binds the elements of `params` in binary format; `null`/`None` elements are sent as SQL NULL. */
    private def encodeBindProduct(ps: PreparedStatement, params: Product, output: Buffer): Unit = {
        val pos = writeBindHead(params.productArity.toShort, ps, output)

        var i = 0
        while (i < params.productArity) {
            val param = params.productElement(i)
            if (param == null || param == None) { // NULL value
                output.writeInt(-1)
            } else {
                val datatype = ps.paramDesc(i)
                if (datatype.supportsBinary) {
                    val valueLenPos = output.writerOffset
                    output.writeInt(0)
                    RowValueCodec.encodeBinary(datatype, param, output)
                    output.setInt(valueLenPos, output.writerOffset - valueLenPos - 4)
                } else throw textParamUnsupported(datatype)
            }
            i += 1
        }

        writeBindTail(ps, pos, output)
    }

    private def encodeBindTyped(ps: PreparedStatement, params: Product, pcodec: RowCodec[?], output: Buffer): Unit = {
        rowWriter.setPreparedStatement(ps)
        rowWriter.setBuffer(output)
        val pos = writeBindHead(params.productArity.toShort, ps, output)
        pcodec.encodeProduct(params, rowWriter)
        writeBindTail(ps, pos, output)
    }

    private def encodeExecute(fetch: Int, output: Buffer): Unit = {
        output.writeByte(EXECUTE)
        val pos = output.writerOffset
        output.writeInt(0)
        output.writeByte(0)
        output.writeInt(fetch)                        // Zero denotes "no limit" maybe for ReadStream
        output.setInt(pos, output.writerOffset - pos) // set packet length
    }

    private def encodeSync(output: Buffer): Unit = {
        output.writeByte(SYNC)
        output.writeInt(4)
    }

    /** Stores the parameter types of the in-flight prepared statement and marks it parsed. */
    private def decodeParameterDescription(payload: Buffer, length: Int): Unit = {
        val paramDataTypes = new Array[DataType](payload.readUnsignedShort)
        for (idx <- paramDataTypes.indices) {
            paramDataTypes(idx) = DataType.fromOid(payload.readInt)
        }
        val prepare = ctx.inflightFutures.first.getAsk().asInstanceOf[PrepareQuery[?]]
        val ps      = prepareStatements(prepare.sql)
        ps.paramDesc = paramDataTypes
        ps.parsed = true
    }

    override def handlerAdded(ctx: ChannelHandlerContext): Unit = {
        super.handlerAdded(ctx)
        this.ctx = ctx
        logger = LoggerFactory.getLogger(getClass, ctx.system)

        rowParser.setRowDesc(rowDesc)
        rowParser.setRowOffsets(rowOffsets)
    }

}

object PostgresDriver {

    // Connection life-cycle states held by PostgresDriver.status.
    private val ST_CONNECTING          = 0
    private val ST_AUTHENTICATING      = 1
    private val ST_CONNECTED           = 2
    private val ST_AUTHENTICATED       = 4
    private val ST_AUTHENTICATE_FAILED = 5
    private val ST_CLOSING             = 6

    // Parameter names written into the StartupMessage.
    private val USER: Array[Byte]     = utf8("user")
    private val DATABASE: Array[Byte] = utf8("database")

    // Frontend message type bytes.
    private val PASSWORD_MESSAGE: Byte = 'p'.toByte
    private val QUERY: Byte            = 'Q'.toByte
    private val TERMINATE: Byte        = 'X'.toByte
    private val PARSE: Byte            = 'P'.toByte
    private val BIND: Byte             = 'B'.toByte
    private val DESCRIBE: Byte         = 'D'.toByte
    private val EXECUTE: Byte          = 'E'.toByte
    private val CLOSE: Byte            = 'C'.toByte
    private val SYNC: Byte             = 'S'.toByte

    // Transaction status bytes of ReadyForQuery checked by the driver (idle / in transaction).
    private val I: Byte = 'I'.toByte
    private val T: Byte = 'T'.toByte

    // CommandComplete tag prefixes; the row count follows each prefix.
    private val CMD_COMPLETED_INSERT: Array[Byte] = ascii("INSERT 0 ")
    private val CMD_COMPLETED_DELETE: Array[Byte] = ascii("DELETE ")
    private val CMD_COMPLETED_UPDATE: Array[Byte] = ascii("UPDATE ")
    private val CMD_COMPLETED_SELECT: Array[Byte] = ascii("SELECT ")
    private val CMD_COMPLETED_MOVE: Array[Byte]   = ascii("MOVE ")
    private val CMD_COMPLETED_FETCH: Array[Byte]  = ascii("FETCH ")
    private val CMD_COMPLETED_COPY: Array[Byte]   = ascii("COPY ")

    /** UTF-8 bytes of `text`. */
    private def utf8(text: String): Array[Byte] = text.getBytes(StandardCharsets.UTF_8)

    /** US-ASCII bytes of `text`. */
    private def ascii(text: String): Array[Byte] = text.getBytes(StandardCharsets.US_ASCII)

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy