All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.datasets.MNISTViewer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.datasets;


import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.GridLayout;
import java.awt.Rectangle;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;

import javax.swing.BorderFactory;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;


public class MNISTViewer {

	MnistManager manager;
	NN network;
	InputPanel input;
	OutputPanel output;
	OptionsPanel options;
	
	int xnodes = 50;
	int ynodes = 24;
	int connections = 10;
	
	int outputRows = 10;
	int outputCols = 10;

	public MNISTViewer(){

		// Load the data manager
		try {
			manager = new MnistManager("MNIST/train-images-idx3-ubyte", "MNIST/train-labels-idx1-ubyte");
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}	
		
		network = new NN(xnodes*ynodes, connections);
		network.init();
		
		// Setup viewer
		MyFrame mf = new MyFrame("MNSIT Viewer");

		//Display the viewer.
		mf.pack();
		mf.setVisible(true);		
	}


	/**
	 * @param args
	 */
	public static void main(String[] args) {
		MNISTViewer viewer = new MNISTViewer();
	}

	class MyFrame extends JFrame {

		public MyFrame(String s){
			super(s);
			setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

			JPanel main = new JPanel();
			main.setBorder(BorderFactory.createLineBorder(Color.black));
			main.setLayout(new GridLayout(1, 3));

			input = new InputPanel();
			output = new OutputPanel();
			options = new OptionsPanel();

			add(main);
			main.add(input); 
			main.add(output); 
			main.add(options);

		}
	}	

	class InputPanel extends JPanel{

		private int width = 200;
		private int height = 200;
		private int mx = 20; // Margin x
		private int my = 30; // Margin y	
		
		private int imageIndex = 1;
		private int maxIndex;

		public InputPanel(){
			super();
			setPreferredSize(new Dimension(width, height));
			setBorder(BorderFactory.createLineBorder(Color.black));
			JLabel title = new JLabel("Input");
			add(title);	
			manager.setCurrent(imageIndex);
			maxIndex = manager.getImages().getCount();
		}
		
		public void nextImage(){
			imageIndex = imageIndex + 1 > maxIndex ? maxIndex : imageIndex + 1;
			manager.setCurrent(imageIndex);
		}
		
		public void previousImage(){
			imageIndex = imageIndex - 1 < 1 ? 1 : imageIndex - 1;
			manager.setCurrent(imageIndex);
		}	
		
		public void drawCurrentImage(Graphics g){
			manager.setCurrent(imageIndex);
			int[][] image = null;
			int rows = manager.getImages().getRows();
			int cols = manager.getImages().getRows();
			
			try {
				image = manager.readImage();
			} catch (IOException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}
			
			for (int i = 0; i < rows; i++){
				for (int j = 0; j < cols; j++){
					int c = image[i][j];
					g.setColor(new Color(c, c, c));
					g.fillRect(mx+j, my + i, 1, 1);
				}
				//System.out.println();
			}
		}
		
		public void drawCurrentOutput(Graphics g){
	
			int xoffset = mx;
			int yoffset = 3 * my;			
			
			int target;
			try {
				target = manager.readLabel();
				int[][] values = new int[outputRows][outputCols];
				for (int i = 0; i < outputCols; i++){
					values[target][i] = 255;;
				}
				int c = 0;
				for (int i = 0; i < outputRows; i++){
					for (int j = 0; j < outputCols; j++){
						c = values[i][j];
						g.setColor(new Color(c, c, c));
						g.fillRect(xoffset+j, yoffset + i, 1, 1);
					}
				}				
				
			} catch (IOException e) {
				// TODO Auto-generated catch block
				e.printStackTrace();
			}

		}		

		@Override
		public void paintComponent(Graphics g){
			super.paintComponent(g);
			g.drawRect(mx, my, getWidth()-2*mx, getHeight()-2*my);
			drawCurrentImage(g);
			drawCurrentOutput(g);
		}

	}	

	class OutputPanel extends JPanel{

		private int width = 200;
		private int height = 200;
		private int mx = 20; // Margin x
		private int my = 30; // Margin y

		public OutputPanel(){
			super();
			setPreferredSize(new Dimension(width, height));
			setBorder(BorderFactory.createLineBorder(Color.black));
			JLabel title = new JLabel("Output");
			add(title);
		}

		private void drawInputNodes(Graphics g){
			
			int rows = manager.getImages().getRows();
			int cols = manager.getImages().getRows();
			
			float[] state = network.readInput();
			if (state == null){
				state = new float[rows*cols];
			}
			
			int xoffset = mx;
			int yoffset = my;
			
			for (int i = 0; i < rows; i++){
				for (int j = 0; j < cols; j++){
					int index = i * cols + j;
					int v = (int)(255.0 * state[index]);
					int c = v > 255 ? 255 : v;
					try {
					g.setColor(new Color(c, c, c));
					g.fillRect(xoffset+j, yoffset + i, 1, 1);
					} catch (Exception e) {
						e.printStackTrace();
					}
				}
			}				
		}		
		
		private void drawState(Graphics g){
			float[] state = network.getState();
			
			int xoffset = mx;
			int yoffset = 3 * my;
			
			int rows = ynodes;
			int cols = xnodes;
			
			for (int i = 0; i < rows; i++){
				for (int j = 0; j < cols; j++){
					int v = (int)(255.0 * state[i * cols + j]);
					int c = v > 255 ? 255 : v;					
					g.setColor(new Color(c, c, c));
					g.fillRect(xoffset+j, yoffset + i, 1, 1);
				}
			}				
		}
		
		private void drawOutputNodes(Graphics g){
			
			int rows = outputRows;
			int cols = outputCols;
			
			float[] state = network.readOutput();
			if (state == null){
				state = new float[rows*cols];
			}
			
			int xoffset =  mx;
			int yoffset = 4 * my;
			
			for (int i = 0; i < rows; i++){
				for (int j = 0; j < cols; j++){
					int index = i * cols + j;
					int v = (int)(255.0 * state[index]);
					int c = v > 255 ? 255 : v;
					try {
					g.setColor(new Color(c, c, c));
					g.fillRect(xoffset+j, yoffset + i, 1, 1);
					} catch (Exception e) {
						e.printStackTrace();
					}
				}
			}				
		}			

		@Override
		public void paintComponent(Graphics g){
			super.paintComponent(g);
			g.drawRect(mx, my, getWidth()-2*mx, getHeight()-2*my);
			drawInputNodes(g);
			drawState(g);
			drawOutputNodes(g);
		}	

	}

	class OptionsPanel extends JPanel{

		public OptionsPanel(){
			super();
			setPreferredSize(new Dimension(200,200));
			setBorder(BorderFactory.createLineBorder(Color.black));
			JLabel title = new JLabel("Options");
			add(title);
			JButton next = new JButton("Next");
			JButton previous = new JButton("Previous");
			JButton setInput = new JButton("Set Input");
			JButton setOutput = new JButton("Set Output");
			JButton update = new JButton("Update");
			JButton reset = new JButton("Reset");
			
			next.addActionListener
			(
				new ActionListener(){
					public void actionPerformed( ActionEvent e )
					{
						input.nextImage();
						input.repaint();
					}
				}
			);
			
			previous.addActionListener(
				new ActionListener(){
					public void actionPerformed( ActionEvent e )
					{
						input.previousImage();
						input.repaint();
					}
				}
			);
			
			setInput.addActionListener(
					new ActionListener(){
						public void actionPerformed( ActionEvent e )
						{
							try {
								// serialise image
								manager.setCurrent(input.imageIndex);
								int[][] image = manager.readImage();
								int size = image.length*image[0].length;
								int[] inputNodes = new int[size];
								for (int i = 0; i < size; i++){
									inputNodes[i] = i;
								}
								float[] inputValues = new float[size];
								for (int i = 0; i < image.length; i++){
									for (int j = 0; j < image[0].length; j++){
										inputValues[i*image.length + j] = (float)(image[i][j])/255f;
									}
								}
								network.setInput(inputNodes, inputValues);
							} catch (IOException e1) {
								// TODO Auto-generated catch block
								e1.printStackTrace();
							}
							
							output.repaint();
						}			
					}
				);		
			
			setOutput.addActionListener(
					new ActionListener(){
						public void actionPerformed( ActionEvent e )
						{

							int size = outputRows*outputCols;
							
							// set output as the last nodes in the network list
							int[] outputNodes = new int[size];
							int index = xnodes*ynodes-1-size;
							for (int i = 0; i < size; i++){
								outputNodes[i] = index;
								index++;
							}
								
							// map output values to current image label
							int target;
							try {
								target = manager.readLabel();
								float[][] values = new float[outputRows][outputCols];
								for (int i = 0; i < outputCols; i++){
									values[target][i] = 1.0f;
								}
								float[] outputValues = new float[size];
								for (int i = 0; i < outputRows; i++){
									for (int j = 0; j < outputCols; j++){
										outputValues[i*outputCols + j] = values[i][j];
									}
								}							

								network.setOutput(outputNodes, outputValues);
								input.repaint();
								output.repaint();								
							} catch (IOException e1) {
								// TODO Auto-generated catch block
								e1.printStackTrace();
							}
						}			
					}
				);				
			
			update.addActionListener(
				new ActionListener(){
					public void actionPerformed( ActionEvent e )
					{
						network.update();
						output.repaint();	
						//System.out.println(network.toString());
					}
				}
			);	
		
			reset.addActionListener(
					new ActionListener(){
						public void actionPerformed( ActionEvent e )
						{
							network.reset();
							output.repaint();	
						}
					}
				);				
			
			add(next); add(previous); add(setInput); add(setOutput); add(update); add(reset);
			
		}

	}

}









© 2015 - 2025 Weber Informatics LLC | Privacy Policy