All Downloads are FREE. Search and download functionalities are using the official Maven repository.

sttp.client3.HttpURLConnectionBackend.scala Maven / Gradle / Ivy

There is a newer version: 3.10.1
Show newest version
package sttp.client3

import java.io._
import java.net._
import java.nio.channels.Channels
import java.nio.charset.CharacterCodingException
import java.nio.file.Files
import java.util.concurrent.ThreadLocalRandom
import java.util.zip.{GZIPInputStream, InflaterInputStream}
import sttp.capabilities.Effect
import sttp.client3.HttpURLConnectionBackend.EncodingHandler
import sttp.client3.internal._
import sttp.client3.monad.IdMonad
import sttp.client3.testing.SttpBackendStub
import sttp.client3.ws.{GotAWebSocketException, NotAWebSocketException}
import sttp.model.{Header, HeaderNames, ResponseMetadata, StatusCode, Uri}
import sttp.monad.MonadError

import scala.collection.JavaConverters._
import scala.concurrent.duration.Duration

class HttpURLConnectionBackend private (
    opts: SttpBackendOptions,
    customizeConnection: HttpURLConnection => Unit,
    createURL: String => URL,
    openConnection: (URL, Option[java.net.Proxy]) => URLConnection,
    customEncodingHandler: EncodingHandler
) extends SttpBackend[Identity, Any] {

  /** Sends the request synchronously over a freshly opened [[HttpURLConnection]].
    *
    * Redirects are not followed here; they are handled by [[FollowRedirectsBackend]]. When reading the input stream
    * fails but a status code was received, the response body is read from the connection's error stream instead.
    */
  override def send[T, R >: Any with Effect[Identity]](r: Request[T, R]): Response[T] =
    adjustExceptions(r) {
      val c = openConnection(r.uri)
      c.setRequestMethod(r.method.method)
      r.headers.foreach { h => c.setRequestProperty(h.name, h.value) }
      c.setDoInput(true)
      c.setReadTimeout(timeout(r.options.readTimeout))
      c.setConnectTimeout(timeout(opts.connectionTimeout))

      // redirects are handled by FollowRedirectsBackend
      c.setInstanceFollowRedirects(false)

      customizeConnection(c)

      if (r.body != NoBody) {
        c.setDoOutput(true)
        // we need to take care to:
        // (1) only call getOutputStream after the headers are set
        // (2) call it only once
        writeBody(r, c).foreach { os =>
          // close the stream even if flushing fails
          try os.flush()
          finally os.close()
        }
      }

      try {
        val is = c.getInputStream
        readResponse(c, is, r)
      } catch {
        // encoding and socket-level failures are propagated as-is rather than treated as an error response
        case e: CharacterCodingException     => throw e
        case e: UnsupportedEncodingException => throw e
        case e: SocketException              => throw e
        // a status code was received (e.g. an error status), so read the body from the error stream
        case _: IOException if c.getResponseCode != -1 =>
          readResponse(c, c.getErrorStream, r)
      }
    }

  override implicit val responseMonad: MonadError[Identity] = IdMonad

  /** Opens a connection to `uri`, routed through the configured proxy unless the proxy ignores the host.
    *
    * @throws IllegalArgumentException
    *   if the opened connection is not an [[HttpURLConnection]] (e.g. for a non-HTTP URL scheme)
    */
  private def openConnection(uri: Uri): HttpURLConnection = {
    val url = createURL(uri.toString)
    val conn = opts.proxy match {
      case Some(p) if uri.host.forall(!p.ignoreProxy(_)) =>
        p.auth.foreach { proxyAuth =>
          // NOTE: this replaces the JVM-wide default authenticator on every proxied request
          Authenticator.setDefault(new Authenticator() {
            override def getPasswordAuthentication: PasswordAuthentication =
              // only answer proxy challenges, so the proxy credentials are never handed to a target server
              if (getRequestorType == Authenticator.RequestorType.PROXY)
                new PasswordAuthentication(proxyAuth.username, proxyAuth.password.toCharArray)
              else super.getPasswordAuthentication
          })
        }

        openConnection(url, Some(p.asJavaProxy))
      case _ => openConnection(url, None)
    }

    conn match {
      case http: HttpURLConnection => http
      case other =>
        throw new IllegalArgumentException(
          s"Expected an HttpURLConnection for $url, but got: ${other.getClass.getName}"
        )
    }
  }

  /** Writes the request body to the connection's output stream.
    *
    * @return
    *   the output stream that was written to (to be flushed and closed by the caller), or `None` if nothing was
    *   written
    */
  private def writeBody(r: Request[_, Nothing], c: HttpURLConnection): Option[OutputStream] = {
    r.body match {
      case NoBody =>
        // skip
        None

      case b: BasicRequestBody =>
        val os = c.getOutputStream
        writeBasicBody(b, os)
        Some(os)

      case StreamBody(_) =>
        // we have an instance of nothing - everything's possible!
        None

      case mp: MultipartBody[Nothing] =>
        setMultipartBody(r, mp, c)
    }
  }

  /** Converts a duration to the millisecond `Int` expected by `setReadTimeout` / `setConnectTimeout`, where `0`
    * means "no timeout".
    *
    * Infinite durations map to `0`. Finite durations longer than `Int.MaxValue` milliseconds are clamped, instead of
    * overflowing to a negative value which the connection would reject. Positive sub-millisecond durations are
    * rounded up to `1`, instead of truncating to `0`, which would disable the timeout altogether.
    */
  private def timeout(t: Duration): Int =
    if (!t.isFinite) 0
    else {
      val millis = t.toMillis
      if (millis > Int.MaxValue) Int.MaxValue
      else if (millis == 0 && t > Duration.Zero) 1
      else millis.toInt
    }

  /** Writes a non-streaming, non-multipart body to `os` without closing it. */
  private def writeBasicBody(body: BasicRequestBody, os: OutputStream): Unit = {
    body match {
      case StringBody(b, encoding, _) =>
        val writer = new OutputStreamWriter(os, encoding)
        writer.write(b)
        // don't close - as this will close the underlying OS and cause errors
        // with multi-part
        writer.flush()

      case ByteArrayBody(b, _) =>
        os.write(b)

      case ByteBufferBody(b, _) =>
        // Write from a duplicate, so the caller's buffer position is untouched and the body can be sent again
        // (e.g. when a redirect is followed). Loop, because a channel write is allowed to be partial.
        val buffer = b.duplicate()
        val channel = Channels.newChannel(os)
        while (buffer.hasRemaining) channel.write(buffer)

      case InputStreamBody(b, _) =>
        transfer(b, os)

      case FileBody(f, _) =>
        Files.copy(f.toPath, os)
    }
  }

  private val BoundaryChars =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".toCharArray

  /** Sets the multipart Content-Type header (with a random 32-character boundary) and writes all parts.
    *
    * When the length of every part is known up front, fixed-length streaming mode is enabled with the computed total
    * length. Nested multipart bodies are rejected.
    *
    * @return
    *   the output stream that was written to
    */
  private def setMultipartBody(
      r: Request[_, Nothing],
      mp: MultipartBody[Nothing],
      c: HttpURLConnection
  ): Option[OutputStream] = {
    val boundary = {
      val tlr = ThreadLocalRandom.current()
      List
        .fill(32)(BoundaryChars(tlr.nextInt(BoundaryChars.length)))
        .mkString
    }

    // inspired by: https://github.com/scalaj/scalaj-http/blob/master/src/main/scala/scalaj/http/Http.scala#L542
    val partsWithHeaders = mp.parts.map { p =>
      val contentDisposition = s"${HeaderNames.ContentDisposition}: ${p.contentDispositionHeaderValue}"
      val otherHeaders = p.headers.map(h => s"${h.name}: ${h.value}")
      val allHeaders = List(contentDisposition) ++ otherHeaders
      (allHeaders.mkString(CrLf), p)
    }

    val dashes = "--"

    val dashesLen = dashes.length.toLong
    val crLfLen = CrLf.length.toLong
    val boundaryLen = boundary.length.toLong
    val finalBoundaryLen = dashesLen + boundaryLen + dashesLen + crLfLen

    // https://stackoverflow.com/questions/31406022/how-is-an-http-multipart-content-length-header-value-calculated
    // None as soon as the length of any part is unknown
    val contentLength = partsWithHeaders
      .map { case (headers, p) =>
        val bodyLen: Option[Long] = p.body match {
          case StringBody(b, encoding, _) =>
            Some(b.getBytes(encoding).length.toLong)
          case ByteArrayBody(b, _)   => Some(b.length.toLong)
          // writeBasicBody writes exactly the remaining bytes of the buffer
          case ByteBufferBody(b, _)  => Some(b.remaining.toLong)
          case InputStreamBody(_, _) => None
          case FileBody(b, _)        => Some(b.toFile.length())
          case NoBody                => None
          case StreamBody(_)         => None
          case MultipartBody(_)      => None
        }

        val headersLen = headers.getBytes(Utf8).length

        bodyLen.map(bl => dashesLen + boundaryLen + crLfLen + headersLen + crLfLen + crLfLen + bl + crLfLen)
      }
      .foldLeft(Option(finalBoundaryLen)) {
        case (Some(acc), Some(l)) => Some(acc + l)
        case _                    => None
      }

    val baseContentType = r.headers.find(_.is(HeaderNames.ContentType)).map(_.value).getOrElse("multipart/form-data")
    c.setRequestProperty(HeaderNames.ContentType, s"$baseContentType; boundary=$boundary")

    contentLength.foreach { cl =>
      c.setFixedLengthStreamingMode(cl)
      c.setRequestProperty(HeaderNames.ContentLength, cl.toString)
    }

    val os = c.getOutputStream
    // multipart framing (boundaries, part headers, line breaks) is always written as UTF-8
    def writeMeta(s: String): Unit = os.write(s.getBytes(Utf8))

    partsWithHeaders.foreach { case (headers, p) =>
      writeMeta(dashes)
      writeMeta(boundary)
      writeMeta(CrLf)
      writeMeta(headers)
      writeMeta(CrLf)
      writeMeta(CrLf)
      p.body match {
        case NoBody                 => // skip
        case body: BasicRequestBody => writeBasicBody(body, os)
        case StreamBody(_)          => // not possible
        case MultipartBody(_)       => throwNestedMultipartNotAllowed
      }
      writeMeta(CrLf)
    }

    // final boundary
    writeMeta(dashes)
    writeMeta(boundary)
    writeMeta(dashes)
    writeMeta(CrLf)

    Some(os)
  }

  /** Builds the response from the connection's status, headers and the given body stream (which may be `null`).
    *
    * The body is decompressed according to `Content-Encoding`, unless auto-decompression is disabled or the response
    * cannot have a body (HEAD requests, 204 and 304 responses).
    */
  private def readResponse[T](
      c: HttpURLConnection,
      is: InputStream,
      request: Request[T, Nothing]
  ): Response[T] = {
    // the status line is reported under a null key, so it is skipped here
    val headers = c.getHeaderFields.asScala.toVector
      .filter(_._1 != null)
      .flatMap { case (k, vv) => vv.asScala.map(Header(k, _)) }
    val contentEncoding = Option(c.getHeaderField(HeaderNames.ContentEncoding)).filter(_.nonEmpty)
    val code = StatusCode(c.getResponseCode)
    val statusText = c.getResponseMessage
    // these responses carry no body; wrapping their empty stream in e.g. a GZIPInputStream would fail with an EOF
    val mayHaveBody =
      c.getRequestMethod != "HEAD" && code != StatusCode.NoContent && code != StatusCode.NotModified
    val wrappedIs =
      if (mayHaveBody && !request.autoDecompressionDisabled) wrapInput(contentEncoding, handleNullInput(is))
      else handleNullInput(is)
    val responseMetadata = ResponseMetadata(code, statusText, headers)
    val body = bodyFromResponseAs(request.response, responseMetadata, Left(wrappedIs))

    Response(body, code, statusText, headers, Nil, request.onlyMetadata)
  }

  // Converts the response InputStream into the requested body type; streaming and web sockets are not supported.
  private val bodyFromResponseAs = new BodyFromResponseAs[Identity, InputStream, Nothing, Nothing]() {
    override protected def withReplayableBody(
        response: InputStream,
        replayableBody: Either[Array[Byte], SttpFile]
    ): Identity[InputStream] =
      replayableBody match {
        case Left(bytes) => new ByteArrayInputStream(bytes)
        case Right(file) => new BufferedInputStream(new FileInputStream(file.toFile))
      }
    override protected def regularIgnore(response: InputStream): Identity[Unit] = response.close()
    override protected def regularAsByteArray(response: InputStream): Identity[Array[Byte]] = toByteArray(response)
    override protected def regularAsFile(response: InputStream, file: SttpFile): Identity[SttpFile] = {
      FileHelpers.saveFile(file.toFile, response)
      file
    }
    override protected def regularAsStream(response: InputStream): (Nothing, () => Identity[Unit]) =
      throw new IllegalStateException("Streaming is not supported by HttpURLConnectionBackend")
    override protected def handleWS[T](
        responseAs: WebSocketResponseAs[T, _],
        meta: ResponseMetadata,
        ws: Nothing
    ): Identity[T] = ws
    override protected def cleanupWhenNotAWebSocket(response: InputStream, e: NotAWebSocketException): Identity[Unit] =
      ()
    override protected def cleanupWhenGotWebSocket(response: Nothing, e: GotAWebSocketException): Identity[Unit] = ()
  }

  /** Replaces a `null` stream (e.g. a missing error stream) with an empty one. */
  private def handleNullInput(is: InputStream): InputStream =
    if (is == null)
      new ByteArrayInputStream(Array.empty[Byte])
    else
      is

  /** Wraps `is` in a decompressing stream matching the (case-insensitive) content encoding.
    *
    * @throws UnsupportedEncodingException
    *   if the encoding is neither built-in (gzip, deflate) nor handled by the custom encoding handler
    */
  private def wrapInput(contentEncoding: Option[String], is: InputStream): InputStream =
    // Locale.ROOT: a locale-sensitive lowercase (e.g. Turkish dotless i) would not match "gzip"
    contentEncoding.map(_.toLowerCase(java.util.Locale.ROOT)) match {
      case None                                                    => is
      case Some("gzip")                                            => new GZIPInputStream(is)
      case Some("deflate")                                         => new InflaterInputStream(is)
      case Some(ce) if customEncodingHandler.isDefinedAt((is, ce)) => customEncodingHandler(is -> ce)
      case Some(ce) =>
        throw new UnsupportedEncodingException(s"Unsupported encoding: $ce")
    }

  /** Runs `t`, translating exceptions using sttp's default exception-to-[[SttpClientException]] mapping. */
  private def adjustExceptions[T](request: Request[_, _])(t: => T): T =
    SttpClientException.adjustExceptions(responseMonad)(t)(
      SttpClientException.defaultExceptionToSttpClientException(request, _)
    )

  // Nothing to release: each request opens its own connection.
  override def close(): Unit = {}
}

object HttpURLConnectionBackend {

  /** Decompresses a response body for a custom `Content-Encoding`. It receives the raw stream and the lowercased
    * encoding name, and returns the decoded stream.
    */
  type EncodingHandler = PartialFunction[(InputStream, String), InputStream]

  /** Opens a connection to the URL, directly or through the given proxy. */
  private[client3] val defaultOpenConnection: (URL, Option[java.net.Proxy]) => URLConnection = {
    case (url, None)        => url.openConnection()
    case (url, Some(proxy)) => url.openConnection(proxy)
  }

  /** Creates a synchronous backend based on [[HttpURLConnection]], wrapped in a [[FollowRedirectsBackend]].
    *
    * @param options
    *   connection timeout and proxy settings
    * @param customizeConnection
    *   invoked on every connection after the request settings are applied, and before the body is written
    * @param createURL
    *   converts the request URI string into a [[URL]]
    * @param openConnection
    *   opens a connection for a URL and an optional proxy; defaults to [[defaultOpenConnection]]
    * @param customEncodingHandler
    *   decoders for content encodings other than gzip and deflate
    */
  def apply(
      options: SttpBackendOptions = SttpBackendOptions.Default,
      customizeConnection: HttpURLConnection => Unit = _ => (),
      createURL: String => URL = new URL(_),
      openConnection: (URL, Option[java.net.Proxy]) => URLConnection = defaultOpenConnection,
      customEncodingHandler: EncodingHandler = PartialFunction.empty
  ): SttpBackend[Identity, Any] =
    new FollowRedirectsBackend[Identity, Any](
      new HttpURLConnectionBackend(options, customizeConnection, createURL, openConnection, customEncodingHandler)
    )

  /** Create a stub backend for testing, which uses the [[Identity]] response wrapper, and doesn't support streaming.
    *
    * See [[SttpBackendStub]] for details on how to configure stub responses.
    */
  def stub: SttpBackendStub[Identity, Any] = SttpBackendStub.synchronous
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy