All Downloads are FREE. Search and download functionalities are using the official Maven repository.

spice.http.client.OkHttpClientInstance.scala Maven / Gradle / Ivy

There is a newer version: 0.7.2
Show newest version
package spice.http.client

import cats.effect.unsafe.implicits.global
import cats.effect.{Deferred, IO}
import fabric.io.JsonFormatter
import okhttp3.{Authenticator, Request, Response, Route}
import spice.http.content.FormDataEntry.{FileEntry, StringEntry}
import spice.http.content._
import spice.http._
import spice.net.{ContentType, URL}
import spice.streamer._

import java.io.{File, IOException}
import java.net.{InetAddress, InetSocketAddress, Socket}
import java.security.SecureRandom
import java.security.cert.X509Certificate
import java.util
import java.util.concurrent.TimeUnit
import javax.net.ssl._
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

/**
 * Asynchronous HttpClient for simple request response support.
 *
 * Adds support for simple restful request/response JSON support.
 */
class OkHttpClientInstance(client: HttpClient) extends HttpClientInstance {
  private lazy val instance = {
    val b = new okhttp3.OkHttpClient.Builder()
    b.sslSocketFactory(new SSLSocketFactory {
      //      private val default = SSLSocketFactory.getDefault.asInstanceOf[SSLSocketFactory]
      private val disabled = {
        val trustAllCerts = Array[TrustManager](new X509TrustManager {
          override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String): Unit = {}

          override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String): Unit = {}

          override def getAcceptedIssuers: Array[X509Certificate] = null
        })
        val sc = SSLContext.getInstance("SSL")
        sc.init(null, trustAllCerts, new SecureRandom)
        sc.getSocketFactory
      }

      private def f: SSLSocketFactory = disabled //if (config.validateSSLCertificates) default else disabled

      override def getDefaultCipherSuites: Array[String] = f.getDefaultCipherSuites

      override def getSupportedCipherSuites: Array[String] = f.getSupportedCipherSuites

      override def createSocket(socket: Socket, s: String, i: Int, b: Boolean): Socket = f.createSocket(socket, s, i, b)

      override def createSocket(s: String, i: Int): Socket = f.createSocket(s, i)

      override def createSocket(s: String, i: Int, inetAddress: InetAddress, i1: Int): Socket = f.createSocket(s, i, inetAddress, i1)

      override def createSocket(inetAddress: InetAddress, i: Int): Socket = f.createSocket(inetAddress, i)

      override def createSocket(inetAddress: InetAddress, i: Int, inetAddress1: InetAddress, i1: Int): Socket = f.createSocket(inetAddress, i, inetAddress1, i1)
    }, new X509TrustManager {
      //      private val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
      //      private val default = {
      //        factory.init(KeyStore.getInstance(KeyStore.getDefaultType))
      //        factory.getTrustManagers.apply(0).asInstanceOf[X509TrustManager]
      //      }

      override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String): Unit = {
        if (client.validateSSLCertificates) {
          //          default.checkClientTrusted(x509Certificates, s)
        }
      }

      override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String): Unit = {
        if (client.validateSSLCertificates) {
          //          default.checkServerTrusted(x509Certificates, s)
        }
      }

      override def getAcceptedIssuers: Array[X509Certificate] = //if (config.validateSSLCertificates) {
      //        default.getAcceptedIssuers
      //      } else {
        Array.empty[X509Certificate]
      //      }
    })
    b.hostnameVerifier(new HostnameVerifier {
      override def verify(s: String, sslSession: SSLSession): Boolean = true
    })
    b.connectionPool(client.connectionPool.asInstanceOf[OkHttpConnectionPool].pool)
    b.connectTimeout(client.timeout.toMillis, TimeUnit.MILLISECONDS)
    b.readTimeout(client.timeout.toMillis, TimeUnit.MILLISECONDS)
    b.writeTimeout(client.timeout.toMillis, TimeUnit.MILLISECONDS)
    b.dns((hostname: String) => {
      val list = new util.ArrayList[InetAddress]()
      client.dns.lookup(hostname).unsafeRunSync() match {
        case Some(ip) => list.add(InetAddress.getByAddress(ip.address.map(_.toByte).toArray))
        case None => // None
      }
      list
    })
    client.pingInterval.foreach(d => b.pingInterval(d.toMillis, TimeUnit.MILLISECONDS))
    client.proxy.foreach {
      case Proxy(t, h, p, c) =>
        val `type` = t match {
          case ProxyType.Direct => java.net.Proxy.Type.DIRECT
          case ProxyType.Http => java.net.Proxy.Type.HTTP
          case ProxyType.Socks => java.net.Proxy.Type.SOCKS
        }
        val sa = new InetSocketAddress(h, p)
        b.proxy(new java.net.Proxy(`type`, sa))
        c.foreach { creds =>
          b.proxyAuthenticator(new Authenticator {
            override def authenticate(route: Route, response: Response): Request = {
              val credential = okhttp3.Credentials.basic(creds.username, creds.password)
              response.request().newBuilder().header("Proxy-Authorization", credential).build()
            }
          })
        }
    }
    b.build()
  }

  /*def disableSSLVerification(): Unit = {
    val trustAllCerts = Array[TrustManager](new X509TrustManager {
      override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String): Unit = {}

      override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String): Unit = {}

      override def getAcceptedIssuers: Array[X509Certificate] = null
    })
    val sc = SSLContext.getInstance("SSL")
    sc.init(null, trustAllCerts, new SecureRandom)
    HttpsURLConnection.setDefaultSSLSocketFactory(sc.getSocketFactory)
    val allHostsValid = new HostnameVerifier {
      override def verify(s: String, sslSession: SSLSession): Boolean = true
    }
    HttpsURLConnection.setDefaultHostnameVerifier(allHostsValid)
  }*/

  override def send(request: HttpRequest): IO[Try[HttpResponse]] = Deferred[IO, Try[HttpResponse]].flatMap { deferred =>
    val req = requestToOk(request)
    instance.newCall(req).enqueue(new okhttp3.Callback {
      override def onResponse(call: okhttp3.Call, res: okhttp3.Response): Unit = {
        val response = responseFromOk(res)
        deferred.complete(Success(response)).unsafeRunAndForget()
      }

      override def onFailure(call: okhttp3.Call, exc: IOException): Unit = {
        deferred.complete(Failure(exc)).unsafeRunAndForget()
      }
    })
    OkHttpClientImplementation.process(deferred.get)
  }

  private def requestToOk(request: HttpRequest): okhttp3.Request = {
    val r = new okhttp3.Request.Builder().url(request.url.toString)

    // Headers
    request.headers.map.foreach {
      case (key, values) => values.foreach(r.addHeader(key, _))
    }

    def ct(contentType: ContentType): okhttp3.MediaType = okhttp3.MediaType.parse(contentType.outputString)

    // Content
    val body = request.content.map {
      case StringContent(value, contentType, _) => okhttp3.RequestBody.create(value, ct(contentType))
      case FileContent(file, contentType, _) => okhttp3.RequestBody.create(file, ct(contentType))
      case BytesContent(array, contentType, _) => okhttp3.RequestBody.create(array, ct(contentType))
      case JsonContent(json, compact, contentType, _) =>
        val jsonString = if (compact) JsonFormatter.Compact(json) else JsonFormatter.Default(json)
        okhttp3.RequestBody.create(jsonString, ct(contentType))
      case FormDataContent(entries) =>
        val form = new okhttp3.MultipartBody.Builder()
        form.setType(ct(ContentType.`multipart/form-data`))
        entries.foreach {
          case (key, entry) => entry match {
            case StringEntry(value, _) => form.addFormDataPart(key, value)
            case FileEntry(fileName, file, headers) =>
              val partType = Headers.`Content-Type`.value(headers).getOrElse(ContentType.`application/octet-stream`)
              val part = okhttp3.RequestBody.create(file, ct(partType))
              form.addFormDataPart(key, fileName, part)
          }
        }
        form.build()
      case URLContent(url, contentType, _) =>
        val conn = url.openConnection()
        val length = conn.getContentLength
        val input = conn.getInputStream
        try {
          val bytes = input.readNBytes(length)
          okhttp3.RequestBody.create(bytes, ct(contentType))
        } finally {
          Try(input.close())
        }
      case c => throw new RuntimeException(s"Unsupported request content: $c")
    }.getOrElse {
      if (request.method != HttpMethod.Get) {
        okhttp3.RequestBody.create("", None.orNull)
      } else {
        None.orNull
      }
    }

    // Method
    r
      .method(request.method.value, body)
      .header("Content-Length", Option(body).map(_.contentLength().toString).getOrElse("0"))
      .build()
  }

  private def responseFromOk(r: okhttp3.Response): HttpResponse = {
    // Status
    val status = HttpStatus.getByCode(r.code()).getOrElse(HttpStatus(code = r.code(), message = r.message()))

    // Headers
    val headersMap = r.headers().names().asScala.toList.map { key =>
      key -> r.headers(key).asScala.toList
    }.toMap
    val headers = Headers(headersMap)

    // Content
    val contentType = Headers.`Content-Type`.value(headers).getOrElse(ContentType.`application/octet-stream`)
    val contentLength = Headers.`Content-Length`.value(headers)
    val content = Option(r.body()).map { responseBody =>
      if (contentToString(contentType, contentLength)) {
        Content.string(responseBody.string(), contentType)
      } else if (contentToBytes(contentType, contentLength)) {
        Content.bytes(responseBody.bytes(), contentType)
      } else {
        val suffix = contentType.extension.getOrElse("client")
        val file = File.createTempFile("spice", s".$suffix", new File(client.saveDirectory))
        Streamer(responseBody.byteStream(), file).unsafeRunSync()
        Content.file(file, contentType)
      }
    }

    HttpResponse(
      status = status,
      headers = headers,
      content = content
    )
  }

  private def contentToString(contentType: ContentType, contentLength: Option[Long]): Boolean = {
    contentType.`type` == "text" || contentType.subType == "json"
  }

  private def contentToBytes(contentType: ContentType, contentLength: Option[Long]): Boolean = {
    contentLength.exists(l => l > 0L && l < 512000L)
  }

  override def webSocket(url: URL): WebSocket = new OkHttpWebSocket(url, instance)

  override def dispose(): IO[Unit] = for {
    _ <- IO(Try(instance.dispatcher().executorService().shutdown()))
    _ <- IO(Try(instance.connectionPool().evictAll()))
    _ <- IO(Try(instance.cache().close()))
  } yield {
    ()
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy