All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.openai.HttpClientManager.kt Maven / Gradle / Ivy

There is a newer version: 1.0.33
Show newest version
package com.simiacryptus.openai

import com.google.common.util.concurrent.*
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.http.impl.client.HttpClientBuilder
import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import java.io.BufferedOutputStream
import java.io.IOException
import java.time.Duration
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.collections.HashSet
import kotlin.math.pow

/**
 * Base class for API clients that issue HTTP calls through per-thread Apache [CloseableHttpClient]s
 * and need retry, timeout and cancellation helpers around them.
 *
 * Aborting an in-flight request works by closing the client registered for the thread running it
 * and interrupting that thread (see [closeClient]).
 *
 * @param logLevel level used by [log] when the caller does not pass one.
 * @param auxillaryLogOutputStream optional second log sink; when set, every [log] call is also
 *   written to it, prefixed with the level and the seconds elapsed since [startTime].
 */
@Suppress("MemberVisibilityCanBePrivate")
open class HttpClientManager(
    protected val logLevel: Level = Level.INFO,
    val auxillaryLogOutputStream: BufferedOutputStream? = null
) {

    companion object {
        val log = LoggerFactory.getLogger(HttpClientManager::class.java)

        /** Runs the timeout and cancellation-polling tasks used by [withTimeout] and [withCancellationMonitor]. */
        val scheduledPool: ListeningScheduledExecutorService =
            MoreExecutors.listeningDecorator(
                ScheduledThreadPoolExecutor(
                    4,
                    ThreadFactoryBuilder().setNameFormat("API Scheduler %d").build()
                )
            )

        /**
         * Runs request bodies submitted by [withPool] and [withCancellationMonitor].
         *
         * NOTE(review): with an unbounded [LinkedBlockingQueue] a [ThreadPoolExecutor] never grows past
         * its core size, so this pool stays at 8 threads and the maximum of 16 is never reached.
         */
        val workPool: ListeningExecutorService =
            MoreExecutors.listeningDecorator(
                ThreadPoolExecutor(
                    8,
                    16,
                    0,
                    TimeUnit.MILLISECONDS,
                    LinkedBlockingQueue(),
                    ThreadFactoryBuilder().setNameFormat("API Thread %d").build()
                )
            )

        /** Reference time for auxiliary-log timestamps; fixed the first time it is read. */
        val startTime by lazy { System.currentTimeMillis() }

    }

    /**
     * Runs [fn] on [workPool] and blocks the calling thread until it finishes.
     *
     * @return the value produced by [fn].
     * @throws ExecutionException wrapping any exception thrown by [fn].
     */
    fun <T> withPool(fn: () -> T): T = workPool.submit(Callable { fn() }).get()

    /**
     * Calls [fn] up to [retryCount] times, sleeping `sleepScale * 2^attempt` ms between failed attempts.
     *
     * Failures that retrying cannot fix are rethrown immediately:
     * - a [ModelMaxException] anywhere in the cause chain,
     * - an [IOException] whose message contains "Incorrect API key" anywhere in the cause chain,
     * - an [InterruptedException], which signals cancellation rather than a transient failure
     *   ([withTimeout] reports timeouts as a [RuntimeException], so those are still retried).
     *
     * @param retryCount total number of attempts; must be positive.
     * @param sleepScale base back-off in milliseconds.
     * @return the first successful result of [fn].
     * @throws IllegalArgumentException if [retryCount] is not positive.
     * @throws Exception the last failure once all attempts are exhausted.
     */
    fun <T> withExpBackoffRetry(
        retryCount: Int = 7,
        sleepScale: Long = TimeUnit.SECONDS.toMillis(5),
        fn: () -> T
    ): T {
        require(retryCount > 0) { "retryCount must be positive: $retryCount" }
        var lastException: Exception? = null
        for (attempt in 0 until retryCount) {
            try {
                return fn()
            } catch (e: ModelMaxException) {
                throw e
            } catch (e: InterruptedException) {
                // Cancellation: retrying would just re-run a request the caller asked to stop.
                throw e
            } catch (e: Exception) {
                modelMaxException(e)?.let { throw it }
                apiKeyException(e)?.let { throw it }
                lastException = e
                log(Level.DEBUG, "Request failed; retrying ($attempt/$retryCount): " + e.message)
                // Only back off if another attempt will actually follow.
                if (attempt < retryCount - 1) {
                    Thread.sleep(sleepScale * 2.0.pow(attempt.toDouble()).toLong())
                }
            }
        }
        throw checkNotNull(lastException) { "retry loop exited without recording a failure" }
    }

    /** Returns the first [ModelMaxException] found by walking [e]'s cause chain, or null if none. */
    private fun modelMaxException(e: Throwable?): ModelMaxException? {
        if (e == null) return null
        if (e is ModelMaxException) return e
        if (e.cause != null && e.cause != e) return modelMaxException(e.cause)
        return null
    }

    /** Returns the first [IOException] in [e]'s cause chain whose message mentions "Incorrect API key", or null. */
    private fun apiKeyException(e: Throwable?): IOException? {
        if (e == null) return null
        if (e is IOException && true == e.message?.contains("Incorrect API key")) return e
        if (e.cause != null && e.cause != e) return apiKeyException(e.cause)
        return null
    }

    /**
     * Client registered per thread. Weak keys let entries for dead threads be collected.
     * All access must hold the map's monitor: [WeakHashMap] is not thread-safe, even for reads.
     */
    protected val clients: MutableMap<Thread, CloseableHttpClient> = WeakHashMap()

    /**
     * Returns the client registered for [thread], creating and registering a new one if absent.
     *
     * The lookup and insertion happen under one lock, so concurrent callers never build duplicate
     * clients or read the map while another thread mutates it.
     */
    fun getClient(thread: Thread = Thread.currentThread()): CloseableHttpClient =
        synchronized(clients) {
            clients.getOrPut(thread) { HttpClientBuilder.create().build() }
        }

    /**
     * Aborts whatever [thread] is doing: closes its registered client (if any) and interrupts it.
     * Close failures are logged at DEBUG and otherwise ignored.
     */
    protected fun closeClient(thread: Thread) {
        try {
            synchronized(clients) {
                clients[thread]
            }?.close()
            thread.interrupt()
        } catch (e: IOException) {
            log(Level.DEBUG, "Error closing client: " + e.message)
        }
    }

    /** Runs [fn] on [workPool], aborting it if the calling thread becomes interrupted. */
    protected fun <T> withCancellationMonitor(fn: () -> T): T {
        val currentThread = Thread.currentThread()
        return withCancellationMonitor(fn) { currentThread.isInterrupted }
    }

    /**
     * Runs [fn] on [workPool] while polling [cancelCheck] every 10 ms on [scheduledPool].
     *
     * When [cancelCheck] returns true, the clients of the calling thread and of the worker thread are
     * closed and both threads are interrupted (see [closeClient]).
     */
    protected fun <T> withCancellationMonitor(fn: () -> T, cancelCheck: () -> Boolean): T {
        // Mutated by the worker thread and iterated by the scheduler thread, so it must be concurrent.
        val threads: MutableSet<Thread> = ConcurrentHashMap.newKeySet()
        threads.add(Thread.currentThread())
        val start = Date()
        val cancellationFuture = scheduledPool.scheduleAtFixedRate({
            if (cancelCheck()) {
                log(Level.DEBUG, "Request cancelled at ${Date()} (started $start); closing client for thread $threads")
                threads.forEach { closeClient(it) }
            }
        }, 0, 10, TimeUnit.MILLISECONDS)
        try {
            return runAsync(threads, fn)
        } finally {
            cancellationFuture.cancel(false)
        }
    }

    /**
     * Runs [fn] on the calling thread. If it has not finished within [duration], the thread's client
     * is closed and the thread is interrupted.
     *
     * @throws RuntimeException wrapping the [InterruptedException] caused by a timeout.
     * @throws InterruptedException if the thread was interrupted for any other reason.
     */
    protected fun <T> withTimeout(duration: Duration, fn: () -> T): T {
        // Read from the scheduler thread when the timeout fires, so it must be concurrent.
        val threads: MutableSet<Thread> = ConcurrentHashMap.newKeySet()
        val currentThread = Thread.currentThread()
        threads.add(currentThread)
        val isTimeout = AtomicBoolean(false)
        val start = Date()
        val cancellationFuture = scheduledPool.schedule({
            log(
                Level.DEBUG,
                "Request timed out after $duration at ${Date()} (started $start); closing client for thread $threads"
            )
            isTimeout.set(true)
            threads.forEach { closeClient(it) }
        }, duration.toMillis(), TimeUnit.MILLISECONDS)
        try {
            return fn()
        } catch (ex: InterruptedException) {
            if (!isTimeout.get()) throw ex
            throw RuntimeException("Request timed out after $duration", ex)
        } finally {
            // Deregister before cancelling so a late-firing timeout cannot close this thread's client.
            threads.remove(currentThread)
            cancellationFuture.cancel(false)
        }
    }

    /**
     * Submits [fn] to [workPool] and blocks until it completes.
     *
     * The worker thread is added to [threads] while [fn] runs, so a cancellation monitor can abort it.
     * If the waiting thread is interrupted, the worker task is cancelled with interruption instead of
     * being left running in the background.
     *
     * @throws ExecutionException wrapping any exception thrown by [fn].
     * @throws InterruptedException if the waiting thread is interrupted.
     */
    private fun <T> runAsync(
        threads: MutableSet<Thread>,
        fn: () -> T
    ): T {
        log(Level.DEBUG, "Async request started")
        val future = workPool.submit(Callable {
            val workerThread = Thread.currentThread()
            threads.add(workerThread)
            try {
                fn()
            } finally {
                threads.remove(workerThread)
            }
        })
        try {
            return future.get()
        } catch (e: InterruptedException) {
            future.cancel(true)
            throw e
        } finally {
            log(Level.DEBUG, "Async request completed")
        }
    }

    /**
     * Standard wrapper for a request: runs [fn] on the work pool under a cancellation monitor, limits
     * each attempt to [requestTimeoutSeconds], and retries failures with exponential back-off.
     *
     * @param requestTimeoutSeconds per-attempt timeout.
     * @param retryCount total number of attempts; must be positive.
     */
    protected fun <T> withReliability(requestTimeoutSeconds: Long = (5 * 60), retryCount: Int = 3, fn: () -> T): T =
        withExpBackoffRetry(retryCount) {
            withTimeout(Duration.ofSeconds(requestTimeoutSeconds)) {
                withCancellationMonitor(fn)
            }
        }

    /** Runs [fn] and logs its wall-clock duration at DEBUG, whether it succeeds or throws. */
    protected fun <T> withPerformanceLogging(fn: () -> T): T {
        val start = System.currentTimeMillis()
        try {
            return fn()
        } finally {
            log(Level.DEBUG, "Request completed in ${System.currentTimeMillis() - start}ms")
        }
    }

    /**
     * Applies [fn] to the current thread's client.
     *
     * Afterwards the client is closed and deregistered, so the next call on this thread gets a new one.
     */
    protected fun <T> withClient(fn: java.util.function.Function<CloseableHttpClient, T>): T {
        // getClient() already registers the client under the current thread.
        val client = getClient()
        try {
            client.use { httpClient ->
                return fn.apply(httpClient)
            }
        } finally {
            synchronized(clients) {
                clients.remove(Thread.currentThread())
            }
        }
    }

    /**
     * Logs [msg] (trimmed; continuation lines indented with a tab) via SLF4J.
     *
     * ERROR, WARN and INFO map to their SLF4J counterparts; every other level, including TRACE, is
     * emitted at DEBUG. If [auxillaryLogOutputStream] is set, the same message is also written there
     * as `[LEVEL] [seconds-since-startTime] message` and flushed.
     */
    protected open fun log(level: Level = logLevel, msg: String) {
        val message = msg.trim().replace("\n", "\n\t")
        when (level) {
            Level.ERROR -> log.error(message)
            Level.WARN -> log.warn(message)
            Level.INFO -> log.info(message)
            else -> log.debug(message)
        }
        auxillaryLogOutputStream?.let { out ->
            val elapsedSeconds = (System.currentTimeMillis() - startTime) / 1000.0
            // `message` is already indented; indenting again would double-tab continuation lines.
            val line = "[$level] [${"%.3f".format(elapsedSeconds)}] $message\n"
            // Keep each line's write+flush together when several threads log concurrently.
            synchronized(out) {
                out.write(line.toByteArray())
                out.flush()
            }
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy