All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.example.udfpredictor.Utils.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.example.udfpredictor

import java.io.{File, InputStream, PrintWriter}

import com.intel.analytics.bigdl.example.utils.WordMeta
import com.intel.analytics.bigdl.example.utils.TextClassifier
import com.intel.analytics.bigdl.models.utils.ModelBroadcast
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.{Storage, Tensor}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.nn.Module
import org.apache.spark.SparkContext

import scala.io.Source
import scopt.OptionParser


object Utils {

  /** The model type used by this example: Float-typed, Activity in/out. */
  type Model = AbstractModule[Activity, Activity, Float]
  /** Word -> metadata, as produced by `TextClassifier.getData`. */
  type Word2Meta = Map[String, WordMeta]
  /** Word -> integer index. */
  type Word2Index = Map[String, Int]
  /** Word index (stored as a Float) -> embedding vector. */
  type Word2Vec = Map[Float, Array[Float]]
  /** Per-sample shape, here `Array(maxSequenceLength, embeddingDim)`. */
  type SampleShape = Array[Int]
  type TFP = TextClassificationUDFParams

  /** A test document: the name of the file it was read from and its full text. */
  case class Sample(filename: String, text: String)

  // Lazily created by getTextClassifier and reused by getWord2Vec.
  private var textClassification: TextClassifier = null

  /**
   * Returns the shared [[TextClassifier]], creating it from `param` on first use.
   *
   * Note: once created, later calls return the same instance and ignore `param`.
   */
  def getTextClassifier(param: TextClassificationUDFParams): TextClassifier = {
    if (textClassification == null) {
      textClassification = new TextClassifier(param)
    }
    textClassification
  }

  /**
   * Loads a pre-trained model, or trains a new one.
   *
   * If `param.modelPath` is set, the model is loaded from that path and no word
   * metadata or vectors are returned. Otherwise the training data is fetched
   * through the shared classifier. The word -> index mapping is saved as text
   * under `${param.baseDir}/word2Meta.txt`. The model is then trained and, if
   * `param.checkpoint` is set, saved (overwriting) to
   * `${param.checkpoint}/model.1`.
   *
   * @return (model, word metadata, word vectors, sample shape), where the sample
   *         shape is `Array(param.maxSequenceLength, param.embeddingDim)`
   */
  def getModel(sc: SparkContext, param: TFP): (Model, Option[Word2Meta],
    Option[Word2Vec], SampleShape) = {
    // Always initialise the shared classifier (even when loading a saved model)
    // so that getWord2Vec can be used afterwards.
    val classifier = getTextClassifier(param)
    val sampleShape: SampleShape = Array(param.maxSequenceLength, param.embeddingDim)

    param.modelPath match {
      case Some(path) =>
        (Module.load[Float](path), None, None, sampleShape)

      case None =>
        // get train and validation rdds
        val (rdds, word2Meta, word2Vec) = classifier.getData(sc)
        // Persist the word -> index mapping so vectors can be regenerated later.
        val word2IndexPairs = word2Meta.toSeq.map { case (word, meta) => (word, meta.index) }
        sc.parallelize(word2IndexPairs).saveAsTextFile(s"${param.baseDir}/word2Meta.txt")

        val trainedModel = classifier.trainFromData(sc, rdds)
        param.checkpoint.foreach { dir =>
          trainedModel.save(s"$dir/model.1", overWrite = true)
        }

        (trainedModel.evaluate(), Some(word2Meta), Some(word2Vec), sampleShape)
    }
  }

  /**
   * Builds the index -> vector map for `word2Index` using the shared classifier.
   *
   * @throws IllegalStateException if the shared classifier has not been created
   *                               yet (via `getTextClassifier` or `getModel`)
   */
  def getWord2Vec(word2Index: Map[String, Int]): Map[Float, Array[Float]] = {
    if (textClassification == null) {
      throw new IllegalStateException(
        "TextClassifier is not initialised; call getTextClassifier or getModel first")
    }
    textClassification.buildWord2VecWithIndex(word2Index)
  }

  /**
   * Creates a `String => Int` prediction function suitable for a Spark UDF.
   *
   * The model and both lookup maps are broadcast once; the returned closure
   * reads them from the broadcasts, so it only captures broadcast handles,
   * `sampleShape` and `ev`.
   *
   * Per call:
   *  1. Replace non-letters with spaces, lower-case, split on whitespace, and
   *     keep words longer than 2 chars.
   *  2. Map known words to their index (unknown words are dropped).
   *  3. Keep the last `sampleShape(0)` indices or right-pad with 0.
   *  4. Replace each index by its vector, or zeros if no vector is known.
   *  5. Run the model on a `(1, sampleShape(0), sampleShape(1))` tensor
   *     transposed to `(1, sampleShape(1), sampleShape(0))`.
   *  6. Return the index of the max output as reported by `Tensor.max`.
   *
   * @param sampleShape `Array(sequenceLength, embeddingDim)`
   * @throws IllegalArgumentException (from the returned function) if the model
   *                                  output is neither 1- nor 2-dimensional
   */
  def genUdf(sc: SparkContext,
             model: Model,
             sampleShape: Array[Int],
             word2Index: Word2Index,
             word2Vec: Word2Vec)
            (implicit ev: TensorNumeric[Float]): (String) => Int = {

    val broadcastModel = ModelBroadcast[Float]().broadcast(sc, model)
    val word2IndexBC = sc.broadcast(word2Index)
    val word2VecBC = sc.broadcast(word2Vec)

    val udf = (text: String) => {
      val sequenceLen = sampleShape(0)
      val embeddingDim = sampleShape(1)
      val sampleSize = sampleShape.product
      val wordIndex = word2IndexBC.value
      val wordVectors = word2VecBC.value

      // Tokenize and map known words to their (float) index; unknown words are skipped.
      val tokens: Array[Float] = text.replaceAll("[^a-zA-Z]", " ")
        .toLowerCase().split("\\s+").filter(_.length > 2)
        .flatMap(word => wordIndex.get(word).map(_.toFloat))

      // Keep the last `sequenceLen` tokens, or right-pad with 0 up to that length.
      val paddedTokens =
        if (tokens.length > sequenceLen) tokens.takeRight(sequenceLen)
        else tokens ++ Array.fill[Float](sequenceLen - tokens.length)(0)

      // Treat a token as zeros if it cannot be found in the pre-trained word2Vec.
      val data = paddedTokens.map { token =>
        wordVectors.getOrElse(token, Array.fill[Float](embeddingDim)(0))
      }.flatten

      // Build a (1, sequenceLen, embeddingDim) tensor from the flattened vectors.
      val featureData = new Array[Float](sampleSize)
      Array.copy(data.map(ev.fromType(_)), 0, featureData, 0, sampleSize)
      val featureTensor: Tensor[Float] = Tensor[Float]()
      featureTensor.set(Storage[Float](featureData), sizes = Array(1) ++ sampleShape)
      // Swap dims 2 and 3: (1, sequenceLen, embeddingDim) -> (1, embeddingDim, sequenceLen).
      val input = featureTensor.transpose(2, 3)

      // predict
      val localModel = broadcastModel.value()
      val output = localModel.forward(input).toTensor[Float]
      val predict = output.dim() match {
        case 2 => output.max(2)._2.squeeze().storage().array()
        case 1 => output.max(1)._2.squeeze().storage().array()
        case other => throw new IllegalArgumentException(
          s"Unexpected model output dimension $other; expected 1 or 2")
      }
      ev.toType[Int](predict(0))
    }

    udf
  }

  /**
   * Reads every regular file in `testDir` whose name consists only of digits.
   *
   * Files are read as ISO-8859-1. Their lines are joined with "\n", and results
   * are ordered by `File`'s natural ordering.
   *
   * @throws IllegalArgumentException if `testDir` cannot be listed
   *                                  (e.g. it does not exist or is not a directory)
   */
  def loadTestData(testDir: String): IndexedSeq[Sample] = {
    // listFiles() returns null rather than throwing when the path is not a listable directory.
    val entries: Array[File] = Option(new File(testDir).listFiles()).getOrElse(
      throw new IllegalArgumentException(s"Cannot list test directory: $testDir"))

    val fileList = entries
      .filter(f => f.isFile && f.getName.forall(Character.isDigit))
      .sorted

    fileList.map { file =>
      val source = Source.fromFile(file, "ISO-8859-1")
      val text = try source.getLines().mkString("\n") finally source.close()
      Sample(file.getName, text)
    }
  }

  /**
   * Copies the classpath resource `resource` into a new temporary file and
   * returns that file's absolute path.
   *
   * The content is read and written as text using the default charset.
   *
   * @throws IllegalArgumentException if the resource cannot be found
   */
  def getResourcePath(resource: String): String = {
    val stream: InputStream = getClass.getResourceAsStream(resource)
    if (stream == null) {
      throw new IllegalArgumentException(s"Resource not found on classpath: $resource")
    }
    // Closing the Source also closes the underlying resource stream.
    val source = scala.io.Source.fromInputStream(stream)
    val lines = try source.mkString finally source.close()

    val file = File.createTempFile(resource, "")
    val pw = new PrintWriter(file)
    try pw.write(lines) finally pw.close()
    file.getAbsolutePath
  }

  /** Command-line parser for [[TextClassificationUDFParams]]. */
  val localParser = new OptionParser[TextClassificationUDFParams]("BigDL Example") {
    opt[String]('b', "baseDir")
      .text("Base dir containing the training and word2Vec data")
      .action((x, c) => c.copy(baseDir = x))
    opt[String]('p', "partitionNum")
      .text("you may want to tune the partitionNum if run into spark mode")
      .action((x, c) => c.copy(partitionNum = x.toInt))
    opt[String]('s', "maxSequenceLength")
      .text("maxSequenceLength")
      .action((x, c) => c.copy(maxSequenceLength = x.toInt))
    opt[String]('w', "maxWordsNum")
      .text("maxWordsNum")
      .action((x, c) => c.copy(maxWordsNum = x.toInt))
    opt[String]('l', "trainingSplit")
      .text("trainingSplit")
      .action((x, c) => c.copy(trainingSplit = x.toDouble))
    opt[String]('z', "batchSize")
      .text("batchSize")
      .action((x, c) => c.copy(batchSize = x.toInt))
    opt[String]("modelPath")
      .text("where to load the model")
      .action((x, c) => c.copy(modelPath = Some(x)))
    opt[String]("checkpoint")
      .text("where to save the trained model")
      .action((x, c) => c.copy(checkpoint = Some(x)))
    opt[String]('f', "dataDir")
      .text("Text dir containing the text data")
      .action((x, c) => c.copy(testDir = x))
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy