All Downloads are FREE. Search and download functionalities are using the official Maven repository.

software.purpledragon.sttp.scribe.ScribeBackend.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2018 Michael Stringer
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package software.purpledragon.sttp.scribe

import java.io.{ByteArrayInputStream, FileOutputStream, InputStream, UnsupportedEncodingException}
import java.net.URLDecoder
import java.util.Locale
import java.util.zip.{GZIPInputStream, InflaterInputStream}
import com.github.scribejava.core.model.{OAuthRequest, Response => ScribeResponse, Token, Verb}
import com.github.scribejava.core.oauth.OAuthService
import software.purpledragon.sttp.scribe.QueryParamEncodingStyle._
import sttp.client._
import sttp.client.monad.{IdMonad, MonadError}
import sttp.client.ws.WebSocketResponse
import sttp.model._

import scala.annotation.tailrec
import scala.collection.compat.immutable
import scala.jdk.CollectionConverters._
import scala.util.{Try, Using}

/**
 * Base class for synchronous sttp backends that execute requests through a scribe [[OAuthService]],
 * signing every request before it is sent.
 *
 * If the server answers `401 Unauthorized` and `isTokenExpiredResponse` accepts the response, the
 * subclass is asked to renew its access token and the request is retried once.
 *
 * @param service                scribe service used to execute the signed requests
 * @param isTokenExpiredResponse check applied to `401` responses to decide whether to attempt a token renewal
 * @param encodingStyle          whether sttp or scribe encodes the URL query parameters
 */
abstract class ScribeBackend(
    service: OAuthService,
    isTokenExpiredResponse: TokenExpiredResponseCheck,
    encodingStyle: QueryParamEncodingStyle
) extends SttpBackend[Identity, Nothing, NothingT] {

  /**
   * Url query parameter encoding is handled slightly differently by sttp and scribe. This allows
   * you to configure which implementation the backend should use.
   */
  def withEncodingStyle(encodingStyle: QueryParamEncodingStyle): ScribeBackend

  override def send[T](request: Request[T, Nothing]): Response[T] = {
    send(request, retrying = false)
  }

  /**
   * Builds, signs and executes `request`, retrying at most once after a successful token renewal.
   *
   * @param retrying `true` when this call is already the retry, so no further renewal is attempted
   */
  @tailrec
  private def send[T](request: Request[T, Nothing], retrying: Boolean): Response[T] = {
    val oAuthRequest = buildOAuthRequest(request)
    signRequest(oAuthRequest)

    val response = service.execute(oAuthRequest)

    if (
      !retrying && response.getCode == StatusCode.Unauthorized.code && isTokenExpiredResponse(response) &&
      renewAccessToken(response)
    ) {
      // renewed access token - release the rejected response's body, then retry the request
      discardResponse(response)
      send(request, retrying = true)
    } else {
      // NOTE(review): if `isTokenExpiredResponse` consumed the body (e.g. via `getBody`), the stream read
      // by handleResponse may already be exhausted - confirm against the TokenExpiredResponseCheck impls.
      handleResponse(response, request.response)
    }
  }

  override def openWebsocket[T, WS_RESULT](
      request: Request[T, Nothing],
      handler: NothingT[WS_RESULT]
  ): Identity[WebSocketResponse[WS_RESULT]] = {
    // we don't handle websockets
    handler
  }

  override def close(): Identity[Unit] = ()

  override val responseMonad: MonadError[Identity] = IdMonad

  /** Signs `request` before it is executed; called once per attempt, including the retry. */
  protected def signRequest(request: OAuthRequest): Unit

  /**
   * Attempts to renew the access token after `response` was recognised as a token-expired response.
   *
   * @return `true` if the token was renewed and the request should be retried
   */
  protected def renewAccessToken(response: ScribeResponse): Boolean

  /**
   * Translates an sttp request into an unsigned scribe request: url, query parameters (according to
   * `encodingStyle`), headers and body.
   */
  private def buildOAuthRequest[T](request: Request[T, Nothing]): OAuthRequest = {
    val (url, params) = encodingStyle match {
      case Sttp =>
        (request.uri.toString, Nil)
      case Scribe =>
        (request.uri.copy(querySegments = Nil).toString, request.uri.paramsSeq)
    }
    val oAuthRequest = new OAuthRequest(method2Verb(request.method), url)

    params.foreach { case (name, value) => oAuthRequest.addQuerystringParameter(name, value) }
    request.headers.foreach(header => oAuthRequest.addHeader(header.name, header.value))

    // media type only: drop parameters such as "; charset=..." and surrounding whitespace
    val contentType = request.headers
      .find(_.name.equalsIgnoreCase(HeaderNames.ContentType))
      .map(_.value.takeWhile(_ != ';').trim)
    setRequestPayload(request.body, contentType, oAuthRequest)

    oAuthRequest
  }

  /** Closes the body stream of a response that will not be returned to the caller, ignoring failures. */
  private def discardResponse(response: ScribeResponse): Unit =
    Option(response.getStream).foreach(stream => Try(stream.close()))

  /**
   * Converts a scribe response into an sttp [[Response]], decoding the body according to `responseAs`.
   * The body stream is closed once it has been read.
   */
  private def handleResponse[T](r: ScribeResponse, responseAs: ResponseAs[T, Nothing]): Response[T] = {
    val statusCode = StatusCode(r.getCode)

    // scribe includes the status line as a header with a key of 'null' :-(
    val headers = r.getHeaders.asScala.toList.collect {
      case (name, value) if name != null => Header(name, value)
    }

    val metadata = ResponseMetadata(headers, statusCode, r.getMessage)
    // header names are case-insensitive, so search the parsed headers instead of scribe's exact-match getHeader
    val contentEncoding = headers.find(_.name.equalsIgnoreCase(HeaderNames.ContentEncoding)).map(_.value)
    // guard against a missing body stream (null) by substituting an empty one
    val rawStream: InputStream = Option(r.getStream).getOrElse(new ByteArrayInputStream(Array.emptyByteArray))

    val body = Using.resource(rawStream) { raw =>
      Using.resource(wrapInput(raw, contentEncoding)) { is =>
        readResponseBody(is, responseAs, metadata)
      }
    }

    Response(body, statusCode, r.getMessage, headers, Nil)
  }

  /**
   * Reads the response body from `is` as described by `responseAs`, recursing through mapped and
   * metadata-dependent descriptions.
   */
  private def readResponseBody[T](is: InputStream, responseAs: ResponseAs[T, Nothing], meta: ResponseMetadata): T = {
    responseAs match {
      case MappedResponseAs(raw, g) => g(readResponseBody(is, raw, meta), meta)

      case ResponseAsFromMetadata(f) => readResponseBody(is, f(meta), meta)

      case IgnoreResponse =>
        // drain in chunks rather than one byte per read() call
        val buffer = new Array[Byte](8192)
        @tailrec def drain(): Unit = if (is.read(buffer) != -1) drain()
        drain()

      case ResponseAsByteArray =>
        toByteArray(is)

      case ResponseAsStream() =>
        // only possible when the user requests the response as a stream of Nothing
        throw new IllegalStateException("ScribeBackend does not support streaming responses")

      case ResponseAsFile(output) =>
        val file = output.toFile
        // FileOutputStream creates (or truncates) the file itself; only its directory has to exist
        Option(file.getParentFile).foreach(_.mkdirs())
        Using.resource(new FileOutputStream(file))(os => transfer(is, os))
        output
    }
  }

  // scalastyle:off cyclomatic.complexity
  /**
   * Copies the sttp request body onto the scribe request.
   *
   * @param contentType the request's media type without parameters, if a Content-Type header was set
   * @throws UnsupportedOperationException for input-stream, stream and multipart bodies
   */
  private def setRequestPayload(body: RequestBody[_], contentType: Option[String], request: OAuthRequest): Unit = {
    body match {
      case StringBody(content, encoding, _)
          if contentType.exists(_.equalsIgnoreCase(MediaType.ApplicationXWwwFormUrlencoded.toString())) =>
        // have to add these as "body parameters" so that they get included in the oauth signature
        parseFormParams(content, encoding).foreach { case (name, value) => request.addBodyParameter(name, value) }

      case StringBody(content, encoding, _) =>
        request.setPayload(content)
        request.setCharset(encoding)

      case ByteArrayBody(content, _) =>
        request.setPayload(content)

      case ByteBufferBody(content, _) =>
        // send only the readable region (position until limit): `array()` ignores position/limit and the
        // array offset, and throws for direct or read-only buffers. `duplicate` keeps the caller's position.
        val readable = content.duplicate()
        val bytes = new Array[Byte](readable.remaining())
        readable.get(bytes)
        request.setPayload(bytes)

      case FileBody(content, _) =>
        request.setPayload(content.toFile)

      case InputStreamBody(_, _) =>
        throw new UnsupportedOperationException("scribe does not support InputStream bodies")

      case StreamBody(_) =>
        throw new UnsupportedOperationException("scribe does not support Stream bodies")

      case MultipartBody(_) =>
        throw new UnsupportedOperationException("scribe does not support Multipart bodies")

      case NoBody =>
      // nothing to set
    }
  }

  // scalastyle:on

  /**
   * Decodes an `application/x-www-form-urlencoded` string into name/value pairs, preserving order.
   * Each `&`-separated field is split on its first `=`; a field without `=` gets an empty value and
   * empty fields are skipped.
   */
  private def parseFormParams(content: String, encoding: String): Seq[(String, String)] = {
    immutable.ArraySeq.unsafeWrapArray(content.split("&")).collect {
      case field if field.nonEmpty =>
        val (rawName, rawValue) = field.indexOf('=') match {
          case -1 => (field, "")
          case i => (field.substring(0, i), field.substring(i + 1))
        }
        (URLDecoder.decode(rawName, encoding), URLDecoder.decode(rawValue, encoding))
    }
  }

  /**
   * Maps an sttp method onto the equivalent scribe verb.
   *
   * @throws NotImplementedError for methods this backend has no mapping for
   */
  private def method2Verb(method: Method): Verb = {
    method match {
      case Method.GET => Verb.GET
      case Method.POST => Verb.POST
      case Method.PUT => Verb.PUT
      case Method.DELETE => Verb.DELETE
      case Method.OPTIONS => Verb.OPTIONS
      case Method.PATCH => Verb.PATCH
      case Method.TRACE => Verb.TRACE
      case m => throw new NotImplementedError(s"Scribe does not support $m")
    }
  }

  /**
   * Wraps the raw body stream in a decoding stream matching the response's `Content-Encoding`.
   *
   * @throws UnsupportedEncodingException for encodings other than gzip, x-gzip, deflate or identity
   */
  private def wrapInput(is: InputStream, encoding: Option[String]): InputStream = {
    encoding.map(_.trim.toLowerCase(Locale.ROOT)) match {
      case None | Some("") | Some("identity") => is
      case Some("gzip") | Some("x-gzip") => new GZIPInputStream(is)
      case Some("deflate") => new InflaterInputStream(is)
      case Some(ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce")
    }
  }

}

/**
 * Supplies the OAuth access token for outgoing requests and is told when a new one replaces it.
 *
 * @tparam T the scribe token type this provider hands out and receives back
 */
trait OAuthTokenProvider[T <: Token] {

  /** Token to use for the request currently being prepared (presumably when signing it). */
  def accessTokenForRequest: T

  /**
   * Notification that `token` has been issued in place of the previous token.
   * NOTE(review): the callers are outside this file; presumably invoked after a token renewal.
   *
   * @param token the newly issued token
   */
  def tokenRenewed(token: T): Unit
}

/**
 * Selects which library is responsible for encoding URL query parameters when a request is sent
 * through `ScribeBackend`.
 */
sealed trait QueryParamEncodingStyle

/** The available [[QueryParamEncodingStyle]] choices. */
object QueryParamEncodingStyle {

  /** The request URI, query string included, is handed to scribe exactly as sttp renders it. */
  case object Sttp extends QueryParamEncodingStyle

  /**
   * The query string is stripped from the URI and each parameter is added to the scribe request
   * individually, leaving the encoding to scribe.
   */
  case object Scribe extends QueryParamEncodingStyle
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy