All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.cequence.openaiscala.task.examples.RunCollectionZeroesOnesTask.scala Maven / Gradle / Ivy

The newest version!
package io.cequence.openaiscala.task.examples

import akka.actor.ActorSystem
import akka.stream.{ActorMaterializer, Materializer}
import io.cequence.openaiscala.domain.ModelId
import io.cequence.openaiscala.service.OpenAIServiceFactory
import io.cequence.openaiscala.task.binary.ConvertBinaryToDecimalTask
import io.cequence.openaiscala.task.domain.{BinaryTaskCoreSettings, TextCompletionTaskSettings}
import io.cequence.openaiscala.task.CompletionTaskExec

import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal

/**
 * Runs the [[ConvertBinaryToDecimalTask]] through a [[CompletionTaskExec]] backed by an
 * OpenAI service and reports how the evaluated results split up.
 *
 * Results are bucketed by their `evalMetrics` value:
 *  - `Some(1)` -> "trues"
 *  - `Some(0)` -> "falses"
 *  - `None`    -> "dontKnows"
 *
 * The bucket sizes are printed first. Then, for the "falses" and "dontKnows" buckets,
 * each result is printed as `input -> output vs expectedOutput`, with an empty string
 * when there is no expected output.
 *
 * The process exits with status 0 on success, or prints the stack trace and exits with
 * status 1 if the execution (or the reporting) fails. The actor system is never
 * terminated explicitly, so these `System.exit` calls are what end the program.
 */
object RunCollectionZeroesOnesTask extends App {

  // change implicits if needed
  private implicit val actorSystem: ActorSystem = ActorSystem()
  // NOTE(review): `ActorMaterializer()` is deprecated since Akka 2.6 in favour of
  // `Materializer(actorSystem)`; kept as-is until the Akka version in use is confirmed.
  private implicit val materializer: Materializer = ActorMaterializer()
  private implicit val ec: ExecutionContext = ExecutionContext.Implicits.global

  private val service = OpenAIServiceFactory()
  private val executor = CompletionTaskExec(service, new ConvertBinaryToDecimalTask(withAssistantSeedPrompt = false))

  // general settings
  private val settings = TextCompletionTaskSettings(
    repetitions = 5,

    model = ModelId.gpt_4,
    temperature = Some(0),
    max_tokens = Some(500),
  )

  // task specific settings
  private val taskSpecificSettings = BinaryTaskCoreSettings(
    minStringSize = 10,
    maxStringSize = 10,
    withSpaces = true,
    withExamples = true,
    useDensityUniformDistribution = true
  )

  {
    for {
      results <- executor.execute(settings, taskSpecificSettings)
    } yield {
      def collectAux(evalMetricsValue: Option[Int]) =
        results.filter(_.evalMetrics == evalMetricsValue)

      val trues = collectAux(Some(1))
      val falses = collectAux(Some(0))
      val dontKnows = collectAux(None)

      println(s"trues #: ${trues.size}, falses #: ${falses.size}, dontKnows #: ${dontKnows.size}")

      // Both non-true buckets share the same "input -> output vs expected" listing.
      Seq("falses:" -> falses, "dontKnows:" -> dontKnows).foreach { case (label, items) =>
        println(label)
        items.foreach(x =>
          println(s"${x.input} -> ${x.output} vs ${x.expectedOutput.getOrElse("")}")
        )
      }

      System.exit(0)
    }
  } recover {
    // A failed Future may carry any non-fatal Throwable, including Errors such as
    // NotImplementedError or AssertionError. Matching only `Exception` would let those
    // through, and since nothing else stops the actor system the app would hang
    // instead of exiting with status 1.
    case NonFatal(e) =>
      e.printStackTrace()
      System.exit(1)
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy