com.johnsnowlabs.ml.tensorflow.TensorflowSerializeModel.scala

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.ml.tensorflow

import com.johnsnowlabs.ml.tensorflow.sentencepiece.LoadSentencepiece
import com.johnsnowlabs.nlp.annotators.ner.dl.LoadsContrib
import com.johnsnowlabs.util.FileHelper
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SparkSession

import java.io.File
import java.nio.file.{Files, Paths}
import java.util.UUID

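/** Helper trait for annotators that persist TensorFlow models. The write methods stage a
  * serialized [[TensorflowWrapper]] in a local temporary directory and copy it (or, for
  * `writeTensorflowHub`, an existing SavedModel export) to the destination path through
  * the Hadoop [[FileSystem]] API, cleaning up the temporary directory afterwards.
  */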
trait WriteTensorflowModel {

  def writeTensorflowModel(
      path: String,
      spark: SparkSession,
      tensorflow: TensorflowWrapper,
      suffix: String,
      filename: String,
      configProtoBytes: Option[Array[Byte]] = None): Unit = {

    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

    // 1. Create tmp folder
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
      .toAbsolutePath
      .toString

    val tfFile = Paths.get(tmpFolder, filename).toString

    // 2. Save Tensorflow state
    tensorflow.saveToFile(tfFile, configProtoBytes)

    // 3. Copy to dest folder
    fs.copyFromLocalFile(new Path(tfFile), new Path(path))

    // 4. Remove tmp folder
    FileUtils.deleteDirectory(new File(tmpFolder))
  }

  def writeTensorflowModelV2(
      path: String,
      spark: SparkSession,
      tensorflow: TensorflowWrapper,
      suffix: String,
      filename: String,
      configProtoBytes: Option[Array[Byte]] = None,
      savedSignatures: Option[Map[String, String]] = None): Unit = {

    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

    // 1. Create tmp folder
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
      .toAbsolutePath
      .toString

    val tfFile = Paths.get(tmpFolder, filename).toString

    // 2. Save Tensorflow state
    tensorflow.saveToFileV1V2(tfFile, configProtoBytes, savedSignatures = savedSignatures)

    // 3. Copy to dest folder
    fs.copyFromLocalFile(new Path(tfFile), new Path(path))

    // 4. Remove tmp folder
    FileUtils.deleteDirectory(new File(tmpFolder))
  }

  def writeTensorflowHub(
      path: String,
      tfPath: String,
      spark: SparkSession,
      suffix: String = "_use"): Unit = {

    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

    // 1. Create tmp folder
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
      .toAbsolutePath
      .toString

    // 2. Get the paths to saved_model.pb and variables directory
    val savedModelPath = Paths.get(tfPath, "saved_model.pb").toString
    val variableFilesPath = Paths.get(tfPath, "variables").toString

    // 3. Copy to dest folder
    fs.copyFromLocalFile(new Path(savedModelPath), new Path(path))
    fs.copyFromLocalFile(new Path(variableFilesPath), new Path(path))

    // 4. Remove tmp folder
    FileUtils.deleteDirectory(new File(tmpFolder))
  }

}
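
/* Hedged usage sketch (illustrative addition, not part of the upstream file): a minimal
 * helper that mixes in WriteTensorflowModel to persist a TensorflowWrapper. The object
 * name, suffix, and filename below are assumptions chosen for the example; real
 * annotators supply their own values.
 */
object ExampleTensorflowWriter extends WriteTensorflowModel {

  def save(path: String, spark: SparkSession, wrapper: TensorflowWrapper): Unit =
    // Stages the graph in a local temp directory, copies it to `path`, then cleans up.
    writeTensorflowModel(
      path,
      spark,
      wrapper,
      suffix = "_example_tf",
      filename = "example_tensorflow")
}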

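/** Helper trait for annotators that restore TensorFlow models. Implementations provide
  * the name of the serialized graph via `tfFile`; the read methods copy the model from
  * the given path into a local temporary directory, deserialize it into a
  * [[TensorflowWrapper]], and delete the temporary directory afterwards.
  */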
trait ReadTensorflowModel {
  val tfFile: String

  def readTensorflowModel(
      path: String,
      spark: SparkSession,
      suffix: String,
      zipped: Boolean = true,
      useBundle: Boolean = false,
      tags: Array[String] = Array.empty,
      initAllTables: Boolean = false,
      savedSignatures: Option[Map[String, String]] = None): TensorflowWrapper = {

    LoadsContrib.loadContribToCluster(spark)

    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

    // 1. Create tmp directory
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
      .toAbsolutePath
      .toString

    // 2. Copy to local dir
    fs.copyToLocalFile(new Path(path, tfFile), new Path(tmpFolder))

    // 3. Read Tensorflow state
    val (tf, _) = TensorflowWrapper.read(
      new Path(tmpFolder, tfFile).toString,
      zipped,
      tags = tags,
      useBundle = useBundle,
      initAllTables = initAllTables,
      savedSignatures = savedSignatures)

    // 4. Remove tmp folder
    FileHelper.delete(tmpFolder)

    tf
  }

  def readTensorflowWithSPModel(
      path: String,
      spark: SparkSession,
      suffix: String,
      zipped: Boolean = true,
      useBundle: Boolean = false,
      tags: Array[String] = Array.empty,
      initAllTables: Boolean = false,
      loadSP: Boolean = false): TensorflowWrapper = {

    if (loadSP) {
      LoadSentencepiece.loadSPToCluster(spark)
    }

    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

    // 1. Create tmp directory
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
      .toAbsolutePath
      .toString

    // 2. Copy to local dir
    fs.copyToLocalFile(new Path(path, tfFile), new Path(tmpFolder))

    // 3. Read Tensorflow state
    val tf = TensorflowWrapper.readWithSP(
      new Path(tmpFolder, tfFile).toString,
      zipped,
      tags = tags,
      useBundle = useBundle,
      initAllTables = initAllTables,
      loadSP = loadSP)

    // 4. Remove tmp folder
    FileHelper.delete(tmpFolder)

    tf
  }

  def readTensorflowChkPoints(
      path: String,
      spark: SparkSession,
      suffix: String,
      zipped: Boolean = true,
      tags: Array[String] = Array.empty,
      initAllTables: Boolean = false): TensorflowWrapper = {

    LoadsContrib.loadContribToCluster(spark)

    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

    // 1. Create tmp directory
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
      .toAbsolutePath
      .toString

    // 2. Copy to local dir
    fs.copyToLocalFile(new Path(path, tfFile), new Path(tmpFolder))

    // 3. Read Tensorflow state
    val tf = TensorflowWrapper.readChkPoints(
      new Path(tmpFolder, tfFile).toString,
      zipped,
      tags = tags,
      initAllTables = initAllTables)

    // 4. Remove tmp folder
    FileHelper.delete(tmpFolder)

    tf
  }

  def readTensorflowHub(
      path: String,
      spark: SparkSession,
      suffix: String,
      zipped: Boolean = true,
      useBundle: Boolean = false,
      tags: Array[String] = Array.empty): TensorflowWrapper = {

    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

    // 1. Create tmp directory
    val tmpFolder = Files
      .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
      .toAbsolutePath
      .toString

    // 2. Copy to local dir
    fs.copyToLocalFile(new Path(path, "saved_model.pb"), new Path(tmpFolder))
    fs.copyToLocalFile(new Path(path, "variables"), new Path(tmpFolder))

    // 3. Read Tensorflow state
    val (tf, _) = TensorflowWrapper.read(
      new Path(tmpFolder).toString,
      zipped,
      tags = tags,
      useBundle = useBundle)

    // 4. Remove tmp folder
    FileHelper.delete(tmpFolder)

    tf
  }
}
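
/* Hedged usage sketch (illustrative addition, not part of the upstream file): the read
 * counterpart to ExampleTensorflowWriter above, restoring a previously written
 * TensorflowWrapper. The tfFile value and suffix are assumptions for the example and
 * must match whatever was used at write time.
 */
object ExampleTensorflowReader extends ReadTensorflowModel {

  // Filename of the serialized graph expected under the model path (assumed to match
  // the `filename` that ExampleTensorflowWriter used at write time).
  override val tfFile: String = "example_tensorflow"

  def load(path: String, spark: SparkSession): TensorflowWrapper =
    // Copies the graph to a local temp directory, deserializes it, then deletes the temp dir.
    readTensorflowModel(path, spark, suffix = "_example_tf")
}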



