All downloads are free. The search and download functionality uses the official Maven repository.

breeze.classify.Trainer.scala Maven / Gradle / Ivy

package breeze.classify

import breeze.config._
import breeze.data.SparseFeatureDataset
import io.Source
import breeze.stats.ContingencyStats
import breeze.serialization.DataSerialization
import java.io._
import breeze.linalg._
import breeze.math.VectorSpace

/**
 * Command-line options read by [[Trainer]].
 *
 * @param type  name of the classifier to train; [[Trainer]] lower-cases it before
 *              dispatching, so it is case-insensitive. Only "Logistic" and "SVM"
 *              are currently handled there.
 * @param input training data file, in svm light format (defaults to a file named "train")
 * @param help  when true, [[Trainer]] prints the generated help text instead of training
 */
case class TrainerParams(
  // Pegasos is not listed: the only dispatch site has that case commented out,
  // so advertising it would lead users into a failure.
  @Help(text="The kind of classifier to train. {Logistic,SVM}") `type`: String= "Logistic",
  @Help(text="Input file in svm light format.") input: File= new java.io.File("train"),
  @Help(text="Prints this") help:Boolean = false)

/**
 * Main program that builds a classifier from a [[breeze.data.SparseFeatureDataset]].
 * You can build a logistic classifier or an SVM, at the moment.
 *
 * The program parses the command line into a [[TrainerParams]]. If help was requested
 * it prints the generated help text; otherwise it loads the input file, trains the
 * selected classifier on it and prints [[breeze.stats.ContingencyStats]] computed on
 * the same examples that were used for training.
 *
 * An unrecognised classifier type raises an `IllegalArgumentException` naming the
 * offending value.
 *
 * @author dlwh
 */
object Trainer extends App {
  val config = CommandLineParser.parseArguments(args)._1
  val params = config.readIn[TrainerParams]("")
  if(params.help) {
    println(GenerateHelp[TrainerParams](config))
  } else {
    // Keep a handle on the Source so the underlying file can be released. It is
    // closed only after training and evaluation are done, so this is safe whether
    // or not the dataset reads its input lazily.
    val source = Source.fromFile(params.input)
    try {
      val input = SparseFeatureDataset.fromSource[Int](source,params.input.getName)
      type TheClassifier = LinearClassifier[Int,UnindexedLFMatrix[Int,SparseVector[Double]],Counter[Int,Double],SparseVector[Double]]

      // Locale.ROOT keeps the lower-casing independent of the JVM default locale
      // (e.g. under a Turkish locale "LOGISTIC" would otherwise not become "logistic").
      val trainer:Classifier.Trainer[Int,SparseVector[Double]] { type MyClassifier = TheClassifier } = params.`type`.toLowerCase(java.util.Locale.ROOT) match {
        case "logistic" => new LogisticClassifier.Trainer[Int,SparseVector[Double]]
        case "svm" => new SVM.SMOTrainer[Int,SparseVector[Double]]()
//      case "pegasos" => new SVM.Pegasos[Int,SparseVector[Double]](30 * input.examples.length)
        case _ =>
          // Fail with an actionable message rather than an opaque scala.MatchError.
          throw new IllegalArgumentException(
            "Unknown classifier type '" + params.`type` + "'; expected one of: Logistic, SVM")
      }
      val classifier = trainer.train(input.examples)
      println("Performance on training set: ")
      println(ContingencyStats(classifier,input.examples))
    } finally {
      source.close()
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy