com.avsystem.commons.redis.protocol.messages.scala Maven / Gradle / Ivy

package com.avsystem.commons
package redis.protocol

import java.nio.ByteBuffer

import akka.util.{ByteString, ByteStringBuilder}
import com.avsystem.commons.misc.Sam
import com.avsystem.commons.redis.exception.{InvalidDataException, RedisException}
import com.avsystem.commons.redis.util.SizedArraySeqBuilder

import scala.annotation.tailrec
import scala.collection.compat._
import scala.collection.immutable.VectorBuilder

/**
  * Raw result of executing a single [[com.avsystem.commons.redis.RawCommandPack]].
  * It may be a Redis protocol message ([[RedisMsg]]), an object that
  * aggregates transaction results, or an object that indicates failure.
  */
sealed trait RedisReply
/** Reply that aggregates the results of a transaction as a sequence of [[RedisMsg]]s. */
final case class TransactionReply(elements: IndexedSeq[RedisMsg]) extends RedisReply
/** Reply indicating a failure; the cause is exposed as a [[RedisException]]. */
trait FailureReply extends RedisReply {
  def exception: RedisException
/** Factory for [[FailureReply]]. */
object FailureReply {
  /**
    * Creates a [[FailureReply]] from a by-name exception expression.
    *
    * NOTE(review): the implementation of this method is missing from this copy of the file, so
    * when (and how often) `createException` is evaluated cannot be determined here.
    */
  def apply(createException: => RedisException): FailureReply =

/**
  * Redis protocol message. It can be sent over the network from or to a Redis instance.
  */
sealed trait RedisMsg extends RedisReply
/**
  * A [[RedisMsg]] that is not an error reply. In this file every concrete message type except
  * [[ErrorMsg]] extends it.
  */
sealed trait ValidRedisMsg extends RedisMsg
/**
  * Simple string message, marked by a `+` byte on the wire (see `SimpleInd` in [[RedisMsg]]).
  * The payload is kept as raw bytes.
  */
case class SimpleStringMsg(string: ByteString) extends ValidRedisMsg {
  // Renders the payload via RedisMsg.escape so control and non-ASCII bytes are visible.
  override def toString = s"$productPrefix(${RedisMsg.escape(string)})"
/** Companion with a `String`-based convenience constructor. */
object SimpleStringMsg {
  // Encodes `str` with Akka's ByteString(String) factory (UTF-8 per the Akka documentation).
  def apply(str: String): SimpleStringMsg = SimpleStringMsg(ByteString(str))
/** Error message, marked by a `-` byte on the wire; the payload is the raw error text. */
final case class ErrorMsg(errorString: ByteString) extends RedisMsg {
  // Renders the payload via RedisMsg.escape so control and non-ASCII bytes are visible.
  override def toString = s"$productPrefix(${RedisMsg.escape(errorString)})"
  /**
    * Error code: the UTF-8 decoded prefix of `errorString` up to (not including) the first space,
    * or the whole string when it contains no space. Computed lazily and cached.
    */
  lazy val errorCode: String = errorString.indexOf(' '.toByte) match {
    case -1 => errorString.utf8String
    case i => errorString.slice(0, i).utf8String
/** Companion with a `String`-based convenience constructor. */
object ErrorMsg {
  // Encodes `str` with Akka's ByteString(String) factory (UTF-8 per the Akka documentation).
  def apply(str: String): ErrorMsg = ErrorMsg(ByteString(str))
// Integer message, marked by a ':' byte on the wire.
final case class IntegerMsg(value: Long) extends ValidRedisMsg
// Null bulk string; its size in the encoding is a fixed 5 bytes (see NullBulk and encodedSize).
case object NullBulkStringMsg extends ValidRedisMsg
// Bulk (length-prefixed) string message, marked by a dollar-sign byte on the wire.
// Declared `sealed` rather than `final` so that CommandKeyMsg can extend it.
sealed case class BulkStringMsg(string: ByteString) extends ValidRedisMsg {
  // Renders the payload via RedisMsg.escape so control and non-ASCII bytes are visible.
  override def toString: String = s"$productPrefix(${RedisMsg.escape(string)})"
  /** Whether this bulk string is a command key; always `false` here, `true` in [[CommandKeyMsg]]. */
  def isCommandKey: Boolean = false
/**
  * A [[BulkStringMsg]] flagged as a command key (`isCommandKey` returns `true`).
  * It inherits the payload, `toString` and case-class equality of [[BulkStringMsg]].
  */
final class CommandKeyMsg(key: ByteString) extends BulkStringMsg(key) {
  override def isCommandKey: Boolean = true
/** Constructor and extractor for [[CommandKeyMsg]] (which is not itself a case class). */
object CommandKeyMsg {
  def apply(key: ByteString): CommandKeyMsg = new CommandKeyMsg(key)
  // Extracts the key bytes; returns the commons `Opt` type rather than `Option`.
  def unapply(keyBulkStringMsg: CommandKeyMsg): Opt[ByteString] = Opt(keyBulkStringMsg.string)
// Null array; its size in the encoding is a fixed 5 bytes (see NullArray and encodedSize).
case object NullArrayMsg extends ValidRedisMsg
/** Array message (marked by a `*` byte on the wire); covariant in the element type. */
final case class ArrayMsg[+E <: RedisMsg](elements: IndexedSeq[E]) extends ValidRedisMsg
/** Companion holding a shared empty array instance. */
object ArrayMsg {
  // Shared zero-element array; the decoder emits it for an array header with size 0.
  final val Empty = ArrayMsg(IndexedSeq.empty)

/** Extractor matching a [[SimpleStringMsg]] and yielding its payload as a `String`. */
object SimpleStringStr {
  // NOTE(review): the body of this method is missing from this copy; presumably it decodes
  // `ss.string` (e.g. as UTF-8) - confirm against the upstream source.
  def unapply(ss: SimpleStringMsg): Opt[String] =

/** Predefined messages plus the encoding and decoding routines for [[RedisMsg]]. */
object RedisMsg {
  // Predefined simple-string messages with the payloads "OK", "QUEUED" and "NOKEY".
  final val Ok = SimpleStringMsg(ByteString("OK"))
  final val Queued = SimpleStringMsg(ByteString("QUEUED"))
  final val Nokey = SimpleStringMsg(ByteString("NOKEY"))

  /**
    * Renders `bs` as a human-readable, backslash-escaped string literal.
    *
    * Tab, backspace, newline, carriage return and form feed are written as their two-character
    * escapes; single quotes, double quotes and backslashes are escaped with a backslash;
    * printable ASCII (0x20-0x7E) is copied verbatim; every other byte is written as `\xHH`
    * (two lowercase hex digits).
    *
    * @param bs    bytes to render
    * @param quote whether to surround the result with double quotes
    * @return the escaped representation of `bs`
    */
  def escape(bs: ByteString, quote: Boolean = true): String = {
    val sb = new StringBuilder(if (quote) "\"" else "")
    bs.foreach {
      // Fixed: a tab used to be rendered as "\r".
      case '\t' => sb ++= "\\t"
      case '\b' => sb ++= "\\b"
      case '\n' => sb ++= "\\n"
      case '\r' => sb ++= "\\r"
      case '\f' => sb ++= "\\f"
      case '\'' => sb ++= "\\'"
      // Fixed: a double quote used to be rendered as a lone backslash, losing the quote itself.
      case '\"' => sb ++= "\\\""
      case '\\' => sb ++= "\\\\"
      case b if b > 0x1F && b < 0x7F => sb += b.toChar
      // java.util.Formatter prints a negative Byte as its unsigned value, so 0x80-0xFF stay 2 digits.
      case b => sb ++= f"\\x$b%02x"
    }
    if (quote) {
      sb += '\"'
    }
    sb.result()
  }

  // Pre-built line terminator and null-value encodings (presumably used by the `encode` methods,
  // whose branch bodies are missing from this copy).
  private final val CRLF = ByteString("\r\n")
  private final val NullBulk = ByteString("$-1\r\n")
  private final val NullArray = ByteString("*-1\r\n")

  // Single-byte markers. Their names start with an uppercase letter, so inside `match` patterns
  // (e.g. in Decoder) they act as stable identifiers compared by equality, not fresh bindings.
  private final val CRByte: Byte = '\r'
  private final val LFByte: Byte = '\n'
  private final val SimpleInd: Byte = '+'
  private final val ErrorInd: Byte = '-'
  private final val IntegerInd: Byte = ':'
  private final val BulkInd: Byte = '$'
  private final val ArrayInd: Byte = '*'

  // Long.MinValue has no positive counterpart in Long, so the integer encoders and encodedSize
  // use this pre-rendered text instead of negating the value.
  private final val LongMinValue = ByteString(Long.MinValue.toString)

  /**
    * Number of bytes `msg` occupies in the wire encoding, computed without encoding it.
    *
    * Simple strings, errors and integers take a marker byte, the payload and CRLF (payload + 3).
    * Bulk strings take a marker, the length digits, CRLF, the payload and CRLF (+ 5). Arrays take
    * a marker, the element-count digits and CRLF (+ 3) plus the sizes of all elements. The two
    * null values take a fixed 5 bytes.
    */
  def encodedSize(msg: RedisMsg): Int = {
    // Length of the decimal text of `value`, including a leading '-' for negative values.
    def integerSize(value: Long): Int = value match {
      case 0 => 1
      // Long.MinValue cannot be negated, so measure its pre-rendered text instead.
      case Long.MinValue => LongMinValue.size
      case v if v < 0 => integerSize(-v) + 1
      case v =>
        // Counts the decimal digits of a positive number.
        @tailrec def posIntegerSize(v: Long, acc: Int): Int =
          if (v == 0) acc
          else posIntegerSize(v / 10, acc + 1)
        posIntegerSize(v, 0)

    msg match {
      case NullBulkStringMsg | NullArrayMsg => 5
      case SimpleStringMsg(data) => data.size + 3
      case ErrorMsg(data) => data.size + 3
      case IntegerMsg(value) => integerSize(value) + 3
      case BulkStringMsg(data) => integerSize(data.size) + data.size + 5
      // Recursive: nested arrays are sized element by element.
      case ArrayMsg(data) => integerSize(data.size) + data.foldLeft(0)((acc, msg) => acc + encodedSize(msg)) + 3

  /**
    * Encodes a single message into a new `ByteString`.
    *
    * NOTE(review): the final statement (presumably `builder.result()`) is missing from this copy.
    */
  def encode(msg: RedisMsg): ByteString = {
    val builder = new ByteStringBuilder
    encode(msg, builder)

  /**
    * Encodes all messages from `msgs`, in iteration order, into a single new `ByteString`.
    *
    * NOTE(review): the final statement (presumably `builder.result()`) is missing from this copy.
    */
  def encode(msgs: IterableOnce[RedisMsg]): ByteString = {
    val builder = new ByteStringBuilder
    msgs.iterator.foreach(encode(_, builder))

  /**
    * Appends the decimal ASCII text of `value` to `bsb`.
    *
    * Negative values are written as '-' followed by the digits of the absolute value.
    * Long.MinValue is special-cased because its negation overflows.
    */
  @tailrec def encodeInteger(value: Long, bsb: ByteStringBuilder): Unit = value match {
    case 0 => bsb.putByte('0')
    case Long.MinValue => bsb.append(LongMinValue)
    case v if v < 0 => bsb.putByte('-'); encodeInteger(-v, bsb)
    case v =>
      // Writes digits from the most significant one down, dividing `pow` by 10 after each digit.
      @tailrec def encodePosInteger(value: Long, pow: Long): Unit =
        if (pow > 0) {
          bsb.putByte(('0' + (value / pow)).toByte)
          encodePosInteger(value % pow, pow / 10)
      // Largest power of ten that is not greater than `value` (for value >= 1).
      @tailrec def maxPow10(value: Long, pow: Long): Long =
        if (value < 10) pow else maxPow10(value / 10, pow * 10)
      encodePosInteger(v, maxPow10(v, 1))

  // Extension adding `append(Long)`, which writes the number's decimal text via encodeInteger.
  // NOTE(review): the line returning the builder (presumably `bsb`) is missing from this copy.
  private implicit class ByteStringBuilderOps(private val bsb: ByteStringBuilder) extends AnyVal {
    def append(value: Long): ByteStringBuilder = {
      encodeInteger(value, bsb)

  /**
    * Writes the wire encoding of `msg` into `builder`.
    *
    * NOTE(review): the bodies of all match branches below (and presumably a final call to
    * `encodeIn(msg)`) are missing from this copy, so the exact byte layout cannot be confirmed
    * here; `encodedSize` documents the expected size of each message kind.
    */
  def encode(msg: RedisMsg, builder: ByteStringBuilder): Unit = {
    // One branch per message kind. The BulkStringMsg pattern also matches CommandKeyMsg,
    // which is a subclass of BulkStringMsg.
    def encodeIn(msg: RedisMsg): Unit = msg match {
      case SimpleStringMsg(string) =>
      case ErrorMsg(errorString) =>
      case IntegerMsg(value: Long) =>
      case NullBulkStringMsg =>
      case BulkStringMsg(string) =>
      case NullArrayMsg =>
      case ArrayMsg(elements) =>

  /**
    * Writes the decimal ASCII text of `value` into `bb` at its current position.
    * Same algorithm as the ByteStringBuilder overload of encodeInteger.
    *
    * NOTE(review): for Long.MinValue this uses Akka's `copyToBuffer`, which copies only as many
    * bytes as fit and does not throw, whereas `bb.put` in the other branches throws on overflow.
    * Callers presumably size the buffer with encodedSize - confirm.
    */
  @tailrec def encodeInteger(value: Long, bb: ByteBuffer): Unit = value match {
    case 0 => bb.put('0': Byte)
    case Long.MinValue => LongMinValue.copyToBuffer(bb)
    case v if v < 0 => bb.put('-': Byte); encodeInteger(-v, bb)
    case v =>
      // Writes digits from the most significant one down, dividing `pow` by 10 after each digit.
      @tailrec def encodePosInteger(value: Long, pow: Long): Unit =
        if (pow > 0) {
          bb.put(('0' + (value / pow)).toByte)
          encodePosInteger(value % pow, pow / 10)
      // Largest power of ten that is not greater than `value` (for value >= 1).
      @tailrec def maxPow10(value: Long, pow: Long): Long =
        if (value < 10) pow else maxPow10(value / 10, pow * 10)
      encodePosInteger(v, maxPow10(v, 1))

  // Extensions for writing numbers and ByteStrings into a ByteBuffer.
  // NOTE(review): the remaining lines of both methods (presumably returning `bb`, and the copy
  // logic of `put`) are missing from this copy.
  private implicit class ByteBufferOps(private val bb: ByteBuffer) extends AnyVal {
    def putNum(value: Long): ByteBuffer = {
      encodeInteger(value, bb)

    def put(bs: ByteString): ByteBuffer = {

  // Byte-array forms of the CRLF and null encodings for the ByteBuffer encoder.
  // `getBytes` without a charset uses the platform default; these literals are pure ASCII, so the
  // bytes are the same under any ASCII-compatible default charset.
  private final val CRLFBytes = "\r\n".getBytes
  private final val NullBulkBytes = "$-1\r\n".getBytes
  private final val NullArrayBytes = "*-1\r\n".getBytes

  /**
    * Writes the wire encoding of `msg` into `buffer`, starting at its current position.
    *
    * NOTE(review): the bodies of all match branches below (and presumably a final call to
    * `encodeIn(msg)`) are missing from this copy. The buffer presumably needs at least
    * `encodedSize(msg)` bytes remaining - confirm with callers.
    */
  def encode(msg: RedisMsg, buffer: ByteBuffer): Unit = {
    def encodeIn(msg: RedisMsg): Unit = msg match {
      case SimpleStringMsg(string) =>
      case ErrorMsg(errorString) =>
      case IntegerMsg(value: Long) =>
      case NullBulkStringMsg =>
      case BulkStringMsg(string) =>
      case NullArrayMsg =>
      case ArrayMsg(elements) =>

  /**
    * Decodes all complete top-level messages contained in `bs` using a fresh [[Decoder]].
    *
    * Because the decoder is local to this call, the bytes of a trailing incomplete message are
    * not reported. NOTE(review): the final statement (presumably `builder.result()`) is missing
    * from this copy.
    */
  def decode(bs: ByteString): Seq[RedisMsg] = {
    val builder = new VectorBuilder[RedisMsg]
    val decoder = new Decoder
    decoder.decodeMore(bs)(builder += _)

  /** State constants and helpers shared by [[Decoder]] instances. */
  object Decoder {
    // States of the decoding state machine (values of the decoder's `state` field).
    private final val Initial = 0
    private final val ReadingSimple = 1
    private final val CREncountered = 2
    private final val StartingInt = 3
    private final val ReadingInt = 4
    private final val ReadingBulk = 5

    // Byte constants; uppercase names make them stable identifier patterns in `match`.
    private final val ZeroDigitByte: Byte = '0'
    private final val NineDigitByte: Byte = '9'
    private final val MinusByte: Byte = '-'

    // Result type of the `Digit` name-based extractor: empty unless `b` is an ASCII digit, in
    // which case `get` is its numeric value. It is a value class, which lets matching avoid
    // allocating a wrapper in the common case.
    private class Digit(private val b: Byte) extends AnyVal {
      def isEmpty: Boolean = b < ZeroDigitByte || b > NineDigitByte
      def get: Long = b - ZeroDigitByte
    // Extractor used as `case Digit(d)` in the decoder.
    private object Digit {
      def unapply(b: Byte): Digit = new Digit(b)

  /**
    * Incremental decoder for the Redis protocol messages defined in this file.
    *
    * Input can be fed in arbitrary chunks through [[decodeMore]]; a message split across chunks is
    * carried over in this instance's mutable fields. The fields are plain unsynchronized `var`s, so
    * an instance is presumably meant to be confined to one thread at a time.
    */
  final class Decoder {

    import Decoder._

    // Builders of arrays that are still being filled, innermost array first.
    private[this] var arrayStack: List[SizedArraySeqBuilder[RedisMsg]] = Nil
    // Current state of the state machine (one of the constants in object Decoder).
    private[this] var state: Int = Initial
    // Marker byte ('+', '-', ':', '$' or '*') of the message currently being decoded.
    private[this] var currentType: Byte = 0
    // True while the number being read is the length header of a bulk string or an array.
    private[this] var readingLength: Boolean = false
    // True once a leading '-' has been read for the number currently being parsed.
    private[this] var numberNegative: Boolean = false
    // Number being parsed: an integer payload, a bulk string length or an array size.
    private[this] var numberValue: Long = 0
    // Payload bytes of a simple, error or bulk string gathered so far, possibly from earlier chunks.
    private[this] val dataBuilder = new ByteStringBuilder

    // Aborts decoding by throwing an InvalidDataException with the given message.
    // NOTE(review): no state is reset before throwing, so a decoder that failed is presumably
    // not meant to be reused.
    def fail(msg: String) = throw new InvalidDataException(msg)

    /**
      * Feeds the next chunk of input to the decoder.
      *
      * `consumer` is invoked for each top-level message completed while processing `bytes`;
      * messages that are elements of a still-open array are added to that array's builder instead.
      *
      * @throws InvalidDataException when the input does not follow the protocol
      */
    def decodeMore(bytes: ByteString)(consumer: RedisMsg => Unit): Unit = {
      // Routes a finished message to `consumer` (top level) or to the innermost open array.
      // NOTE(review): the lines after `arrayStack = tail` (presumably building an ArrayMsg from the
      // finished builder and recursing with it) are missing from this copy.
      @tailrec def completed(msg: RedisMsg): Unit = {
        arrayStack match {
          case Nil => consumer(msg)
          case builder :: tail =>
            builder += msg
            if (builder.complete) {
              arrayStack = tail

      // Processes bytes(idx), then continues with idx + 1. `prevDataStart` is the index in this
      // chunk where string payload not yet copied into dataBuilder begins, or -1 if there is none.
      @tailrec def decode(idx: Int, prevDataStart: Int): Unit = if (idx < bytes.length) {
        val byte = bytes(idx)
        var dataStart = prevDataStart
        state match {
          // Expecting the marker byte of a new message.
          case Initial =>
            currentType = byte
            byte match {
              case SimpleInd | ErrorInd =>
                state = ReadingSimple
              case IntegerInd =>
                state = StartingInt
              case BulkInd | ArrayInd =>
                state = StartingInt
                readingLength = true
              case _ => fail("Expected one of: '+', '-', ':', '$', '*'")
          // First byte of a number: either a '-' sign or the first digit.
          // NOTE(review): a lone '-' directly followed by CR is accepted and yields 0.
          case StartingInt =>
            numberValue = 0
            state = ReadingInt
            byte match {
              case MinusByte =>
                numberNegative = true
              case Digit(digitValue) =>
                numberValue = digitValue
              case _ => fail("Expected '-' sign or digit")
          // Remaining digits until CR. Negative numbers are built by subtracting digits, so
          // Long.MinValue can be parsed; values beyond the Long range overflow silently.
          case ReadingInt => byte match {
            case CRByte =>
              numberNegative = false
              state = CREncountered
            case Digit(digitValue) =>
              numberValue = numberValue * 10 + (if (numberNegative) -digitValue else digitValue)
            case _ => fail("Expected digit or CR")
          // Simple string or error payload: everything up to CR; a bare LF is rejected.
          case ReadingSimple =>
            if (dataStart < 0) {
              dataStart = idx
            byte match {
              case CRByte =>
                dataBuilder.append(bytes.slice(dataStart, idx))
                dataStart = -1
                state = CREncountered
              case LFByte => fail("LF not allowed in simple string message")
              case _ =>
          // Bulk string payload: once the bytes gathered so far (earlier chunks plus this chunk
          // from dataStart) reach the declared length, the current byte must be CR.
          case ReadingBulk =>
            if (dataStart < 0) {
              dataStart = idx
            if (dataBuilder.length + idx - dataStart == numberValue) {
              if (byte == CRByte) {
                dataBuilder.append(bytes.slice(dataStart, idx))
                dataStart = -1
                state = CREncountered
              } else fail("Expected CR at the end of bulk string message")
          // After CR only LF is allowed; it ends either a length header or a whole message.
          case CREncountered => byte match {
            case LFByte if readingLength =>
              readingLength = false
              currentType match {
                case BulkInd =>
                  numberValue match {
                    // NOTE(review): only the state is reset for a null bulk string here; the line
                    // delivering NullBulkStringMsg appears to be missing from this copy.
                    case -1 =>
                      state = Initial
                    case size if size >= 0 =>
                      state = ReadingBulk
                    case _ => fail("Invalid bulk string length")
                case ArrayInd =>
                  state = Initial
                  numberValue match {
                    case -1 => completed(NullArrayMsg)
                    case 0 => completed(ArrayMsg.Empty)
                    // Non-empty array: open a builder sized for `size` elements.
                    // NOTE(review): `toInt` silently truncates sizes above Int.MaxValue.
                    case size if size > 0 =>
                      val is = size.toInt
                      arrayStack = new SizedArraySeqBuilder[RedisMsg](is) :: arrayStack
                    case _ => fail("Invalid array size")
                case _ => fail("Length can be read only for bulk strings or arrays")
            // End of a simple string, error, bulk string or integer message.
            // NOTE(review): the rest of extractData (presumably returning `res` after clearing
            // dataBuilder) and the line passing `msg` to `completed` are missing from this copy.
            case LFByte =>
              def extractData() = {
                val res = dataBuilder.result()
              val msg = currentType match {
                case SimpleInd => SimpleStringMsg(extractData())
                case ErrorInd => ErrorMsg(extractData())
                case BulkInd => BulkStringMsg(extractData())
                case IntegerInd => IntegerMsg(numberValue)
              state = Initial
            case _ => fail("Expected LF after CR")
        decode(idx + 1, dataStart)
      // End of chunk. NOTE(review): the branch bodies are missing from this copy; presumably the
      // pending payload (from prevDataStart on) is copied into dataBuilder for the next call.
      } else state match {
        case ReadingSimple | ReadingBulk if prevDataStart >= 0 =>
        case _ =>
      decode(0, -1)

