All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.plot.NeuralNetPlotter Maven / Gradle / Ivy

package org.deeplearning4j.plot;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.UUID;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;


/**
 * Renders debugging plots (weight filters, histograms, scatter plots, hidden-bias
 * activations) for a {@link NeuralNetwork}. Matrices are written to temporary
 * comma-separated files, and the bundled {@code plot.py} script is invoked with an
 * external {@code python} executable.
 *
 * <p>When this class is initialised, the script is copied from the classpath resource
 * {@code /scripts/plot.py} to {@code /tmp/plot.py}. A {@code python} executable must be
 * on the {@code PATH} for any of the plotting methods to work.
 *
 * <p>Credit to :
 * http://yosinski.com/media/papers/Yosinski2012VisuallyDebuggingRestrictedBoltzmannMachine.pdf
 * for visualizations
 *
 * @author Adam Gibson
 */
public class NeuralNetPlotter implements Serializable {

    /** Where the bundled plotting script is copied to and executed from. */
    private static final String SCRIPT_PATH = "/tmp/plot.py";
    /** Classpath location of the plotting script shipped with this library. */
    private static final ClassPathResource SCRIPT_RESOURCE = new ClassPathResource("/scripts/plot.py");
    /** How long the hbias activation plot is left running before the script is destroyed. */
    private static final long ACTIVATION_DISPLAY_MILLIS = 10000L;
    private static final Logger log = LoggerFactory.getLogger(NeuralNetPlotter.class);

    static {
        loadIntoTmp();
    }

    /**
     * Renders the given weight matrix using the script's {@code filter} mode.
     *
     * @param w      the weights to render
     * @param r      forwarded verbatim to the script as its first numeric argument
     * @param c      forwarded verbatim to the script as its second numeric argument
     * @param length forwarded verbatim to the script as its third numeric argument
     * @throws RuntimeException wrapping any {@link IOException} raised while writing the
     *                          matrix or running the script
     */
    public void renderFilter(INDArray w, int r, int c, long length) {
        try {
            String filePath = writeMatrix(w);
            log.info("Rendering weights " + filePath);
            runScript("filter", filePath, String.valueOf(r), String.valueOf(c), String.valueOf(length));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Produces the plots used to debug one training step:
     * <ul>
     *   <li>histograms of the network's parameters and of the gradient;</li>
     *   <li>the hidden-bias activations;</li>
     *   <li>the weight filters, rendered to {@code currimg.png}.</li>
     * </ul>
     *
     * <p>Activation and filter plotting are best-effort. Their failures are logged and the
     * method returns normally.
     *
     * @param network       the network whose parameters are plotted
     * @param gradient      the gradient whose components are plotted
     * @param patchesPerRow passed to the filter renderer when the weights have at most two dimensions
     */
    public void plotNetworkGradient(NeuralNetwork network, NeuralNetworkGradient gradient, int patchesPerRow) {
        histogram(
                new String[] {"W", "hbias", "vbias", "w-gradient", "hbias-gradient", "vbias-gradient"},
                new INDArray[] {
                        network.getW(),
                        network.gethBias(),
                        network.getvBias(),
                        gradient.getwGradient(),
                        gradient.gethBiasGradient(),
                        gradient.getvBiasGradient()
                });
        plotActivations(network);

        FilterRenderer render = new FilterRenderer();
        try {
            INDArray weights = network.getW();
            if (weights.shape().length > 2) {
                INDArray w = weights.dup();
                render.renderFilters(w.transpose(), "currimg.png", w.columns(), w.rows(), w.slices());
            } else {
                int side = (int) Math.sqrt(weights.rows());
                render.renderFilters(weights.dup(), "currimg.png", side, side, patchesPerRow);
            }
        } catch (Exception e) {
            log.error("Unable to plot filter, continuing...", e);
        }
    }

    /**
     * Plots the given matrices using the script's {@code scatter} mode. This is primarily
     * used for debugging gradients. You don't necessarily use this directly.
     *
     * @param titles   the titles of the plots, one per matrix
     * @param matrices the matrices to plot; each one is flattened first
     * @throws IllegalArgumentException if {@code titles} and {@code matrices} differ in length
     * @throws RuntimeException         wrapping any {@link IOException} raised while writing
     *                                  the matrices or running the script
     */
    public void scatter(String[] titles, INDArray[] matrices) {
        plotMatrices("scatter", "Rendering Matrix scatter plots... ", titles, matrices);
    }

    /**
     * Histograms the given matrices using the script's {@code multi} mode. This is
     * primarily used for debugging gradients. You don't necessarily use this directly.
     *
     * @param titles   the titles of the plots, one per matrix
     * @param matrices the matrices to plot; each one is flattened first
     * @throws IllegalArgumentException if {@code titles} and {@code matrices} differ in length
     * @throws RuntimeException         wrapping any {@link IOException} raised while writing
     *                                  the matrices or running the script
     */
    public void histogram(String[] titles, INDArray[] matrices) {
        plotMatrices("multi", "Rendering Matrix histograms... ", titles, matrices);
    }

    /**
     * Shared implementation of {@link #scatter} and {@link #histogram}.
     *
     * <p>Each flattened matrix is written to a temporary file. The script then runs in
     * {@code mode} with a single argument of alternating {@code path,title} entries joined
     * by commas.
     */
    private void plotMatrices(String mode, String description, String[] titles, INDArray[] matrices) {
        if (titles.length != matrices.length) {
            throw new IllegalArgumentException("Titles and matrix lengths must be equal");
        }
        if (matrices.length == 0) {
            return; // nothing to plot
        }

        String[] pathsAndTitles = new String[matrices.length * 2];
        try {
            for (int i = 0; i < matrices.length; i++) {
                pathsAndTitles[2 * i] = writeMatrix(matrices[i].ravel());
                pathsAndTitles[2 * i + 1] = titles[i];
            }
            log.info(description);
            // Passed as ONE argv entry, so titles containing spaces stay intact. A comma
            // inside a title would still be indistinguishable from the separator.
            runScript(mode, StringUtils.join(pathsAndTitles, ","));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes {@code matrix} to a new file in {@code java.io.tmpdir}.
     *
     * <p>Each matrix row becomes one line of comma-separated values with 10 decimal places.
     * The file is marked for deletion when the JVM exits.
     *
     * @param matrix the matrix to write
     * @return the path of the written file
     * @throws IOException if the file cannot be created or written
     */
    protected String writeMatrix(INDArray matrix) throws IOException {
        File write = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString());
        write.deleteOnExit();
        try (Writer out = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(write), StandardCharsets.UTF_8))) {
            StringBuilder line = new StringBuilder();
            for (int i = 0; i < matrix.rows(); i++) {
                INDArray row = matrix.getRow(i);
                line.setLength(0);
                for (int j = 0; j < row.length(); j++) {
                    if (j > 0) {
                        line.append(',');
                    }
                    // Locale.ROOT forces '.' as the decimal separator. In locales such as
                    // de_DE, "%.10f" would emit ',' and corrupt the comma-separated output.
                    line.append(String.format(Locale.ROOT, "%.10f", row.getDouble(j)));
                }
                line.append('\n');
                out.write(line.toString());
            }
        }
        return write.getPath();
    }

    /**
     * Plots the network's {@link NeuralNetwork#hBiasMean()} using the script's
     * {@code hbias} mode.
     *
     * <p>The script is left running for 10 seconds ({@code ACTIVATION_DISPLAY_MILLIS}) and
     * then destroyed, whether or not it has finished.
     *
     * <p>Best-effort: if the network has no input, or writing or launching fails, a warning
     * is logged and the method returns normally. If the thread is interrupted during the
     * wait, the script is destroyed early and the interrupt flag is restored.
     *
     * @param network the network to plot
     */
    public void plotActivations(NeuralNetwork network) {
        File stdOut = null;
        File stdErr = null;
        try {
            if (network.getInput() == null) {
                log.warn("Unable to plot activations; missing input");
                return;
            }

            String filePath = writeMatrix(network.hBiasMean());
            stdOut = File.createTempFile("plot", ".out");
            stdErr = File.createTempFile("plot", ".err");
            Process process = startScript(stdOut, stdErr, "hbias", filePath);
            try {
                Thread.sleep(ACTIVATION_DISPLAY_MILLIS);
            } catch (InterruptedException e) {
                // Stop waiting early, but keep the interrupt visible to the caller.
                Thread.currentThread().interrupt();
            } finally {
                process.destroy();
            }

            log.info("Rendering hbias " + filePath);
            logStdErr(stdErr);
        } catch (IOException | RuntimeException e) {
            log.warn("Unable to plot activations, continuing...", e);
        } finally {
            deleteTempFile(stdOut);
            deleteTempFile(stdErr);
        }
    }

    /**
     * Starts {@code python /tmp/plot.py args...}.
     *
     * <p>Each argument is its own argv entry. This avoids the whitespace tokenization of
     * {@code Runtime.exec(String)}, so paths or titles containing spaces are passed intact.
     * stdout and stderr go to the given files instead of pipes, so the child can never
     * stall on a full, undrained pipe buffer.
     */
    private static Process startScript(File stdOut, File stdErr, String... args) throws IOException {
        List<String> command = new ArrayList<>(args.length + 2);
        command.add("python");
        command.add(SCRIPT_PATH);
        Collections.addAll(command, args);
        return new ProcessBuilder(command)
                .redirectOutput(stdOut)
                .redirectError(stdErr)
                .start();
    }

    /**
     * Runs the plotting script to completion.
     *
     * <p>The script's stdout is logged at info level. Any stderr output is logged at error
     * level.
     *
     * @throws InterruptedIOException if interrupted while waiting. In that case the script
     *                                is destroyed and the interrupt flag is restored.
     */
    private static void runScript(String... args) throws IOException {
        File stdOut = null;
        File stdErr = null;
        try {
            stdOut = File.createTempFile("plot", ".out");
            stdErr = File.createTempFile("plot", ".err");
            Process process = startScript(stdOut, stdErr, args);
            try {
                process.waitFor();
            } catch (InterruptedException e) {
                process.destroy();
                Thread.currentThread().interrupt();
                InterruptedIOException interrupted =
                        new InterruptedIOException("Interrupted while waiting for " + SCRIPT_PATH);
                interrupted.initCause(e);
                throw interrupted;
            }
            log.info("Std out " + readLines(stdOut));
            logStdErr(stdErr);
        } finally {
            deleteTempFile(stdOut);
            deleteTempFile(stdErr);
        }
    }

    /**
     * Reads all lines of {@code file} in the platform default charset.
     *
     * <p>The script's output is assumed to use that encoding (TODO confirm).
     */
    private static List<?> readLines(File file) throws IOException {
        try (InputStream in = new FileInputStream(file)) {
            return IOUtils.readLines(in);
        }
    }

    /** Logs the captured stderr at error level, but only if the script wrote anything there. */
    private static void logStdErr(File stdErr) throws IOException {
        List<?> errors = readLines(stdErr);
        if (!errors.isEmpty()) {
            log.error(errors.toString());
        }
    }

    /**
     * Deletes a temp file if possible. Otherwise it is scheduled for deletion at JVM exit,
     * e.g. when a just-destroyed process still holds it open. Null is ignored.
     */
    private static void deleteTempFile(File file) {
        if (file != null && !file.delete()) {
            file.deleteOnExit();
        }
    }

    /**
     * Copies the bundled plotting script from the classpath to {@code /tmp/plot.py}.
     *
     * @throws IllegalStateException if the resource cannot be read or the file cannot be
     *                               written; the {@link IOException} is kept as the cause
     */
    private static void loadIntoTmp() {
        File script = new File(SCRIPT_PATH);
        try (InputStream in = SCRIPT_RESOURCE.getInputStream()) {
            List<?> lines = IOUtils.readLines(in);
            FileUtils.writeLines(script, lines);
        } catch (IOException e) {
            throw new IllegalStateException("Unable to load python file", e);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy