All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.models.rnn.Test.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.models.rnn


import com.intel.analytics.bigdl.dataset.{DataSet, LocalDataSet, MiniBatch, SampleToMiniBatch}
import com.intel.analytics.bigdl.dataset.text.{Dictionary, LabeledSentence, LabeledSentenceToSample}
import com.intel.analytics.bigdl.nn.{Concat, Identity, LogSoftMax, Module}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.Engine
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import scala.util.Random

/**
 * Command-line driver that loads a trained RNN model snapshot and uses it to
 * extend the sentences found under `param.folder`.
 *
 * Each sentence is one-hot encoded with the dictionary built from the same
 * folder and batched on Spark. The model then repeatedly predicts the next
 * word, which is appended to the input, `param.numOfWords` times. Finally,
 * every resulting sentence is decoded back to words and logged.
 */
object Test {
  Logger.getLogger("org").setLevel(Level.ERROR)
  Logger.getLogger("akka").setLevel(Level.ERROR)
  Logger.getLogger("breeze").setLevel(Level.ERROR)
  Logger.getLogger("com.intel.analytics.bigdl.optim").setLevel(Level.INFO)

  import Utils._
  val logger = Logger.getLogger(getClass)

  /**
   * Entry point.
   *
   * @param args command-line arguments, parsed by `testParser` into `TestParams`
   * @throws IllegalArgumentException if no model snapshot was supplied
   */
  def main(args: Array[String]): Unit = {
    testParser.parse(args, new TestParams()).foreach { param =>
      // Validate before any Spark resources are acquired, so a missing argument
      // fails with a clear message instead of a bare Option.get failure.
      val snapshotPath = param.modelSnapshot.getOrElse(
        throw new IllegalArgumentException("A model snapshot is required to run the test"))

      val vocab = Dictionary(param.folder)
      val conf = Engine.createSparkConf()
        .setAppName("Test rnn on text")
        .set("spark.task.maxFailures", "1")
      val sc = new SparkContext(conf)
      try {
        Engine.init

        val model = Module.load[Float](snapshotPath)
        // Inference only: switch the module out of training mode.
        model.evaluate()

        // Batched input tensors are indexed (batch, time, feature); see the
        // resize/narrow calls below.
        val timeDim = 2
        val featDim = 3
        val numOfWords = param.numOfWords.getOrElse(0)

        val labeledInput = readSentence(param.folder).map { sentence =>
          val indices = sentence.map(word => vocab.getIndex(word).toFloat)
          new LabeledSentence[Float](indices, indices)
        }

        val vocabSize = vocab.getVocabSize() + 1
        val batchSize = param.batchSize

        val rdd = sc.parallelize(labeledInput).mapPartitions(iter =>
          LabeledSentenceToSample[Float](vocabSize).apply(iter)
        ).mapPartitions(iter =>
          SampleToMiniBatch[Float](batchSize).apply(iter)
        )

        val flow = rdd.mapPartitions(iter => {
          iter.map(batch => {
            var curInput = batch.getInput().toTensor[Float]
            // Iteratively append the predicted next word to every sentence in the batch.
            for (_ <- 1 to numOfWords) {
              val output = model.forward(curInput).toTensor[Float]
              // Arg-max over the feature dimension, taken at the last time step.
              val predict = output.max(featDim)._2.select(timeDim, output.size(timeDim))
              val batchLen = curInput.size(1)
              val steps = curInput.size(timeDim)
              // Allocate a fresh (zero-filled) buffer every step. Resizing one shared
              // tensor in place while it is also the copy source scrambles the earlier
              // time steps, and would leave stale values in the appended slot.
              val next = Tensor[Float](batchLen, steps + 1, curInput.size(featDim))
              next.narrow(timeDim, 1, steps).copy(curInput)
              for (j <- 1 to batchLen) {
                // NOTE(review): the +1 offset between the max() index and the one-hot
                // position is kept from the original; confirm it matches the encoding
                // used by LabeledSentenceToSample.
                next.setValue(j, steps + 1, predict.valueAt(j, 1).toInt + 1, 1.0f)
              }
              curInput = next
            }
            // Decode each one-hot time step back to its index, one row per sentence.
            val predIdx = curInput.max(featDim)._2
            val predArray = new Array[Float](predIdx.nElement())
            Array.copy(predIdx.storage().array(), predIdx.storageOffset() - 1,
              predArray, 0, predIdx.nElement())
            predArray.grouped(predIdx.size(timeDim)).toArray[Array[Float]]
          })
        }).collect().flatten

        val results = flow.map(x => x.map(t => vocab.getWord(t)))
        results.foreach(x => logger.info(x.mkString(" ")))
      } finally {
        // Always release the Spark context, even if loading or prediction fails.
        sc.stop()
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy