All downloads are free. Search and download functionality uses the official Maven repository.

io.cequence.openaiscala.task.FineTuningTrainingSetGenerator.scala Maven / Gradle / Ivy

The newest version!
package io.cequence.openaiscala.task

import akka.stream.scaladsl.Source
import io.cequence.openaiscala.task.domain.PromptCompletionSeparators
import org.slf4j.LoggerFactory
import play.api.libs.json.Json

/** Produces a fine-tuning training set as a stream of serialized records.
  *
  * @tparam S
  *   type of the task-specific settings handed to every generation call
  */
trait FineTuningTrainingSetGenerator[S] {

  /** Streams training records built from `taskSettings`.
    *
    * @param count
    *   number of records requested
    * @param taskSettings
    *   settings forwarded to the underlying task when building each record
    * @return
    *   a source of serialized records; the materialized value type is left existential
    */
  def generate(count: Int, taskSettings: S): Source[String, _]
}

/** Default [[FineTuningTrainingSetGenerator]] that turns a [[CompletionTask]] into
  * JSON-encoded prompt/completion records.
  *
  * @param task
  *   task supplying inputs, the optional role prompt, prompts and expected outputs
  * @param separators
  *   separators appended around prompt and completion; falls back to
  *   `PromptCompletionSeparators.Default` when `None`
  */
private class FineTuningTrainingSetGeneratorImpl[S](
  task: CompletionTask[S],
  separators: Option[PromptCompletionSeparators] = None
) extends FineTuningTrainingSetGenerator[S] {

  protected val logger = LoggerFactory.getLogger("Fine-tuning Training Set Generator")

  private val actualSeparators = separators.getOrElse(PromptCompletionSeparators.Default)

  /** Emits `count` records (none when `count <= 0`). Each element is built by a separate
    * call to `generateOnce` inside the stream's `map` stage, so any exception it throws
    * surfaces as a stream failure under Akka's default supervision.
    */
  override def generate(
    count: Int,
    taskSettings: S
  ): Source[String, _] =
    Source
      .fromIterator(() => Iterator.fill(count)(taskSettings))
      .map(settings => generateOnce(settings))

  /** Builds a single compact JSON record with `prompt` and `completion` fields.
    *
    * The prompt is the optional role prompt (followed by one space), then the task prompt,
    * then `promptEnd`. The completion is `completionStart`, the expected output, a single
    * space, then `completionEnd`.
    *
    * @throws IllegalStateException
    *   when the task defines no expected output for the generated input
    */
  private def generateOnce(
    taskSettings: S
  ): String = {
    val generatedInput = task.generateInput(taskSettings)

    // Task calls are kept in their original order: input, role prompt, prompt, expected output.
    val rolePrefix = task.rolePrompt.fold("")(role => s"$role ")
    val fullPrompt = rolePrefix + task.generatePrompt(generatedInput, taskSettings)

    val expected = task.expectedOutput(generatedInput, taskSettings) match {
      case Some(output) => output
      case None =>
        throw new IllegalStateException(s"Expected output is not defined for input: $generatedInput")
    }

    val promptField = s"$fullPrompt${actualSeparators.promptEnd}"
    // The single space before completionEnd is part of the existing record format.
    val completionField =
      s"${actualSeparators.completionStart}$expected ${actualSeparators.completionEnd}"

    Json.stringify(
      Json.obj(
        "prompt" -> promptField,
        "completion" -> completionField
      )
    )
  }
}

object FineTuningTrainingSetGenerator {

  /** Creates the default generator backed by `task`.
    *
    * @param task
    *   completion task that supplies inputs, prompts and expected outputs
    * @param separators
    *   optional prompt/completion separators, passed through unchanged to the implementation
    * @tparam S
    *   task settings type
    */
  def apply[S](
    task: CompletionTask[S],
    separators: Option[PromptCompletionSeparators] = None
  ): FineTuningTrainingSetGenerator[S] = {
    val generator: FineTuningTrainingSetGenerator[S] =
      new FineTuningTrainingSetGeneratorImpl[S](task = task, separators = separators)
    generator
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy