All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.util.NetSaverLoaderUtils Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.util;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.util.HashMap;
import java.util.Map;

/**
 * Utility to save and load network configuration, parameters and updater state.
 *
 * <p>All methods work on local file paths. I/O failures are logged rather than rethrown: save
 * methods return normally, load methods return {@code null} (or leave the target unchanged).
 */
public class NetSaverLoaderUtils {
    private static final Logger log = LoggerFactory.getLogger(NetSaverLoaderUtils.class);

    /** Charset used for the JSON configuration file, both when writing and when reading it. */
    private static final String CONF_ENCODING = "UTF-8";

    private NetSaverLoaderUtils(){}

    /**
     * Save model configuration and parameters.
     *
     * <p>Writes two files into {@code basePath}: {@code <net.toString()>-conf.json} containing
     * {@code net.getLayerWiseConfigurations().toJson()}, and {@code <net.toString()>.bin}
     * containing {@code net.params()} written with {@code Nd4j.write}. If writing the parameters
     * fails, the configuration file is not written. Failures are logged, not thrown.
     *
     * @param net trained network | model
     * @param basePath directory in which to store the configuration and parameter files
     */
    public static void saveNetworkAndParameters(MultiLayerNetwork net, String basePath) {
        String baseName = net.toString();
        String confPath = FilenameUtils.concat(basePath, baseName + "-conf.json");
        String paramPath = FilenameUtils.concat(basePath, baseName + ".bin");
        log.info("Saving model and parameters to {} and {} ...",  confPath, paramPath);

        // save parameters
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(paramPath)))) {
            Nd4j.write(net.params(), dos);
            dos.flush();
        } catch (IOException e) {
            log.error("Failed to save parameters to {}", paramPath, e);
            return;
        }

        // save model configuration
        try {
            FileUtils.writeStringToFile(new File(confPath), net.getLayerWiseConfigurations().toJson(), CONF_ENCODING);
        } catch (IOException e) {
            log.error("Failed to save model configuration to {}", confPath, e);
        }
    }

    /**
     * Load existing model configuration and parameters.
     *
     * @param confPath string path of the JSON configuration file (as written by
     *     {@link #saveNetworkAndParameters})
     * @param paramPath string path where parameters are stored
     * @return an initialized network with the saved parameters, or {@code null} if either file
     *     could not be read
     */
    public static MultiLayerNetwork loadNetworkAndParameters(String confPath, String paramPath) {
        log.info("Loading saved model and parameters...");
        try {
            // load model configuration: fromJson parses JSON text (the counterpart of toJson()),
            // so the file contents must be read first rather than passing the path itself.
            String json = FileUtils.readFileToString(new File(confPath), CONF_ENCODING);
            MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);

            // load parameters; try-with-resources closes the stream even if Nd4j.read fails
            INDArray newParams;
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(paramPath)))) {
                newParams = Nd4j.read(dis);
            }

            MultiLayerNetwork savedNetwork = new MultiLayerNetwork(confFromJson);
            savedNetwork.init();
            savedNetwork.setParams(newParams);
            return savedNetwork;
        } catch (IOException e) {
            log.error("Failed to load model from {} and {}", confPath, paramPath, e);
            return null;
        }
    }

    /**
     * Save model updaters.
     *
     * <p>Serializes {@code net.getUpdater()} with Java object serialization into
     * {@code <basePath>/<net.toString()>updators.bin}. Failures are logged, not thrown.
     *
     * @param net trained network | model
     * @param basePath directory in which to store the updater file
     */
    public static void saveUpdators(MultiLayerNetwork net, String basePath){
        String updaterPath = FilenameUtils.concat(basePath, net.toString() + "updators.bin");
        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(updaterPath)))) {
            oos.writeObject(net.getUpdater());
        } catch (IOException e) {
            log.error("Failed to save updater to {}", updaterPath, e);
        }
    }

    /**
     * Load model updaters previously written by {@link #saveUpdators}.
     *
     * <p>NOTE: this uses Java-native deserialization; only load files from trusted sources.
     *
     * @param updatorPath path of the serialized updater file
     * @return the deserialized updater, or {@code null} if the file could not be read
     */
    public static Updater loadUpdators(String updatorPath){
        try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(updatorPath)))) {
            return (Updater) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            log.error("Failed to load updater from {}", updatorPath, e);
            return null;
        }
    }

    /**
     * Save existing parameters for the layer using {@code Nd4j.write}.
     *
     * @param param layer parameters in INDArray format
     * @param paramPath string path where parameters are stored
     */
    public static void saveLayerParameters(INDArray param, String paramPath)  {
        log.info("Saving parameters to {} ...", paramPath);

        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(paramPath)))) {
            Nd4j.write(param, dos);
            dos.flush();
        } catch (IOException e) {
            log.error("Failed to save parameters to {}", paramPath, e);
        }
    }

    /**
     * Load existing parameters into the layer.
     *
     * @param layer layer to load the parameters into
     * @param paramPath string path where parameters are stored
     * @return the same {@code layer}; its parameters are left unchanged if reading fails
     */
    public static Layer loadLayerParameters(Layer layer, String paramPath) {
        String name = layer.conf().getLayer().getLayerName();
        log.info("Loading saved parameters for layer {} ...", name);

        // try-with-resources closes the stream even if Nd4j.read fails
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(paramPath)))) {
            INDArray param = Nd4j.read(dis);
            layer.setParams(param);
        } catch (IOException e) {
            log.error("Failed to load parameters for layer {} from {}", name, paramPath, e);
        }

        return layer;
    }


    /**
     * Save existing parameters for the selected layers of the network.
     * Layers with an empty parameter table are skipped.
     *
     * @param net trained network | model
     * @param layerIds list of *int* layer ids
     * @param paramPaths map of layer ids and string paths to store parameters
     */
    public static void saveParameters(MultiLayerNetwork net, int[] layerIds, Map<Integer, String> paramPaths) {
        for (int layerId : layerIds) {
            Layer layer = net.getLayer(layerId);
            if (!layer.paramTable().isEmpty()) {
                saveLayerParameters(layer.params(), paramPaths.get(layerId));
            }
        }
    }

    /**
     * Save existing parameters for the selected layers of the network.
     * Layers with an empty parameter table are skipped.
     *
     * @param net trained network | model
     * @param layerIds list of *string* layer ids
     * @param paramPaths map of layer ids and string paths to store parameters
     */
    public static void saveParameters(MultiLayerNetwork net, String[] layerIds, Map<String, String> paramPaths) {
        for (String layerId : layerIds) {
            Layer layer = net.getLayer(layerId);
            if (!layer.paramTable().isEmpty()) {
                saveLayerParameters(layer.params(), paramPaths.get(layerId));
            }
        }
    }

    /**
     * Load existing parameters for the selected layers of the network.
     *
     * @param net trained network | model
     * @param layerIds list of *int* layer ids
     * @param paramPaths map of layer ids and string paths to find parameters
     * @return the same {@code net}, with the selected layers' parameters loaded in place
     */
    public static MultiLayerNetwork loadParameters(MultiLayerNetwork net, int[] layerIds, Map<Integer, String> paramPaths) {
        for (int layerId : layerIds) {
            loadLayerParameters(net.getLayer(layerId), paramPaths.get(layerId));
        }
        return net;
    }

    /**
     * Load existing parameters for the selected layers of the network.
     *
     * @param net trained network | model
     * @param layerIds list of *string* layer ids
     * @param paramPaths map of layer ids and string paths to find parameters
     * @return the same {@code net}, with the selected layers' parameters loaded in place
     */
    public static MultiLayerNetwork loadParameters(MultiLayerNetwork net, String[] layerIds, Map<String, String> paramPaths) {
        for (String layerId : layerIds) {
            loadLayerParameters(net.getLayer(layerId), paramPaths.get(layerId));
        }
        return net;
    }


    /**
     * Create map of *int* layerIds to path, each mapped to {@code <basePath>/<id>.bin}.
     *
     * @param basePath string path to find parameters
     * @param layerIds list of *int* layer ids
     * @return a new mutable map from layer id to parameter file path
     */
    public static Map<Integer, String> getIdParamPaths(String basePath, int[] layerIds){
        Map<Integer, String> paramPaths = new HashMap<>();
        for (int id : layerIds) {
            paramPaths.put(id, FilenameUtils.concat(basePath, id + ".bin"));
        }
        return paramPaths;
    }

    /**
     * Create map of *string* layerIds to path, each mapped to {@code <basePath>/<name>.bin}.
     *
     * @param basePath string path to find parameters
     * @param layerIds list of *string* layer ids
     * @return a new mutable map from layer name to parameter file path
     */
    public static Map<String, String> getStringParamPaths(String basePath, String[] layerIds){
        Map<String, String> paramPaths = new HashMap<>();
        for (String name : layerIds) {
            paramPaths.put(name, FilenameUtils.concat(basePath, name + ".bin"));
        }
        return paramPaths;
    }

    /**
     * Define output directory based on network type: {@code <java.io.tmpdir>/<networkType>/output}.
     * The directory (and any missing parents) is created if it does not exist yet.
     *
     * @param networkType name of the sub-directory under the system temp directory
     * @return the output directory path
     */
    public static String defineOutputDir(String networkType){
        String tmpDir = System.getProperty("java.io.tmpdir");
        String outputPath = File.separator + networkType + File.separator + "output";
        File dataDir = new File(tmpDir, outputPath);
        // Check the output directory itself (not its parent), otherwise it is never created
        // when the parent already exists.
        if (!dataDir.exists() && !dataDir.mkdirs()) {
            log.warn("Could not create output directory {}", dataDir);
        }
        return dataDir.toString();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy