org.deeplearning4j.spark.ml.nn.TrainingStrategy.scala

/*
 * Copyright 2015 Skymind, Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package org.deeplearning4j.spark.ml.nn

import org.apache.spark.AccumulatorParam
import org.apache.spark.rdd.RDD
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

import scala.collection.JavaConversions._

/**
 * An abstract training strategy for a distributed neural network, given partitions of data.
 * @tparam RowType the type of the rows in the input RDD
 *
 * @author Eron Wright
 */
private[spark] abstract class TrainingStrategy[RowType] {
  /**
   * Train a network over the given RDD.
   * @param rdd the input RDD of training rows
   * @param partitionTrainer a function that trains a local MultiLayerNetwork on one partition's rows
   *                         (an illustrative trainer is sketched below the class)
   * @return the final network parameters
   */
  def train(
       rdd: RDD[RowType], 
       partitionTrainer: (MultiLayerNetwork, Iterator[RowType]) => Unit): INDArray
}
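
// ---------------------------------------------------------------------------
// Illustrative sketch only (not part of the library): one possible
// partitionTrainer for rows of MLlib LabeledPoint. The object name, the
// `numClasses` / `batchSize` parameters and the one-hot label encoding are
// assumptions made for this example.
// ---------------------------------------------------------------------------
private[spark] object ExamplePartitionTrainers {
  import org.apache.spark.mllib.regression.LabeledPoint

  /** Builds a trainer that fits the local network on mini-batches drawn from one partition. */
  def labeledPointTrainer(numClasses: Int, batchSize: Int = 64)
      : (MultiLayerNetwork, Iterator[LabeledPoint]) => Unit =
    (network, rows) =>
      rows.grouped(batchSize).foreach { batch =>
        // stack the feature vectors into a (batch x numFeatures) matrix
        val features = Nd4j.vstack(batch.map(p => Nd4j.create(p.features.toArray)): _*)
        // one-hot encode the integer class labels
        val labels = Nd4j.zeros(batch.size, numClasses)
        batch.zipWithIndex.foreach { case (p, i) =>
          labels.putScalar(Array(i, p.label.toInt), 1.0)
        }
        network.fit(features, labels)
      }
}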

/**
 * Train a distributed neural network using parameter averaging.   
 * 
 * Training runs for a configurable number of epochs.  In each epoch:
 * 1. broadcast the current parameters to all workers.
 * 2. train a local copy of the network on each partition (the supplied partition trainer controls the local iterations).
 * 3. accumulate the resulting parameters from each partition and average them across partitions.
 *
 * @author Eron Wright
 */
private[spark] class ParameterAveragingTrainingStrategy[RowType](
    @transient val conf: MultiLayerConfiguration,
    epochs: Int)
  extends TrainingStrategy[RowType] {
 
   override def train(
       rdd: RDD[RowType], 
       partitionTrainer: (MultiLayerNetwork, Iterator[RowType]) => Unit): INDArray = {
     val sc = rdd.sparkContext
     val confJson = conf.toJson()
     
     def initialParams() = {
       val network = new MultiLayerNetwork(conf)
       network.init()
       network.setListeners(List(new ScoreIterationListener(1)))
       network.params()
     }
   
     var networkParams = initialParams()
     
     for(epoch <- (1 to epochs)) {
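        // (1) broadcast the current parameters and create an accumulator for the per-partition results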
       val broadcastedParams = sc.broadcast(networkParams)
       val accumulatedParams = sc.accumulator(Nd4j.zeros(networkParams.shape(): _*))(INDArrayAccumulatorParam)
     
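        // (2) train a local copy of the network on each partition, starting from the broadcast parameters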
       rdd.foreachPartition { iterator =>
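          // rebuild the network configuration on the worker from its JSON form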
         @transient val conf: MultiLayerConfiguration = MultiLayerConfiguration.fromJson(confJson)

         val network = new MultiLayerNetwork(conf)
         network.init()
         network.setListeners(List(new ScoreIterationListener(1)))
         network.setParams(broadcastedParams.value.dup())
         
         partitionTrainer(network, iterator)
         
         accumulatedParams += network.params()
       }
       
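        // (3) average the accumulated parameters across partitions to get the next epoch's starting point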
       networkParams = accumulatedParams.value.divi(rdd.partitions.length)
     }
     
     networkParams
   }
}

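/**
 * Spark accumulator parameter that sums INDArrays element-wise, in place.
 */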
private[spark] object INDArrayAccumulatorParam extends AccumulatorParam[INDArray] {
  def addInPlace(t1: INDArray, t2: INDArray): INDArray = { t1.addi(t2); t1 }
  def zero(initialValue: INDArray): INDArray = Nd4j.zeros(initialValue.shape(): _*)
}
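
// ---------------------------------------------------------------------------
// Illustrative driver-side usage sketch (not a library API): trains with the
// parameter-averaging strategy and the example trainer above, then loads the
// averaged parameters into a local network. Names here are hypothetical.
// ---------------------------------------------------------------------------
private[spark] object ParameterAveragingExample {
  import org.apache.spark.mllib.regression.LabeledPoint

  def fit(conf: MultiLayerConfiguration,
          data: RDD[LabeledPoint],
          numClasses: Int,
          epochs: Int): MultiLayerNetwork = {
    val strategy = new ParameterAveragingTrainingStrategy[LabeledPoint](conf, epochs)
    // run distributed training and collect the averaged parameter vector
    val averagedParams = strategy.train(data, ExamplePartitionTrainers.labeledPointTrainer(numClasses))
    // materialize a local network holding the averaged parameters
    val network = new MultiLayerNetwork(conf)
    network.init()
    network.setParams(averagedParams)
    network
  }
}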



