All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.mauricio.async.db.postgresql.PostgreSQLConnection.scala Maven / Gradle / Ivy

/*
 * Copyright 2013 Maurício Linhares
 *
 * Maurício Linhares licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package com.github.mauricio.async.db.postgresql

import com.github.mauricio.async.db.QueryResult
import com.github.mauricio.async.db.column.{
  ColumnDecoderRegistry,
  ColumnEncoderRegistry
}
import com.github.mauricio.async.db.exceptions.{
  ConnectionStillRunningQueryException,
  InsufficientParametersException
}
import com.github.mauricio.async.db.general.ResultSetBuilder
import com.github.mauricio.async.db.pool.TimeoutScheduler
import com.github.mauricio.async.db.postgresql.codec.{
  PostgreSQLConnectionDelegate,
  PostgreSQLConnectionHandler
}
import com.github.mauricio.async.db.postgresql.column.{
  PostgreSQLColumnDecoderRegistry,
  PostgreSQLColumnEncoderRegistry
}
import com.github.mauricio.async.db.postgresql.exceptions._
import com.github.mauricio.async.db.util._
import com.github.mauricio.async.db.{Configuration, Connection}
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference}

import messages.backend._
import messages.frontend._

import scala.concurrent._
import io.netty.channel.EventLoopGroup
import java.util.concurrent.CopyOnWriteArrayList

import com.github.mauricio.async.db.postgresql.util.URLParser
import scala.collection.compat.immutable.ArraySeq

/** Constants and shared state used by every [[PostgreSQLConnection]]. */
object PostgreSQLConnection {

  /** Parameter-status key whose value is the server's version string. */
  final val ServerVersionKey = "server_version"

  /** Logger shared by all connections. */
  final val log = Log.get[PostgreSQLConnection]

  /** Global sequence; every new connection takes the next value as its id. */
  final val Counter = new AtomicLong(0L)
}

/** Asynchronous PostgreSQL connection driven by a [[PostgreSQLConnectionHandler]].
 *
 * At most one query runs at a time. The promise of the pending query is kept in
 * `queryPromiseReference`. The `on*` overrides below are the
 * [[PostgreSQLConnectionDelegate]] callbacks; the handler is constructed with
 * `this` as its delegate. Those callbacks accumulate rows and complete that
 * promise.
 *
 * @param configuration    connection settings (credentials, charset, query timeout, ...)
 * @param encoderRegistry  encodes prepared-statement parameter values
 * @param decoderRegistry  decodes the column values of incoming data rows
 * @param group            Netty event loop group handed to the connection handler
 * @param executionContext context for future callbacks; defaults to
 *                         `ExecutorServiceUtils.SameThread`
 */
class PostgreSQLConnection(
  configuration: Configuration = URLParser.DEFAULT,
  encoderRegistry: ColumnEncoderRegistry =
    PostgreSQLColumnEncoderRegistry.Instance,
  decoderRegistry: ColumnDecoderRegistry =
    PostgreSQLColumnDecoderRegistry.Instance,
  group: EventLoopGroup = NettyUtils.DefaultEventLoopGroup,
  implicit val executionContext: ExecutionContext =
    ExecutorServiceUtils.SameThread
) extends PostgreSQLConnectionDelegate
    with Connection
    with TimeoutScheduler {

  import PostgreSQLConnection._

  // Owns the network channel; receives `this` as its delegate.
  private final val connectionHandler = new PostgreSQLConnectionHandler(
    configuration,
    encoderRegistry,
    decoderRegistry,
    this,
    group,
    executionContext
  )

  // Id of this connection, used in error reports and toString.
  private final val currentCount              = Counter.incrementAndGet()
  // Source of statement ids for newly cached prepared statements.
  private final val preparedStatementsCounter = new AtomicInteger()

  // Latest value the server reported for each parameter-status key.
  private val parameterStatus =
    new scala.collection.mutable.HashMap[String, String]()
  // Prepared statements seen on this connection, keyed by their query text.
  private val parsedStatements =
    new scala.collection.mutable.HashMap[String, PreparedStatementHolder]()
  // Written on AuthenticationOk; not read within this class.
  private var authenticated = false

  // Succeeds at the first ReadyForQuery; fails on an error reported before that.
  private val connectionFuture = Promise[Connection]()

  // True from the last reported error until the next ReadyForQuery.
  private var recentError = false
  // Promise of the query in flight; None while the connection is idle.
  private val queryPromiseReference =
    new AtomicReference[Option[Promise[QueryResult]]](None)
  // Rows accumulated for the current query.
  private var currentQuery: Option[ResultSetBuilder[PostgreSQLColumnData]] =
    None
  // Prepared statement being executed, if any.
  private var currentPreparedStatement: Option[PreparedStatementHolder] = None
  // Parsed from the server_version parameter status; not read within this class.
  private var version = Version(0, 0, 0)
  // The list itself is mutated through its thread-safe methods; the reference never changes.
  private val notifyListeners =
    new CopyOnWriteArrayList[NotificationResponse => Unit]()

  // Result built at CommandComplete and delivered at the following ReadyForQuery.
  private var queryResult: Option[QueryResult] = None

  override def eventLoopGroup: EventLoopGroup = group

  /** True when no query is pending on this connection. */
  def isReadyForQuery: Boolean = this.queryPromise.isEmpty

  /** Opens the connection.
   *
   * The returned future succeeds at the server's first ReadyForQuery. It fails
   * if the handler's connect fails, or if an error is reported before then.
   */
  def connect: Future[Connection] = {
    this.connectionHandler.connect.failed.foreach { e =>
      this.connectionFuture.tryFailure(e)
    }

    this.connectionFuture.future
  }

  /** Closes the underlying channel; completes with this connection. */
  override def disconnect: Future[Connection] =
    this.connectionHandler.disconnect.map(_ => this)

  // Presumably called by TimeoutScheduler when a timeout registered through
  // addTimeout fires; the connection is dropped.
  override def onTimeout = disconnect

  override def isConnected: Boolean = this.connectionHandler.isConnected

  /** Immutable snapshot of the parameter statuses reported by the server. */
  def parameterStatuses: scala.collection.immutable.Map[String, String] =
    this.parameterStatus.toMap

  /** Runs `query` as a simple (unprepared) query.
   *
   * @throws ConnectionStillRunningQueryException if another query is pending
   * @throws QueryMustNotBeNullOrEmptyException  if `query` is null or empty
   */
  override def sendQuery(query: String): Future[QueryResult] = {
    validateQuery(query)

    val promise = Promise[QueryResult]()
    this.setQueryPromise(promise)
    // A query that failed never reaches succeedQueryPromise, so its partially
    // built result set would otherwise end up in this query's QueryResult
    // whenever this query produces no RowDescription (e.g. an UPDATE).
    this.currentQuery = None

    write(new QueryMessage(query))
    addTimeout(promise, configuration.queryTimeout)
    promise.future
  }

  /** Runs `query` as a prepared statement bound to `values`.
   *
   * Statements are cached per query text. The first run sends an opening
   * message; later runs send only an execute message.
   *
   * @throws ConnectionStillRunningQueryException if another query is pending
   * @throws QueryMustNotBeNullOrEmptyException  if `query` is null or empty
   * @throws InsufficientParametersException     if `values` does not match the
   *                                             statement's parameter count
   */
  override def sendPreparedStatement(
    query: String,
    values: Seq[Any] = List()
  ): Future[QueryResult] = {
    validateQuery(query)

    val promise = Promise[QueryResult]()
    this.setQueryPromise(promise)

    // getOrElseUpdate evaluates the holder lazily, so a statement id is only
    // consumed for queries not yet cached.
    val holder = this.parsedStatements.getOrElseUpdate(
      query,
      new PreparedStatementHolder(
        query,
        preparedStatementsCounter.incrementAndGet()
      )
    )

    if (holder.paramsCount != values.length) {
      this.clearQueryPromise
      throw new InsufficientParametersException(holder.paramsCount, values)
    }

    this.currentPreparedStatement = Some(holder)
    this.currentQuery = Some(
      new ResultSetBuilder(ArraySeq.unsafeWrapArray(holder.columnDatas))
    )

    val message =
      if (holder.prepared)
        new PreparedStatementExecuteMessage(
          holder.statementId,
          holder.realQuery,
          values,
          this.encoderRegistry
        )
      else {
        holder.prepared = true
        new PreparedStatementOpeningMessage(
          holder.statementId,
          holder.realQuery,
          values,
          this.encoderRegistry
        )
      }
    write(message)
    addTimeout(promise, configuration.queryTimeout)
    promise.future
  }

  /** Handles a connection-level failure by failing pending futures. */
  override def onError(exception: Throwable): Unit = {
    this.setErrorOnFutures(exception)
  }

  /** True from the last reported error until the next ReadyForQuery. */
  def hasRecentError: Boolean = this.recentError

  // Fails the connect future (disconnecting if it was still pending) and the
  // pending query. It also evicts the prepared statement that was executing,
  // so that it is re-opened next time.
  private def setErrorOnFutures(e: Throwable): Unit = {
    this.recentError = true

    log.error("Error on connection", e)

    // tryFailure rather than isCompleted + failure: connect's failure callback
    // may complete this promise from the execution context concurrently, and a
    // plain failure() would then throw IllegalStateException.
    if (this.connectionFuture.tryFailure(e)) {
      this.disconnect
    }

    this.currentPreparedStatement.foreach(p =>
      this.parsedStatements.remove(p.query)
    )
    this.currentPreparedStatement = None
    // A result from an earlier CommandComplete of the failed query must not be
    // handed out by the next ReadyForQuery (possibly to a newer query's promise).
    this.queryResult = None
    this.failQueryPromise(e)
  }

  /** Completes the connect future and delivers any result built by CommandComplete. */
  override def onReadyForQuery(): Unit = {
    this.connectionFuture.trySuccess(this)

    this.recentError = false
    queryResult.foreach(this.succeedQueryPromise)
  }

  /** Handles a server error message by failing pending futures. */
  override def onError(m: ErrorMessage): Unit = {
    log.error("Error with message -> {}", m)

    val error = new GenericDatabaseException(m)
    error.fillInStackTrace()

    this.setErrorOnFutures(error)
  }

  /** Builds the query result; it is delivered at the next ReadyForQuery. */
  override def onCommandComplete(m: CommandCompleteMessage): Unit = {
    this.currentPreparedStatement = None
    queryResult = Some(
      new QueryResult(
        m.rowsAffected,
        m.statusMessage,
        this.currentQuery.map(_.build())
      )
    )
  }

  /** Records a parameter status; `server_version` also updates `version`. */
  override def onParameterStatus(m: ParameterStatusMessage): Unit = {
    this.parameterStatus.put(m.key, m.value)
    if (ServerVersionKey == m.key) {
      this.version = Version(m.value)
    }
  }

  /** Decodes one data row into the current result set.
   *
   * Every non-null column buffer is released exactly once, including when a
   * decoder throws partway through the row.
   */
  override def onDataRow(m: DataRowMessage): Unit = {
    val columnCount = m.values.size
    val items       = new Array[Any](columnCount)
    var x           = 0

    try {
      while (x < columnCount) {
        val buf = m.values(x)
        items(x) = if (buf == null) {
          null
        } else {
          try {
            val columnType = this.currentQuery.get.columnTypes(x)
            this.decoderRegistry.decode(columnType, buf, configuration.charset)
          } finally {
            buf.release()
          }
        }
        x += 1
      }
    } finally {
      // On success x == columnCount, so nothing remains. If decoding column x
      // threw, its buffer was already released above, but the later columns
      // were never visited; release them so they do not leak.
      var rest = x + 1
      while (rest < columnCount) {
        val buf = m.values(rest)
        if (buf != null) buf.release()
        rest += 1
      }
    }

    this.currentQuery.get.addRow(items)
  }

  /** Starts a new result set and caches its columns on the executing statement. */
  override def onRowDescription(m: RowDescriptionMessage): Unit = {
    this.currentQuery = Some(
      new ResultSetBuilder(ArraySeq.unsafeWrapArray(m.columnDatas))
    )
    this.setColumnDatas(m.columnDatas)
  }

  // Remembers the column layout on the current prepared statement (if any).
  private def setColumnDatas(columnDatas: Array[PostgreSQLColumnData]): Unit = {
    this.currentPreparedStatement.foreach { holder =>
      holder.columnDatas = columnDatas
    }
  }

  /** Handles the authentication exchange.
   *
   * AuthenticationOk marks the session authenticated. Cleartext and MD5
   * challenges are answered with a credential message. Any other
   * AuthenticationMessage subtype raises a MatchError.
   */
  override def onAuthenticationResponse(
    message: AuthenticationMessage
  ): Unit = {
    message match {
      case _: AuthenticationOkMessage =>
        log.debug("Successfully logged in to database")
        this.authenticated = true
      case m: AuthenticationChallengeCleartextMessage =>
        write(this.credential(m))
      case m: AuthenticationChallengeMD5 =>
        write(this.credential(m))
    }
  }

  /** Invokes every registered notification listener with `message`. */
  override def onNotificationResponse(message: NotificationResponse): Unit = {
    val iterator = this.notifyListeners.iterator()
    while (iterator.hasNext) {
      iterator.next().apply(message)
    }
  }

  /** Adds a listener for asynchronous notifications. */
  def registerNotifyListener(listener: NotificationResponse => Unit): Unit = {
    this.notifyListeners.add(listener)
  }

  /** Removes a previously registered notification listener. */
  def unregisterNotifyListener(listener: NotificationResponse => Unit): Unit = {
    this.notifyListeners.remove(listener)
  }

  /** Removes all notification listeners. */
  def clearNotifyListeners(): Unit = {
    this.notifyListeners.clear()
  }

  // Builds the reply to an authentication challenge.
  // Throws MissingCredentialInformationException when the username is null or
  // no password is configured.
  private def credential(
    authenticationMessage: AuthenticationChallengeMessage
  ): CredentialMessage = {
    if (configuration.username != null && configuration.password.isDefined) {
      new CredentialMessage(
        configuration.username,
        configuration.password.get,
        authenticationMessage.challengeType,
        authenticationMessage.salt
      )
    } else {
      throw new MissingCredentialInformationException(
        this.configuration.username,
        this.configuration.password,
        authenticationMessage.challengeType
      )
    }
  }

  // Logs `errorMessage` and always throws ConnectionStillRunningQueryException.
  private[this] def notReadyForQueryError(
    errorMessage: String,
    race: Boolean
  ): Nothing = {
    log.error(errorMessage)
    throw new ConnectionStillRunningQueryException(
      this.currentCount,
      race
    )
  }

  /** Throws ConnectionStillRunningQueryException (after logging `errorMessage`)
   * if a query is pending.
   */
  def validateIfItIsReadyForQuery(errorMessage: String): Unit =
    if (this.queryPromise.isDefined)
      notReadyForQueryError(errorMessage, false)

  // Rejects the query when one is already pending or when the text is null/empty.
  private def validateQuery(query: String): Unit = {
    this.validateIfItIsReadyForQuery(
      "Can't run query because there is one query pending already"
    )

    if (query == null || query.isEmpty) {
      throw new QueryMustNotBeNullOrEmptyException(query)
    }
  }

  private def queryPromise: Option[Promise[QueryResult]] =
    queryPromiseReference.get()

  // Atomically claims the connection for `promise`; the CAS catches two
  // callers that both passed validateQuery at the same time.
  private def setQueryPromise(promise: Promise[QueryResult]): Unit = {
    if (!this.queryPromiseReference.compareAndSet(None, Some(promise)))
      notReadyForQueryError(
        "Can't run query due to a race with another started query",
        true
      )
  }

  // Releases the connection and returns the promise that was pending, if any.
  private def clearQueryPromise: Option[Promise[QueryResult]] = {
    this.queryPromiseReference.getAndSet(None)
  }

  // tryFailure: the promise was also handed to addTimeout, which may already
  // have completed it; failure() would then throw IllegalStateException.
  private def failQueryPromise(t: Throwable): Unit = {
    this.clearQueryPromise.foreach { promise =>
      log.error("Setting error on future {}", promise)
      promise.tryFailure(t)
    }
  }

  // Resets per-query state and completes the pending promise (if still open).
  private def succeedQueryPromise(result: QueryResult): Unit = {
    this.queryResult = None
    this.currentQuery = None
    this.clearQueryPromise.foreach {
      _.trySuccess(result)
    }
  }

  private def write(message: ClientMessage): Unit = {
    this.connectionHandler.write(message)
  }

  override def toString: String = {
    s"${this.getClass.getSimpleName}{counter=${this.currentCount}}"
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy