All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.HasLlamaCppInferenceProperties.scala Maven / Gradle / Ivy

There is a newer version: 6.0.3
Show newest version
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.annotators.seq2seq.AutoGGUFModel
import com.johnsnowlabs.nlp.llama.InferenceParameters
import com.johnsnowlabs.nlp.llama.args._
import com.johnsnowlabs.nlp.serialization.StructFeature
import org.apache.spark.ml.param._

import scala.collection.mutable
import scala.jdk.CollectionConverters._

/** Contains settable inference parameters for the [[AutoGGUFModel]].
  *
  * Only parameters that are defined (explicitly set or given a default) are forwarded to the
  * underlying llama.cpp [[InferenceParameters]] when they are built.
  *
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio setParam  1
  * @groupprio getParam  2
  * @groupprio param  3
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
trait HasLlamaCppInferenceProperties {
  this: ParamsAndFeaturesWritable with HasProtectedParams =>

  /** @group param */
  val inputPrefix =
    new Param[String](this, "inputPrefix", "Set the prompt to start generation with")

  /** @group param */
  val inputSuffix =
    new Param[String](this, "inputSuffix", "Set a suffix for infilling")

  /** @group param */
  val cachePrompt = new BooleanParam(
    this,
    "cachePrompt",
    "Whether to remember the prompt to avoid reprocessing it")

  /** @group param */
  val nPredict = new IntParam(this, "nPredict", "Set the number of tokens to predict")

  /** @group param */
  val topK = new IntParam(this, "topK", "Set top-k sampling")

  /** @group param */
  val topP = new FloatParam(this, "topP", "Set top-p sampling")

  /** @group param */
  val minP = new FloatParam(this, "minP", "Set min-p sampling")

  /** @group param */
  val tfsZ = new FloatParam(this, "tfsZ", "Set tail free sampling, parameter z")

  /** @group param */
  val typicalP = new FloatParam(this, "typicalP", "Set locally typical sampling, parameter p")

  /** @group param */
  val temperature = new FloatParam(this, "temperature", "Set the temperature")

  /** @group param */
  val dynamicTemperatureRange =
    new FloatParam(this, "dynatempRange", "Set the dynamic temperature range")

  /** @group param */
  val dynamicTemperatureExponent =
    new FloatParam(this, "dynatempExponent", "Set the dynamic temperature exponent")

  /** @group param */
  val repeatLastN =
    new IntParam(this, "repeatLastN", "Set the last n tokens to consider for penalties")

  /** @group param */
  val repeatPenalty =
    new FloatParam(this, "repeatPenalty", "Set the penalty of repeated sequences of tokens")

  /** @group param */
  val frequencyPenalty =
    new FloatParam(this, "frequencyPenalty", "Set the repetition alpha frequency penalty")

  /** @group param */
  val presencePenalty =
    new FloatParam(this, "presencePenalty", "Set the repetition alpha presence penalty")

  /** @group param */
  val miroStat = new Param[String](this, "miroStat", "Set MiroStat sampling strategies.")

  /** @group param */
  val miroStatTau =
    new FloatParam(this, "mirostatTau", "Set the MiroStat target entropy, parameter tau")

  /** @group param */
  val miroStatEta =
    new FloatParam(this, "mirostatEta", "Set the MiroStat learning rate, parameter eta")

  /** @group param */
  val penalizeNl = new BooleanParam(this, "penalizeNl", "Whether to penalize newline tokens")

  /** @group param */
  val nKeep =
    new IntParam(this, "nKeep", "Set the number of tokens to keep from the initial prompt")

  /** @group param */
  val seed = new IntParam(this, "seed", "Set the RNG seed")

  /** @group param */
  val nProbs = new IntParam(
    this,
    "nProbs",
    "Set the amount top tokens probabilities to output if greater than 0.")

  /** @group param */
  val minKeep = new IntParam(
    this,
    "minKeep",
    "Set the amount of tokens the samplers should return at least (0 = disabled)")

  /** @group param */
  val grammar =
    new Param[String](this, "grammar", "Set BNF-like grammar to constrain generations")

  /** @group param */
  val penaltyPrompt = new Param[String](
    this,
    "penaltyPrompt",
    "Override which part of the prompt is penalized for repetition.")

  /** @group param */
  val ignoreEos = new BooleanParam(
    this,
    "ignoreEos",
    "Set whether to ignore end of stream token and continue generating (implies --logit-bias 2-inf)")

  /** Modify the likelihood of tokens appearing in the completion by their id.
    *
    * @group param
    */
  val tokenIdBias: StructFeature[Map[Int, Float]] =
    new StructFeature[Map[Int, Float]](this, "tokenIdBias")

  /** Modify the likelihood of tokens appearing in the completion by their string.
    *
    * @group param
    */
  val tokenBias: StructFeature[Map[String, Float]] =
    new StructFeature[Map[String, Float]](this, "tokenBias")

  /** @group param */
  val disableTokenIds =
    new IntArrayParam(this, "disableTokenIds", "Set the token ids to disable in the completion")

  /** @group param */
  val stopStrings = new StringArrayParam(
    this,
    "stopStrings",
    "Set strings upon seeing which token generation is stopped")

  /** @group param */
  val samplers = new StringArrayParam(
    this,
    "samplers",
    "Set which samplers to use for token generation in the given order")

  /** @group param */
  val useChatTemplate = new BooleanParam(
    this,
    "useChatTemplate",
    "Set whether or not generate should apply a chat template")

  /** Set the prompt to start generation with
    *
    * @group setParam
    */
  def setInputPrefix(inputPrefix: String): this.type = { set(this.inputPrefix, inputPrefix) }

  /** Set a suffix for infilling
    *
    * @group setParam
    */
  def setInputSuffix(inputSuffix: String): this.type = { set(this.inputSuffix, inputSuffix) }

  /** Whether to remember the prompt to avoid reprocessing it
    *
    * @group setParam
    */
  def setCachePrompt(cachePrompt: Boolean): this.type = { set(this.cachePrompt, cachePrompt) }

  /** Set the number of tokens to predict
    *
    * @group setParam
    */
  def setNPredict(nPredict: Int): this.type = { set(this.nPredict, nPredict) }

  /** Set top-k sampling
    *
    * @group setParam
    */
  def setTopK(topK: Int): this.type = { set(this.topK, topK) }

  /** Set top-p sampling
    *
    * @group setParam
    */
  def setTopP(topP: Float): this.type = { set(this.topP, topP) }

  /** Set min-p sampling
    *
    * @group setParam
    */
  def setMinP(minP: Float): this.type = { set(this.minP, minP) }

  /** Set tail free sampling, parameter z
    *
    * @group setParam
    */
  def setTfsZ(tfsZ: Float): this.type = { set(this.tfsZ, tfsZ) }

  /** Set locally typical sampling, parameter p
    *
    * @group setParam
    */
  def setTypicalP(typicalP: Float): this.type = { set(this.typicalP, typicalP) }

  /** Set the temperature
    *
    * @group setParam
    */
  def setTemperature(temperature: Float): this.type = { set(this.temperature, temperature) }

  /** Set the dynamic temperature range
    *
    * @group setParam
    */
  def setDynamicTemperatureRange(dynatempRange: Float): this.type = {
    set(this.dynamicTemperatureRange, dynatempRange)
  }

  /** Set the dynamic temperature exponent
    *
    * @group setParam
    */
  def setDynamicTemperatureExponent(dynatempExponent: Float): this.type = {
    set(this.dynamicTemperatureExponent, dynatempExponent)
  }

  /** Set the last n tokens to consider for penalties
    *
    * @group setParam
    */
  def setRepeatLastN(repeatLastN: Int): this.type = { set(this.repeatLastN, repeatLastN) }

  /** Set the penalty of repeated sequences of tokens
    *
    * @group setParam
    */
  def setRepeatPenalty(repeatPenalty: Float): this.type = {
    set(this.repeatPenalty, repeatPenalty)
  }

  /** Set the repetition alpha frequency penalty
    *
    * @group setParam
    */
  def setFrequencyPenalty(frequencyPenalty: Float): this.type = {
    set(this.frequencyPenalty, frequencyPenalty)
  }

  /** Set the repetition alpha presence penalty
    *
    * @group setParam
    */
  def setPresencePenalty(presencePenalty: Float): this.type = {
    set(this.presencePenalty, presencePenalty)
  }

  /** Set MiroStat sampling strategies.
    *
    *   - DISABLED: No MiroStat
    *   - V1: MiroStat V1
    *   - V2: MiroStat V2
    *
    * The value is resolved with `MiroStat.valueOf` when the inference parameters are built, so
    * it must match one of the names above exactly.
    *
    * @group setParam
    */
  def setMiroStat(mirostat: String): this.type = set(this.miroStat, mirostat)

  /** Set the MiroStat target entropy, parameter tau
    *
    * @group setParam
    */
  def setMiroStatTau(mirostatTau: Float): this.type = { set(this.miroStatTau, mirostatTau) }

  /** Set the MiroStat learning rate, parameter eta
    *
    * @group setParam
    */
  def setMiroStatEta(mirostatEta: Float): this.type = { set(this.miroStatEta, mirostatEta) }

  /** Set whether to penalize newline tokens
    *
    * @group setParam
    */
  def setPenalizeNl(penalizeNl: Boolean): this.type = { set(this.penalizeNl, penalizeNl) }

  /** Set the number of tokens to keep from the initial prompt
    *
    * @group setParam
    */
  def setNKeep(nKeep: Int): this.type = { set(this.nKeep, nKeep) }

  /** Set the RNG seed
    *
    * @group setParam
    */
  def setSeed(seed: Int): this.type = { set(this.seed, seed) }

  /** Set the number of top token probabilities to output, if greater than 0.
    *
    * @group setParam
    */
  def setNProbs(nProbs: Int): this.type = { set(this.nProbs, nProbs) }

  /** Set the amount of tokens the samplers should return at least (0 = disabled)
    *
    * @group setParam
    */
  def setMinKeep(minKeep: Int): this.type = { set(this.minKeep, minKeep) }

  /** Set BNF-like grammar to constrain generations
    *
    * @group setParam
    */
  def setGrammar(grammar: String): this.type = { set(this.grammar, grammar) }

  /** Override which part of the prompt is penalized for repetition.
    *
    * @group setParam
    */
  def setPenaltyPrompt(penaltyPrompt: String): this.type = {
    set(this.penaltyPrompt, penaltyPrompt)
  }

  /** Set whether to ignore end of stream token and continue generating (implies --logit-bias
    * 2-inf)
    *
    * @group setParam
    */
  def setIgnoreEos(ignoreEos: Boolean): this.type = { set(this.ignoreEos, ignoreEos) }

  /** Set the token bias: a map from token string to a bias value that modifies the likelihood of
    * that token appearing in the completion.
    *
    * @group setParam
    */
  def setTokenBias(tokenBias: Map[String, Float]): this.type = {
    set(this.tokenBias, tokenBias)
  }

  /** Set the token bias: a map from token string to a bias value that modifies the likelihood of
    * that token appearing in the completion. (Override for PySpark, whose floats arrive as
    * `java.lang.Double` and are narrowed to `Float` here.)
    *
    * @group setParam
    */
  def setTokenBias(tokenBias: java.util.HashMap[String, java.lang.Double]): this.type = {
    val scalaTokenBias = tokenBias.asScala.map { case (token, bias) =>
      token -> bias.floatValue()
    }
    set(this.tokenBias, scalaTokenBias.toMap)
  }

  /** Set the token id bias: a map from token id to a bias value that modifies the likelihood of
    * that token appearing in the completion.
    *
    * @group setParam
    */
  def setTokenIdBias(tokenIdBias: Map[Int, Float]): this.type = {
    set(this.tokenIdBias, tokenIdBias)
  }

  /** Set the token id bias: a map from token id to a bias value that modifies the likelihood of
    * that token appearing in the completion. (Override for PySpark, whose ints and floats arrive
    * boxed as `Integer` / `java.lang.Double` and are unboxed to `Int` / `Float` here.)
    *
    * @group setParam
    */
  def setTokenIdBias(tokenIdBias: java.util.HashMap[Integer, java.lang.Double]): this.type = {
    val scalaTokenIdBias = tokenIdBias.asScala.map { case (tokenId, bias) =>
      tokenId.intValue() -> bias.floatValue()
    }
    set(this.tokenIdBias, scalaTokenIdBias.toMap)
  }

  /** Set the token ids to disable in the completion. This corresponds to `setTokenIdBias` with a
    * value of `Float.NEGATIVE_INFINITY` for each id.
    *
    * @group setParam
    */
  def setDisableTokenIds(disableTokenIds: Array[Int]): this.type = {
    set(this.disableTokenIds, disableTokenIds)
  }

  /** Set strings upon seeing which token generation is stopped
    *
    * @group setParam
    */
  def setStopStrings(stopStrings: Array[String]): this.type = {
    set(this.stopStrings, stopStrings)
  }

  /** Set which samplers to use for token generation in the given order.
    *
    * Available Samplers are:
    *
    *   - TOP_K: Top-k sampling
    *   - TFS_Z: Tail free sampling
    *   - TYPICAL_P: Locally typical sampling p
    *   - TOP_P: Top-p sampling
    *   - MIN_P: Min-p sampling
    *   - TEMPERATURE: Temperature sampling
    *
    * Each name is resolved with `Sampler.valueOf` when the inference parameters are built, so it
    * must match one of the names above exactly.
    *
    * @group setParam
    */
  def setSamplers(samplers: Array[String]): this.type = { set(this.samplers, samplers) }

  /** Set whether or not generate should apply a chat template
    *
    * @group setParam
    */
  def setUseChatTemplate(useChatTemplate: Boolean): this.type = {
    set(this.useChatTemplate, useChatTemplate)
  }

  // ---------------- GETTERS ----------------
  /** @group getParam */
  def getInputPrefix: String = $(inputPrefix)

  /** @group getParam */
  def getInputSuffix: String = $(inputSuffix)

  /** @group getParam */
  def getCachePrompt: Boolean = $(cachePrompt)

  /** @group getParam */
  def getNPredict: Int = $(nPredict)

  /** @group getParam */
  def getTopK: Int = $(topK)

  /** @group getParam */
  def getTopP: Float = $(topP)

  /** @group getParam */
  def getMinP: Float = $(minP)

  /** @group getParam */
  def getTfsZ: Float = $(tfsZ)

  /** @group getParam */
  def getTypicalP: Float = $(typicalP)

  /** @group getParam */
  def getTemperature: Float = $(temperature)

  /** @group getParam */
  def getDynamicTemperatureRange: Float = $(dynamicTemperatureRange)

  /** @group getParam */
  def getDynamicTemperatureExponent: Float = $(dynamicTemperatureExponent)

  /** @group getParam */
  def getRepeatLastN: Int = $(repeatLastN)

  /** @group getParam */
  def getRepeatPenalty: Float = $(repeatPenalty)

  /** @group getParam */
  def getFrequencyPenalty: Float = $(frequencyPenalty)

  /** @group getParam */
  def getPresencePenalty: Float = $(presencePenalty)

  /** @group getParam */
  def getMiroStat: String = $(miroStat)

  /** @group getParam */
  def getMiroStatTau: Float = $(miroStatTau)

  /** @group getParam */
  def getMiroStatEta: Float = $(miroStatEta)

  /** @group getParam */
  def getPenalizeNl: Boolean = $(penalizeNl)

  /** @group getParam */
  def getNKeep: Int = $(nKeep)

  /** @group getParam */
  def getSeed: Int = $(seed)

  /** @group getParam */
  def getNProbs: Int = $(nProbs)

  /** @group getParam */
  def getMinKeep: Int = $(minKeep)

  /** @group getParam */
  def getGrammar: String = $(grammar)

  /** @group getParam */
  def getPenaltyPrompt: String = $(penaltyPrompt)

  /** @group getParam */
  def getIgnoreEos: Boolean = $(ignoreEos)

  /** @group getParam */
  def getTokenIdBias: Map[Int, Float] = $$(tokenIdBias)

  /** @group getParam */
  def getTokenBias: Map[String, Float] = $$(tokenBias)

  /** @group getParam */
  def getDisableTokenIds: Array[Int] = $(disableTokenIds)

  /** @group getParam */
  def getStopStrings: Array[String] = $(stopStrings)

  /** @group getParam */
  def getSamplers: Array[String] = $(samplers)

  /** @group getParam */
  def getUseChatTemplate: Boolean = $(useChatTemplate)

  /** Resolves a string-valued parameter into a llama.cpp argument constant using `parse`.
    *
    * An `IllegalArgumentException` from `parse` is rethrown as an `IllegalArgumentException`
    * whose message names the offending parameter and value. The original exception is kept as
    * the cause.
    */
  private def parseLlamaCppArg[T](paramName: String, value: String)(parse: String => T): T =
    try parse(value)
    catch {
      case e: IllegalArgumentException =>
        throw new IllegalArgumentException(
          s"Invalid value '$value' for parameter '$paramName'.",
          e)
    }

  /** Builds the llama.cpp [[InferenceParameters]] from this instance's settings.
    *
    * A Param is forwarded only when `isDefined` holds for it (explicitly set or given a default).
    * The `tokenBias` / `tokenIdBias` features are forwarded only when set. Anything not forwarded
    * is left at whatever [[InferenceParameters]] uses by default.
    *
    * @throws IllegalArgumentException
    *   if `miroStat` or any entry of `samplers` is not a valid constant name
    */
  protected def getInferenceParameters: InferenceParameters = {
    // NOTE(review): the prompt is left empty here; presumably callers supply the actual prompt
    // later — confirm against the call sites.
    val inferenceParams = new InferenceParameters("")
    if (isDefined(cachePrompt)) inferenceParams.setCachePrompt(getCachePrompt)
    if (isDefined(disableTokenIds)) {
      val javaCollection: java.util.Collection[Integer] =
        getDisableTokenIds.map(int2Integer).toSeq.asJava
      inferenceParams.disableTokenIds(javaCollection)
    }
    if (isDefined(dynamicTemperatureExponent))
      inferenceParams.setDynamicTemperatureExponent(getDynamicTemperatureExponent)
    if (isDefined(dynamicTemperatureRange))
      inferenceParams.setDynamicTemperatureRange(getDynamicTemperatureRange)
    if (isDefined(frequencyPenalty)) inferenceParams.setFrequencyPenalty(getFrequencyPenalty)
    if (isDefined(grammar)) inferenceParams.setGrammar(getGrammar)
    if (isDefined(ignoreEos)) inferenceParams.setIgnoreEos(getIgnoreEos)
    if (isDefined(inputPrefix)) inferenceParams.setInputPrefix(getInputPrefix)
    if (isDefined(inputSuffix)) inferenceParams.setInputSuffix(getInputSuffix)
    if (isDefined(minKeep)) inferenceParams.setMinKeep(getMinKeep)
    if (isDefined(minP)) inferenceParams.setMinP(getMinP)
    if (isDefined(miroStat))
      inferenceParams.setMiroStat(
        parseLlamaCppArg[MiroStat](miroStat.name, getMiroStat)(MiroStat.valueOf))
    if (isDefined(miroStatEta)) inferenceParams.setMiroStatEta(getMiroStatEta)
    if (isDefined(miroStatTau)) inferenceParams.setMiroStatTau(getMiroStatTau)
    if (isDefined(nKeep)) inferenceParams.setNKeep(getNKeep)
    if (isDefined(nPredict)) inferenceParams.setNPredict(getNPredict)
    if (isDefined(nProbs)) inferenceParams.setNProbs(getNProbs)
    if (isDefined(penalizeNl)) inferenceParams.setPenalizeNl(getPenalizeNl)
    if (isDefined(penaltyPrompt)) inferenceParams.setPenaltyPrompt(getPenaltyPrompt)
    if (isDefined(presencePenalty)) inferenceParams.setPresencePenalty(getPresencePenalty)
    if (isDefined(repeatLastN)) inferenceParams.setRepeatLastN(getRepeatLastN)
    if (isDefined(repeatPenalty)) inferenceParams.setRepeatPenalty(getRepeatPenalty)
    if (isDefined(samplers)) {
      val parsedSamplers: Array[Sampler] = getSamplers.map { name =>
        parseLlamaCppArg[Sampler](samplers.name, name)(Sampler.valueOf)
      }
      inferenceParams.setSamplers(parsedSamplers: _*)
    }
    if (isDefined(seed)) inferenceParams.setSeed(getSeed)
    if (isDefined(stopStrings)) inferenceParams.setStopStrings(getStopStrings: _*)
    if (isDefined(temperature)) inferenceParams.setTemperature(getTemperature)
    if (isDefined(tfsZ)) inferenceParams.setTfsZ(getTfsZ)
    if (isDefined(topK)) inferenceParams.setTopK(getTopK)
    if (isDefined(topP)) inferenceParams.setTopP(getTopP)
    if (isDefined(typicalP)) inferenceParams.setTypicalP(getTypicalP)
    if (isDefined(useChatTemplate)) inferenceParams.setUseChatTemplate(getUseChatTemplate)
    if (tokenBias.isSet) {
      // Box the values so the map matches the Java-side Map<String, Float> signature.
      val javaTokenBias: java.util.Map[String, java.lang.Float] =
        getTokenBias.map { case (token, bias) => token -> float2Float(bias) }.asJava
      inferenceParams.setTokenBias(javaTokenBias)
    }
    if (tokenIdBias.isSet) {
      // Box keys and values so the map matches the Java-side Map<Integer, Float> signature.
      val javaTokenIdBias: java.util.Map[Integer, java.lang.Float] =
        getTokenIdBias.map { case (tokenId, bias) =>
          int2Integer(tokenId) -> float2Float(bias)
        }.asJava
      inferenceParams.setTokenIdBias(javaTokenIdBias)
    }

    inferenceParams
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy