io.cequence.openaiscala.task.FineTuningTrainingSetGenerator.scala Maven / Gradle / Ivy
The newest version!
package io.cequence.openaiscala.task
import akka.stream.scaladsl.Source
import io.cequence.openaiscala.task.domain.PromptCompletionSeparators
import org.slf4j.LoggerFactory
import play.api.libs.json.Json
/**
 * Generator of fine-tuning training-set entries, exposed as a stream of strings.
 *
 * @tparam S type of the settings passed to the generator
 */
trait FineTuningTrainingSetGenerator[S] {

  /**
   * Builds a stream of training-set entries.
   *
   * @param count        requested number of entries
   * @param taskSettings settings handed to the underlying generation logic
   * @return a source of string entries; the materialized value type is left unspecified
   */
  def generate(
    count: Int,
    taskSettings: S
  ): Source[String, _]
}
/**
 * [[FineTuningTrainingSetGenerator]] implementation driven by a [[CompletionTask]].
 *
 * Every emitted element is a JSON object, serialized to a string, holding a `prompt` and a
 * `completion` field.
 *
 * @param task       provides the input, the optional role prompt, the prompt text and the
 *                   expected output for each entry
 * @param separators separators wrapped around the prompt and the completion; when `None`,
 *                   `PromptCompletionSeparators.Default` is used
 * @tparam S type of the task settings
 */
private class FineTuningTrainingSetGeneratorImpl[S](
  task: CompletionTask[S],
  separators: Option[PromptCompletionSeparators] = None
) extends FineTuningTrainingSetGenerator[S] {

  protected val logger = LoggerFactory.getLogger("Fine-tuning Training Set Generator")

  // Resolved once at construction so every entry uses the same separators.
  private val actualSeparators = separators.getOrElse(PromptCompletionSeparators.Default)

  /**
   * Emits `count` entries, each produced by a separate call to `generateOnce`.
   * A non-positive `count` yields an empty stream.
   */
  override def generate(
    count: Int,
    taskSettings: S
  ): Source[String, _] =
    Source
      .fromIterator(() => Iterator.range(0, count))
      .map(_ => generateOnce(taskSettings))

  /**
   * Produces a single training-set entry as a JSON string.
   *
   * Throws `IllegalStateException` when the task defines no expected output for the
   * generated input.
   */
  private def generateOnce(taskSettings: S): String = {
    val taskInput = task.generateInput(taskSettings)

    // The role prompt, when present, is prepended to the prompt followed by a single space.
    val rolePrefix = task.rolePrompt.fold("")(_ + " ")
    val fullPrompt = rolePrefix + task.generatePrompt(taskInput, taskSettings)

    // A training entry is meaningless without a completion, so a missing one is fatal.
    val completionBody = task.expectedOutput(taskInput, taskSettings) match {
      case Some(output) => output
      case None =>
        throw new IllegalStateException(s"Expected output is not defined for input: $taskInput")
    }

    val promptField = s"$fullPrompt${actualSeparators.promptEnd}"
    val completionField =
      s"${actualSeparators.completionStart}$completionBody ${actualSeparators.completionEnd}"

    val entry = Json.obj(
      "prompt" -> promptField,
      "completion" -> completionField
    )

    Json.stringify(entry)
  }
}
/** Factory for [[FineTuningTrainingSetGenerator]] instances. */
object FineTuningTrainingSetGenerator {

  /**
   * Creates a generator backed by the given completion task.
   *
   * @param task       task used to produce the entries
   * @param separators optional prompt/completion separators, passed through unchanged
   * @tparam S type of the task settings
   * @return a new generator instance
   */
  def apply[S](
    task: CompletionTask[S],
    separators: Option[PromptCompletionSeparators] = None
  ): FineTuningTrainingSetGenerator[S] =
    new FineTuningTrainingSetGeneratorImpl[S](
      task = task,
      separators = separators
    )
}