All downloads are free. Search and download functionality uses the official Maven repository.

com.softwaremill.sttp.okhttp.OkHttpClientHandler.scala Maven / Gradle / Ivy

package com.softwaremill.sttp.okhttp

import java.io.IOException
import java.nio.charset.Charset
import java.util.concurrent.TimeUnit

import com.softwaremill.sttp._
import ResponseAs.EagerResponseHandler
import okhttp3.internal.http.HttpMethod
import okhttp3.{
  Call,
  Callback,
  Headers,
  MediaType,
  OkHttpClient,
  MultipartBody => OkHttpMultipartBody,
  Request => OkHttpRequest,
  RequestBody => OkHttpRequestBody,
  Response => OkHttpResponse
}
import okio.{BufferedSink, Okio}

import scala.collection.JavaConverters._
import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ExecutionContext, Future}
import scala.language.higherKinds
import scala.util.{Failure, Try}

abstract class OkHttpHandler[R[_], S](client: OkHttpClient,
                                      closeClient: Boolean)
    extends SttpHandler[R, S] {

  /**
    * Converts an sttp request into an OkHttp request.
    *
    * When the sttp request has no body but OkHttp requires one for the method,
    * an empty body is supplied. The `Accept-Encoding` header is not forwarded,
    * because OkHttp supports automatic gzip compression itself.
    */
  private[okhttp] def convertRequest[T](request: Request[T, S]): OkHttpRequest = {
    val builder = new OkHttpRequest.Builder()
      .url(request.uri.toString)

    val body = bodyToOkHttp(request.body)
    builder.method(request.method.m, body.getOrElse {
      if (HttpMethod.requiresRequestBody(request.method.m))
        OkHttpRequestBody.create(null, "")
      else null // "no body" for OkHttp's Builder.method
    })

    // OkHttp supports automatic gzip compression, so a user-supplied
    // Accept-Encoding header is dropped.
    request.headers
      .filterNot { case (name, _) => name.equalsIgnoreCase(AcceptEncodingHeader) }
      .foreach { case (name, value) => builder.addHeader(name, value) }

    builder.build()
  }

  /**
    * Converts an sttp request body into an OkHttp body.
    *
    * @return `None` for `NoBody` (and for stream bodies when
    *         [[streamToRequestBody]] is not overridden). No media type is
    *         attached to the produced bodies.
    */
  private def bodyToOkHttp(body: RequestBody[S]): Option[OkHttpRequestBody] = {
    body match {
      case NoBody => None
      case StringBody(b, encoding, _) =>
        // Encode with the body's declared charset. Passing the String itself
        // with a null media type would make OkHttp always encode it as UTF-8.
        Some(OkHttpRequestBody.create(null, b.getBytes(encoding)))
      case ByteArrayBody(b, _) =>
        Some(OkHttpRequestBody.create(null, b))
      case ByteBufferBody(b, _) =>
        // Copy only the readable region [position, limit). `array()` would
        // ignore position/limit/arrayOffset and throws for direct or read-only
        // buffers. `duplicate()` leaves the caller's buffer position untouched.
        val bytes = new Array[Byte](b.remaining())
        b.duplicate().get(bytes)
        Some(OkHttpRequestBody.create(null, bytes))
      case InputStreamBody(b, _) =>
        Some(new OkHttpRequestBody() {
          override def writeTo(sink: BufferedSink): Unit =
            sink.writeAll(Okio.source(b))
          override def contentType(): MediaType = null
        })
      case PathBody(b, _) =>
        Some(OkHttpRequestBody.create(null, b.toFile))
      case StreamBody(s) =>
        streamToRequestBody(s)
      case MultipartBody(ps) =>
        val b = new OkHttpMultipartBody.Builder()
          .setType(OkHttpMultipartBody.FORM)
        ps.foreach(addMultipart(b, _))
        Some(b.build())
    }
  }

  /**
    * Adds one part to a multipart builder. The part's Content-Disposition
    * header is merged into its additional headers. A part whose body converts
    * to `None` is skipped.
    */
  private def addMultipart(builder: OkHttpMultipartBody.Builder,
                           mp: Multipart): Unit = {
    val allHeaders = mp.additionalHeaders + (ContentDispositionHeader -> mp.contentDispositionHeaderValue)
    val headers = Headers.of(allHeaders.asJava)

    bodyToOkHttp(mp.body).foreach(builder.addPart(headers, _))
  }

  /**
    * Converts an OkHttp response into an sttp response.
    *
    * For a success status the body is read as `responseAs` (a `Right`).
    * Otherwise it is read as a string (a `Left`).
    */
  private[okhttp] def readResponse[T](
      res: OkHttpResponse,
      responseAs: ResponseAs[T, S]): R[Response[T]] = {

    val code = res.code()

    val body = if (codeIsSuccess(code)) {
      responseMonad.map(responseHandler(res).handle(responseAs, responseMonad))(
        Right(_))
    } else {
      responseMonad.map(responseHandler(res).handle(asString, responseMonad))(
        Left(_))
    }

    // Index-based iteration keeps the headers in wire order and keeps repeated
    // identical name/value pairs. Traversing `names()` (a Set) would sort the
    // names and collapse duplicate pairs.
    val okHeaders = res.headers()
    val headers =
      (0 until okHeaders.size()).map(i => (okHeaders.name(i), okHeaders.value(i)))

    responseMonad.map(body)(Response(_, code, headers.toList, Nil))
  }

  /**
   * Reads the body of `res` eagerly according to the basic response spec.
   * Every branch except streaming closes the response after reading.
   */
  private def responseHandler(res: OkHttpResponse) =
    new EagerResponseHandler[S] {
      override def handleBasic[T](bra: BasicResponseAs[T, S]): Try[T] =
        bra match {
          case IgnoreResponse =>
            Try(res.close())
          case ResponseAsString(encoding) =>
            val body = Try(
              res.body().source().readString(Charset.forName(encoding)))
            res.close()
            body
          case ResponseAsByteArray =>
            val body = Try(res.body().bytes())
            res.close()
            body
          case ras @ ResponseAsStream() =>
            val stream = responseBodyToStream(res)
            // On success the stream's consumer owns the body. On failure no
            // one will read it, so release the response here.
            if (stream.isFailure) res.close()
            stream.map(ras.responseIsStream)
          case ResponseAsFile(file, overwrite) =>
            val body = Try(
              ResponseAs.saveFile(file, res.body().byteStream(), overwrite))
            res.close()
            body
        }
    }

  /** Streaming hook for request bodies. The default supports no streams. */
  def streamToRequestBody(stream: S): Option[OkHttpRequestBody] = None

  /** Streaming hook for response bodies. The default always fails. */
  def responseBodyToStream(res: OkHttpResponse): Try[S] =
    Failure(new IllegalStateException("Streaming isn't supported"))

  /**
    * Releases the client's resources when this handler owns it: shuts down
    * the dispatcher's executor and evicts pooled connections. Does nothing for
    * user-supplied clients.
    */
  override def close(): Unit = if (closeClient) {
    client.dispatcher().executorService().shutdown()
    client.connectionPool().evictAll()
  }
}

object OkHttpHandler {

  /**
    * Builds the client used when a handler owns its client. OkHttp's own
    * redirect following is disabled (the public factories wrap handlers in
    * sttp's `FollowRedirectsHandler`).
    *
    * @param readTimeout       read timeout in milliseconds
    * @param connectionTimeout connect timeout in milliseconds
    */
  private[okhttp] def defaultClient(readTimeout: Long,
                                    connectionTimeout: Long): OkHttpClient =
    new OkHttpClient.Builder()
      .followRedirects(false)
      .followSslRedirects(false)
      .connectTimeout(connectionTimeout, TimeUnit.MILLISECONDS)
      .readTimeout(readTimeout, TimeUnit.MILLISECONDS)
      .build()

  /**
    * Returns `client` unchanged when the request uses `DefaultReadTimeout`.
    * Otherwise returns a derived client configured with the request's read
    * timeout.
    *
    * An infinite timeout maps to 0 ms, which OkHttp treats as "no timeout".
    * A positive timeout shorter than 1 ms is rounded up to 1 ms. Truncating it
    * to 0 would silently turn a very short timeout into no timeout at all.
    */
  private[okhttp] def updateClientIfCustomReadTimeout[T, S](
      r: Request[T, S],
      client: OkHttpClient): OkHttpClient = {
    val readTimeout = r.options.readTimeout
    if (readTimeout == DefaultReadTimeout) client
    else {
      val millis =
        if (!readTimeout.isFinite()) 0L
        else if (readTimeout.toMillis == 0L && readTimeout.toNanos > 0L) 1L
        else readTimeout.toMillis
      client
        .newBuilder()
        .readTimeout(millis, TimeUnit.MILLISECONDS)
        .build()
    }
  }
}

/**
  * Blocking OkHttp handler. Each request is run on the calling thread with
  * `Call.execute()`, and its result is returned directly (the `Id` effect).
  */
class OkHttpSyncHandler private (client: OkHttpClient, closeClient: Boolean)
    extends OkHttpHandler[Id, Nothing](client, closeClient) {

  override def responseMonad: MonadError[Id] = IdMonad

  override def send[T](r: Request[T, Nothing]): Response[T] = {
    val okRequest = convertRequest(r)
    // A per-request read timeout may require a derived client.
    val effectiveClient =
      OkHttpHandler.updateClientIfCustomReadTimeout(r, client)
    val okResponse = effectiveClient.newCall(okRequest).execute()
    readResponse(okResponse, r.response)
  }
}

object OkHttpSyncHandler {

  // Every factory goes through here, so the raw handler is always wrapped
  // in sttp's redirect-following handler.
  private def apply(client: OkHttpClient,
                    closeClient: Boolean): SttpHandler[Id, Nothing] = {
    val underlying = new OkHttpSyncHandler(client, closeClient)
    new FollowRedirectsHandler[Id, Nothing](underlying)
  }

  /**
    * Creates a blocking handler backed by a freshly built client.
    * The client is owned by the handler (`closeClient = true`).
    */
  def apply(
      connectionTimeout: FiniteDuration = SttpHandler.DefaultConnectionTimeout)
    : SttpHandler[Id, Nothing] = {
    val owned = OkHttpHandler.defaultClient(
      readTimeout = DefaultReadTimeout.toMillis,
      connectionTimeout = connectionTimeout.toMillis)
    OkHttpSyncHandler(owned, closeClient = true)
  }

  /**
    * Creates a blocking handler around a caller-provided client.
    * The handler will not close that client.
    */
  def usingClient(client: OkHttpClient): SttpHandler[Id, Nothing] =
    OkHttpSyncHandler(client, closeClient = false)
}

/**
  * Non-blocking OkHttp handler. Requests are enqueued on OkHttp's dispatcher,
  * and the outcome is delivered through `rm.async`.
  */
abstract class OkHttpAsyncHandler[R[_], S](client: OkHttpClient,
                                           rm: MonadAsyncError[R],
                                           closeClient: Boolean)
    extends OkHttpHandler[R, S](client, closeClient) {

  /**
    * Sends `r` asynchronously.
    *
    * Failures are reported through the effect rather than thrown from `send`.
    * This covers failures raised while building or enqueuing the call (e.g. an
    * invalid URL, or an executor rejected after `close()`) and failures raised
    * while reading the response.
    */
  override def send[T](r: Request[T, S]): R[Response[T]] =
    rm.flatten(rm.async[R[Response[T]]] { cb =>
      def success(result: R[Response[T]]): Unit = cb(Right(result))
      def error(t: Throwable): Unit = cb(Left(t))

      try {
        val request = convertRequest(r)

        OkHttpHandler
          .updateClientIfCustomReadTimeout(r, client)
          .newCall(request)
          .enqueue(new Callback {
            override def onFailure(call: Call, e: IOException): Unit =
              error(e)

            // NonFatal rather than Exception: a non-Exception throwable
            // escaping here would leave the async result never completed.
            override def onResponse(call: Call, response: OkHttpResponse): Unit =
              try success(readResponse(response, r.response))
              catch { case scala.util.control.NonFatal(e) => error(e) }
          })
      } catch {
        case scala.util.control.NonFatal(e) => error(e)
      }
    })

  override def responseMonad: MonadError[R] = rm
}

/**
  * `Future`-based OkHttp handler. Results are mapped on the supplied
  * `ExecutionContext` via `FutureMonad`.
  */
class OkHttpFutureHandler private (client: OkHttpClient, closeClient: Boolean)(
    implicit ec: ExecutionContext)
    extends OkHttpAsyncHandler[Future, Nothing](
      client = client,
      rm = new FutureMonad,
      closeClient = closeClient
    )

object OkHttpFutureHandler {

  // Every factory goes through here, so the raw handler is always wrapped
  // in sttp's redirect-following handler.
  private def apply(client: OkHttpClient, closeClient: Boolean)(
      implicit ec: ExecutionContext): SttpHandler[Future, Nothing] = {
    val underlying = new OkHttpFutureHandler(client, closeClient)
    new FollowRedirectsHandler[Future, Nothing](underlying)
  }

  /**
    * Creates a `Future` handler backed by a freshly built client.
    * The client is owned by the handler (`closeClient = true`).
    */
  def apply(connectionTimeout: FiniteDuration =
              SttpHandler.DefaultConnectionTimeout)(
      implicit ec: ExecutionContext = ExecutionContext.Implicits.global)
    : SttpHandler[Future, Nothing] = {
    val owned = OkHttpHandler.defaultClient(
      readTimeout = DefaultReadTimeout.toMillis,
      connectionTimeout = connectionTimeout.toMillis)
    OkHttpFutureHandler(owned, closeClient = true)
  }

  /**
    * Creates a `Future` handler around a caller-provided client.
    * The handler will not close that client.
    */
  def usingClient(client: OkHttpClient)(implicit ec: ExecutionContext =
                                          ExecutionContext.Implicits.global)
    : SttpHandler[Future, Nothing] =
    OkHttpFutureHandler(client, closeClient = false)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy