All downloads are free. Search and download functionality uses the official Maven repository.

io.cequence.wsclient.service.ws.stream.WSStreamRequestHelper.scala Maven / Gradle / Ivy

There is a newer version: 0.4.2
Show newest version
package io.cequence.wsclient.service.ws.stream

import akka.NotUsed
import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
import akka.stream.Materializer
import akka.stream.scaladsl.Framing.FramingException
import akka.stream.scaladsl.{Flow, Framing, Source}
import akka.util.ByteString
import com.fasterxml.jackson.core.JsonParseException
import io.cequence.wsclient.domain.{
  CequenceWSException,
  CequenceWSTimeoutException,
  CequenceWSUnknownHostException
}
import io.cequence.wsclient.service.ws.WSRequestHelper
import play.api.libs.json.{JsObject, JsString, JsValue, Json}
import play.api.libs.ws.JsonBodyWritables._

import java.net.UnknownHostException
import java.util.concurrent.TimeoutException

/**
 * Stream request support specifically tailored for the OpenAI API: the response body is split
 * into chunks separated by a blank line (`"\n\n"`), each chunk optionally carries a `data: `
 * prefix followed by a JSON payload, and the stream is terminated by a `[DONE]` chunk.
 *
 * @since Feb 2023
 */
trait WSStreamRequestHelper extends WSRequestHelper {

  private val itemPrefix = "data: "
  private val endOfStreamToken = "[DONE]"

  /**
   * Maximum length (in bytes) of a single framed chunk in [[execJsonStreamAux]]. A chunk longer
   * than this fails the stream with a [[FramingException]]. Override to accept larger chunks.
   */
  protected def streamMaxFrameLength: Int = 1000

  /**
   * Parses one framed chunk. Everything up to and including the first `data: ` prefix (if
   * present) is dropped and surrounding whitespace is trimmed. The end-of-stream marker is
   * returned as `JsString("[DONE]")`; anything else is parsed as JSON, so a malformed payload
   * surfaces as a parse exception that fails the stream.
   */
  private implicit val jsonMarshaller: Unmarshaller[ByteString, JsValue] =
    Unmarshaller.strict[ByteString, JsValue] { byteString =>
      val string = byteString.utf8String

      val itemStartIndex = string.indexOf(itemPrefix)
      val payload =
        if (itemStartIndex > -1)
          string.substring(itemStartIndex + itemPrefix.length)
        else
          string

      // Trim so that e.g. a truncated last frame ("[DONE]\n") or a trailing CR is still
      // recognised as the end-of-stream marker; valid JSON is unaffected by the trimming.
      val data = payload.trim

      if (data == endOfStreamToken) JsString(endOfStreamToken)
      else Json.parse(data)
    }

  /**
   * Executes a request whose response is a stream of JSON chunks and emits each parsed chunk.
   *
   * @param endPoint
   *   the endpoint to call
   * @param method
   *   the HTTP method (e.g. "POST")
   * @param endPointParam
   *   optional extra path parameter appended to the endpoint
   * @param params
   *   query parameters; `None` values are passed through `toStringParams`
   * @param bodyParams
   *   JSON body fields; entries with a `None` value are omitted
   * @return
   *   a source of parsed JSON chunks, completing before the `[DONE]` marker. It fails with
   *   [[CequenceWSException]] when a chunk is not valid JSON or cannot be framed (e.g. it exceeds
   *   [[streamMaxFrameLength]]).
   */
  protected def execJsonStreamAux(
    endPoint: PEP,
    method: String,
    endPointParam: Option[String] = None,
    params: Seq[(PT, Option[Any])] = Nil,
    bodyParams: Seq[(PT, Option[JsValue])] = Nil
  )(
    implicit materializer: Materializer
  ): Source[JsValue, NotUsed] = {
    val source = execStreamRequestAux[JsValue](
      endPoint,
      method,
      endPointParam,
      params,
      bodyParams,
      Framing.delimiter(ByteString("\n\n"), streamMaxFrameLength, allowTruncation = true),
      {
        case e: JsonParseException =>
          throw new CequenceWSException(
            s"$serviceName.$endPoint: Response is not a JSON. ${e.getMessage}."
          )
        case e: FramingException =>
          throw new CequenceWSException(
            s"$serviceName.$endPoint: Response stream chunk could not be framed. ${e.getMessage}."
          )
      }
    )

    // take until you encounter the end of stream marked with DONE
    source.takeWhile(_ != JsString(endOfStreamToken))
  }

  /**
   * Executes a streaming request, frames the response body with `framing` and unmarshals each
   * frame, one at a time and in order, into `T`.
   *
   * Failures are translated as follows, both for the request itself (e.g. connection setup) and
   * for the response body stream:
   *   - `TimeoutException` becomes [[CequenceWSTimeoutException]]
   *   - `UnknownHostException` becomes [[CequenceWSUnknownHostException]]
   *   - any failure then passes through the caller-supplied `recoverBlock`
   *
   * @param framing
   *   flow splitting the raw response bytes into individual items
   * @param recoverBlock
   *   extra recovery applied after the built-in translation above
   */
  protected def execStreamRequestAux[T](
    endPoint: PEP,
    method: String,
    endPointParam: Option[String],
    params: Seq[(PT, Option[Any])],
    bodyParams: Seq[(PT, Option[JsValue])],
    framing: Flow[ByteString, ByteString, NotUsed],
    recoverBlock: PartialFunction[Throwable, T]
  )(
    implicit um: Unmarshaller[ByteString, T],
    materializer: Materializer
  ): Source[T, NotUsed] = {
    val request = getWSRequestOptional(Some(endPoint), endPointParam, toStringParams(params))

    val requestWithBody = if (bodyParams.nonEmpty) {
      // body fields with a None value are left out of the JSON object
      val bodyParamsX = bodyParams.collect { case (fieldName, Some(jsValue)) =>
        (fieldName.toString, jsValue)
      }
      request.withBody(JsObject(bodyParamsX))
    } else
      request

    val futureSource =
      requestWithBody.withMethod(method).stream().map { response =>
        response.bodyAsSource
          .via(framing)
          .mapAsync(1)(bytes => Unmarshal(bytes).to[T]) // unmarshal one by one
      }

    // The recovery is attached to the outer source rather than the per-response body source.
    // This way it also covers a failed request future (e.g. an unresolvable host or a timeout
    // before any response arrives), which never reaches the body source.
    Source
      .fromFutureSource(futureSource)
      .recover {
        case e: TimeoutException =>
          throw new CequenceWSTimeoutException(
            s"$serviceName.$endPoint timed out: ${e.getMessage}."
          )
        case e: UnknownHostException =>
          throw new CequenceWSUnknownHostException(
            s"$serviceName.$endPoint cannot resolve a host name: ${e.getMessage}."
          )
      }
      .recover(recoverBlock) // extra, caller-supplied recover
      .mapMaterializedValue(_ => NotUsed)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy