com.intel.analytics.bigdl.example.textclassification.TextClassifier.scala Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.example.textclassification
import com.intel.analytics.bigdl.example.utils._
import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, _}
import com.intel.analytics.bigdl.utils.{Engine, LoggerFilter, T}
import org.apache.log4j.{Level => Levle4j, Logger => Logger4j}
import org.slf4j.{Logger, LoggerFactory}
import scopt.OptionParser
import scala.collection.mutable.{ArrayBuffer, Map => MMap}
import scala.language.existentials
/**
 * Command-line entry point of the text-classification example.
 *
 * `main` parses the command-line flags into a `TextClassificationParams`,
 * logs the resulting parameters, builds a `TextClassifier` from them and
 * calls its `train()` method.
 */
object TextClassifier {
  val log: Logger = LoggerFactory.getLogger(this.getClass)

  // Logging setup; runs once when this object is initialised.
  LoggerFilter.redirectSparkInfoLogs()
  Logger4j.getLogger("com.intel.analytics.bigdl.optim").setLevel(Levle4j.INFO)

  /**
   * Parses `args` and, when parsing yields a parameter set, runs training.
   *
   * Recognised options (short / long form):
   *  - `-b` / `--baseDir`            (required)
   *  - `-p` / `--partitionNum`
   *  - `-s` / `--maxSequenceLength`
   *  - `-w` / `--maxWordsNum`
   *  - `-l` / `--trainingSplit`
   *  - `-z` / `--batchSize`
   *  - `-r` / `--learningRate`
   *
   * @param args raw command-line arguments
   */
  def main(args: Array[String]): Unit = {
    val localParser = new OptionParser[TextClassificationParams]("BigDL Example") {
      opt[String]('b', "baseDir")
        .required()
        .text("Base dir containing the training and word2Vec data")
        .action((x, c) => c.copy(baseDir = x))
      // Numeric options use scopt's typed readers instead of reading a String
      // and converting it by hand inside the action.
      opt[Int]('p', "partitionNum")
        .text("you may want to tune the partitionNum if run into spark mode")
        .action((x, c) => c.copy(partitionNum = x))
      opt[Int]('s', "maxSequenceLength")
        .text("maxSequenceLength")
        .action((x, c) => c.copy(maxSequenceLength = x))
      opt[Int]('w', "maxWordsNum")
        .text("maxWordsNum")
        .action((x, c) => c.copy(maxWordsNum = x))
      opt[Double]('l', "trainingSplit")
        .text("trainingSplit")
        .action((x, c) => c.copy(trainingSplit = x))
      opt[Int]('z', "batchSize")
        .text("batchSize")
        .action((x, c) => c.copy(batchSize = x))
      // The short name used to be 'l' as well, which collided with
      // trainingSplit above and made `-l` ambiguous. `-l` stays with
      // trainingSplit; learningRate gets its own short name.
      // NOTE(review): the value is read as an Int, so fractional learning
      // rates cannot be passed on the command line -- confirm against the
      // declared type of TextClassificationParams.learningRate.
      opt[Int]('r', "learningRate")
        .text("learningRate")
        .action((x, c) => c.copy(learningRate = x))
    }

    // `parse` returns an Option; nothing is run when it is empty.
    localParser.parse(args, TextClassificationParams()).foreach { param =>
      log.info(s"Current parameters: $param")
      val textClassification = new TextClassifier(param)
      textClassification.train()
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy