All downloads are FREE. Search and download functionality uses the official Maven repository.

zio.http.StreamingForm.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2021 - 2023 Sporta Technologies PVT LTD & the ZIO HTTP contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package zio.http

import java.nio.charset.Charset

import scala.annotation.tailrec

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.stream.{Take, ZStream}

import zio.http.StreamingForm.Buffer
import zio.http.internal.{FormAST, FormState}

/**
 * A form whose parts are decoded incrementally from a stream of bytes.
 *
 * @param source
 *   the raw bytes of the encoded form
 * @param boundary
 *   the boundary separating the parts; it also supplies the charset passed to
 *   `FormField.fromFormAST` when non-streaming fields are built
 * @param bufferSize
 *   initial size, in bytes, of the buffer that accumulates the content of a
 *   streaming binary part before it is offered to that part's queue
 */
final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary: Boundary, bufferSize: Int = 8192) {
  /** The charset carried by `boundary`. */
  def charset: Charset = boundary.charset

  /**
   * Runs the streaming form and collects all parts in memory, returning a Form.
   *
   * Streaming binary fields are replaced by the result of their `collect`;
   * every other field is kept as is. The content of binary parts is buffered
   * up to the part's boundary (`bufferUpToBoundary = true`).
   */
  def collectAll(implicit trace: Trace): ZIO[Any, Throwable, Form] =
    streamFormFields(bufferUpToBoundary = true).mapZIO {
      case sb: FormField.StreamingBinary =>
        sb.collect
      case other: FormField              =>
        ZIO.succeed(other)
    }.runCollect.map { formData =>
      Form(formData)
    }

  /**
   * Streams the fields of the form as they are decoded from `source`.
   *
   * Unlike [[collectAll]] this runs with `bufferUpToBoundary = false`, so the
   * buffered content of a binary part may already be forwarded at the end of
   * an incoming chunk instead of only when the part's boundary is found.
   */
  def fields(implicit trace: Trace): ZStream[Any, Throwable, FormField] =
    streamFormFields(bufferUpToBoundary = false)

  /**
   * Builds the stream of decoded fields.
   *
   * A reader fiber is forked in the stream's scope. It consumes `source` chunk
   * by chunk, feeds every byte to the form parser and offers the fields it
   * completes to a bounded queue; the returned stream is read from that queue.
   *
   * @param bufferUpToBoundary
   *   forwarded to the [[StreamingForm.Buffer]] holding binary content: when
   *   `true` content is only released once the full boundary has been seen,
   *   when `false` it may also be released at the end of each chunk
   */
  private def streamFormFields(
    bufferUpToBoundary: Boolean,
  )(implicit trace: Trace): ZStream[Any, Throwable, FormField] =
    ZStream.unwrapScoped {
      // Required by the `runtime.unsafe.run` calls made while parsing.
      implicit val unsafe: Unsafe = Unsafe.unsafe

      for {
        runtime    <- ZIO.runtime[Any]
        buffer     <- ZIO.succeed(new Buffer(bufferSize, crlfBoundary, bufferUpToBoundary))
        // Completed by the scope finalizer below to unblock the reader when the scope closes with a failure.
        abort      <- Promise.make[Nothing, Unit]
        // Hand-off between the reader fiber and the stream returned to the caller.
        fieldQueue <- Queue.bounded[Take[Throwable, FormField]](4)
        // Mutable parser state, created once and shared by every chunk the reader processes.
        state  = initialState
        reader =
          source.runForeachChunk { bytes =>
            // Called when the parser reports a boundary. For a non-streaming part the collected AST is turned
            // into a field (a decoding error is thrown as an exception); otherwise there is nothing to emit.
            // In both cases the buffer and the parser state are reset for the next part.
            def handleBoundary(ast: Chunk[FormAST]): Option[FormField] =
              if (state.inNonStreamingPart) {
                FormField.fromFormAST(ast, charset) match {
                  case Right(formData) =>
                    buffer.reset()
                    state.reset
                    Some(formData)
                  case Left(e)         => throw e.asException
                }
              } else {
                buffer.reset()
                state.reset
                None
              }

            // Feeds a single byte to the parser and returns the field completed or started by it, if any.
            // `isLastByte` tells whether `byte` is the last byte of the current chunk.
            def handleByte(byte: Byte, isLastByte: Boolean): Option[FormField] = {
              state.formState match {
                case formState: FormState.FormStateBuffer =>
                  val nextFormState = formState.append(byte)
                  // While a streaming part is active the byte is also written to the content buffer, and whatever
                  // the buffer releases is pushed to the part's queue.
                  state.currentQueue match {
                    case Some(queue) =>
                      val takes = buffer.addByte(byte, isLastByte)
                      if (takes.nonEmpty) {
                        // Runs synchronously; racing with `abort` lets the offer be abandoned once `abort` completes.
                        runtime.unsafe.run(queue.offerAll(takes).raceFirst(abort.await)).getOrThrowFiberFailure()
                      }
                    case None        =>
                  }
                  nextFormState match {
                    case newFormState: FormState.FormStateBuffer =>
                      // Decide once per part whether it is streamed: no queue yet, parser in `Part2`, and the part
                      // not already marked as non-streaming.
                      // NOTE(review): `Phase.Part2` presumably means the part's headers are complete - confirm in FormState.
                      if (
                        state.currentQueue.isEmpty &&
                        (newFormState.phase eq FormState.Phase.Part2) &&
                        !state.inNonStreamingPart
                      ) {
                        val contentType = FormField.getContentType(newFormState.tree)
                        if (contentType.binary) {
                          runtime.unsafe.run {
                            for {
                              newQueue <- Queue.bounded[Take[Nothing, Byte]](3)
                              // Seed the queue with the content bytes the parser has already collected in its tree.
                              _ <- newQueue.offer(Take.chunk(newFormState.tree.collect { case FormAST.Content(bytes) =>
                                bytes
                              }.flatten))
                              streamingFormData <- FormField
                                .incomingStreamingBinary(newFormState.tree, newQueue)
                                .mapError(_.asException)
                              // From now on `handleByte` forwards content to this queue.
                              _ = state.withCurrentQueue(newQueue)
                            } yield Some(streamingFormData)
                          }.getOrThrowFiberFailure()
                        } else {
                          // Remember the decision so the content type is not inspected again for this part.
                          val _ = state.withInNonStreamingPart(true)
                          None
                        }
                      } else {
                        None
                      }
                    case FormState.BoundaryEncapsulated(ast)     =>
                      handleBoundary(ast)
                    case FormState.BoundaryClosed(ast)           =>
                      handleBoundary(ast)
                  }
                case _                                    =>
                  // Any other form state does not accept bytes; the byte is ignored.
                  None
              }
            }

            // Walk the chunk byte by byte, looking one element ahead so the last byte can be flagged,
            // and gather every field produced along the way.
            val builder = Chunk.newBuilder[FormField]
            val it      = bytes.iterator
            var hasNext = it.hasNext
            while (hasNext) {
              val byte = it.next()
              hasNext = it.hasNext
              handleByte(byte, !hasNext) match {
                case Some(field) => builder += field
                case _           => ()
              }
            }
            val fields  = builder.result()
            // Publish the fields of this chunk in one go; nothing is offered for chunks that produced none.
            fieldQueue.offer(Take.chunk(fields)).when(fields.nonEmpty)
          }
        // FIXME: .blocking here is temporary until we figure out a better way to avoid running effects within mapAccumImmediate
        // Any failure or defect of the reader is forwarded to the consumer, and `Take.end` is always offered last.
        _ <- ZIO
          .blocking(reader)
          .catchAllCause(cause => fieldQueue.offer(Take.failCause(cause)))
          .ensuring(fieldQueue.offer(Take.end))
          .forkScoped
          .interruptible
        _ <- Scope.addFinalizerExit { exit =>
          // If the fieldStream fails, we need to make sure the reader stream can be interrupted, as it may be blocked
          // in the unsafe.run(queue.offer) call (interruption does not propagate into the unsafe.run). This is implemented
          // by setting the abort promise which is raced within the unsafe run when offering the element to the queue.
          abort.succeed(()).when(exit.isFailure)
        }
        // Unwrap the takes offered by the reader back into elements, failures and end-of-stream.
        fieldStream = ZStream.fromQueue(fieldQueue).flattenTake
      } yield fieldStream
    }

  /** A fresh parser state for `boundary`. */
  private def initialState: StreamingForm.State =
    StreamingForm.initialState(boundary)

  /** The encapsulation boundary bytes prefixed with CR LF (bytes 13, 10). */
  private def crlfBoundary: Array[Byte] = Array[Byte](13, 10) ++ boundary.encapsulationBoundaryBytes.toArray[Byte]
}

object StreamingForm {

  /**
   * Mutable state of the streaming parser.
   *
   * @param formState
   *   the underlying form parser state; it is reset in place by [[reset]]
   * @param _currentQueue
   *   queue receiving the content of the part currently being streamed, if any
   * @param _inNonStreamingPart
   *   whether the current part has been marked as non-streaming
   */
  private final class State(
    val formState: FormState,
    private var _currentQueue: Option[Queue[Take[Nothing, Byte]]],
    private var _inNonStreamingPart: Boolean,
  ) {
    def currentQueue: Option[Queue[Take[Nothing, Byte]]] = _currentQueue
    def inNonStreamingPart: Boolean                      = _inNonStreamingPart

    /** Registers `queue` as the target of the current part's content. Mutates and returns this state. */
    def withCurrentQueue(queue: Queue[Take[Nothing, Byte]]): State = {
      _currentQueue = Some(queue)
      this
    }

    /** Sets the non-streaming flag of the current part. Mutates and returns this state. */
    def withInNonStreamingPart(value: Boolean): State = {
      _inNonStreamingPart = value
      this
    }

    /** Clears the queue and the non-streaming flag and resets the form parser. Mutates and returns this state. */
    def reset: State = {
      _currentQueue = None
      _inNonStreamingPart = false
      formState.reset()
      this
    }
  }

  /** Creates the state used at the start of a form: no active queue, not in a non-streaming part. */
  private def initialState(boundary: Boundary): State = {
    new State(FormState.fromBoundary(boundary), None, _inNonStreamingPart = false)
  }

  /**
   * Growable byte buffer that accumulates the content of a streamed part and
   * detects the boundary terminating it.
   *
   * @param bufferSize
   *   initial capacity in bytes; the buffer grows on demand, also when this is 0
   * @param crlfBoundary
   *   the byte sequence that terminates the content of a part
   * @param bufferUpToBoundary
   *   when `true`, content is only released once the full boundary has been
   *   seen; when `false`, it may also be released at the end of a chunk
   */
  private final class Buffer(bufferSize: Int, crlfBoundary: Array[Byte], bufferUpToBoundary: Boolean) {
    private var buffer: Array[Byte] = Array.ofDim(bufferSize)
    private var index: Int          = 0
    private val boundarySize        = crlfBoundary.length

    /** Grows the backing array (by repeated doubling) until it is at least `requiredCapacity` long. */
    private def ensureHasCapacity(requiredCapacity: Int): Unit = {
      @tailrec
      def calculateNewCapacity(existing: Int, required: Int): Int = {
        // Double from at least 1: doubling a capacity of 0 stays 0 and would never reach `required`,
        // which made this loop spin forever for a buffer created with `bufferSize = 0`.
        val newCap = math.max(existing, 1) * 2
        if (newCap < required) calculateNewCapacity(newCap, required)
        else newCap
      }

      val l = buffer.length
      if (l <= requiredCapacity) {
        val newArray = Array.ofDim[Byte](calculateNewCapacity(l, requiredCapacity))
        java.lang.System.arraycopy(buffer, 0, newArray, 0, l)
        buffer = newArray
      }
    }

    /**
     * Returns `true` when the bytes ending at position `idx` are equal to a
     * non-empty prefix of the boundary, i.e. the buffered tail may be the
     * beginning of a boundary that continues in the next chunk.
     */
    private def matchesPartialBoundary(idx: Int): Boolean = {
      val bs     = boundarySize
      var i      = 0
      var result = false
      // `i + 1` is the length of the candidate suffix, which starts at `idx - i`.
      while (i < bs && i <= idx && !result) {
        val i0 = idx - i
        var i1 = 0
        while (i >= i1 && buffer(i0 + i1) == crlfBoundary(i1) && !result) {
          if (i == i1) result = true
          i1 += 1
        }
        i += 1
      }
      result
    }

    /**
     * Appends `byte` to the buffer and returns the takes that are ready to be
     * forwarded to the consumer of the current part.
     *
     *   - When the buffer now ends with the full boundary: the content
     *     preceding it (if any) followed by `Take.end`; the buffer is rewound.
     *   - Otherwise, when not buffering up to the boundary, `byte` is the last
     *     byte of the chunk, is not `-`, and the buffered tail cannot be the
     *     start of a boundary: everything buffered so far; the buffer is
     *     rewound.
     *   - Otherwise nothing.
     *
     * @param byte
     *   the next content byte
     * @param isLastByte
     *   whether `byte` is the last byte of the chunk being processed
     */
    def addByte(byte: Byte, isLastByte: Boolean): Chunk[Take[Nothing, Byte]] = {
      val idx = index
      ensureHasCapacity(idx + boundarySize + 1)
      buffer(idx) = byte
      index += 1

      // Compare the last `boundarySize` bytes with the boundary; impossible while fewer bytes are buffered.
      var i                 = 0
      var foundFullBoundary = idx >= boundarySize - 1
      while (i < boundarySize && foundFullBoundary) {
        if (buffer(idx + 1 - boundarySize + i) != crlfBoundary(i)) {
          foundFullBoundary = false
        }
        i += 1
      }

      if (foundFullBoundary) {
        // `reset()` only rewinds the write position, so the bytes can still be copied out afterwards.
        reset()
        val toTake = idx + 1 - boundarySize
        if (toTake == 0) Chunk(Take.end)
        else Chunk(Take.chunk(Chunk.fromArray(buffer.take(toTake))), Take.end)
      } else if (!bufferUpToBoundary && isLastByte && byte != '-' && !matchesPartialBoundary(idx)) {
        reset()
        Chunk(Take.chunk(Chunk.fromArray(buffer.take(idx + 1))))
      } else {
        Chunk.empty
      }
    }

    /** Rewinds the write position to the start; the backing array is kept as is. */
    def reset(): Unit =
      index = 0
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy