All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.util.TrainingHelper.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.util

import com.johnsnowlabs.nlp.pretrained.ResourceType.ResourceType
import com.johnsnowlabs.nlp.pretrained.{ResourceMetadata, ResourceType}
import org.apache.commons.io.FileUtils
import org.apache.spark.ml.util.MLWriter
import org.apache.spark.sql.Dataset

import java.io.File
import java.nio.file.Paths
import java.sql.Timestamp
import java.util.Date
import scala.util.Try

/**
 * Helpers used around model training: saving a trained model as a zipped resource that is
 * registered in a `metadata.json` file, and checking whether a Dataset exposes a column.
 */
object TrainingHelper {

  /**
   * Saves a model with `modelWriter`, zips it, and registers it in `folder/metadata.json`.
   *
   * Steps: build a [[ResourceMetadata]] stamped with the current time, save the model into
   * `folder/<meta.key>`, zip that directory to `folder/<meta.fileName>`, remove the unzipped
   * directory (best-effort, never fails the save), and record `meta` in `folder/metadata.json`
   * via `ResourceMetadata.addMetadataToFile`.
   *
   * @param name         resource name recorded in the metadata
   * @param language     optional language recorded in the metadata
   * @param libVersion   optional library version recorded in the metadata
   * @param sparkVersion optional Spark version recorded in the metadata
   * @param modelWriter  writer used to save the model to disk
   * @param folder       target folder for the model archive and `metadata.json`
   * @param category     resource category; defaults to `Some(ResourceType.NOT_DEFINED)`
   */
  def saveModel(
      name: String,
      language: Option[String],
      libVersion: Option[Version],
      sparkVersion: Option[Version],
      modelWriter: MLWriter,
      folder: String,
      category: Option[ResourceType] = Some(ResourceType.NOT_DEFINED)): Unit = {

    // 1. Stamp the resource with the current time
    val timestamp = new Timestamp(new Date().getTime)

    // 2. Create resource metadata
    val meta = new ResourceMetadata(
      name,
      language,
      libVersion,
      sparkVersion,
      true, // NOTE(review): unnamed Boolean flag - confirm its meaning in ResourceMetadata
      timestamp,
      true, // NOTE(review): unnamed Boolean flag - confirm its meaning in ResourceMetadata
      category = category)

    // 3. Save the model into folder/<meta.key>. Backslashes are normalised to '/' so the
    //    writer always receives a single separator style. `String.replace` is a literal
    //    (non-regex) replace-all, equivalent to the 2.13-deprecated `replaceAllLiterally`.
    val modelDir = Paths.get(folder, meta.key).toString.replace("\\", "/")
    modelWriter.save(modelDir)

    // 4. Zip the saved directory
    val zipFile = Paths.get(folder, meta.fileName).toString
    ZipArchiveUtil.zip(modelDir, zipFile)

    // 5. Remove the unzipped directory; cleanup problems must not prevent step 6
    deleteBestEffort(new File(modelDir))

    // 6. Add to metadata.json info about resource
    val metadataFile = Paths.get(folder, "metadata.json").toString
    ResourceMetadata.addMetadataToFile(metadataFile, meta)
  }

  /**
   * Deletes `dir` immediately. If that fails with an IOException (a file lock may prevent
   * deletion), schedules whatever remains for deletion at JVM exit instead. Any failure of
   * that fallback is ignored too, so this method never throws an IOException.
   *
   * Deletion-on-exit is requested only on failure because `File.deleteOnExit` registrations
   * cannot be cancelled and are kept until the JVM exits.
   */
  private def deleteBestEffort(dir: File): Unit =
    try FileUtils.deleteDirectory(dir)
    catch {
      case _: java.io.IOException =>
        // file lock may prevent deletion; retry at JVM exit and continue
        Try(FileUtils.forceDeleteOnExit(dir))
        ()
    }

  /**
   * Returns true when `column` can be resolved on `dataSet`, i.e. when `dataSet(column)`
   * does not throw.
   */
  def hasColumn(dataSet: Dataset[_], column: String): Boolean = Try(dataSet(column)).isSuccess

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy