All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.spark.ml.classification.MultiLayerNetworkClassification.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2015 Skymind,Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.deeplearning4j.spark.ml.classification

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, GenericRowWithSchema}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.conf.layers.{OutputLayer, FeedForwardLayer}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.deeplearning4j.spark.ml.nn.ParameterAveragingTrainingStrategy
import org.deeplearning4j.spark.ml.param.shared.{HasEpochs, HasMultiLayerConfiguration}
import org.deeplearning4j.spark.ml.util.Identifiable
import org.deeplearning4j.spark.util.conversions._
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.util.FeatureUtil

import scala.collection.JavaConversions

/**
 * Parameters for neural network classification.
 *
 * Combines the multilayer network configuration parameter
 * ([[HasMultiLayerConfiguration]]) with the epoch count parameter ([[HasEpochs]]).
 * Mixed into both [[NeuralNetworkClassification]] and [[NeuralNetworkClassificationModel]].
 */
trait NeuralNetworkClassificationParams
  extends Params
  with HasMultiLayerConfiguration
  with HasEpochs

/**
 * Neural network-based learning algorithm for supervised classification.
 *
 * This class is an estimator that produces a model.  Accepts a feature vector and
 * a multiclass numeric label as input, and produces a probability vector and a predicted label as output.
 *
 * Noteworthy parameters:
 *  - conf        - the multilayer configuration (JSON form of a `MultiLayerConfiguration`)
 *  - epochs      - the number of full passes over the dataset, with convergence on each pass.
 *
 * The last layer of the configuration must be an `OutputLayer`. If its `nOut` is 0, the
 * number of classes is taken from the nominal-attribute metadata of the label column.
 *
 *  @author Eron Wright
 */
@DeveloperApi
class NeuralNetworkClassification(override val uid: String)
  extends Classifier[Vector, NeuralNetworkClassification, NeuralNetworkClassificationModel]
  with NeuralNetworkClassificationParams {

  def this() = this(Identifiable.randomUID("nnClassification"))

  /**
   * Sets the network configuration from its JSON representation.
   * @group setParam
   */
  def setConf(value: String): this.type = set(conf, value)

  /**
   * Sets the network configuration; it is stored as JSON.
   * @group setParam
   */
  def setConf(value: MultiLayerConfiguration): this.type = set(conf, value.toJson)

  /** @group setParam */
  def setEpochs(value: Int): this.type = set(epochs, value)
  setDefault(epochs -> 1)

  /**
   * Trains the network on the label and features columns of `dataset`.
   *
   * @param dataset input data containing the configured label and features columns
   * @return a model holding the trained network parameters as a broadcast variable
   * @throws UnsupportedOperationException if the last configured layer is not an output layer,
   *                                       or the number of classes has to be inferred and the
   *                                       label column carries no usable nominal metadata
   */
  override protected def train(dataset: DataFrame): NeuralNetworkClassificationModel = {
    val sc = dataset.sqlContext.sparkContext

    // parameters
    val c = MultiLayerConfiguration.fromJson($(conf))

    // resolve the number of classes/outcomes before caching anything, so that a
    // configuration error cannot leave a persisted dataset behind
    val numClasses = resolveNumClasses(c, dataset)

    // prepare the dataset for classification (column 0 = label, column 1 = features)
    val prepared = dataset.select($(labelCol), $(featuresCol))
    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
    if (handlePersistence) prepared.persist(StorageLevel.MEMORY_AND_DISK)

    val networkParams = try {
      // devise a training strategy for the distributed neural network
      val trainingStrategy = new ParameterAveragingTrainingStrategy[Row](c, $(epochs))

      // train; the function below only touches local values (no reference to `this`)
      trainingStrategy.train(
        prepared.rdd, (network: MultiLayerNetwork, rows: Iterator[Row]) => {

          // features & labels
          val (features, labels) = rows.map { row =>
            (
              row.getAs[Vector](1): INDArray,
              FeatureUtil.toOutcomeVector(row.getDouble(0).toInt, numClasses)
              )
          }.toIterable.unzip

          // nothing to fit when no rows were handed to this invocation
          if (features.nonEmpty) {
            val featureMatrix = Nd4j.vstack(features.toArray: _*)
            val labelMatrix = Nd4j.vstack(labels.toArray: _*)

            network.fit(featureMatrix, labelMatrix)
          }
        })
    } finally {
      // release the cache even when training fails
      if (handlePersistence) prepared.unpersist()
    }

    new NeuralNetworkClassificationModel(uid, numClasses, sc.broadcast(networkParams)).setParent(this)
  }

  /**
   * Determines the number of classes from the last layer of `c`.
   *
   * When the output layer's `nOut` is 0, the count is read from the nominal attribute of the
   * label column and written back into the layer (mutating `c`).
   */
  private def resolveNumClasses(c: MultiLayerConfiguration, dataset: DataFrame): Int = {
    c.getConf(c.getConfs.size() - 1).getLayer match {
      case layer: OutputLayer =>
        layer.getNOut match {
          case 0 =>
            val inferred = NominalAttribute.fromStructField(dataset.schema($(labelCol))) match {
              case attr: NominalAttribute =>
                attr.getNumValues.getOrElse(
                  throw new UnsupportedOperationException("expected numValues on nominal attribute"))
              case _ =>
                throw new UnsupportedOperationException(s"column ${$(labelCol)} must be indexed")
            }
            layer.setNOut(inferred)
            inferred
          case nOut => nOut
        }
      case _ => throw new UnsupportedOperationException("classification requires an output layer")
    }
  }
}

/**
 * Neural network-based classification model.
 *
 * Holds the trained network parameters as a broadcast variable; the network itself is
 * rebuilt from the `conf` parameter on each partition when transforming a dataset.
 *
 * @param uid           identifier of this model
 * @param numClasses    number of classes (outcomes) of the output layer
 * @param networkParams broadcast of the trained network parameters
 *
 * @author Eron Wright
 */
@DeveloperApi
class NeuralNetworkClassificationModel private[ml] (
    override val uid: String,
    override val numClasses: Int,
    val networkParams: Broadcast[INDArray])
  extends ClassificationModel[Vector, NeuralNetworkClassificationModel]
  with NeuralNetworkClassificationParams {

  /**
   * Not supported: prediction is performed per partition in [[transform]].
   *
   * @throws NotImplementedError always
   */
  override protected def predictRaw(features: Vector): Vector = {
    throw new NotImplementedError()
  }

  /**
   * Appends the raw prediction and/or prediction columns (whichever the transformed schema
   * contains) to every row of `dataset`. Each partition is scored as a single matrix.
   *
   * @param dataset input data containing the configured features column
   * @return a DataFrame with the transformed schema
   */
  override def transform(dataset: DataFrame): DataFrame = {

    val schema = transformSchema(dataset.schema, logging = true)

    // position of the features column within the *input* rows; resolved once up front
    val featuresIndex = dataset.schema.fieldIndex($(featuresCol))

    val newRdd = dataset.mapPartitions { iterator: Iterator[Row] =>
      // retain the rows, they are needed again when assembling the output
      val rows = iterator.toArray

      val newRows: Iterator[Row] =
        if (rows.isEmpty) {
          // nothing to score; skip building the network altogether
          Iterator.empty
        } else {
          val netConf = MultiLayerConfiguration.fromJson($(conf))
          // Training only sets an inferred nOut on its own parsed copy of the configuration and
          // never writes it back to the `conf` param, so the JSON seen here may still carry
          // nOut = 0. Apply the known class count so the layout matches the trained parameters.
          netConf.getConf(netConf.getConfs.size() - 1).getLayer match {
            case layer: OutputLayer if layer.getNOut == 0 => layer.setNOut(numClasses)
            case _ => // output size was specified explicitly
          }

          val network = new MultiLayerNetwork(netConf)
          network.init()
          network.setParameters(networkParams.value)

          // prepare the input feature matrix
          val features = rows.map(row => row.getAs[Vector](featuresIndex): INDArray)
          val featureMatrix = Nd4j.vstack(features: _*)

          // compute output
          val outputMatrix = network.output(featureMatrix, true)

          // prepare column generators for required columns, in schema order
          val cols = {
            schema.fieldNames flatMap {
              case f if f == $(rawPredictionCol) => Seq(
                (output: Vector) => output)
              case f if f == $(predictionCol) => Seq(
                (output: Vector) => raw2prediction(output))
              case _ => Seq.empty
            }
          }

          // transform the input rows, appending required columns
          rows.zipWithIndex.map {
            case (row, i) =>
              val output = outputMatrix.getRow(i): Vector
              Row.fromSeq(row.toSeq ++ cols.map(_(output)))
          }.iterator
        }

      newRows
    }

    dataset.sqlContext.createDataFrame(newRdd, schema)
  }

  /** Creates a copy of this model sharing the same broadcast parameters. */
  override def copy(extra: ParamMap): NeuralNetworkClassificationModel = {
    copyValues(new NeuralNetworkClassificationModel(uid, numClasses, networkParams), extra)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy