/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.iterativereduce.impl.multilayer;

import java.util.List;

import org.apache.commons.lang3.time.StopWatch;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.RecordReader;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.iterativereduce.impl.reader.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.scaleout.api.ir.ParameterVectorUpdateable;
import org.deeplearning4j.iterativereduce.runtime.ComputableWorker;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base IterativeReduce worker node
 *
 * @author josh
 *
 */
public class WorkerNode implements ComputableWorker, DeepLearningConfigurable {

    private MultiLayerNetwork multiLayerNetwork;
    private static final Logger log = LoggerFactory.getLogger(WorkerNode.class);
    private RecordReader recordParser;
    private DataSetIterator hdfsDataSetIterator = null;
    private long totalRecordsProcessed = 0;
    private StopWatch totalRunTimeWatch = new StopWatch();
    private StopWatch batchWatch = new StopWatch();

    // confs (for now, may get rid of later)
    private int batchSize = 20;
    private int numberClasses = 2;
    private int labelIndex = -1;
    public final static String LABEL_INDEX = "org.deeplearning4j.labelindex";


    /**
     * Train the local network on the next batch of records from HDFS and
     * return the updated parameter vector.
     *
     * The app.iteration.count property controls how many times the master
     * invokes the workers.
     */
    @Override
    public ParameterVectorUpdateable compute() {
        log.info("Worker > Compute() -------------------------- ");

        DataSet hdfsBatch = null;

        if ( this.hdfsDataSetIterator.hasNext() ) {
            hdfsBatch = this.hdfsDataSetIterator.next();
            if (hdfsBatch.getFeatures().rows() > 0) {

                log.info("Rows: " + hdfsBatch.numExamples() + ", inputs: " + hdfsBatch.numInputs() + ", " + hdfsBatch);


                // calc stats on number records processed
                this.totalRecordsProcessed += hdfsBatch.getFeatures().rows();
                batchWatch.reset();
                batchWatch.start();
                this.multiLayerNetwork.fit( hdfsBatch );
                batchWatch.stop();
                log.info("Worker > Processed Total " + this.totalRecordsProcessed + ", Batch Time " + batchWatch.toString() + " Total Time " + totalRunTimeWatch.toString());
            }
            else {
                // in case we get a blank line
                log.info("Worker > Idle pass, no records left to process");
            }

        }
        else {
            log.info("Worker > Idle pass, no records left to process");
        }

        // return the updated parameter vector to the IterativeReduce runtime
        return new ParameterVectorUpdateable(multiLayerNetwork.params());
    }

    @Override
    public void setRecordReader(RecordReader r) {
        this.recordParser = r;
        // wrap the incoming record reader in a DataSetIterator:
        // (record reader, converter, batch size, label index, number of possible labels)
        this.hdfsDataSetIterator = new RecordReaderDataSetIterator(r, null, batchSize, labelIndex, numberClasses);
    }


    @Override
    public ParameterVectorUpdateable compute(List arg0) {
        return compute();
    }

    @Override
    public ParameterVectorUpdateable getResults() {
        return new ParameterVectorUpdateable(multiLayerNetwork.params());
    }



    /**
     * Set up the local network instance from the job configuration:
     * deserialize the MultiLayerConfiguration JSON, read the number of
     * output classes from the output layer, and read the label index.
     */
    @Override
    public void setup(Configuration conf) {
        log.info("Worker-Conf: " + conf.get(MULTI_LAYER_CONF));

        MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(conf.get(MULTI_LAYER_CONF));
        FeedForwardLayer outputLayer = (FeedForwardLayer) conf2.getConf(conf2.getConfs().size() - 1).getLayer();
        this.numberClasses = outputLayer.getNOut();
        labelIndex = conf.getInt(LABEL_INDEX, -1);
        if (labelIndex < 0)
            throw new IllegalStateException("Illegal label index");
        multiLayerNetwork = new MultiLayerNetwork(conf2);
        // initialize parameters so params()/setParameters() can be called before the first fit()
        multiLayerNetwork.init();
    }

    /**
     * Collect the update from the master node and apply it to the local
     * parameter vector
     *
     * TODO: check the state changes of the incoming message!
     *
     */
    @Override
    public void update(ParameterVectorUpdateable masterUpdateUpdateable) {
        multiLayerNetwork.setParameters(masterUpdateUpdateable.get());
    }

    @Override
    public void setup(org.canova.api.conf.Configuration conf) {
        // no-op: worker configuration is handled through the Hadoop Configuration overload above
    }
}
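
/*
 * Illustrative only, not part of the original source: a minimal sketch of how a
 * driver might wire up a WorkerNode. It assumes the serialized MultiLayerConfiguration
 * JSON and the label index are passed through the Hadoop Configuration under the
 * MULTI_LAYER_CONF and LABEL_INDEX keys, as setup(Configuration) above expects.
 * The RecordReader argument is a placeholder; in the real job the IterativeReduce
 * runtime supplies one per HDFS input split.
 */
class WorkerNodeUsageSketch {

    static WorkerNode configureWorker(String multiLayerConfJson, int labelIndex, RecordReader recordReader) {
        Configuration hadoopConf = new Configuration();
        // keys read by WorkerNode.setup(Configuration)
        hadoopConf.set(WorkerNode.MULTI_LAYER_CONF, multiLayerConfJson);
        hadoopConf.setInt(WorkerNode.LABEL_INDEX, labelIndex);

        WorkerNode worker = new WorkerNode();
        worker.setup(hadoopConf);             // builds the local MultiLayerNetwork from the JSON conf
        worker.setRecordReader(recordReader); // wraps the reader in a DataSetIterator for training batches
        return worker;
    }
}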