All downloads are free. Search and download functionality uses the official Maven repository.

commonMain.com.xebia.functional.xef.mlflow.MlflowClient.kt Maven / Gradle / Ivy

There is a newer version: 0.0.5-alpha.111
Show newest version
package com.xebia.functional.xef.mlflow

import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.contentOrNull

/**
 * Client for an MLflow gateway served at [gatewayUrl].
 *
 * All requests are issued through a copy of [client] that has JSON content negotiation installed
 * (`encodeDefaults = false`, `isLenient = true`, `ignoreUnknownKeys = true`). Non-2xx responses
 * are turned into exceptions: HTTP 422 from an invocation becomes [MLflowValidationError]; every
 * other failure becomes [MLflowClientUnexpectedError].
 *
 * @param gatewayUrl base URL of the gateway; request paths are appended as `"$gatewayUrl/..."`,
 *   so it should not end with a slash.
 * @param client HTTP client whose configuration is copied via [HttpClient.config].
 */
class MlflowClient(private val gatewayUrl: String, client: HttpClient) : AutoClose by autoClose() {

  // NOTE(review): `internal` is a new client derived from `client` and is never registered with
  // autoClose here, so close() on this instance does not appear to close it — confirm ownership.
  private val internal =
    client.config {
      install(ContentNegotiation) {
        json(
          Json {
            encodeDefaults = false
            isLenient = true
            ignoreUnknownKeys = true
          }
        )
      }
    }

  /** Decoder for bodies that are read manually as text (route listing, 422 error bodies). */
  private val json = Json { ignoreUnknownKeys = true }

  /**
   * Lists all routes via `GET /api/2.0/gateway/routes/`.
   *
   * @throws MLflowClientUnexpectedError if the gateway answers with a non-2xx status.
   */
  private suspend fun routes(): List<RouteDefinition> {
    val response = internal.get("$gatewayUrl/api/2.0/gateway/routes/")
    if (!response.status.isSuccess()) {
      throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
    }
    // Decoded by hand with `json` rather than through the content-negotiation plugin.
    return json.decodeFromString<RoutesResponse>(response.bodyAsText()).routes
  }

  /** Returns every route reported by the gateway's routes endpoint. */
  suspend fun searchRoutes(): List<RouteDefinition> = routes()

  /** Returns the first route whose name equals [name], or `null` if there is none. */
  suspend fun getRoute(name: String): RouteDefinition? = routes().find { it.name == name }

  /**
   * Sends a completion request for [prompt] to [route].
   *
   * @throws MLflowValidationError if the gateway answers HTTP 422.
   * @throws MLflowClientUnexpectedError for any other non-2xx status.
   */
  suspend fun prompt(
    route: String,
    prompt: String,
    candidateCount: Int? = null,
    temperature: Double? = null,
    maxTokens: Int? = null,
    stop: List<String>? = null
  ): PromptResponse =
    postInvocation(route, Prompt(prompt, temperature, candidateCount, stop, maxTokens))

  /**
   * Sends a chat request with [messages] to [route].
   *
   * @throws MLflowValidationError if the gateway answers HTTP 422.
   * @throws MLflowClientUnexpectedError for any other non-2xx status.
   */
  suspend fun chat(
    route: String,
    messages: List<ChatMessage>,
    candidateCount: Int? = null,
    temperature: Double? = null,
    maxTokens: Int? = null,
    stop: List<String>? = null
  ): ChatResponse =
    postInvocation(route, Chat(messages, temperature, candidateCount, stop, maxTokens))

  /**
   * Requests embeddings for each string in [text] from [route].
   *
   * @throws MLflowValidationError if the gateway answers HTTP 422.
   * @throws MLflowClientUnexpectedError for any other non-2xx status.
   */
  suspend fun embeddings(route: String, text: List<String>): EmbeddingsResponse =
    postInvocation(route, Embeddings(text))

  /**
   * POSTs [body] as JSON to `$gatewayUrl/gateway/<route>/invocations` and decodes a 2xx reply
   * as [R] through content negotiation.
   *
   * @throws MLflowValidationError on HTTP 422, with the message from [validationMessage].
   * @throws MLflowClientUnexpectedError on any other non-2xx status, carrying the raw body.
   */
  private suspend inline fun <reified B, reified R> postInvocation(route: String, body: B): R {
    val response =
      internal.post("$gatewayUrl/gateway/$route/invocations") {
        accept(ContentType.Application.Json)
        contentType(ContentType.Application.Json)
        setBody(body)
      }
    return when {
      response.status.isSuccess() -> response.body<R>()
      response.status.value == HttpStatusCode.UnprocessableEntity.value ->
        throw MLflowValidationError(response.status, validationMessage(response.bodyAsText()))
      else -> throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
    }
  }

  /**
   * Extracts a readable message from the body of a 422 response without assuming its shape.
   *
   * Uses the `msg` of the first entry when `detail` is an array of objects, or the value itself
   * when `detail` is a primitive. A body that is not valid JSON is returned verbatim (unless
   * blank) so the server's explanation is not lost. Every other case yields `"Unknown error"`.
   */
  private fun validationMessage(body: String): String {
    val root =
      try {
        json.parseToJsonElement(body)
      } catch (e: SerializationException) {
        return body.ifBlank { "Unknown error" }
      }
    val detail = (root as? JsonObject)?.get("detail")
    val message =
      when (detail) {
        is JsonArray ->
          ((detail.firstOrNull() as? JsonObject)?.get("msg") as? JsonPrimitive)?.contentOrNull
        is JsonPrimitive -> detail.contentOrNull
        else -> null
      }
    return message ?: "Unknown error"
  }

  /**
   * Thrown when an invocation is rejected with HTTP 422.
   *
   * @property httpStatusCode status returned by the gateway.
   * @property error message extracted from the response body, or `"Unknown error"`.
   */
  class MLflowValidationError(val httpStatusCode: HttpStatusCode, val error: String) :
    IllegalStateException("$httpStatusCode: $error")

  /**
   * Thrown for any non-2xx response not covered by [MLflowValidationError].
   *
   * @property httpStatusCode status returned by the gateway.
   * @property error raw response body.
   */
  class MLflowClientUnexpectedError(val httpStatusCode: HttpStatusCode, val error: String) :
    IllegalStateException("$httpStatusCode: $error")
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy