io.cequence.wsclient.service.ws.stream.WSStreamRequestHelper.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of ws-client-stream_2.13 Show documentation
Show all versions of ws-client-stream_2.13 Show documentation
Generic Play WebServices library
package io.cequence.wsclient.service.ws.stream
import akka.NotUsed
import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
import akka.stream.Materializer
import akka.stream.scaladsl.Framing.FramingException
import akka.stream.scaladsl.{Flow, Framing, Source}
import akka.util.ByteString
import com.fasterxml.jackson.core.JsonParseException
import io.cequence.wsclient.domain.{
CequenceWSException,
CequenceWSTimeoutException,
CequenceWSUnknownHostException
}
import io.cequence.wsclient.service.ws.WSRequestHelper
import play.api.libs.json.{JsObject, JsString, JsValue, Json}
import play.api.libs.ws.JsonBodyWritables._
import java.net.UnknownHostException
import java.util.concurrent.TimeoutException
/**
 * Stream request support specifically tailored for OpenAI API.
 *
 * The response body is split into frames on a blank line (`"\n\n"`), each frame is stripped of
 * its `"data: "` prefix and parsed as JSON, and the stream is cut off when the `[DONE]` token
 * is encountered.
 *
 * @since Feb 2023
 */
trait WSStreamRequestHelper extends WSRequestHelper {

  /** Prefix that precedes the payload of each streamed item. */
  private val itemPrefix = "data: "

  /** Payload announcing that no more items will follow. */
  private val endOfStreamToken = "[DONE]"

  /** Default upper bound (in bytes) for a single frame; kept at the historical value. */
  private final val defaultMaxFrameLength = 1000

  /**
   * Converts a single frame into a JSON value.
   *
   * Everything after the first occurrence of the item prefix is treated as the payload; a frame
   * without the prefix is used as is. The end-of-stream token is not JSON, so it is passed
   * downstream wrapped in a `JsString` and filtered out by `execJsonStreamAux`.
   */
  private implicit val jsonMarshaller: Unmarshaller[ByteString, JsValue] =
    Unmarshaller.strict[ByteString, JsValue] { byteString =>
      val string = byteString.utf8String

      // NOTE(review): the prefix is searched anywhere in the frame, hence a frame WITHOUT a
      // leading prefix whose JSON contains "data: " inside a string value would be cut there.
      val itemStartIndex = string.indexOf(itemPrefix)

      val data =
        if (itemStartIndex > -1)
          string.substring(itemStartIndex + itemPrefix.length)
        else
          string

      // compare the trimmed payload so that a trailing new line / white space (e.g. in a
      // truncated last frame) does not send the token to the JSON parser
      if (data.trim == endOfStreamToken) JsString(endOfStreamToken)
      else Json.parse(data)
    }

  /**
   * Executes a streamed request and exposes the response as a source of JSON values.
   *
   * @param endPoint       end point to call
   * @param method         HTTP method set on the request (passed to `withMethod`)
   * @param endPointParam  optional end point parameter passed to `getWSRequestOptional`
   * @param params         request params; converted by `toStringParams`
   * @param bodyParams     JSON body fields; entries with `None` value are skipped
   * @param maxFrameLength maximal size (in bytes) of a single frame; a longer frame fails the
   *                       stream with a [[CequenceWSException]]
   * @param materializer   materializer needed for unmarshalling
   * @return source of the parsed items, completed once the end-of-stream token arrives
   */
  protected def execJsonStreamAux(
    endPoint: PEP,
    method: String,
    endPointParam: Option[String] = None,
    params: Seq[(PT, Option[Any])] = Nil,
    bodyParams: Seq[(PT, Option[JsValue])] = Nil,
    maxFrameLength: Int = defaultMaxFrameLength
  )(
    implicit materializer: Materializer
  ): Source[JsValue, NotUsed] = {
    val source = execStreamRequestAux[JsValue](
      endPoint,
      method,
      endPointParam,
      params,
      bodyParams,
      Framing.delimiter(ByteString("\n\n"), maxFrameLength, allowTruncation = true),
      {
        case e: JsonParseException =>
          throw new CequenceWSException(
            s"$serviceName.$endPoint: 'Response is not a JSON. ${e.getMessage}."
          )

        // a framing failure is about the frame size / delimiter, not about JSON validity
        case e: FramingException =>
          throw new CequenceWSException(
            s"$serviceName.$endPoint: Response cannot be split into frames (max frame length: $maxFrameLength bytes). ${e.getMessage}."
          )
      }
    )

    // take until you encounter the end of stream marked with DONE
    source.takeWhile(_ != JsString(endOfStreamToken))
  }

  /**
   * Executes a streamed request, frames the response body and unmarshals each frame to `T`.
   *
   * Failures are translated for the response body as well as for the request itself (i.e., when
   * the future returned by `stream()` fails before any response is available).
   *
   * @param endPoint      end point to call
   * @param method        HTTP method set on the request (passed to `withMethod`)
   * @param endPointParam optional end point parameter passed to `getWSRequestOptional`
   * @param params        request params; converted by `toStringParams`
   * @param bodyParams    JSON body fields; a body is attached only if the sequence is non-empty,
   *                      and entries with `None` value are skipped
   * @param framing       flow splitting the raw bytes into frames
   * @param recoverBlock  extra failure handler applied after the timeout / unknown-host ones
   * @param um            unmarshaller turning a frame into `T`
   * @param materializer  materializer needed for unmarshalling
   * @tparam T type of the emitted items
   * @return source of the unmarshalled items
   */
  protected def execStreamRequestAux[T](
    endPoint: PEP,
    method: String,
    endPointParam: Option[String],
    params: Seq[(PT, Option[Any])],
    bodyParams: Seq[(PT, Option[JsValue])],
    framing: Flow[ByteString, ByteString, NotUsed],
    recoverBlock: PartialFunction[Throwable, T]
  )(
    implicit um: Unmarshaller[ByteString, T],
    materializer: Materializer
  ): Source[T, NotUsed] = {
    val request = getWSRequestOptional(Some(endPoint), endPointParam, toStringParams(params))

    val requestWithBody =
      if (bodyParams.nonEmpty) {
        val bodyParamsX = bodyParams.collect { case (fieldName, Some(jsValue)) =>
          (fieldName.toString, jsValue)
        }
        request.withBody(JsObject(bodyParamsX))
      } else
        request

    val futureSource =
      requestWithBody.withMethod(method).stream().map { response =>
        response.bodyAsSource
          .via(framing)
          .mapAsync(1)(bytes => Unmarshal(bytes).to[T]) // unmarshal one by one
      }

    // The recover stages are attached to the flattened source (not to the response body only),
    // so a failure of the request future - such as an unresolvable host or a connection
    // timeout - is translated the same way as a failure in the middle of the body.
    Source
      .fromFutureSource(futureSource)
      .mapMaterializedValue(_ => NotUsed)
      .recover {
        case e: TimeoutException =>
          throw new CequenceWSTimeoutException(
            s"$serviceName.$endPoint timed out: ${e.getMessage}."
          )
        case e: UnknownHostException =>
          throw new CequenceWSUnknownHostException(
            s"$serviceName.$endPoint cannot resolve a host name: ${e.getMessage}."
          )
      }
      .recover(recoverBlock) // extra recover
  }
}