io.cequence.openaiscala.task.CompletionTaskExec.scala

package io.cequence.openaiscala.task

import akka.stream.Materializer
import akka.stream.scaladsl.{Flow, Sink, Source}
import io.cequence.openaiscala.domain.settings.{CreateChatCompletionSettings, CreateCompletionSettings}
import io.cequence.openaiscala.domain.{ChatRole, MessageSpec, ModelId}
import io.cequence.openaiscala.service.OpenAIService
import io.cequence.openaiscala.task.domain.{CompletionTaskIO, PromptCompletionSeparators, TextCompletionTaskSettings}
import org.slf4j.LoggerFactory

import scala.concurrent.{ExecutionContext, Future}

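/**
 * Executes a completion task against the OpenAI API: `execute` runs the task for the
 * configured number of repetitions and collects the results, `executionSource` exposes
 * the same processing as an Akka Streams source, and `runSingleCompletion` performs a
 * single completion (or chat completion) call for a given input.
 *
 * @tparam S the task-specific settings type
 */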
trait CompletionTaskExec[S] {

  def execute(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ): Future[Seq[CompletionTaskIO]]

  def executionSource(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ): Source[CompletionTaskIO, _]

  def runSingleCompletion(
    input: String,
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ): Future[String]
}

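// Default implementation backed by an OpenAIService; prompts are routed either to the
// text completion or the chat completion endpoint, depending on the selected model.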
private class CompletionTaskExecImpl[S](
  service: OpenAIService,
  task: CompletionTask[S],
  separators: Option[PromptCompletionSeparators] = None)( // used for text completions; for chat completions only the completionEnd part is used as a stop sequence
  implicit ec: ExecutionContext, materializer: Materializer
) extends CompletionTaskExec[S] {

  protected val logger = LoggerFactory.getLogger("Completion Task Executor")

  private val chatModels = Set(
    ModelId.gpt_3_5_turbo,
    ModelId.gpt_3_5_turbo_0301,
    ModelId.gpt_4,
    ModelId.gpt_4_0314,
    ModelId.gpt_4_32k
  )

  private val retries = 3

  override def execute(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ): Future[Seq[CompletionTaskIO]] = {
    val execSource = executionSource(completionSettings, taskSettings)
    execSource.runWith(Sink.seq)
  }

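  // Builds a source that executes the task `repetitions` times with the configured
  // parallelism, retrying each failed OpenAI call up to `retries` times.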
  override def executionSource(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ): Source[CompletionTaskIO, _] = {
    val repetitions = completionSettings.repetitions

    val flowProcess = Flow[Int].mapAsyncUnordered(completionSettings.parallelism) { repetition =>
      logger.info(s"Executing a completion task ${task.getClass.getSimpleName} - repetition $repetition")

      retry(s"OpenAI call for a completion task ${task.getClass.getSimpleName} failed:", logger.error(_: String), retries)(
        executeOnce(completionSettings, taskSettings)
      )
    }

    val source = Source.fromIterator(() => (1 to repetitions).iterator)

    source.via(flowProcess)
  }

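  // Generates a single task input, obtains a completion for it and evaluates the output,
  // either via the task's eval function or by comparing against the expected output.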
  private def executeOnce(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) = {
    val input = task.generateInput(taskSettings)

    runSingleCompletion(input, completionSettings, taskSettings).map { output =>
      val expectedOutput = task.expectedOutput(input, taskSettings)

      val evalMetrics =
        if (task.hasEvalResult)
          // the task provides an explicit evaluation function
          task.evalResult(input, output)
        else
          // otherwise simply compare the output with the expected output
          expectedOutput.map(expected =>
            if (expected == output) Some(1) else Some(0)
          ).getOrElse(
            throw new RuntimeException("Either expected output or evalResult function must be defined.")
          )

      CompletionTaskIO(input, output, expectedOutput, evalMetrics)
    }
  }

  override def runSingleCompletion(
    input: String,
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) = {
    val isChatModel = chatModels.contains(completionSettings.model)

    if (isChatModel)
      runChatCompletionFor(input, completionSettings, taskSettings)
    else
      runCompletionFor(input, completionSettings, taskSettings)
  }

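  // Text (non-chat) completion: prepends an optional role prompt, appends an optional
  // prompt-end separator, and strips the completion start/end separators from the response.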
  private def runCompletionFor(
    input: String,
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) = {
    val rolePrompt = task.rolePrompt.map(_ + " ").getOrElse("")
    val prompt = task.generatePrompt(input, taskSettings)
    val prompt2 = s"${rolePrompt}${prompt}"
    // add a separator if needed
    val finalPrompt = prompt2 + separators.map(_.promptEnd).getOrElse("")

    service.createCompletion(
      finalPrompt,
      CreateCompletionSettings(
        model = completionSettings.model,
        max_tokens = completionSettings.max_tokens,
        temperature = completionSettings.temperature,
        top_p = completionSettings.top_p,
        stop = separators.map(separators => Seq(separators.completionEnd)).getOrElse(Nil)
      )
    ).map { response =>
      val responseText = response.choices.head.text

      // strip the completion start/end separators (if needed)
      separators.map { separators =>
        val responseTextAux = responseText.stripPrefix(separators.completionStart)
        responseTextAux.stripSuffix(separators.completionEnd)
      }.getOrElse {
        responseText
      }
    }
  }

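  // Chat completion: the role prompt is sent as a system message, the generated prompt as a
  // user message, and an optional seed assistant prompt is both sent and prepended to the output.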
  private def runChatCompletionFor(
    input: String,
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) = {
    val prompt = task.generatePrompt(input, taskSettings)

    val messages = Seq(
      task.rolePrompt.map(
        MessageSpec(ChatRole.System, _)
      ),
      Some(
        MessageSpec(ChatRole.User, prompt)
      ),
      task.seedAssistantPrompt.map(
        MessageSpec(ChatRole.Assistant, _)
      )
    )

    service.createChatCompletion(
      messages.flatten,
      CreateChatCompletionSettings(
        model = completionSettings.model,
        max_tokens = completionSettings.max_tokens,
        temperature = completionSettings.temperature,
        top_p = completionSettings.top_p,
        stop = separators.map(separators => Seq(separators.completionEnd)).getOrElse(Nil)
      )
    ).map { response =>
      val text = response.choices.head.message.content
      // add a seed to the output (if provided)
      s"${task.seedAssistantPrompt.getOrElse("")}$text"
    }
  }

  // simple retry... TODO: move elsewhere
  private def retry[T](
    failureMessage: String,
    log: String => Unit,
    maxAttemptNum: Int,
    sleepOnFailureMs: Option[Int] = None)(
    f: => Future[T]
  ): Future[T] = {
    def retryAux(attempt: Int): Future[T] =
      f.recoverWith {
        case e: Exception =>
          if (attempt < maxAttemptNum) {
            log(s"${failureMessage}. ${e.getMessage}. Attempt ${attempt}. Retrying...")
            sleepOnFailureMs.foreach(time =>
              Thread.sleep(time)
            )
            retryAux(attempt + 1)
          } else
            throw e
      }

    retryAux(1)
  }
}

object CompletionTaskExec {

  def apply[S](
    service: OpenAIService,
    task: CompletionTask[S],
    separators: Option[PromptCompletionSeparators] = None)(
    implicit ec: ExecutionContext, materializer: Materializer
  ): CompletionTaskExec[S] = new CompletionTaskExecImpl[S](service, task, separators)
}
