All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.impl.layer.Master Maven / Gradle / Ivy

package org.deeplearning4j.spark.impl.layer;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.canova.api.records.reader.RecordReader;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.canova.RDDMiniBatches;
import org.deeplearning4j.spark.canova.RecordReaderFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;

/**
 * Master class for spark
 *
 * @author Adam Gibson
 */
public class Master implements Serializable {

    private transient SparkContext sparkContext;
    private transient JavaSparkContext sc;
    private NeuralNetConfiguration conf;
    private RecordReader recordReader;



    public Master(SparkContext sparkContext,NeuralNetConfiguration conf,RecordReader recordReader) {
        this.sparkContext = sparkContext;
        this.conf = conf.clone();
        this.recordReader = recordReader;
        sc = new JavaSparkContext(this.sparkContext);
    }

    public Master(JavaSparkContext sc,NeuralNetConfiguration conf,RecordReader recordReader) {
        this.sc = sc;
        this.recordReader = recordReader;
        this.conf = conf.clone();
    }

    public Layer fit(String path,int labelIndex,int numLabels) {
        JavaRDD lines = sc.textFile(path);
        // gotta map this to a Matrix/INDArray
        JavaRDD points = lines.map(new RecordReaderFunction(recordReader
                , labelIndex, numLabels));
        return fit(points);

    }

    public Layer fit(JavaRDD rdd) {
        int batchSize = conf.getBatchSize();
        JavaRDD miniBatches = new RDDMiniBatches(batchSize,rdd).miniBatchesJava();
        Layer layer = conf.getLayerFactory().create(conf);
        INDArray params = layer.params();
        int paramsLength = layer.numParams();
        if(params.length() != paramsLength)
            throw new IllegalStateException("Number of params " + paramsLength + " was not equal to " + params.length());
        INDArray newParams = miniBatches.map(new DL4jWorker(conf.toJson(),params)).reduce(new Function2() {
            @Override
            public INDArray call(INDArray v1, INDArray v2) throws Exception {
                return v1.add(v2);
            }
        }).divi(miniBatches.count());
        layer.setParameters(newParams);
        return layer;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy