org.deeplearning4j.plot.NeuralNetPlotter Maven / Gradle / Ivy
package org.deeplearning4j.plot;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.NeuralNetworkGradient;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;
/**
 * Plots neural-network parameters, gradients and activations by running a bundled Python
 * script ({@code /scripts/plot.py} on the classpath) as an external {@code python} process.
 *
 * <p>When this class is initialized the script is copied to {@code /tmp/plot.py}. Each plot
 * method writes the matrices it needs to comma-separated temporary files (see
 * {@link #writeMatrix(DoubleMatrix)}) and invokes the script with a mode argument
 * ({@code filter}, {@code multi}, {@code weights} or {@code hbias}) followed by its data.
 *
 * <p>Credit for the visualizations:
 * http://yosinski.com/media/papers/Yosinski2012VisuallyDebuggingRestrictedBoltzmannMachine.pdf
 *
 * @author Adam Gibson
 */
public class NeuralNetPlotter {

    /**
     * Location the bundled script is copied to and run from.
     * NOTE(review): hard-coded to /tmp while data files go to java.io.tmpdir; kept unchanged so
     * the script location stays the same, but it will fail on systems without /tmp.
     */
    private static final String SCRIPT_PATH = "/tmp/plot.py";

    /** How long {@link #plotActivations(NeuralNetwork)} lets the script run before destroying it. */
    private static final long ACTIVATION_DISPLAY_MILLIS = 10000L;

    private static final ClassPathResource SCRIPT_RESOURCE = new ClassPathResource("/scripts/plot.py");
    private static final Logger log = LoggerFactory.getLogger(NeuralNetPlotter.class);

    static {
        loadIntoTmp();
    }

    /**
     * Renders {@code w} using the script's {@code filter} mode and waits for the script to exit.
     *
     * @param w the matrix to render; written to a temporary file first
     * @param r passed to the script as its first numeric argument
     * @param c passed to the script as its second numeric argument
     * @param length passed to the script as its third numeric argument
     * @throws RuntimeException wrapping any {@link IOException} from writing the matrix or
     *     running the script
     */
    public void renderFilter(DoubleMatrix w, int r, int c, long length) {
        try {
            String filePath = writeMatrix(w);
            log.info("Rendering weights " + filePath);
            runScript("filter", filePath, String.valueOf(r), String.valueOf(c), String.valueOf(length));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Plots the network's weights and biases next to the corresponding gradients in a single
     * {@code multi} plot, then plots the hidden activations via
     * {@link #plotActivations(NeuralNetwork)}.
     *
     * @param network supplies W, hidden bias and visible bias
     * @param gradient supplies the W, hidden-bias and visible-bias gradients
     */
    public void plotNetworkGradient(NeuralNetwork network, NeuralNetworkGradient gradient) {
        plotMatrices(
                new String[] {"W", "hbias", "vbias", "w-gradient", "hbias-gradient", "vbias-gradient"},
                new DoubleMatrix[] {
                    network.getW(),
                    network.gethBias(),
                    network.getvBias(),
                    gradient.getwGradient(),
                    gradient.gethBiasGradient(),
                    gradient.getvBiasGradient()
                });
        plotActivations(network);
    }

    /**
     * Plots several matrices at once using the script's {@code multi} mode.
     *
     * <p>Each matrix is unrolled with {@link MatrixUtil#unroll} and written to its own temporary
     * file. The script receives one comma-joined argument of alternating
     * {@code path,title} pairs, so titles must not contain commas.
     *
     * @param titles one title per matrix
     * @param matrices the matrices to plot
     * @throws IllegalArgumentException if {@code titles} and {@code matrices} differ in length
     * @throws RuntimeException wrapping any {@link IOException} from writing or running the script
     */
    public void plotMatrices(String[] titles, DoubleMatrix[] matrices) {
        String[] path = new String[matrices.length * 2];
        if (titles.length != matrices.length) {
            throw new IllegalArgumentException("Titles and matrix lengths must be equal");
        }
        try {
            for (int i = 0; i < matrices.length; i++) {
                path[2 * i] = writeMatrix(MatrixUtil.unroll(matrices[i]));
                path[2 * i + 1] = titles[i];
            }
            log.info("Rendering multiple matrices... ");
            runScript("multi", StringUtils.join(path, ","));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes {@code matrix} to a new uniquely named file in {@code java.io.tmpdir}.
     *
     * <p>The file has one line per row. Values are comma-separated and formatted with ten
     * decimal places. The file is scheduled for deletion when the JVM exits.
     *
     * @param matrix the matrix to write
     * @return the path of the written file
     * @throws IOException if the file cannot be written
     */
    protected String writeMatrix(DoubleMatrix matrix) throws IOException {
        String filePath = System.getProperty("java.io.tmpdir") + File.separator + UUID.randomUUID().toString();
        File file = new File(filePath);
        file.deleteOnExit();
        BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(file));
        try {
            for (int i = 0; i < matrix.rows; i++) {
                StringBuilder line = new StringBuilder();
                for (int j = 0; j < matrix.columns; j++) {
                    if (j > 0) {
                        line.append(',');
                    }
                    // Locale.ROOT guarantees '.' as the decimal separator. Under a comma-decimal
                    // default locale the value itself would contain ',' and corrupt the columns.
                    line.append(String.format(Locale.ROOT, "%.10f", matrix.get(i, j)));
                }
                line.append('\n');
                out.write(line.toString().getBytes("UTF-8"));
            }
        } finally {
            out.close();
        }
        return filePath;
    }

    /**
     * Plots the network's weight matrix using the script's {@code weights} mode.
     *
     * <p>Best-effort: any failure is logged as a warning and never propagated.
     *
     * @param network supplies the weight matrix
     */
    public void plotWeights(NeuralNetwork network) {
        try {
            String filePath = writeMatrix(network.getW());
            log.info("Rendering weights " + filePath);
            runScript("weights", filePath);
        } catch (Exception e) {
            log.warn("Unable to plot weights", e);
        }
    }

    /**
     * Plots {@code input * W + hbias} (computed row-wise) using the script's {@code hbias} mode.
     *
     * <p>Blocks the caller for a fixed 10 seconds, then destroys the script process.
     * Presumably this closes the displayed image; NOTE(review): confirm against plot.py.
     * Best-effort: missing input and any failure are logged as warnings and never propagated.
     *
     * @param network supplies input, W and hidden bias
     */
    public void plotActivations(NeuralNetwork network) {
        Process process;
        try {
            DoubleMatrix input = network.getInput();
            if (input == null) {
                log.warn("Unable to plot; missing input");
                return;
            }
            DoubleMatrix hbiasMean = input.mmul(network.getW()).addRowVector(network.gethBias());
            String filePath = writeMatrix(hbiasMean);
            log.info("Rendering hbias " + filePath);
            process = startScript("hbias", filePath);
        } catch (Exception e) {
            log.warn("Unable to plot activations", e);
            return;
        }
        try {
            Thread.sleep(ACTIVATION_DISPLAY_MILLIS);
        } catch (InterruptedException e) {
            // Keep the caller's interrupt status; the plot is best-effort.
            Thread.currentThread().interrupt();
        } finally {
            // destroy() also closes the process streams, so there is nothing left to read.
            process.destroy();
        }
        log.info("Image closed");
    }

    /**
     * Runs the script with {@code args}, blocks until it exits, and logs its combined output.
     *
     * <p>The output is logged at info level on exit code 0 and at error level otherwise.
     *
     * @param args script arguments, starting with the plot mode
     * @throws IOException if the process cannot be started or its output cannot be read
     */
    private static void runScript(String... args) throws IOException {
        Process process = startScript(args);
        InputStream output = process.getInputStream();
        List<?> lines;
        try {
            lines = IOUtils.readLines(output);
        } catch (IOException e) {
            process.destroy();
            throw e;
        } finally {
            IOUtils.closeQuietly(output);
            IOUtils.closeQuietly(process.getOutputStream());
        }
        try {
            int exitCode = process.waitFor();
            if (exitCode == 0) {
                log.info("Std out " + lines);
            } else {
                log.error("Plot script exited with code " + exitCode + ": " + lines);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            process.destroy();
        }
    }

    /**
     * Starts {@code python SCRIPT_PATH args...} with stderr merged into stdout.
     *
     * <p>Each argument is passed to the process verbatim, with no shell and no whitespace
     * splitting. Merging the streams lets a single reader drain all output, so the child
     * cannot block on a full stderr pipe.
     *
     * @param args script arguments, starting with the plot mode
     * @return the started process
     * @throws IOException if the process cannot be started
     */
    private static Process startScript(String... args) throws IOException {
        String[] command = new String[args.length + 2];
        command[0] = "python";
        command[1] = SCRIPT_PATH;
        System.arraycopy(args, 0, command, 2, args.length);
        return new ProcessBuilder(command).redirectErrorStream(true).start();
    }

    /**
     * Copies the bundled {@code /scripts/plot.py} classpath resource to {@link #SCRIPT_PATH}.
     *
     * @throws IllegalStateException if the resource cannot be read or the copy cannot be written
     */
    private static void loadIntoTmp() {
        File script = new File(SCRIPT_PATH);
        InputStream in = null;
        try {
            in = SCRIPT_RESOURCE.getInputStream();
            List<?> lines = IOUtils.readLines(in);
            FileUtils.writeLines(script, lines);
        } catch (IOException e) {
            throw new IllegalStateException("Unable to load python file", e);
        } finally {
            IOUtils.closeQuietly(in);
        }
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy