All downloads are free. The search and download features use the official Maven repository.

io.cequence.pineconescala.service.PineconeInferenceServiceImpl.scala Maven / Gradle / Ivy

The newest version!
package io.cequence.pineconescala.service

import akka.stream.Materializer
import io.cequence.pineconescala.domain.response.{EvaluateResponse, GenerateEmbeddingsResponse, RerankResponse}
import io.cequence.pineconescala.domain.settings.{GenerateEmbeddingsSettings, RerankSettings}
import io.cequence.wsclient.ResponseImplicits._
import io.cequence.wsclient.service.ws.{PlayWSClientEngine, Timeouts}
import io.cequence.pineconescala.JsonFormats._
import io.cequence.pineconescala.PineconeScalaClientException
import io.cequence.wsclient.domain.WsRequestContext
import io.cequence.wsclient.service.WSClientEngine
import io.cequence.wsclient.service.WSClientWithEngineTypes.WSClientWithEngine
import play.api.libs.json.JsObject

import scala.concurrent.{ExecutionContext, Future}

/**
 * Implementation of [[PineconeInferenceService]] that issues its requests through a
 * play-ws based client engine.
 *
 * Every request is sent with an `Api-Key` header holding the supplied key and a fixed
 * `X-Pinecone-API-Version` header.
 *
 * @param apiKey
 *   value placed in the `Api-Key` request header
 * @param explicitTimeouts
 *   optional timeouts forwarded to the request context
 */
private class PineconeInferenceServiceImpl(
  apiKey: String,
  explicitTimeouts: Option[Timeouts] = None
)(
  implicit val ec: ExecutionContext,
  val materializer: Materializer
) extends PineconeInferenceService
    with WSClientWithEngine {

  override protected type PEP = EndPoint
  override protected type PT = Tag

  // Host prefixes handed to the individual end points; the engine itself only knows the scheme.
  private val regularURL = "api.pinecone.io/"
  private val prodURL = "prod-1-data.ke.pinecone.io/"

  // we use play-ws backend
  override protected val engine: WSClientEngine = {
    val headers = Seq(
      "Api-Key" -> apiKey,
      "X-Pinecone-API-Version" -> "2024-10"
    )

    PlayWSClientEngine(
      coreUrl = "https://", // TODO: change to regularURL eventually
      requestContext = WsRequestContext(
        authHeaders = headers,
        explTimeouts = explicitTimeouts
      )
    )
  }

  /**
   * Generates embeddings for the given inputs with the model named in `settings`.
   *
   * Each input string is wrapped as a `{"text": ...}` entry; `input_type` and `truncate` from
   * the settings are sent as the request's `parameters`.
   *
   * @param inputs
   *   Input sequence for which to generate embeddings.
   * @param settings
   *   Supplies the model plus the `input_type` and `truncate` parameters.
   * @return
   *   the response body decoded as a [[GenerateEmbeddingsResponse]]
   */
  override def createEmbeddings(
    inputs: Seq[String],
    settings: GenerateEmbeddingsSettings
  ): Future[GenerateEmbeddingsResponse] = {
    val textEntries = inputs.map(text => Map("text" -> text))
    val modelParameters: Map[String, Any] = Map(
      "input_type" -> settings.input_type.map(_.toString),
      "truncate" -> settings.truncate.toString
    )

    for {
      response <- execPOST(
        EndPoint.embed(regularURL),
        bodyParams = jsonBodyParams(
          Tag.inputs -> Some(textEntries),
          Tag.model -> Some(settings.model),
          Tag.parameters -> Some(modelParameters)
        )
      )
    } yield response.asSafeJson[GenerateEmbeddingsResponse]
  }

  /**
   * Reranks a list of documents against a query using the model named in `settings`.
   *
   * `rank_fields` and `parameters` are only included in the request when they are non-empty.
   *
   * @param query
   *   The query to rerank documents against (required)
   * @param documents
   *   The documents to rerank (required)
   * @param settings
   *   Supplies the model, `top_n`, `return_documents`, `rank_fields` and `parameters`.
   * @return
   *   the response body decoded as a [[RerankResponse]]
   */
  override def rerank(
    query: String,
    documents: Seq[Map[String, Any]],
    settings: RerankSettings
  ): Future[RerankResponse] = {
    // Empty collections are dropped from the payload rather than sent as empty values.
    val rankFieldsOpt = Some(settings.rank_fields).filter(_.nonEmpty)
    val parametersOpt = Some(settings.parameters).filter(_.nonEmpty)

    for {
      response <- execPOST(
        EndPoint.rerank(regularURL),
        bodyParams = jsonBodyParams(
          Tag.query -> Some(query),
          Tag.documents -> Some(documents),
          Tag.model -> Some(settings.model),
          Tag.top_n -> settings.top_n,
          Tag.return_documents -> Some(settings.return_documents),
          Tag.rank_fields -> rankFieldsOpt,
          Tag.parameters -> parametersOpt
        )
      )
    } yield response.asSafeJson[RerankResponse]
  }

  /**
   * Posts a question, an answer and a ground-truth answer to the evaluate end point.
   *
   * Unlike the other calls, this one is addressed to the `prodURL` host.
   *
   * @param question
   *   sent as the `question` body field
   * @param answer
   *   sent as the `answer` body field
   * @param groundTruthAnswer
   *   sent as the `ground_truth_answer` body field
   * @return
   *   the response body decoded as an [[EvaluateResponse]]
   */
  override def evaluate(
    question: String,
    answer: String,
    groundTruthAnswer: String
  ): Future[EvaluateResponse] =
    for {
      response <- execPOST(
        EndPoint.evaluate(prodURL),
        bodyParams = jsonBodyParams(
          Tag.question -> Some(question),
          Tag.answer -> Some(answer),
          Tag.ground_truth_answer -> Some(groundTruthAnswer)
        )
      )
    } yield response.asSafeJson[EvaluateResponse]

  /**
   * Turns an error status into a [[PineconeScalaClientException]].
   *
   * @param httpCode
   *   HTTP status code included in the exception message
   * @param message
   *   error text included in the exception message
   * @return
   *   never returns normally; always throws
   */
  override protected def handleErrorCodes(
    httpCode: Int,
    message: String
  ): Nothing = {
    val errorText = s"Code ${httpCode} : ${message}"
    throw new PineconeScalaClientException(errorText)
  }
}

/**
 * Factory producing [[PineconeInferenceService]] instances backed by
 * `PineconeInferenceServiceImpl`.
 */
object PineconeInferenceServiceFactory
    extends SimplePineconeServiceFactory[PineconeInferenceService] {

  /**
   * Creates a new inference service.
   *
   * @param apiKey
   *   API key handed to the service implementation
   * @param timeouts
   *   optional timeouts handed to the service implementation (none by default)
   * @param ec
   *   execution context the service runs its futures on
   * @param materializer
   *   materializer passed through to the service implementation
   * @return
   *   a freshly constructed service instance
   */
  override def apply(
    apiKey: String,
    timeouts: Option[Timeouts] = None
  )(
    implicit ec: ExecutionContext,
    materializer: Materializer
  ): PineconeInferenceService =
    new PineconeInferenceServiceImpl(
      apiKey = apiKey,
      explicitTimeouts = timeouts
    )(ec, materializer)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy