com.johnsnowlabs.ml.gguf.GGUFWrapper.scala
/*
 * Copyright 2017-2024 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.johnsnowlabs.ml.gguf

import com.johnsnowlabs.nlp.llama.{LlamaModel, ModelParameters}
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}

import java.io.File
import java.nio.file.{Files, Paths}

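/** Wrapper around a llama.cpp model in GGUF format that is distributed to executors via
  * `SparkFiles`. The model is loaded lazily and only once per JVM through `getSession`.
  *
  * @param modelFileName
  *   Name of the GGUF model file registered with `SparkContext.addFile`
  * @param modelFolder
  *   Local folder the model file was originally read from
  */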
class GGUFWrapper(var modelFileName: String, var modelFolder: String) extends Serializable {

  /** For Deserialization */
  def this() = {
    this(null, null)
  }

  // Important for serialization with non-Kryo serializers
  @transient private var llamaModel: LlamaModel = _

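  /** Returns the loaded [[LlamaModel]], loading it from `SparkFiles` on first access. Subsequent
    * calls return the cached instance and ignore the passed model parameters.
    *
    * @param modelParameters
    *   Parameters used when the model is loaded for the first time
    * @return
    *   The loaded llama.cpp model
    */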
  def getSession(modelParameters: ModelParameters): LlamaModel =
    this.synchronized {
      if (llamaModel == null) {
        // TODO: Validate the case where modelFileName or modelFolder is null
        val modelFilePath = SparkFiles.get(modelFileName)

        if (Paths.get(modelFilePath).toFile.exists()) {
          modelParameters.setModelFilePath(modelFilePath)
          llamaModel = GGUFWrapper.withSafeGGUFModelLoader(modelParameters)
        } else
          throw new IllegalStateException(
            s"Model file $modelFileName does not exist in SparkFiles.")
      }
      // TODO: If the model is already loaded, the passed model parameters will not apply. Perhaps log a warning here.
      llamaModel
    }

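  /** Copies the model file from `SparkFiles` into the given folder, keeping its file name.
    *
    * @param file
    *   Destination folder for the copied GGUF file
    */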
  def saveToFile(file: String): Unit = {
    val modelFilePath = SparkFiles.get(modelFileName)
    val modelOutputPath = Paths.get(file, modelFileName)
    Files.copy(Paths.get(modelFilePath), modelOutputPath)
  }

  // Destructor to free the model when this object is garbage collected
  override def finalize(): Unit = {
    if (llamaModel != null) {
      llamaModel.close()
    }
  }

}

/** Companion object */
object GGUFWrapper {
  private[GGUFWrapper] val logger: Logger = LoggerFactory.getLogger("GGUFWrapper")

  // TODO: Verify whether this.synchronized is needed and that it does not become a bottleneck
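  /** Creates a new [[LlamaModel]] while holding a lock on the companion object, so that only one
    * model is loaded at a time per JVM.
    */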
  private def withSafeGGUFModelLoader(modelParameters: ModelParameters): LlamaModel =
    this.synchronized {
      new LlamaModel(modelParameters)
    }

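  /** Registers a local GGUF model file with `SparkContext.addFile` and wraps it.
    *
    * @param sparkSession
    *   Active Spark session used to distribute the file
    * @param modelPath
    *   Path to a local `.gguf` file
    * @return
    *   A [[GGUFWrapper]] referencing the registered model file
    */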
  def read(sparkSession: SparkSession, modelPath: String): GGUFWrapper = {
    // TODO: Better sanity check
    val modelFile = new File(modelPath)
    val modelFileExist: Boolean = modelFile.exists()

    if (!modelFile.getName.endsWith(".gguf"))
      throw new IllegalArgumentException(s"Model file $modelPath is not a GGUF model file")

    if (modelFileExist) {
      sparkSession.sparkContext.addFile(modelPath)
    } else throw new IllegalArgumentException(s"Model file $modelPath does not exist")

    new GGUFWrapper(modelFile.getName, modelFile.getParent)
  }

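  /** Reads a GGUF model from a folder that may live on another file system (e.g. HDFS or DBFS),
    * copies it to the local file system and registers it with Spark.
    *
    * @param modelFolderPath
    *   Folder containing exactly one `.gguf` file
    * @param spark
    *   Active Spark session
    * @return
    *   A [[GGUFWrapper]] referencing the registered model file
    */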
  def readModel(modelFolderPath: String, spark: SparkSession): GGUFWrapper = {
    def findGGUFModelInFolder(folderPath: String): String = {
      val folder = new File(folderPath)
      if (folder.exists && folder.isDirectory) {
        val ggufFile: String = folder.listFiles
          .filter(_.isFile)
          .filter(_.getName.endsWith(".gguf"))
          .map(_.getAbsolutePath)
          .headOption // Should only be one file
          .getOrElse(
            throw new IllegalArgumentException(s"Could not find GGUF model in $folderPath"))

        new File(ggufFile).getAbsolutePath
      } else {
        throw new IllegalArgumentException(s"Path $folderPath is not a directory")
      }
    }

    val uri = new java.net.URI(modelFolderPath.replaceAllLiterally("\\", "/"))
    // In case the path belongs to a different file system but doesn't have the scheme prepended (e.g. dbfs)
    val fileSystem: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val actualFolderPath = fileSystem.resolvePath(new Path(modelFolderPath)).toString
    val localFolder = ResourceHelper.copyToLocal(actualFolderPath)
    val modelFile = findGGUFModelInFolder(localFolder)
    read(spark, modelFile)
  }
}
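
// Usage sketch (hypothetical, not part of the original source): given an active SparkSession
// `spark` and a local GGUF file at an assumed path, a wrapper could be obtained and a model
// session started as follows. `ModelParameters` is used with defaults here; further setters
// depend on the com.johnsnowlabs.nlp.llama bindings.
//
//   val wrapper = GGUFWrapper.read(spark, "/path/to/model.gguf")
//   val model   = wrapper.getSession(new ModelParameters())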
