// io.cequence.openaiscala.task.CompletionTaskExec.scala
package io.cequence.openaiscala.task
import akka.stream.Materializer
import akka.stream.scaladsl.{Flow, Sink, Source}
import io.cequence.openaiscala.domain.settings.{CreateChatCompletionSettings, CreateCompletionSettings}
import io.cequence.openaiscala.domain.{ChatRole, MessageSpec, ModelId}
import io.cequence.openaiscala.service.OpenAIService
import io.cequence.openaiscala.task.domain.{CompletionTaskIO, PromptCompletionSeparators, TextCompletionTaskSettings}
import org.slf4j.LoggerFactory

import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
/**
 * Executor of a [[CompletionTask]]: generates inputs/prompts, sends them to an
 * OpenAI model, and evaluates the produced outputs.
 *
 * @tparam S the task-specific settings type
 */
trait CompletionTaskExec[S] {
/**
 * Runs the task for all repetitions (as given by the completion settings)
 * and collects every input/output/eval-metrics triple.
 */
def execute(
completionSettings: TextCompletionTaskSettings,
taskSettings: S
): Future[Seq[CompletionTaskIO]]
/**
 * Creates a (streaming) source emitting one [[CompletionTaskIO]] per repetition;
 * useful when results should be consumed incrementally rather than collected.
 */
def executionSource(
completionSettings: TextCompletionTaskSettings,
taskSettings: S
): Source[CompletionTaskIO, _]
/**
 * Runs a single completion for the given input and returns the raw model output.
 */
def runSingleCompletion(
input: String,
completionSettings: TextCompletionTaskSettings,
taskSettings: S
): Future[String]
}
/**
 * Default [[CompletionTaskExec]] implementation backed by an [[OpenAIService]].
 *
 * Each repetition generates a task input, sends the prompt either to the
 * chat-completion endpoint (for chat models) or to the text-completion endpoint
 * (otherwise), and evaluates the response.
 *
 * @param service    OpenAI service used to issue the API calls
 * @param task       task definition supplying input/prompt generation and evaluation
 * @param separators used for (text) completion... for chat completion only the
 *                   completion-end part is used as a stop sequence
 */
private class CompletionTaskExecImpl[S](
  service: OpenAIService,
  task: CompletionTask[S],
  separators: Option[PromptCompletionSeparators] = None)(
  implicit ec: ExecutionContext, materializer: Materializer
) extends CompletionTaskExec[S] {

  protected val logger = LoggerFactory.getLogger(getClass)

  // models that must be invoked through the chat-completion endpoint
  // (fix: the original set listed ModelId.gpt_4_32k twice; likely gpt_4_32k_0314
  // was intended - TODO confirm and add it if so)
  private val chatModels = Set(
    ModelId.gpt_3_5_turbo,
    ModelId.gpt_3_5_turbo_0301,
    ModelId.gpt_4,
    ModelId.gpt_4_0314,
    ModelId.gpt_4_32k
  )

  // max number of attempts per OpenAI call (1 initial attempt + 2 retries)
  private val retries = 3

  /**
   * Runs the task for all repetitions and collects the results.
   */
  override def execute(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ): Future[Seq[CompletionTaskIO]] =
    executionSource(completionSettings, taskSettings).runWith(Sink.seq)

  /**
   * Emits one [[CompletionTaskIO]] per repetition, executing up to
   * `completionSettings.parallelism` completions concurrently (unordered).
   */
  override def executionSource(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ): Source[CompletionTaskIO, _] = {
    val repetitions = completionSettings.repetitions

    val flowProcess = Flow[Int].mapAsyncUnordered(completionSettings.parallelism) { repetition =>
      logger.info(s"Executing a completion task ${task.getClass.getSimpleName} - repetition $repetition")
      retry(s"OpenAI call for a completion task ${task.getClass.getSimpleName} failed:", logger.error(_: String), retries)(
        executeOnce(completionSettings, taskSettings)
      )
    }

    Source.fromIterator(() => (1 to repetitions).iterator).via(flowProcess)
  }

  /**
   * Generates one input, obtains a completion for it, and evaluates the output.
   */
  private def executeOnce(
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) = {
    val input = task.generateInput(taskSettings)

    runSingleCompletion(input, completionSettings, taskSettings).map { output =>
      val expectedOutput = task.expectedOutput(input, taskSettings)

      val evalMetrics =
        if (task.hasEvalResult)
          // has an explicit evaluation function
          task.evalResult(input, output)
        else
          // otherwise simply compare the output with the expected output (exact match -> 1, mismatch -> 0)
          expectedOutput.map(expected =>
            if (expected == output) Some(1) else Some(0)
          ).getOrElse(
            throw new RuntimeException("Either expected output or evalResult function must be defined.")
          )

      CompletionTaskIO(input, output, expectedOutput, evalMetrics)
    }
  }

  /**
   * Routes a single completion to the chat or the text-completion endpoint
   * based on the model in the settings.
   */
  override def runSingleCompletion(
    input: String,
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) =
    if (chatModels.contains(completionSettings.model))
      runChatCompletionFor(input, completionSettings, taskSettings)
    else
      runCompletionFor(input, completionSettings, taskSettings)

  /**
   * Calls the (legacy) text-completion endpoint: role prompt + task prompt
   * (+ optional prompt-end separator) and strips completion separators from
   * the response.
   */
  private def runCompletionFor(
    input: String,
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) = {
    // role prompt (if any) is prepended to the task-generated prompt
    val rolePrompt = task.rolePrompt.map(_ + " ").getOrElse("")
    val prompt = task.generatePrompt(input, taskSettings)
    // add a prompt-end separator if needed
    val finalPrompt = s"${rolePrompt}${prompt}" + separators.map(_.promptEnd).getOrElse("")

    service.createCompletion(
      finalPrompt,
      CreateCompletionSettings(
        model = completionSettings.model,
        max_tokens = completionSettings.max_tokens,
        temperature = completionSettings.temperature,
        top_p = completionSettings.top_p,
        // stop generating once the completion-end separator is produced
        stop = separators.map(seps => Seq(seps.completionEnd)).getOrElse(Nil)
      )
    ).map { response =>
      val responseText = response.choices.head.text
      // strip the completion start/end separators (if defined)
      separators.fold(responseText) { seps =>
        responseText.stripPrefix(seps.completionStart).stripSuffix(seps.completionEnd)
      }
    }
  }

  /**
   * Calls the chat-completion endpoint with system (role prompt), user (task prompt),
   * and optional seeded assistant messages; the seed is re-prepended to the output.
   */
  private def runChatCompletionFor(
    input: String,
    completionSettings: TextCompletionTaskSettings,
    taskSettings: S
  ) = {
    val prompt = task.generatePrompt(input, taskSettings)

    val messages = Seq(
      task.rolePrompt.map(
        MessageSpec(ChatRole.System, _)
      ),
      Some(
        MessageSpec(ChatRole.User, prompt)
      ),
      task.seedAssistantPrompt.map(
        MessageSpec(ChatRole.Assistant, _)
      )
    )

    service.createChatCompletion(
      messages.flatten,
      CreateChatCompletionSettings(
        model = completionSettings.model,
        max_tokens = completionSettings.max_tokens,
        temperature = completionSettings.temperature,
        top_p = completionSettings.top_p,
        stop = separators.map(seps => Seq(seps.completionEnd)).getOrElse(Nil)
      )
    ).map { response =>
      val text = response.choices.head.message.content
      // add a seed to the output (if provided)
      s"${task.seedAssistantPrompt.getOrElse("")}$text"
    }
  }

  // simple retry... TODO: move elsewhere
  // Retries `f` up to `maxAttemptNum` total attempts, logging each failure.
  // Fatal errors (OOM, etc.) are NOT retried (NonFatal). Note: Thread.sleep
  // blocks the calling thread - acceptable here only because failures are rare.
  private def retry[T](
    failureMessage: String,
    log: String => Unit,
    maxAttemptNum: Int,
    sleepOnFailureMs: Option[Int] = None)(
    f: => Future[T]
  ): Future[T] = {
    def retryAux(attempt: Int): Future[T] =
      f.recoverWith {
        case NonFatal(e) =>
          if (attempt < maxAttemptNum) {
            log(s"${failureMessage}. ${e.getMessage}. Attempt ${attempt}. Retrying...")
            sleepOnFailureMs.foreach(time => Thread.sleep(time.toLong))
            retryAux(attempt + 1)
          } else
            // propagate the last failure as a failed future (not a thrown exception)
            Future.failed(e)
      }

    retryAux(1)
  }
}
object CompletionTaskExec {

  /**
   * Creates a [[CompletionTaskExec]] for the given service and task.
   *
   * @param service    OpenAI service used to issue the API calls
   * @param task       task definition to execute
   * @param separators optional prompt/completion separators (completion endpoint only;
   *                   chat completion uses just the completion-end part as a stop)
   */
  def apply[S](
    service: OpenAIService,
    task: CompletionTask[S],
    separators: Option[PromptCompletionSeparators] = None)(
    implicit ec: ExecutionContext, materializer: Materializer
  ): CompletionTaskExec[S] = {
    val executor = new CompletionTaskExecImpl[S](service, task, separators)
    executor
  }
}