All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.ml.tensorflow.TensorflowWrapper.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.ml.tensorflow.io.ChunkBytes
import com.johnsnowlabs.ml.tensorflow.sentencepiece.LoadSentencepiece
import com.johnsnowlabs.ml.tensorflow.sign.ModelSignatureManager
import com.johnsnowlabs.nlp.annotators.ner.dl.LoadsContrib
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.util.{FileHelper, ZipArchiveUtil}
import org.apache.commons.io.FileUtils
import org.apache.commons.io.filefilter.WildcardFileFilter
import org.apache.hadoop.fs.Path
import org.slf4j.{Logger, LoggerFactory}
import org.tensorflow._
import org.tensorflow.exceptions.TensorFlowException
import org.tensorflow.proto.framework.{ConfigProto, GraphDef}

import java.io._
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.UUID
import scala.util.{Failure, Success, Try}

/** Serialized TensorFlow checkpoint variables.
  *
  * @param variables
  *   contents of the variables data file, split into byte chunks
  * @param index
  *   raw contents of the variables index file
  */
case class Variables(
    variables: Array[Array[Byte]],
    index: Array[Byte])

/** A model signature entry: an operation name, its value, and the patterns it matches.
  *
  * NOTE(review): this type is not used in this file; the exact meaning of `value` and
  * `matchingPatterns` is defined by its users elsewhere — confirm there.
  */
case class ModelSignature(
    operation: String,
    value: String,
    matchingPatterns: List[String])

/** Serializable holder for a TensorFlow model: the serialized `GraphDef` plus its checkpoint
  * [[Variables]]. A live `Session` is built lazily from these bytes and cached in a `@transient`
  * field, so on a deserialized instance it is rebuilt on first use.
  *
  * @param variables
  *   checkpoint variables (data-file chunks and index file)
  * @param graph
  *   serialized `GraphDef` bytes
  */
class TensorflowWrapper(var variables: Variables, var graph: Array[Byte]) extends Serializable {

  /** For Deserialization */
  def this() = {
    this(null, null)
  }

  // Important for serialization on non-Kryo serializers: the native session is never serialized
  // and is re-created on demand by the session getters below.
  @transient private var m_session: Session = _

  // `lazy` so the logger is re-created after Java deserialization. A plain `@transient val` is
  // left null on a deserialized instance, and every `logger.debug` call would then throw an NPE.
  @transient private lazy val logger: Logger = LoggerFactory.getLogger("TensorflowWrapper")

  /** Creates an empty temporary directory for a checkpoint and returns its absolute path. */
  private def createTempVarsDir(): String =
    Files
      .createTempDirectory(
        UUID.randomUUID().toString.takeRight(12) + TensorflowWrapper.TFVarsSuffix)
      .toAbsolutePath
      .toString

  /** Writes [[variables]] into `folder` as the checkpoint data file and index file. */
  private def writeVariablesTo(folder: String): Unit = {
    // variables per se
    ChunkBytes.writeByteChunksInFile(
      Paths.get(folder, TensorflowWrapper.VariablesPathValue),
      variables.variables)
    // variables' index
    Files.write(Paths.get(folder, TensorflowWrapper.VariablesIdxValue), variables.index)
  }

  /** Returns the cached session, creating it on first call by importing [[graph]] and restoring
    * [[variables]] through the legacy `save/restore_all` op.
    *
    * @param configProtoBytes
    *   serialized `ConfigProto`; defaults to the companion's built-in config
    * @return
    *   the cached (or newly created) session
    */
  def getTFSession(configProtoBytes: Option[Array[Byte]] = None): Session = this.synchronized {

    if (m_session == null) {
      val t = new TensorResources()
      val config = configProtoBytes.getOrElse(TensorflowWrapper.TFSessionConfig)
      val folder = createTempVarsDir()

      try {
        // save the binary data of variables to files the restore op can read
        writeVariablesTo(folder)

        LoadsContrib.loadContribToTensorflow()

        // import the graph
        val _graph = new Graph()
        _graph.importGraphDef(GraphDef.parseFrom(graph))

        // create the session and load the variables
        val session = new Session(_graph, ConfigProto.parseFrom(config))
        val variablesPath =
          Paths.get(folder, TensorflowWrapper.VariablesKey).toAbsolutePath.toString

        session.runner
          .addTarget(TensorflowWrapper.SaveRestoreAllOP)
          .feed(TensorflowWrapper.SaveConstOP, t.createTensor(variablesPath))
          .run()

        m_session = session
      } finally {
        // The restored values now live in the session: remove the whole temp directory (not
        // just its files) and release the fed tensors, also when restoring fails.
        FileHelper.delete(folder)
        t.clearTensors()
      }
    }
    m_session
  }

  /** Like [[getTFSession]], but restores variables through
    * `TensorflowWrapper.processInitAllTableOp`, which tries the legacy restore op first and then
    * falls back to the signature-based restore op.
    *
    * @param configProtoBytes
    *   serialized `ConfigProto`; defaults to the companion's built-in config
    * @param initAllTables
    *   whether to also run the `init_all_tables` op during restore
    * @param loadSP
    *   whether to load the SentencePiece TF ops before importing the graph
    * @param savedSignatures
    *   op-name mapping used by the signature-based restore path
    * @return
    *   the cached (or newly created) session
    */
  def getTFSessionWithSignature(
      configProtoBytes: Option[Array[Byte]] = None,
      initAllTables: Boolean = true,
      loadSP: Boolean = false,
      savedSignatures: Option[Map[String, String]] = None): Session = this.synchronized {

    if (m_session == null) {
      val t = new TensorResources()
      val config = configProtoBytes.getOrElse(TensorflowWrapper.TFSessionConfig)
      val folder = createTempVarsDir()

      try {
        // save the binary data of variables to files the restore op can read
        writeVariablesTo(folder)

        LoadsContrib.loadContribToTensorflow()
        if (loadSP) {
          LoadSentencepiece.loadSPToTensorflowLocally()
          LoadSentencepiece.loadSPToTensorflow()
        }
        // import the graph
        val g = new Graph()
        g.importGraphDef(GraphDef.parseFrom(graph))

        // create the session and load the variables
        val session = new Session(g, ConfigProto.parseFrom(config))

        /** a workaround to fix the issue with '''asset_path_initializer''' suggested at
          * https://github.com/tensorflow/java/issues/434 until we export models natively and
          * not just the GraphDef
          */
        try {
          session.initialize()
        } catch {
          case e: Exception =>
            logger.debug(s"detect asset_path_initializer: ${e.getMessage}")
        }
        TensorflowWrapper
          .processInitAllTableOp(
            initAllTables,
            t,
            session,
            folder,
            TensorflowWrapper.VariablesKey,
            savedSignatures = savedSignatures)

        m_session = session
      } finally {
        // remove the whole temp directory and release fed tensors, also on failure
        FileHelper.delete(folder)
        t.clearTensors()
      }
    }
    m_session
  }

  /** Returns the cached session, creating it on first call from [[graph]] only. No variables are
    * restored by this method.
    *
    * @param configProtoBytes
    *   serialized `ConfigProto`; defaults to the companion's built-in config
    */
  def createSession(configProtoBytes: Option[Array[Byte]] = None): Session = this.synchronized {

    if (m_session == null) {
      val config = configProtoBytes.getOrElse(TensorflowWrapper.TFSessionConfig)

      LoadsContrib.loadContribToTensorflow()

      // import the graph
      val g = new Graph()
      g.importGraphDef(GraphDef.parseFrom(graph))

      // create the session (variables are NOT restored here)
      val session = new Session(g, ConfigProto.parseFrom(config))

      m_session = session
    }
    m_session
  }

  /** Saves the current session variables (legacy `save/control_dependency` op) and [[graph]]
    * into a temporary folder, then zips that folder into `file`.
    *
    * @param file
    *   destination zip file
    * @param configProtoBytes
    *   session config used if the session still has to be created
    */
  def saveToFile(file: String, configProtoBytes: Option[Array[Byte]] = None): Unit = {
    val t = new TensorResources()

    // 1. Create tmp directory
    val folder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_ner")
      .toAbsolutePath
      .toString

    try {
      val variablesFile = Paths.get(folder, TensorflowWrapper.VariablesKey).toString

      // 2. Save variables
      getTFSession(configProtoBytes).runner
        .addTarget(TensorflowWrapper.SaveControlDependenciesOP)
        .feed(TensorflowWrapper.SaveConstOP, t.createTensor(variablesFile))
        .run()

      // 3. Save Graph
      val graphFile = Paths.get(folder, TensorflowWrapper.SavedModelPB).toString
      FileUtils.writeByteArrayToFile(new File(graphFile), graph)

      // 4. Zip folder
      ZipArchiveUtil.zip(folder, file)
    } finally {
      // 5. Remove tmp directory (also when saving fails)
      FileHelper.delete(folder)
      t.clearTensors()
    }
  }

  /** Variant of [[saveToFile]] that also handles TF2-style graphs: it tries the legacy save op
    * first and falls back to the save op named in `savedSignatures`.
    *
    * @param file
    *   destination zip file
    * @param configProtoBytes
    *   session config used if the session still has to be created
    * @param savedSignatures
    *   op-name mapping; defaults to `ModelSignatureManager.apply()`
    */
  def saveToFileV1V2(
      file: String,
      configProtoBytes: Option[Array[Byte]] = None,
      savedSignatures: Option[Map[String, String]] = None): Unit = {

    val t = new TensorResources()
    val _tfSignatures: Map[String, String] =
      savedSignatures.getOrElse(ModelSignatureManager.apply())

    // 1. Create tmp directory
    val folder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_ner")
      .toAbsolutePath
      .toString

    try {
      val variablesFile = Paths.get(folder, TensorflowWrapper.VariablesKey).toString

      // 2. Save variables
      def runSessionLegacy = {
        getTFSession(configProtoBytes).runner
          .addTarget(TensorflowWrapper.SaveControlDependenciesOP)
          .feed(TensorflowWrapper.SaveConstOP, t.createTensor(variablesFile))
          .run()
      }

      /** addTarget operation is the result of '''saverDef.getSaveTensorName()''' feed operation
        * is the result of '''saverDef.getFilenameTensorName()'''
        *
        * @return
        *   List[Tensor]
        */
      def runSessionNew = {
        getTFSession(configProtoBytes).runner
          .addTarget(_tfSignatures.getOrElse("saveTensorName_", "StatefulPartitionedCall_1"))
          .feed(
            _tfSignatures.getOrElse("filenameTensorName_", "saver_filename"),
            t.createTensor(variablesFile))
          .run()
      }

      Try(runSessionLegacy) match {
        case Success(_) => logger.debug("Running legacy session to save variables...")
        case Failure(_) => runSessionNew
      }

      // 3. Save Graph
      val graphFile = Paths.get(folder, TensorflowWrapper.SavedModelPB).toString
      FileUtils.writeByteArrayToFile(new File(graphFile), graph)

      val tfChkPointsVars = FileUtils
        .listFilesAndDirs(
          new File(folder),
          new WildcardFileFilter("part*"),
          new WildcardFileFilter("variables*"))
        .toArray()

      // TF2 Saved Model generate parts for variables on second save
      // This makes sure they are compatible with V1
      // NOTE(review): relies on the ordering returned by listFilesAndDirs to pick index 1.
      if (tfChkPointsVars.length > 3) {
        val variablesDir = tfChkPointsVars(1).toString

        val varData = Paths.get(folder, TensorflowWrapper.VariablesPathValue)
        ChunkBytes.writeByteChunksInFile(varData, variables.variables)

        val varIdx = Paths.get(folder, TensorflowWrapper.VariablesIdxValue)
        Files.write(varIdx, variables.index)

        FileHelper.delete(variablesDir)
      }

      // 4. Zip folder
      ZipArchiveUtil.zip(folder, file)
    } finally {
      // 5. Remove tmp directory (also when saving fails)
      FileHelper.delete(folder)
      t.clearTensors()
    }
  }

}

/** Companion object: shared constants (checkpoint file names, TF op names, default session
  * config) and the factory methods that load a model from disk into a [[TensorflowWrapper]].
  */
object TensorflowWrapper {
  private[TensorflowWrapper] val logger: Logger = LoggerFactory.getLogger("TensorflowWrapper")

  /** Default serialized `ConfigProto`. Decoded, the bytes are `gpu_options { allow_growth: true }`
    * (`50, 2, 32, 1`) and `allow_soft_placement: true` (`56, 1`). `log_device_placement` is NOT
    * set.
    */
  private final val TFSessionConfig: Array[Byte] = Array[Byte](50, 2, 32, 1, 56, 1)

  // Variables: checkpoint prefix and the data/index file names written for it
  val VariablesKey = "variables"
  val VariablesPathValue = "variables.data-00000-of-00001"
  val VariablesIdxValue = "variables.index"

  // Operations
  val InitAllTableOP = "init_all_tables"
  val SaveRestoreAllOP = "save/restore_all"
  val SaveConstOP = "save/Const"
  val SaveControlDependenciesOP = "save/control_dependency"

  // Model
  val SavedModelPB = "saved_model.pb"

  // TF vars suffix folder
  val TFVarsSuffix = "_tf_vars"

  // size of bytes store in each chunk/array
  // (Integer.MAX_VALUE - 8) * BUFFER_SIZE = store over 2 Petabytes
  private val BUFFER_SIZE = 1024 * 1024

  /** Utility method to load the TF saved model bundle.
    *
    * @throws Exception
    *   if the bundle cannot be loaded; the underlying failure is attached as the cause
    */
  def withSafeSavedModelBundleLoader(
      tags: Array[String],
      savedModelDir: String): SavedModelBundle = {
    Try(SavedModelBundle.load(savedModelDir, tags: _*)) match {
      case Success(bundle) => bundle
      case Failure(s) =>
        // Chain the real failure instead of interpolating `printStackTrace()`, which returns
        // Unit and only produced "+ ()" in the message.
        throw new Exception(
          s"Could not retrieve the SavedModelBundle from $savedModelDir: ${s.getMessage}",
          s)
    }
  }

  /** Utility method to load the TF saved model components without a provided bundle.
    *
    * @return
    *   (graph, session, variables data path, variables index path), with the variable files
    *   expected directly under `folder`
    */
  private def unpackWithoutBundle(folder: String) = {
    val graph = readGraph(Paths.get(folder, SavedModelPB).toString)
    val session = new Session(graph, ConfigProto.parseFrom(TFSessionConfig))

    /** a workaround to fix the issue with '''asset_path_initializer''' suggested at
      * https://github.com/tensorflow/java/issues/434 until we export models natively and not just
      * the GraphDef
      */
    try {
      session.initialize()
    } catch {
      case e: Exception =>
        logger.debug(s"detect asset_path_initializer: ${e.getMessage}")
    }
    val varPath = Paths.get(folder, VariablesPathValue)
    val idxPath = Paths.get(folder, VariablesIdxValue)
    (graph, session, varPath, idxPath)
  }

  /** Utility method to load the TF saved model components from a provided bundle.
    *
    * @return
    *   (graph, session, variables data path, variables index path), with the variable files
    *   expected under `folder/variables/`
    */
  private def unpackFromBundle(folder: String, model: SavedModelBundle) = {
    val graph = model.graph()
    val session = model.session()
    val varPath = Paths.get(folder, VariablesKey, VariablesPathValue)
    val idxPath = Paths.get(folder, VariablesKey, VariablesIdxValue)
    (graph, session, varPath, idxPath)
  }

  /** Utility method to restore variables into `session` (optionally also running
    * `init_all_tables`). Tries the legacy `save/restore_all` op first and, if that fails, the
    * restore op named in the signatures.
    *
    * @param initAllTables
    *   whether to add the `init_all_tables` target to the restore run
    * @param tensorResources
    *   owner of the fed filename tensor; the caller is responsible for clearing it
    * @param session
    *   session to restore into
    * @param variablesDir
    *   directory containing the checkpoint
    * @param variablesKey
    *   checkpoint prefix inside `variablesDir`
    * @param savedSignatures
    *   op-name mapping; defaults to `ModelSignatureManager.apply()`
    */
  private def processInitAllTableOp(
      initAllTables: Boolean,
      tensorResources: TensorResources,
      session: Session,
      variablesDir: String,
      variablesKey: String = VariablesKey,
      savedSignatures: Option[Map[String, String]] = None) = {

    val _tfSignatures: Map[String, String] =
      savedSignatures.getOrElse(ModelSignatureManager.apply())

    lazy val legacySessionRunner = session.runner
      .addTarget(SaveRestoreAllOP)
      .feed(
        SaveConstOP,
        tensorResources.createTensor(Paths.get(variablesDir, variablesKey).toString))

    /** addTarget operation is the result of '''saverDef.getRestoreOpName()''' feed operation is
      * the result of '''saverDef.getFilenameTensorName()'''
      */
    lazy val newSessionRunner = session.runner
      .addTarget(_tfSignatures.getOrElse("restoreOpName_", "StatefulPartitionedCall_2"))
      .feed(
        _tfSignatures.getOrElse("filenameTensorName_", "saver_filename"),
        tensorResources.createTensor(Paths.get(variablesDir, variablesKey).toString))

    def runRestoreNewNoInit = {
      newSessionRunner.run()
    }

    def runRestoreNewInit = {
      newSessionRunner.addTarget(InitAllTableOP).run()
    }

    def runRestoreLegacyNoInit = {
      legacySessionRunner.run()
    }

    def runRestoreLegacyInit = {
      legacySessionRunner.addTarget(InitAllTableOP).run()
    }

    if (initAllTables) {
      Try(runRestoreLegacyInit) match {
        case Success(_) => logger.debug("Running restore legacy with init...")
        case Failure(_) => runRestoreNewInit
      }
    } else {
      Try(runRestoreLegacyNoInit) match {
        case Success(_) => logger.debug("Running restore legacy with no init...")
        case Failure(_) => runRestoreNewNoInit
      }
    }
  }

  /** Utility method to load a Graph from path.
    *
    * @throws UnsupportedOperationException
    *   if the graph needs the TF Contrib `BlockLSTM` op and it is not registered
    */
  def readGraph(graphFile: String): Graph = {
    val graphBytesDef = FileUtils.readFileToByteArray(new File(graphFile))
    val graph = new Graph()
    try {
      graph.importGraphDef(GraphDef.parseFrom(graphBytesDef))
      graph
    } catch {
      // getMessage may be null, so guard before calling contains
      case e: TensorFlowException
          if Option(e.getMessage).exists(_.contains("Op type not registered 'BlockLSTM'")) =>
        graph.close()
        throw new UnsupportedOperationException(
          "Spark NLP tried to load a TensorFlow Graph using Contrib module, but" +
            " failed to load it on this system. If you are on Windows, please follow the correct steps for setup: " +
            "https://github.com/JohnSnowLabs/spark-nlp/issues/1022" +
            s" If not the case, please report this issue. Original error message:\n\n${e.getMessage}",
          e)
      case e: Exception =>
        // do not leak the native graph when the GraphDef cannot be parsed or imported
        graph.close()
        throw e
    }
  }

  /** Read method to create tmp folder, unpack archive and read file as SavedModelBundle
    *
    * @param file
    *   : the file to read
    * @param zipped
    *   : boolean flag to know if compression is applied
    * @param useBundle
    *   : whether to use the SaveModelBundle object to parse the TF saved model
    * @param tags
    *   : tags to retrieve on the model bundle
    * @param initAllTables
    *   : boolean flag whether to retrieve the TF init operation
    * @param savedSignatures
    *   : op-name mapping used to restore variables when `useBundle` is false; defaults to
    *   `ModelSignatureManager.apply()`
    * @return
    *   the loaded [[TensorflowWrapper]] with its session attached, and the signatures extracted
    *   from the bundle when `useBundle` is true (otherwise `None`)
    */
  def read(
      file: String,
      zipped: Boolean = true,
      useBundle: Boolean = false,
      tags: Array[String] = Array.empty[String],
      initAllTables: Boolean = false,
      savedSignatures: Option[Map[String, String]] = None)
      : (TensorflowWrapper, Option[Map[String, String]]) = {

    val t = new TensorResources()

    // 1. Create tmp folder
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_ner")
      .toAbsolutePath
      .toString

    try {
      // 2. Unpack archive
      val folder =
        if (zipped)
          ZipArchiveUtil.unzip(new File(file), Some(tmpFolder))
        else
          file

      LoadsContrib.loadContribToTensorflow()

      // 3. Read file as SavedModelBundle
      val (graph, session, varPath, idxPath, signatures) =
        if (useBundle) {
          val model: SavedModelBundle =
            withSafeSavedModelBundleLoader(tags = tags, savedModelDir = folder)
          val (graph, session, varPath, idxPath) = unpackFromBundle(folder, model)
          // NOTE(review): this builds a runner but never calls run(), so it has no effect.
          if (initAllTables) session.runner().addTarget(InitAllTableOP)

          // Extract saved model signatures
          val saverDef = model.metaGraphDef().getSaverDef
          val signatures = ModelSignatureManager.extractSignatures(model, saverDef)

          (graph, session, varPath, idxPath, signatures)
        } else {
          val (graph, session, varPath, idxPath) = unpackWithoutBundle(folder)
          processInitAllTableOp(
            initAllTables,
            t,
            session,
            folder,
            savedSignatures = savedSignatures)

          (graph, session, varPath, idxPath, None)
        }

      val varBytes = ChunkBytes.readFileInByteChunks(varPath, BUFFER_SIZE)
      val idxBytes = Files.readAllBytes(idxPath)

      val tfWrapper =
        new TensorflowWrapper(Variables(varBytes, idxBytes), graph.toGraphDef.toByteArray)
      tfWrapper.m_session = session
      (tfWrapper, signatures)
    } finally {
      // 4. Remove tmp folder and release fed tensors (also when loading fails)
      FileHelper.delete(tmpFolder)
      t.clearTensors()
    }
  }

  /** Same as [[read]] but optionally loads the SentencePiece TF ops first, and does not return
    * signatures.
    */
  def readWithSP(
      file: String,
      zipped: Boolean = true,
      useBundle: Boolean = false,
      tags: Array[String] = Array.empty[String],
      initAllTables: Boolean = false,
      loadSP: Boolean = false): TensorflowWrapper = {
    val t = new TensorResources()

    // 1. Create tmp folder
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_ner")
      .toAbsolutePath
      .toString

    try {
      // 2. Unpack archive
      val folder =
        if (zipped)
          ZipArchiveUtil.unzip(new File(file), Some(tmpFolder))
        else
          file

      if (loadSP) {
        LoadSentencepiece.loadSPToTensorflowLocally()
        LoadSentencepiece.loadSPToTensorflow()
      }
      // 3. Read file as SavedModelBundle
      val (graph, session, varPath, idxPath) =
        if (useBundle) {
          val model: SavedModelBundle =
            withSafeSavedModelBundleLoader(tags = tags, savedModelDir = folder)
          val (graph, session, varPath, idxPath) = unpackFromBundle(folder, model)
          // NOTE(review): this builds a runner but never calls run(), so it has no effect.
          if (initAllTables) session.runner().addTarget(InitAllTableOP)
          (graph, session, varPath, idxPath)
        } else {
          val (graph, session, varPath, idxPath) = unpackWithoutBundle(folder)
          processInitAllTableOp(initAllTables, t, session, folder)
          (graph, session, varPath, idxPath)
        }

      val varBytes = ChunkBytes.readFileInByteChunks(varPath, BUFFER_SIZE)
      val idxBytes = Files.readAllBytes(idxPath)

      val tfWrapper =
        new TensorflowWrapper(Variables(varBytes, idxBytes), graph.toGraphDef.toByteArray)
      tfWrapper.m_session = session
      tfWrapper
    } finally {
      // 4. Remove tmp folder and release fed tensors (also when loading fails)
      FileHelper.delete(tmpFolder)
      t.clearTensors()
    }
  }

  /** Loads a zipped SavedModel shipped as a classpath/resource file under `rootDir`, rebuilding
    * the standard SavedModel layout in a temp folder before loading it as a bundle.
    */
  def readZippedSavedModel(
      rootDir: String = "",
      fileName: String = "",
      tags: Array[String] = Array.empty[String],
      initAllTables: Boolean = false): TensorflowWrapper = {
    val tensorResources = new TensorResources()

    /** Copies `in` to `target` and always closes the stream. */
    def copyAndClose(in: InputStream, target: File): Unit = {
      try {
        Files.copy(in, target.toPath)
      } finally {
        in.close()
      }
      ()
    }

    val listFiles = ResourceHelper.listResourceDirectory(rootDir)

    val path =
      if (listFiles.length > 1)
        s"${listFiles.head.split("/").head}/$fileName"
      else
        listFiles.head

    val uri = new URI(path.replaceAllLiterally("\\", "/"))

    // 1. Create tmp folders: one to unpack the archive, one to rebuild the SavedModel layout
    val tmpFolder =
      Files
        .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_classifier_dl_zip")
        .toAbsolutePath
        .toString

    val finalFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_classifier_dl")
      .toAbsolutePath
      .toString

    try {
      val zipFile = new File(tmpFolder, "tmp_classifier_dl.zip")
      copyAndClose(ResourceHelper.getResourceStream(uri.toString), zipFile)

      // 2. Unpack archive
      val folder = ZipArchiveUtil.unzip(zipFile, Some(tmpFolder))

      try {
        // 3. Prepare the `variables` sub-folder expected by SavedModelBundle
        val variablesFile = Paths.get(finalFolder, VariablesKey).toAbsolutePath
        Files.createDirectory(variablesFile)

        // 4. Copy saved_model.pb and the variable files into the final folder
        copyAndClose(
          ResourceHelper.getResourceStream(new Path(folder, SavedModelPB).toString),
          new File(finalFolder, SavedModelPB))

        copyAndClose(
          ResourceHelper.getResourceStream(new Path(folder, VariablesIdxValue).toString),
          new File(variablesFile.toString, VariablesIdxValue))

        copyAndClose(
          ResourceHelper.getResourceStream(new Path(folder, VariablesPathValue).toString),
          new File(variablesFile.toString, VariablesPathValue))

        // 5. Read file as SavedModelBundle
        val model = withSafeSavedModelBundleLoader(tags = tags, savedModelDir = finalFolder)

        val (graph, session, varPath, idxPath) = unpackFromBundle(finalFolder, model)

        // NOTE(review): this builds a runner but never calls run(), so it has no effect.
        if (initAllTables) session.runner().addTarget(InitAllTableOP)

        val varBytes = ChunkBytes.readFileInByteChunks(varPath, BUFFER_SIZE)
        val idxBytes = Files.readAllBytes(idxPath)

        val tfWrapper =
          new TensorflowWrapper(Variables(varBytes, idxBytes), graph.toGraphDef.toByteArray)
        tfWrapper.m_session = session
        tfWrapper
      } finally {
        FileHelper.delete(folder)
      }
    } finally {
      // 6. Remove tmp folders and release tensors (also when loading fails)
      FileHelper.delete(tmpFolder)
      FileHelper.delete(finalFolder)
      tensorResources.clearTensors()
    }
  }

  /** Loads a model whose variables were written as TF2 checkpoint parts (`part-00000-of-00001`).
    *
    * NOTE(review): relies on the ordering returned by `listFilesAndDirs` to locate the variables
    * directory, data file and index file at positions 1, 2 and 3.
    */
  def readChkPoints(
      file: String,
      zipped: Boolean = true,
      tags: Array[String] = Array.empty[String],
      initAllTables: Boolean = false): TensorflowWrapper = {
    val t = new TensorResources()

    // 1. Create tmp folder
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_ner")
      .toAbsolutePath
      .toString

    try {
      // 2. Unpack archive
      val folder =
        if (zipped)
          ZipArchiveUtil.unzip(new File(file), Some(tmpFolder))
        else
          file

      LoadsContrib.loadContribToTensorflow()

      val tfChkPointsVars = FileUtils
        .listFilesAndDirs(
          new File(folder),
          new WildcardFileFilter("part*"),
          new WildcardFileFilter("variables*"))
        .toArray()

      val variablesDir = tfChkPointsVars(1).toString
      val variablesData = tfChkPointsVars(2).toString
      val variablesIndex = tfChkPointsVars(3).toString

      // 3. Read file as SavedModelBundle
      val graph = readGraph(Paths.get(folder, SavedModelPB).toString)
      val session = new Session(graph, ConfigProto.parseFrom(TFSessionConfig))
      val varPath = Paths.get(variablesData)
      val idxPath = Paths.get(variablesIndex)

      processInitAllTableOp(
        initAllTables,
        t,
        session,
        variablesDir,
        variablesKey = "part-00000-of-00001")

      val varBytes = ChunkBytes.readFileInByteChunks(varPath, BUFFER_SIZE)
      val idxBytes = Files.readAllBytes(idxPath)

      val tfWrapper =
        new TensorflowWrapper(Variables(varBytes, idxBytes), graph.toGraphDef.toByteArray)
      tfWrapper.m_session = session
      tfWrapper
    } finally {
      // 4. Remove tmp folder and release fed tensors (also when loading fails)
      FileHelper.delete(tmpFolder)
      t.clearTensors()
    }
  }

  /** Saves the variables of `session` via the legacy `save/control_dependency` op into a temp
    * folder and reads them back as [[Variables]].
    */
  def extractVariablesSavedModel(session: Session): Variables = {
    val t = new TensorResources()

    val folder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + TFVarsSuffix)
      .toAbsolutePath
      .toString

    try {
      val variablesFile = Paths.get(folder, VariablesKey).toString

      session.runner
        .addTarget(SaveControlDependenciesOP)
        .feed(SaveConstOP, t.createTensor(variablesFile))
        .run()

      val varPath = Paths.get(folder, VariablesPathValue)
      val varBytes = ChunkBytes.readFileInByteChunks(varPath, BUFFER_SIZE)

      val idxPath = Paths.get(folder, VariablesIdxValue)
      val idxBytes = Files.readAllBytes(idxPath)

      Variables(varBytes, idxBytes)
    } finally {
      // remove the temp checkpoint and release the fed filename tensor
      FileHelper.delete(folder)
      t.clearTensors()
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy