All downloads are free. Search and download functionality uses the official Maven repository.

com.malliina.logstreams.client.WebSocketIO.scala Maven / Gradle / Ivy

There is a newer version: 2.8.2
Show newest version
package com.malliina.logstreams.client

import cats.effect.concurrent.Ref
import cats.effect.{Concurrent, ContextShift, IO, Timer}
import com.malliina.http.FullUrl
import com.malliina.values.Username
import fs2.Stream
import fs2.concurrent.{SignallingRef, Topic}
import okhttp3._
import okio.ByteString
import org.slf4j.LoggerFactory
import WebSocketIO.log

import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.duration.{DurationInt, FiniteDuration}

/** Companion of [[WebSocketIO]]: holds the shared logger and the effectful factory. */
object WebSocketIO {
  private val log = LoggerFactory.getLogger(getClass)

  /** Describes the construction of a [[WebSocketIO]] for `url`.
    *
    * Running the returned `IO` allocates a fresh event topic (seeded with `SocketEvent.Idle`)
    * and a fresh interrupt signal (initially `false`), then hands both to the class constructor.
    *
    * @param url     address handed to the socket client
    * @param headers headers handed to the socket client
    * @param client  OkHttp client handed to the socket client
    * @return an `IO` yielding the new [[WebSocketIO]]
    */
  def apply(url: FullUrl, headers: Map[String, String], client: OkHttpClient)(implicit
    cs: ContextShift[IO],
    t: Timer[IO]
  ): IO[WebSocketIO] =
    Topic[IO, SocketEvent](SocketEvent.Idle).flatMap { eventTopic =>
      SignallingRef[IO, Boolean](false).map { stopSignal =>
        new WebSocketIO(url, headers, client, eventTopic, stopSignal)
      }
    }
}

/** WebSocket client on top of OkHttp that exposes socket events as fs2 streams.
  *
  * Consuming [[events]] opens a socket to `url`; the listener registered with OkHttp forwards
  * every callback to `topic`. When a `Failure` event is observed the stream waits for the
  * backoff time and then connects again, until [[close]] flips `interrupter`.
  *
  * @param url         address to connect to
  * @param headers     headers added to the connection request
  * @param client      OkHttp client used to open sockets
  * @param topic       topic the listener forwards socket events to and the streams subscribe to
  * @param interrupter signal that terminates [[events]] once it becomes `true`
  */
class WebSocketIO(
  val url: FullUrl,
  headers: Map[String, String],
  client: OkHttpClient,
  topic: Topic[IO, SocketEvent],
  interrupter: SignallingRef[IO, Boolean]
)(implicit
  val cs: ContextShift[IO],
  c: Concurrent[IO],
  t: Timer[IO]
) extends Closeable {
  private val backoffTime: FiniteDuration = 10.seconds
  // The most recently opened socket, if any; read by `send` and `close`.
  private val active = new AtomicReference[Option[WebSocket]](None)

  import SocketEvent._

  // NOTE: `listener`, `request` and `connect` must be initialized before `events` below.
  // `events` calls `connectSocket()` eagerly while the class is being constructed, and that
  // dereferences `connect`; vals are initialized in textual order, so defining `connect`
  // later would make construction fail with a NullPointerException.

  // Forwards every OkHttp callback to `topic` as the corresponding `SocketEvent`.
  // NOTE(review): `push` is not defined in this file; presumably an extension method that
  // publishes the event to the topic synchronously - confirm against its definition.
  private val listener: WebSocketListener = new WebSocketListener {
    override def onClosed(webSocket: WebSocket, code: Int, reason: String): Unit = {
      log.info(s"Closed  socket to '$url'.")
      topic.push(Closed(webSocket, code, reason))
    }
    override def onClosing(webSocket: WebSocket, code: Int, reason: String): Unit = {
      log.info(s"Closing socket to '$url'.")
      topic.push(Closing(webSocket, code, reason))
    }
    override def onFailure(webSocket: WebSocket, t: Throwable, response: Response): Unit = {
      log.warn(s"Socket to '$url' failed.", t)
      // Both arguments are wrapped with Option(...) so that nulls become None.
      topic.push(Failure(webSocket, Option(t), Option(response)))
    }
    override def onMessage(webSocket: WebSocket, text: String): Unit = {
      log.debug(s"Received text '$text'.")
      topic.push(TextMessage(webSocket, text))
    }
    override def onMessage(webSocket: WebSocket, bytes: ByteString): Unit = {
      log.debug(s"Received bytes $bytes")
      topic.push(BytesMessage(webSocket, bytes))
    }
    override def onOpen(webSocket: WebSocket, response: Response): Unit = {
      log.info(s"Opened socket to '$url'.")
      topic.push(Open(webSocket, response))
    }
  }

  /** The request used for every connection attempt. */
  val request: Request = requestFor(url, headers).build()

  /** Opens a new socket with `request` and `listener` each time it is run. */
  val connect: IO[WebSocket] = IO(client.newWebSocket(request, listener))

  // Opens a socket and records it as the active one.
  private def connectSocket(): IO[WebSocket] = connect.flatMap { socket =>
    IO(active.set(Option(socket))).map(_ => socket)
  }

  // Logs the pending reconnect, sleeps for `backoffTime`, then emits a single `Idle`.
  private val backoff =
    Stream.eval(IO(log.info(s"Reconnecting to '$url' in $backoffTime..."))).flatMap { _ =>
      Stream.sleep(backoffTime).map(_ => Idle)
    }

  private val allEvents: fs2.Stream[IO, SocketEvent] = topic.subscribe(10)

  /** Events of one subscription, ending just before the first `Failure` (which is not emitted).
    *
    * The first element of the subscription is dropped - presumably to skip the value the
    * topic already holds when subscribing; confirm against the fs2 `Topic` semantics.
    */
  val untilFailure: Stream[IO, SocketEvent] = allEvents.drop(1).takeWhile {
    case Failure(_, _, _) => false
    case _                => true
  }

  /** Connects, emits events until a failure, backs off and repeats; ends when `interrupter`
    * becomes `true`.
    */
  val events: Stream[IO, SocketEvent] = Stream
    .eval(connectSocket())
    .flatMap(_ => untilFailure ++ backoff)
    .repeat
    .interruptWhen(interrupter)

  /** The payloads of the `TextMessage` events in [[events]]. */
  val messages: Stream[IO, String] = events.collect {
    case SocketEvent.TextMessage(_, message) => message
  }

  /** Sends `message` over the active socket.
    *
    * @return `false` if no socket has been opened yet, otherwise the result of the socket's `send`
    */
  def send(message: String): Boolean = active.get().exists(_.send(message))

  /** Sets `interrupter` to `true`, running the effect synchronously, then cancels the active
    * socket if there is one.
    */
  def close(): Unit = {
    log.info(s"Closing socket to '$url'...")
    interrupter.set(true).unsafeRunSync()
    active.get().foreach(_.cancel())
  }

  /** Builds a request builder for `url` with every entry of `headers` added as a header.
    *
    * @param url     target address; its `url` member is passed to the builder
    * @param headers header names mapped to their values
    * @return the builder, not yet built
    */
  def requestFor(url: FullUrl, headers: Map[String, String]): Request.Builder =
    headers.foldLeft(new Request.Builder().url(url.url)) {
      case (r, (key, value)) => r.addHeader(key, value)
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy