All downloads are free. Search and download functionality uses the official Maven repository.

zio.redis.internal.RespValue.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2021 John A. De Goes and the ZIO contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package zio.redis.internal

import zio._
import zio.redis.options.Cluster.Slot
import zio.redis.{RedisError, RedisUri}
import zio.stream._

import java.nio.charset.StandardCharsets

/**
 * A single protocol value that can be written to, or read from, a Redis connection.
 *
 * The concrete cases live in the companion object; `asBytes` produces the wire representation of each of them.
 */
private[redis] sealed trait RespValue extends Product with Serializable { self =>
  import RespValue._
  import RespValue.internal.{CrLf, Headers, NullArrayEncoded, NullStringEncoded}

  /**
   * Serializes this value.
   *
   * The two null cases return their pre-built encodings. Every other case starts with a one-byte header, a UTF-8
   * rendered text (the string itself, the number, or the payload/element count) and a CRLF. Bulk strings then append
   * the raw payload followed by another CRLF; arrays append the serialized form of each element in turn.
   */
  final def asBytes: Chunk[Byte] =
    self match {
      case NullBulkString     => NullStringEncoded
      case NullArray          => NullArrayEncoded
      case SimpleString(text) => headerLine(Headers.SimpleString, text).result()
      case Error(message)     => headerLine(Headers.Error, message).result()
      case Integer(number)    => headerLine(Headers.Integer, number.toString).result()

      case BulkString(payload) =>
        val out = headerLine(Headers.BulkString, payload.length.toString)

        out ++= payload
        out ++= CrLf

        out.result()

      case Array(items) =>
        val out = headerLine(Headers.Array, items.size.toString)

        items.foreach(item => out ++= item.asBytes)

        out.result()
    }

  // Starts a fresh builder holding `header`, the UTF-8 bytes of `text` and the CRLF terminator.
  private[this] def headerLine(header: Byte, text: String): ChunkBuilder.Byte = {
    val out = new ChunkBuilder.Byte()

    out += header
    out ++= text.getBytes(StandardCharsets.UTF_8)
    out ++= CrLf

    out
  }
}

private[redis] object RespValue {
  /** A simple string; `asBytes` writes it as a `+` header, the UTF-8 bytes of `value` and a CRLF. */
  final case class SimpleString(value: String) extends RespValue

  /** An error reply; `asBytes` writes it as a `-` header, the UTF-8 bytes of `value` and a CRLF. */
  final case class Error(value: String) extends RespValue {

    /**
     * Converts the raw error text into a [[RedisError]] based on the keyword it starts with.
     *
     * For the plain keywords (`ERR`, `WRONGTYPE`, `BUSYGROUP`, `NOGROUP`, `NOSCRIPT`, `NOTBUSY`, `CROSSSLOT`) the
     * keyword is dropped and the trimmed remainder becomes the message. `ASK` and `MOVED` replies are parsed into a
     * slot and a target address. Anything else, including an `ASK`/`MOVED` reply whose slot or address cannot be
     * parsed, is reported as a `ProtocolError` carrying the whole trimmed text instead of throwing.
     */
    def asRedisError: RedisError =
      if (value.startsWith("ERR")) RedisError.ProtocolError(value.drop(3).trim)
      else if (value.startsWith("WRONGTYPE")) RedisError.WrongType(value.drop(9).trim)
      else if (value.startsWith("BUSYGROUP")) RedisError.BusyGroup(value.drop(9).trim)
      else if (value.startsWith("NOGROUP")) RedisError.NoGroup(value.drop(7).trim)
      else if (value.startsWith("NOSCRIPT")) RedisError.NoScript(value.drop(8).trim)
      else if (value.startsWith("NOTBUSY")) RedisError.NotBusy(value.drop(7).trim)
      else if (value.startsWith("CROSSSLOT")) RedisError.CrossSlot(value.drop(9).trim)
      else if (value.startsWith("ASK")) redirectError(target => RedisError.Ask(target))
      else if (value.startsWith("MOVED")) redirectError(target => RedisError.Moved(target))
      else RedisError.ProtocolError(value.trim)

    // Builds the redirect error when the reply parses, otherwise degrades to a protocol error with the raw text.
    private def redirectError(build: ((Slot, RedisUri)) => RedisError): RedisError =
      parseRedirectError(value) match {
        case Some(target) => build(target)
        case None         => RedisError.ProtocolError(value.trim)
      }

    // Expects "<keyword> <slot> <address>" separated by single spaces; the 2nd and 3rd tokens are used and any
    // further tokens are ignored. Returns None when tokens are missing or either conversion throws.
    private def parseRedirectError(value: String): Option[(Slot, RedisUri)] =
      value.split(' ') match {
        case scala.Array(_, slot, address, _*) =>
          scala.util.Try((Slot(slot.toLong), RedisUri(address))).toOption
        case _ => None
      }
  }

  /** A 64-bit integer; `asBytes` writes it as a `:` header, the decimal text of `value` and a CRLF. */
  final case class Integer(value: Long) extends RespValue

  /**
   * A binary payload; `asBytes` writes it as a `$` header, the payload length, a CRLF, the raw bytes and a final CRLF.
   */
  final case class BulkString(value: Chunk[Byte]) extends RespValue {

    /** Reads the whole payload as a decimal number using the unchecked `internal.unsafeReadLong`. */
    def asLong: Long = internal.unsafeReadLong(value, startFrom = 0)

    /** Decodes the payload as UTF-8 text. */
    def asString: String = new String(value.toArray, StandardCharsets.UTF_8)
  }

  /** A sequence of values; `asBytes` writes a `*` header, the element count, a CRLF and then every element. */
  final case class Array(values: Chunk[RespValue]) extends RespValue

  /** The null bulk string; `asBytes` writes it as `$-1` followed by CRLF. */
  case object NullBulkString extends RespValue

  /** The null array; `asBytes` writes it as `*-1` followed by CRLF. */
  case object NullArray extends RespValue

  /** Extractor that exposes the elements of an [[Array]] for sequence patterns; other values do not match. */
  object ArrayValues {
    def unapplySeq(v: RespValue): Option[Seq[RespValue]] =
      Some(v).collect { case Array(values) => values }
  }

  /**
   * Pipeline from raw bytes to decoded values.
   *
   * The input is first split on CRLF; the resulting lines are then folded through `internal.State` until the state
   * stops being in progress. A completed value is emitted as `Some`, a fold that ends while still in `State.Start`
   * is emitted as `None`, a rejected input fails with a `ProtocolError`, and any other final state is treated as a
   * bug in the decoder.
   */
  final val Decoder: ZPipeline[Any, RedisError.ProtocolError, Byte, Option[RespValue]] = {
    import internal.{CrLf, State}

    // Drives the state machine one line at a time for as long as it reports being in progress.
    val stateMachine =
      ZSink.foldChunks[Byte, State](State.Start)(state => state.inProgress)((state, line) => state.feed(line))

    val frameSink =
      stateMachine.mapZIO {
        case State.Done(value) => ZIO.some(value)
        case State.Failed      => ZIO.fail(RedisError.ProtocolError("Invalid data received."))
        case State.Start       => ZIO.none
        case other             => ZIO.dieMessage(s"Deserialization bug, should not get $other")
      }

    val lines = ZPipeline.splitOnChunk(CrLf)

    lines >>> ZPipeline.fromSink(frameSink)
  }

  /** Builds an [[Array]] holding the given values. */
  def array(values: RespValue*): Array = {
    val elements = Chunk.fromIterable(values)
    Array(elements)
  }

  /** Builds a [[BulkString]] whose payload is the UTF-8 encoding of `s`. */
  def bulkString(s: String): BulkString = {
    val utf8 = s.getBytes(StandardCharsets.UTF_8)
    BulkString(Chunk.fromArray(utf8))
  }

  private object internal {
    // The leading byte that identifies each kind of value, both when encoding (`asBytes`) and when decoding
    // (`State.feed` dispatches on the first byte of a line).
    object Headers {
      final val SimpleString: Byte = '+'
      final val Error: Byte        = '-'
      final val Integer: Byte      = ':'
      final val BulkString: Byte   = '$'
      final val Array: Byte        = '*'
    }

    // Line terminator: appended by `asBytes` and used by `Decoder` to split the incoming bytes.
    final val CrLf: Chunk[Byte]              = Chunk('\r', '\n')
    // Complete encoding returned by `asBytes` for `NullArray`.
    final val NullArrayEncoded: Chunk[Byte]  = Chunk('*', '-', '1', '\r', '\n')
    // The same marker without the terminator, compared against a whole line in `State.feed`.
    final val NullArrayPrefix: Chunk[Byte]   = Chunk('*', '-', '1')
    // Complete encoding returned by `asBytes` for `NullBulkString`.
    final val NullStringEncoded: Chunk[Byte] = Chunk('$', '-', '1', '\r', '\n')
    // The same marker without the terminator, compared against a whole line in `State.feed`.
    final val NullStringPrefix: Chunk[Byte]  = Chunk('$', '-', '1')

    /**
     * State of the incremental decoder. `feed` is called once per CRLF-delimited line and returns the next state;
     * decoding of one value is over as soon as `inProgress` becomes `false`.
     */
    sealed trait State { self =>
      import State._

      /** `false` once a value has been produced (`Done`) or the input has been rejected (`Failed`). */
      final def inProgress: Boolean =
        self match {
          case Done(_) | Failed => false
          case _                => true
        }

      /**
       * Consumes one line (without its CRLF terminator) and returns the resulting state.
       *
       *   - From `Start`, an empty line is skipped, the null markers complete immediately, and otherwise the first
       *     byte selects the kind of value. Integer, bulk-string and array headers whose numeric part is missing or
       *     contains anything but digits (plus a leading `-` for integers) are rejected with `Failed` rather than
       *     being handed to the unchecked number parsers.
       *   - `CollectingArray` forwards the line to the decoder of the current element. A rejected element rejects
       *     the whole array, so the failure is reported instead of being swallowed.
       *   - `CollectingBulkString` gathers payload bytes, restoring the CRLF that the line splitter removed whenever
       *     the payload itself spans several lines.
       *   - `Done` and `Failed` are terminal: feeding them yields `Failed`.
       */
      final def feed(bytes: Chunk[Byte]): State =
        self match {
          case Start if bytes.isEmpty             => Start
          case Start if bytes == NullStringPrefix => Done(NullBulkString)
          case Start if bytes == NullArrayPrefix  => Done(NullArray)

          case Start if bytes.nonEmpty =>
            bytes.head match {
              case Headers.SimpleString => Done(SimpleString(decode(bytes.tail)))
              case Headers.Error        => Done(Error(decode(bytes.tail)))

              case Headers.Integer =>
                // A leading '-' is the only non-digit that `unsafeReadLong` understands.
                val digitsFrom = if (bytes.length > 1 && bytes(1) == '-') 2 else 1

                if (hasOnlyDigits(bytes, digitsFrom)) Done(Integer(unsafeReadLong(bytes, 1)))
                else Failed

              case Headers.BulkString =>
                if (hasOnlyDigits(bytes, 1)) {
                  val size = unsafeReadSize(bytes)
                  CollectingBulkString(size, ChunkBuilder.make(size))
                } else Failed

              case Headers.Array =>
                if (hasOnlyDigits(bytes, 1)) {
                  val size = unsafeReadSize(bytes)

                  if (size > 0)
                    CollectingArray(size, ChunkBuilder.make(size), Start.feed)
                  else
                    Done(Array(Chunk.empty))
                } else Failed

              case _ => Failed
            }

          case CollectingArray(rem, vals, next) =>
            next(bytes) match {
              case Done(v) if rem > 1 => CollectingArray(rem - 1, vals += v, Start.feed)
              case Done(v)            => Done(Array((vals += v).result()))
              // Without this case a rejected element would be re-wrapped as an in-progress array forever.
              case Failed             => Failed
              case state              => CollectingArray(rem, vals, state.feed)
            }

          case CollectingBulkString(rem, vals) =>
            if (bytes.length >= rem) {
              vals ++= bytes.take(rem)
              Done(BulkString(vals.result()))
            } else {
              // The payload contains a CRLF of its own: put it back and keep collecting.
              vals ++= bytes
              vals ++= CrLf
              CollectingBulkString(rem - bytes.length - 2, vals)
            }

          case _ => Failed
        }

      // True when `bytes` has at least one byte at or after `from` and all of them are ASCII digits.
      private def hasOnlyDigits(bytes: Chunk[Byte], from: Int): Boolean = {
        val end = bytes.length
        var idx = from

        while (idx < end && bytes(idx) >= '0' && bytes(idx) <= '9') idx += 1

        from < end && idx == end
      }
    }

    // The concrete decoder states.
    object State {
      // Nothing consumed yet for the current value.
      case object Start  extends State
      // The input was rejected; terminal.
      case object Failed extends State

      // An array is being read: `rem` elements are still missing, `vals` holds the ones decoded so far and `next`
      // consumes the following line on behalf of the element currently being decoded.
      final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: Chunk[Byte] => State)
          extends State

      // A bulk string is being read: `rem` payload bytes are still missing and `vals` holds the ones read so far.
      final case class CollectingBulkString(rem: Int, vals: ChunkBuilder[Byte]) extends State

      // A value has been fully decoded; terminal.
      final case class Done(value: RespValue) extends State
    }

    // Decodes the given bytes as UTF-8 text.
    def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)

    // Reads a decimal number from `bytes`, beginning at `startFrom` and running to the end of the chunk. A `-` at
    // `startFrom` negates the result. "Unsafe" because nothing is validated: `startFrom` is not bounds-checked and
    // every remaining byte is folded into the result as if it were a digit.
    def unsafeReadLong(bytes: Chunk[Byte], startFrom: Int): Long = {
      val negative = bytes(startFrom) == '-'
      val end      = bytes.length

      @scala.annotation.tailrec
      def loop(idx: Int, acc: Long): Long =
        if (idx < end) loop(idx + 1, acc * 10 + bytes(idx) - '0')
        else acc

      val magnitude = loop(if (negative) startFrom + 1 else startFrom, 0L)

      if (negative) -magnitude else magnitude
    }

    // Reads the decimal size that follows the one-byte header of `bytes` (index 1 to the end). "Unsafe" because
    // nothing is validated: every byte after the header is folded into the result as if it were a digit.
    def unsafeReadSize(bytes: Chunk[Byte]): Int = {
      val end = bytes.length

      @scala.annotation.tailrec
      def loop(idx: Int, acc: Int): Int =
        if (idx < end) loop(idx + 1, acc * 10 + bytes(idx) - '0')
        else acc

      loop(1, 0)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy