All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.functions.MultilayerPerceptron Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    MultilayerPerceptron.java
 *    Copyright (C) 2000-2012 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.functions;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.FontMetrics;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.Vector;

import javax.swing.BorderFactory;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTextField;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.functions.neural.LinearUnit;
import weka.classifiers.functions.neural.NeuralConnection;
import weka.classifiers.functions.neural.NeuralNode;
import weka.classifiers.functions.neural.SigmoidUnit;
import weka.core.*;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

/**
 * A classifier that uses backpropagation to learn a multi-layer perceptron to classify instances.
 * The network can be built by hand or set up using a simple heuristic.
 * The network parameters can also be monitored and modified during training time.
 * The nodes in this network are all sigmoid (except for when the class
 * is numeric, in which case the output nodes become unthresholded linear units).
 *
 * Valid options are:
 *

 * -L <learning rate>
 *  Learning rate for the backpropagation algorithm.
 *  (Value should be between 0 - 1, Default = 0.3).
 * 
* *
 * -M <momentum>
 *  Momentum rate for the backpropagation algorithm.
 *  (Value should be between 0 - 1, Default = 0.2).
 * 
* *
 * -N <number of epochs>
 *  Number of epochs to train through.
 *  (Default = 500).
 * 
* *
 * -V <percentage size of validation set>
 *  Percentage size of validation set to use to terminate
 *  training (if this is non zero it can pre-empt num of epochs).
 *  (Value should be between 0 - 100, Default = 0).
 * 
* *
 * -S <seed>
 *  The value used to seed the random number generator
 *  (Value should be >= 0 and a long, Default = 0).
 * 
* *
 * -E <threshold for number of consecutive errors>
 *  The number of consecutive increases of error allowed for validation
 *  testing before training terminates.
 *  (Value should be > 0, Default = 20).
 * 
* *
 * -G
 *  GUI will be opened.
 *  (Use this to bring up a GUI).
 * 
* *
 * -A
 *  Autocreation of the network connections will NOT be done.
 *  (This will be ignored if -G is NOT set)
 * 
* *
 * -B
 *  A NominalToBinary filter will NOT automatically be used.
 *  (Set this to not use a NominalToBinary filter).
 * 
* *
 * -H <comma separated numbers for nodes on each layer>
 *  The hidden layers to be created for the network.
 *  (Value should be a list of comma separated Natural 
 *  numbers or the letters 'a' = (attribs + classes) / 2, 
 *  'i' = attribs, 'o' = classes, 't' = attribs + classes)
 *  for wildcard values, Default = a).
 * 
* *
 * -C
 *  Normalizing a numeric class will NOT be done.
 *  (Set this to not normalize the class if it's numeric).
 * 
* *
 * -I
 *  Normalizing the attributes will NOT be done.
 *  (Set this to not normalize the attributes).
 * 
* *
 * -R
 *  Resetting the network will NOT be allowed.
 *  (Set this to not allow the network to reset).
 * 
* *
 * -D
 *  Learning rate decay will occur.
 *  (Set this to cause the learning rate to decay).
 * 
* * * * @author Malcolm Ware ([email protected]) * @version $Revision: 15021 $ */ public class MultilayerPerceptron extends AbstractClassifier implements OptionHandler, WeightedInstancesHandler, Randomizable, IterativeClassifier { /** for serialization */ private static final long serialVersionUID = -5990607817048210779L; /** * Main method for testing this class. * * @param argv should contain command line options (see setOptions) */ public static void main(String[] argv) { runClassifier(new MultilayerPerceptron(), argv); } /** * This inner class is used to connect the nodes in the network up to the data * that they are classifying, Note that objects of this class are only * suitable to go on the attribute side or class side of the network and not * both. */ protected class NeuralEnd extends NeuralConnection { /** for serialization */ static final long serialVersionUID = 7305185603191183338L; /** * the value that represents the instance value this node represents. For an * input it is the attribute number, for an output, if nominal it is the * class value. */ private int m_link; /** True if node is an input, False if it's an output. */ private boolean m_input; /** * Constructor */ public NeuralEnd(String id) { super(id); m_link = 0; m_input = true; } /** * Call this function to determine if the point at x,y is on the unit. * * @param g The graphics context for font size info. * @param x The x coord. * @param y The y coord. * @param w The width of the display. * @param h The height of the display. * @return True if the point is on the unit, false otherwise. */ @Override public boolean onUnit(Graphics g, int x, int y, int w, int h) { FontMetrics fm = g.getFontMetrics(); int l = (int) (m_x * w) - fm.stringWidth(m_id) / 2; int t = (int) (m_y * h) - fm.getHeight() / 2; if (x < l || x > l + fm.stringWidth(m_id) + 4 || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) { return false; } return true; } /** * This will draw the node id to the graphics context. 
* * @param g The graphics context. * @param w The width of the drawing area. * @param h The height of the drawing area. */ @Override public void drawNode(Graphics g, int w, int h) { if ((m_type & PURE_INPUT) == PURE_INPUT) { g.setColor(Color.green); } else { g.setColor(Color.orange); } FontMetrics fm = g.getFontMetrics(); int l = (int) (m_x * w) - fm.stringWidth(m_id) / 2; int t = (int) (m_y * h) - fm.getHeight() / 2; g.fill3DRect(l, t, fm.stringWidth(m_id) + 4, fm.getHeight() + fm.getDescent() + 4, true); g.setColor(Color.black); g.drawString(m_id, l + 2, t + fm.getHeight() + 2); } /** * Call this function to draw the node highlighted. * * @param g The graphics context. * @param w The width of the drawing area. * @param h The height of the drawing area. */ @Override public void drawHighlight(Graphics g, int w, int h) { g.setColor(Color.black); FontMetrics fm = g.getFontMetrics(); int l = (int) (m_x * w) - fm.stringWidth(m_id) / 2; int t = (int) (m_y * h) - fm.getHeight() / 2; g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8, fm.getHeight() + fm.getDescent() + 8); drawNode(g, w, h); } /** * Call this to get the output value of this unit. * * @param calculate True if the value should be calculated if it hasn't been * already. * @return The output value, or NaN, if the value has not been calculated. */ @Override public double outputValue(boolean calculate) { if (Double.isNaN(m_unitValue) && calculate) { if (m_input) { if (m_currentInstance.isMissing(m_link)) { m_unitValue = 0; } else { m_unitValue = m_currentInstance.value(m_link); } } else { // node is an output. 
m_unitValue = 0; for (int noa = 0; noa < m_numInputs; noa++) { m_unitValue += m_inputList[noa].outputValue(true); } if (m_numeric && m_normalizeClass) { // then scale the value; // this scales linearly from between -1 and 1 m_unitValue = m_unitValue * m_attributeRanges[m_instances.classIndex()] + m_attributeBases[m_instances.classIndex()]; } } } return m_unitValue; } /** * Call this to get the error value of this unit, which in this case is the * difference between the predicted class, and the actual class. * * @param calculate True if the value should be calculated if it hasn't been * already. * @return The error value, or NaN, if the value has not been calculated. */ @Override public double errorValue(boolean calculate) { if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) { if (m_input) { m_unitError = 0; for (int noa = 0; noa < m_numOutputs; noa++) { m_unitError += m_outputList[noa].errorValue(true); } } else { if (m_currentInstance.classIsMissing()) { m_unitError = .1; } else if (m_instances.classAttribute().isNominal()) { if (m_currentInstance.classValue() == m_link) { m_unitError = 1 - m_unitValue; } else { m_unitError = 0 - m_unitValue; } } else if (m_numeric) { if (m_normalizeClass) { if (m_attributeRanges[m_instances.classIndex()] == 0) { m_unitError = 0; } else { m_unitError = (m_currentInstance.classValue() - m_unitValue) / m_attributeRanges[m_instances.classIndex()]; // m_numericRange; } } else { m_unitError = m_currentInstance.classValue() - m_unitValue; } } } } return m_unitError; } /** * Call this to reset the value and error for this unit, ready for the next * run. This will also call the reset function of all units that are * connected as inputs to this one. This is also the time that the update * for the listeners will be performed. 
*/ @Override public void reset() { if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) { m_unitValue = Double.NaN; m_unitError = Double.NaN; m_weightsUpdated = false; for (int noa = 0; noa < m_numInputs; noa++) { m_inputList[noa].reset(); } } } /** * Call this to have the connection save the current weights. */ @Override public void saveWeights() { for (int i = 0; i < m_numInputs; i++) { m_inputList[i].saveWeights(); } } /** * Call this to have the connection restore from the saved weights. */ @Override public void restoreWeights() { for (int i = 0; i < m_numInputs; i++) { m_inputList[i].restoreWeights(); } } /** * Call this function to set What this end unit represents. * * @param input True if this unit is used for entering an attribute, False * if it's used for determining a class value. * @param val The attribute number or class type that this unit represents. * (for nominal attributes). */ public void setLink(boolean input, int val) throws Exception { m_input = input; if (input) { m_type = PURE_INPUT; } else { m_type = PURE_OUTPUT; } if (val < 0 || (input && val > m_instances.numAttributes()) || (!input && m_instances.classAttribute().isNominal() && val > m_instances .classAttribute().numValues())) { m_link = 0; } else { m_link = val; } } /** * @return link for this node. */ public int getLink() { return m_link; } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 15021 $"); } } /** * Inner class used to draw the nodes onto.(uses the node lists!!) This will * also handle the user input. */ private class NodePanel extends JPanel implements RevisionHandler { /** for serialization */ static final long serialVersionUID = -3067621833388149984L; /** * The constructor. 
*/ public NodePanel() { addMouseListener(new MouseAdapter() { @Override public void mousePressed(MouseEvent e) { if (!m_stopped) { return; } if ((e.getModifiers() & MouseEvent.BUTTON1_MASK) == MouseEvent.BUTTON1_MASK && !e.isAltDown()) { Graphics g = NodePanel.this.getGraphics(); int x = e.getX(); int y = e.getY(); int w = NodePanel.this.getWidth(); int h = NodePanel.this.getHeight(); ArrayList tmp = new ArrayList(4); for (int noa = 0; noa < m_numAttributes; noa++) { if (m_inputs[noa].onUnit(g, x, y, w, h)) { tmp.add(m_inputs[noa]); selection( tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, true); return; } } for (int noa = 0; noa < m_numClasses; noa++) { if (m_outputs[noa].onUnit(g, x, y, w, h)) { tmp.add(m_outputs[noa]); selection( tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, true); return; } } for (NeuralConnection m_neuralNode : m_neuralNodes) { if (m_neuralNode.onUnit(g, x, y, w, h)) { tmp.add(m_neuralNode); selection( tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, true); return; } } NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random, m_sigmoidUnit); m_nextId++; temp.setX((double) e.getX() / w); temp.setY((double) e.getY() / h); tmp.add(temp); addNode(temp); selection( tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, true); } else { // then right click Graphics g = NodePanel.this.getGraphics(); int x = e.getX(); int y = e.getY(); int w = NodePanel.this.getWidth(); int h = NodePanel.this.getHeight(); ArrayList tmp = new ArrayList(4); for (int noa = 0; noa < m_numAttributes; noa++) { if (m_inputs[noa].onUnit(g, x, y, w, h)) { tmp.add(m_inputs[noa]); selection( tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, false); return; } } for (int noa = 0; noa < m_numClasses; noa++) { if (m_outputs[noa].onUnit(g, x, y, w, h)) { tmp.add(m_outputs[noa]); selection( tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, 
false); return; } } for (NeuralConnection m_neuralNode : m_neuralNodes) { if (m_neuralNode.onUnit(g, x, y, w, h)) { tmp.add(m_neuralNode); selection( tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, false); return; } } selection( null, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK, false); } } }); } /** * This function gets called when the user has clicked something It will * amend the current selection or connect the current selection to the new * selection. Or if nothing was selected and the right button was used it * will delete the node. * * @param v The units that were selected. * @param ctrl True if ctrl was held down. * @param left True if it was the left mouse button. */ private void selection(ArrayList v, boolean ctrl, boolean left) { if (v == null) { // then unselect all. m_selected.clear(); repaint(); return; } // then exclusive or the new selection with the current one. if ((ctrl || m_selected.size() == 0) && left) { boolean removed = false; for (int noa = 0; noa < v.size(); noa++) { removed = false; for (int nob = 0; nob < m_selected.size(); nob++) { if (v.get(noa) == m_selected.get(nob)) { // then remove that element m_selected.remove(nob); removed = true; break; } } if (!removed) { m_selected.add(v.get(noa)); } } repaint(); return; } if (left) { // then connect the current selection to the new one. for (int noa = 0; noa < m_selected.size(); noa++) { for (int nob = 0; nob < v.size(); nob++) { NeuralConnection.connect(m_selected.get(noa), v.get(nob)); } } } else if (m_selected.size() > 0) { // then disconnect the current selection from the new one. for (int noa = 0; noa < m_selected.size(); noa++) { for (int nob = 0; nob < v.size(); nob++) { NeuralConnection.disconnect(m_selected.get(noa), v.get(nob)); NeuralConnection.disconnect(v.get(nob), m_selected.get(noa)); } } } else { // then remove the selected node. 
(it was right clicked while // no other units were selected for (int noa = 0; noa < v.size(); noa++) { v.get(noa).removeAllInputs(); v.get(noa).removeAllOutputs(); removeNode(v.get(noa)); } } repaint(); } /** * To make sure pane has appropriate size and scroll bars work correctly. */ @Override public Dimension getPreferredSize() { return new Dimension(getWidth(), Integer.max(getHeight(), Integer.max(25 * m_numAttributes, 25 * m_numClasses))); } /** * This will paint the nodes ontot the panel. * * @param g The graphics context. */ @Override public void paintComponent(Graphics g) { super.paintComponent(g); int x = getWidth(); int y = getHeight(); for (int noa = 0; noa < m_numAttributes; noa++) { m_inputs[noa].drawInputLines(g, x, y); } for (int noa = 0; noa < m_numClasses; noa++) { m_outputs[noa].drawInputLines(g, x, y); m_outputs[noa].drawOutputLines(g, x, y); } for (NeuralConnection m_neuralNode : m_neuralNodes) { m_neuralNode.drawInputLines(g, x, y); } for (int noa = 0; noa < m_numAttributes; noa++) { m_inputs[noa].drawNode(g, x, y); } for (int noa = 0; noa < m_numClasses; noa++) { m_outputs[noa].drawNode(g, x, y); } for (NeuralConnection m_neuralNode : m_neuralNodes) { m_neuralNode.drawNode(g, x, y); } for (int noa = 0; noa < m_selected.size(); noa++) { m_selected.get(noa).drawHighlight(g, x, y); } } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 15021 $"); } } /** * This provides the basic controls for working with the neuralnetwork * * @author Malcolm Ware ([email protected]) * @version $Revision: 15021 $ */ class ControlPanel extends JPanel implements RevisionHandler { /** for serialization */ static final long serialVersionUID = 7393543302294142271L; /** The start stop button. */ public JButton m_startStop; /** The button to accept the network (even if it hasn't done all epochs. 
*/ public JButton m_acceptButton; /** A label to state the number of epochs processed so far. */ public JPanel m_epochsLabel; /** A label to state the total number of epochs to be processed. */ public JLabel m_totalEpochsLabel; /** A text field to allow the changing of the total number of epochs. */ public JTextField m_changeEpochs; /** A label to state the learning rate. */ public JLabel m_learningLabel; /** A label to state the momentum. */ public JLabel m_momentumLabel; /** A text field to allow the changing of the learning rate. */ public JTextField m_changeLearning; /** A text field to allow the changing of the momentum. */ public JTextField m_changeMomentum; /** * A label to state roughly the accuracy of the network.(because the * accuracy is calculated per epoch, but the network is changing throughout * each epoch train). */ public JPanel m_errorLabel; /** The constructor. */ public ControlPanel() { setBorder(BorderFactory.createTitledBorder("Controls")); m_totalEpochsLabel = new JLabel("Num Of Epochs "); m_epochsLabel = new JPanel() { /** for serialization */ private static final long serialVersionUID = 2562773937093221399L; @Override public void paintComponent(Graphics g) { super.paintComponent(g); g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground()); g.drawString("Epoch " + m_epoch, 0, 10); } }; m_epochsLabel.setFont(m_totalEpochsLabel.getFont()); m_changeEpochs = new JTextField(); m_changeEpochs.setText("" + m_numEpochs); m_errorLabel = new JPanel() { /** for serialization */ private static final long serialVersionUID = 4390239056336679189L; @Override public void paintComponent(Graphics g) { super.paintComponent(g); g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground()); if (m_valSize == 0) { g.drawString( "Error per Epoch = " + Utils.doubleToString(m_error, 7), 0, 10); } else { g.drawString( "Validation Error per Epoch = " + Utils.doubleToString(m_error, 7), 0, 10); } } }; m_errorLabel.setFont(m_epochsLabel.getFont()); m_learningLabel = 
new JLabel("Learning Rate = "); m_momentumLabel = new JLabel("Momentum = "); m_changeLearning = new JTextField(); m_changeMomentum = new JTextField(); m_changeLearning.setText("" + m_learningRate); m_changeMomentum.setText("" + m_momentum); setLayout(new BorderLayout(15, 10)); m_stopIt = true; m_accepted = false; m_startStop = new JButton("Start"); m_startStop.setActionCommand("Start"); m_acceptButton = new JButton("Accept"); m_acceptButton.setActionCommand("Accept"); JPanel buttons = new JPanel(); buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS)); buttons.add(m_startStop); buttons.add(m_acceptButton); add(buttons, BorderLayout.WEST); JPanel data = new JPanel(); data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS)); Box ab = new Box(BoxLayout.X_AXIS); ab.add(m_epochsLabel); data.add(ab); ab = new Box(BoxLayout.X_AXIS); Component b = Box.createGlue(); ab.add(m_totalEpochsLabel); ab.add(m_changeEpochs); m_changeEpochs.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); ab = new Box(BoxLayout.X_AXIS); ab.add(m_errorLabel); data.add(ab); add(data, BorderLayout.CENTER); data = new JPanel(); data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS)); ab = new Box(BoxLayout.X_AXIS); b = Box.createGlue(); ab.add(m_learningLabel); ab.add(m_changeLearning); m_changeLearning.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); ab = new Box(BoxLayout.X_AXIS); b = Box.createGlue(); ab.add(m_momentumLabel); ab.add(m_changeMomentum); m_changeMomentum.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); add(data, BorderLayout.EAST); m_startStop.addActionListener(new ActionListener() { @Override public void actionPerformed(ActionEvent e) { if (e.getActionCommand().equals("Start")) { m_stopIt = false; m_startStop.setText("Stop"); m_startStop.setActionCommand("Stop"); int n = Integer.valueOf(m_changeEpochs.getText()).intValue(); m_numEpochs = n; m_changeEpochs.setText("" + m_numEpochs); double m = 
Double.valueOf(m_changeLearning.getText()).doubleValue(); setLearningRate(m); m_changeLearning.setText("" + m_learningRate); m = Double.valueOf(m_changeMomentum.getText()).doubleValue(); setMomentum(m); m_changeMomentum.setText("" + m_momentum); blocker(false); } else if (e.getActionCommand().equals("Stop")) { m_stopIt = true; m_startStop.setText("Start"); m_startStop.setActionCommand("Start"); } } }); m_acceptButton.addActionListener(new ActionListener() { @Override public void actionPerformed(ActionEvent e) { m_accepted = true; blocker(false); } }); m_changeEpochs.addActionListener(new ActionListener() { @Override public void actionPerformed(ActionEvent e) { int n = Integer.valueOf(m_changeEpochs.getText()).intValue(); if (n > 0) { m_numEpochs = n; blocker(false); } } }); } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 15021 $"); } } /** * a ZeroR model in case no model can be built from the data or the network * predicts all zeros for the classes */ private Classifier m_ZeroR; /** Whether to use the default ZeroR model */ private boolean m_useDefaultModel = false; /** The training instances. */ private Instances m_instances; /** The current instance running through the network. */ private Instance m_currentInstance; /** A flag to say that it's a numeric class. */ private boolean m_numeric; /** The ranges for all the attributes. */ private double[] m_attributeRanges; /** The base values for all the attributes. */ private double[] m_attributeBases; /** The output units.(only feeds the errors, does no calcs) */ private NeuralEnd[] m_outputs; /** The input units.(only feeds the inputs does no calcs) */ private NeuralEnd[] m_inputs; /** All the nodes that actually comprise the logical neural net. */ private NeuralConnection[] m_neuralNodes; /** The number of classes. */ private int m_numClasses = 0; /** The number of attributes. 
*/ private int m_numAttributes = 0; // note the number doesn't include the class. /** The panel the nodes are displayed on. */ private NodePanel m_nodePanel; /** The control panel. */ private ControlPanel m_controlPanel; /** The next id number available for default naming. */ private int m_nextId; /** A Vector list of the units currently selected. */ private ArrayList m_selected; /** The number of epochs to train through. */ private int m_numEpochs; /** a flag to state if the network should be running, or stopped. */ private boolean m_stopIt; /** a flag to state that the network has in fact stopped. */ private boolean m_stopped; /** a flag to state that the network should be accepted the way it is. */ private boolean m_accepted; /** The window for the network. */ private JFrame m_win; /** * A flag to tell the build classifier to automatically build a neural net. */ private boolean m_autoBuild; /** * A flag to state that the gui for the network should be brought up. To allow * interaction while training. */ private boolean m_gui; /** An int to say how big the validation set should be. */ private int m_valSize; /** The number to to use to quit on validation testing. */ private int m_driftThreshold; /** The number used to seed the random number generator. */ private int m_randomSeed; /** The actual random number generator. */ private Random m_random; /** A flag to state that a nominal to binary filter should be used. */ private boolean m_useNomToBin; /** The actual filter. */ private NominalToBinary m_nominalToBinaryFilter; /** The string that defines the hidden layers */ private String m_hiddenLayers; /** This flag states that the user wants the input values normalized. */ private boolean m_normalizeAttributes; /** This flag states that the user wants the learning rate to decay. */ private boolean m_decay; /** This is the learning rate for the network. */ private double m_learningRate; /** This is the momentum for the network. 
*/ private double m_momentum; /** Shows the number of the epoch that the network just finished. */ private int m_epoch; /** Number of iterations (epochs) performed in this session of iterating */ private int m_numItsPerformed; /** Shows the error of the epoch that the network just finished. */ private double m_error; /** * This flag states that the user wants the network to restart if it is found * to be generating infinity or NaN for the error value. This would restart * the network with the current options except that the learning rate would be * smaller than before, (perhaps half of its current value). This option will * not be available if the gui is chosen (if the gui is open the user can fix * the network themselves, it is an architectural minefield for the network to * be reset with the gui open). */ private boolean m_reset; /** * This flag states that the user wants the class to be normalized while * processing in the network is done. (the final answer will be in the * original range regardless). This option will only be used when the class is * numeric. */ private boolean m_normalizeClass; /** * this is a sigmoid unit. */ private final SigmoidUnit m_sigmoidUnit; /** * This is a linear unit. */ private final LinearUnit m_linearUnit; /** * Whether to allow training to continue at a later point after the initial * model is built. */ protected boolean m_resume; /** * The constructor. */ public MultilayerPerceptron() { m_instances = null; m_currentInstance = null; m_controlPanel = null; m_nodePanel = null; m_epoch = 0; m_error = 0; m_outputs = new NeuralEnd[0]; m_inputs = new NeuralEnd[0]; m_numAttributes = 0; m_numClasses = 0; m_neuralNodes = new NeuralConnection[0]; m_selected = new ArrayList(4); m_nextId = 0; m_stopIt = true; m_stopped = true; m_accepted = false; m_numeric = false; m_random = null; m_nominalToBinaryFilter = new NominalToBinary(); m_sigmoidUnit = new SigmoidUnit(); m_linearUnit = new LinearUnit(); // setting all the options to their defaults. 
To completely change these // defaults they will also need to be changed down the bottom in the // setoptions function (the text info in the accompanying functions should // also be changed to reflect the new defaults m_normalizeClass = true; m_normalizeAttributes = true; m_autoBuild = true; m_gui = false; m_useNomToBin = true; m_driftThreshold = 20; m_numEpochs = 500; m_valSize = 0; m_randomSeed = 0; m_hiddenLayers = "a"; m_learningRate = .3; m_momentum = .2; m_reset = true; m_decay = false; } /** * @param d True if the learning rate should decay. */ public void setDecay(boolean d) { m_decay = d; } /** * @return the flag for having the learning rate decay. */ public boolean getDecay() { return m_decay; } /** * This sets the network up to be able to reset itself with the current * settings and the learning rate at half of what it is currently. This will * only happen if the network creates NaN or infinite errors. Also this will * continue to happen until the network is trained properly. The learning rate * will also get set back to it's original value at the end of this. This can * only be set to true if the GUI is not brought up. * * @param r True if the network should restart with it's current options and * set the learning rate to half what it currently is. */ public void setReset(boolean r) { if (m_gui) { r = false; } m_reset = r; } /** * @return The flag for reseting the network. */ public boolean getReset() { return m_reset; } /** * @param c True if the class should be normalized (the class will only ever * be normalized if it is numeric). (Normalization puts the range * between -1 - 1). */ public void setNormalizeNumericClass(boolean c) { m_normalizeClass = c; } /** * @return The flag for normalizing a numeric class. */ public boolean getNormalizeNumericClass() { return m_normalizeClass; } /** * @param a True if the attributes should be normalized (even nominal * attributes will get normalized here) (range goes between -1 - 1). 
*/
  public void setNormalizeAttributes(boolean a) {
    m_normalizeAttributes = a;
  }

  /**
   * @return The flag for normalizing attributes.
   */
  public boolean getNormalizeAttributes() {
    return m_normalizeAttributes;
  }

  /**
   * @param f True if a nominalToBinary filter should be used on the data.
   */
  public void setNominalToBinaryFilter(boolean f) {
    m_useNomToBin = f;
  }

  /**
   * @return The flag for nominal to binary filter use.
   */
  public boolean getNominalToBinaryFilter() {
    return m_useNomToBin;
  }

  /**
   * Stores the seed for the random number generator that is used whenever
   * the network needs a random number. Negative seeds are ignored.
   *
   * @param l The seed.
   */
  @Override
  public void setSeed(int l) {
    if (l < 0) {
      return;
    }
    m_randomSeed = l;
  }

  /**
   * @return The seed for the random number generator.
   */
  @Override
  public int getSeed() {
    return m_randomSeed;
  }

  /**
   * This sets the threshold to use for when validation testing is being done.
   * It works by ending testing once the error on the validation set has
   * consecutively increased a certain number of times. Values that are not
   * positive are ignored.
   *
   * @param t The threshold to use for this.
   */
  public void setValidationThreshold(int t) {
    if (t <= 0) {
      return;
    }
    m_driftThreshold = t;
  }

  /**
   * @return The threshold used for validation testing.
   */
  public int getValidationThreshold() {
    return m_driftThreshold;
  }

  /**
   * Sets the learning rate of this network. The value must be greater than 0
   * and no more than 1; anything else (NaN included) is ignored. When the
   * control panel exists its learning rate text field is updated too.
   *
   * @param l The New learning rate.
   */
  public void setLearningRate(double l) {
    boolean inRange = l > 0 && l <= 1;
    if (!inRange) {
      return;
    }
    m_learningRate = l;
    if (m_controlPanel != null) {
      m_controlPanel.m_changeLearning.setText("" + l);
    }
  }

  /**
   * @return The learning rate for the nodes.
   */
  public double getLearningRate() {
    return m_learningRate;
  }

  /**
   * The momentum can be set using this command. THE same conditions apply to
   * this as to the learning rate.
   *
   * @param m The new Momentum.
*/ public void setMomentum(double m) { if (m >= 0 && m <= 1) { m_momentum = m; if (m_controlPanel != null) { m_controlPanel.m_changeMomentum.setText("" + m); } } } /** * @return The momentum for the nodes. */ public double getMomentum() { return m_momentum; } /** * This will set whether the network is automatically built or if it is left * up to the user. (there is nothing to stop a user from altering an autobuilt * network however). * * @param a True if the network should be auto built. */ public void setAutoBuild(boolean a) { if (!m_gui) { a = true; } m_autoBuild = a; } /** * @return The auto build state. */ public boolean getAutoBuild() { return m_autoBuild; } /** * This will set what the hidden layers are made up of when auto build is * enabled. Note to have no hidden units, just put a single 0, Any more 0's * will indicate that the string is badly formed and make it unaccepted. * Negative numbers, and floats will do the same. There are also some * wildcards. These are 'a' = (number of attributes + number of classes) / 2, * 'i' = number of attributes, 'o' = number of classes, and 't' = number of * attributes + number of classes. * * @param h A string with a comma seperated list of numbers. Each number is * the number of nodes to be on a hidden layer. */ public void setHiddenLayers(String h) { String tmp = ""; StringTokenizer tok = new StringTokenizer(h, ","); if (tok.countTokens() == 0) { return; } double dval; int val; String c; boolean first = true; while (tok.hasMoreTokens()) { c = tok.nextToken().trim(); if (c.equals("a") || c.equals("i") || c.equals("o") || c.equals("t")) { tmp += c; } else { dval = Double.valueOf(c).doubleValue(); val = (int) dval; if ((val == dval && (val != 0 || (tok.countTokens() == 0 && first)) && val >= 0)) { tmp += val; } else { return; } } first = false; if (tok.hasMoreTokens()) { tmp += ", "; } } m_hiddenLayers = tmp; } /** * @return A string representing the hidden layers, each number is the number * of nodes on a hidden layer. 
*/
  public String getHiddenLayers() {
    return m_hiddenLayers;
  }

  /**
   * Chooses whether a GUI is brought up so the user can interact with the
   * neural network during training. Turning the GUI off forces auto build on;
   * turning it on switches the reset option off.
   *
   * @param a True if gui should be created.
   */
  public void setGUI(boolean a) {
    m_gui = a;
    if (a) {
      setReset(false);
    } else {
      setAutoBuild(true);
    }
  }

  /**
   * Reports whether the GUI should be shown.
   *
   * @return True if should show gui.
   */
  public boolean getGUI() {
    return m_gui;
  }

  /**
   * Sets the size of the validation set. Values outside 0..99 are ignored.
   *
   * @param a The size of the validation set, as a percentage of the whole.
   */
  public void setValidationSetSize(int a) {
    if (a >= 0 && a <= 99) {
      m_valSize = a;
    }
  }

  /**
   * Returns the validation set size.
   *
   * @return The percentage size of the validation set.
   */
  public int getValidationSetSize() {
    return m_valSize;
  }

  /**
   * Sets the number of training epochs to perform. Values that are not
   * greater than 0 are ignored.
   *
   * @param n The number of epochs to train through.
   */
  public void setTrainingTime(int n) {
    if (n <= 0) {
      return;
    }
    m_numEpochs = n;
  }

  /**
   * Returns the configured number of epochs.
   *
   * @return The number of epochs to train through.
   */
  public int getTrainingTime() {
    return m_numEpochs;
  }

  /**
   * Appends a node to the end of the network list by replacing the list with
   * a copy that is one element longer.
   *
   * @param n The node to place in the list.
   */
  private void addNode(NeuralConnection n) {
    int oldCount = m_neuralNodes.length;
    NeuralConnection[] grown = new NeuralConnection[oldCount + 1];
    System.arraycopy(m_neuralNodes, 0, grown, 0, oldCount);
    grown[oldCount] = n;
    m_neuralNodes = grown;
  }

  /**
   * Call this function to remove the passed node from the list. This will only
   * remove the node if it is in the neuralnodes list.
   *
   * @param n The neuralConnection to remove.
   * @return True if removed false if not (because it wasn't there).
*/ private boolean removeNode(NeuralConnection n) { NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length - 1]; int skip = 0; for (int noa = 0; noa < m_neuralNodes.length; noa++) { if (n == m_neuralNodes[noa]) { skip++; } else if (!((noa - skip) >= temp1.length)) { temp1[noa - skip] = m_neuralNodes[noa]; } else { return false; } } m_neuralNodes = temp1; return true; } /** * This function sets what the m_numeric flag to represent the passed class it * also performs the normalization of the attributes if applicable and sets up * the info to normalize the class. (note that regardless of the options it * will fill an array with the range and base, set to normalize all attributes * and the class to be between -1 and 1) * * @param inst the instances. * @return The modified instances. This needs to be done. If the attributes * are normalized then deep copies will be made of all the instances * which will need to be passed back out. */ private Instances setClassType(Instances inst) throws Exception { if (inst != null) { // x bounds m_attributeRanges = new double[inst.numAttributes()]; m_attributeBases = new double[inst.numAttributes()]; for (int noa = 0; noa < inst.numAttributes(); noa++) { double min = Double.POSITIVE_INFINITY; double max = Double.NEGATIVE_INFINITY; for (int i = 0; i < inst.numInstances(); i++) { if (!inst.instance(i).isMissing(noa)) { double value = inst.instance(i).value(noa); if (value < min) { min = value; } if (value > max) { max = value; } } } m_attributeRanges[noa] = (max - min) / 2; m_attributeBases[noa] = (max + min) / 2; } if (m_normalizeAttributes) { for (int i = 0; i < inst.numInstances(); i++) { Instance currentInstance = inst.instance(i); double[] instance = new double[inst.numAttributes()]; for (int noa = 0; noa < inst.numAttributes(); noa++) { if (noa != inst.classIndex()) { if (m_attributeRanges[noa] != 0) { instance[noa] = (currentInstance.value(noa) - m_attributeBases[noa]) / m_attributeRanges[noa]; } else { instance[noa] 
= currentInstance.value(noa) - m_attributeBases[noa]; } } else { instance[noa] = currentInstance.value(noa); } } inst.set(i, new DenseInstance(currentInstance.weight(), instance)); } } if (inst.classAttribute().isNumeric()) { m_numeric = true; } else { m_numeric = false; } } return inst; } /** * A function used to stop the code that called buildclassifier from * continuing on before the user has finished the decision tree. * * @param tf True to stop the thread, False to release the thread that is * waiting there (if one). */ public synchronized void blocker(boolean tf) { if (tf) { try { wait(); } catch (InterruptedException e) { } } else { notifyAll(); } } /** * Call this function to update the control panel for the gui. */ private void updateDisplay() { if (m_gui) { m_controlPanel.m_errorLabel.repaint(); m_controlPanel.m_epochsLabel.repaint(); } } /** * this will reset all the nodes in the network. */ private void resetNetwork() { for (int noc = 0; noc < m_numClasses; noc++) { m_outputs[noc].reset(); } } /** * This will cause the output values of all the nodes to be calculated. Note * that the m_currentInstance is used to calculate these values. */ private void calculateOutputs() { for (int noc = 0; noc < m_numClasses; noc++) { // get the values. m_outputs[noc].outputValue(true); } } /** * This will cause the error values to be calculated for all nodes. Note that * the m_currentInstance is used to calculate these values. Also the output * values should have been calculated first. * * @return The squared error. */ private double calculateErrors() throws Exception { double ret = 0, temp = 0; for (int noc = 0; noc < m_numAttributes; noc++) { // get the errors. m_inputs[noc].errorValue(true); } for (int noc = 0; noc < m_numClasses; noc++) { temp = m_outputs[noc].errorValue(false); ret += temp * temp; } return ret; } /** * This will cause the weight values to be updated based on the learning rate, * momentum and the errors that have been calculated for each node. 
* * @param l The learning rate to update with. * @param m The momentum to update with. */ private void updateNetworkWeights(double l, double m) { for (int noc = 0; noc < m_numClasses; noc++) { // update weights m_outputs[noc].updateWeights(l, m); } } /** * This creates the required input units. */ private void setupInputs() throws Exception { m_inputs = new NeuralEnd[m_numAttributes]; int now = 0; for (int noa = 0; noa < m_numAttributes + 1; noa++) { if (m_instances.classIndex() != noa) { m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name()); m_inputs[noa - now].setX(.1); m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1)); m_inputs[noa - now].setLink(true, noa); } else { now = 1; } } } /** * This creates the required output units. */ private void setupOutputs() throws Exception { m_outputs = new NeuralEnd[m_numClasses]; for (int noa = 0; noa < m_numClasses; noa++) { if (m_numeric) { m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name()); } else { m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().value(noa)); } m_outputs[noa].setX(.9); m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1)); m_outputs[noa].setLink(false, noa); NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random, m_sigmoidUnit); m_nextId++; temp.setX(.75); temp.setY((noa + 1.0) / (m_numClasses + 1)); addNode(temp); NeuralConnection.connect(temp, m_outputs[noa]); } } /** * Call this function to automatically generate the hidden units */ private void setupHiddenLayer() { StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ","); int val = 0; // num of nodes in a layer int prev = 0; // used to remember the previous layer int num = tok.countTokens(); // number of layers String c; for (int noa = 0; noa < num; noa++) { // note that I am using the Double to get the value rather than the // Integer class, because for some reason the Double implementation can // handle leading white space and the integer version can't!?! 
c = tok.nextToken().trim(); if (c.equals("a")) { val = (m_numAttributes + m_numClasses) / 2; } else if (c.equals("i")) { val = m_numAttributes; } else if (c.equals("o")) { val = m_numClasses; } else if (c.equals("t")) { val = m_numAttributes + m_numClasses; } else { val = Double.valueOf(c).intValue(); } for (int nob = 0; nob < val; nob++) { NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random, m_sigmoidUnit); m_nextId++; temp.setX(.5 / (num) * noa + .25); temp.setY((nob + 1.0) / (val + 1)); addNode(temp); if (noa > 0) { // then do connections for (int noc = m_neuralNodes.length - nob - 1 - prev; noc < m_neuralNodes.length - nob - 1; noc++) { NeuralConnection.connect(m_neuralNodes[noc], temp); } } } prev = val; } tok = new StringTokenizer(m_hiddenLayers, ","); c = tok.nextToken(); if (c.equals("a")) { val = (m_numAttributes + m_numClasses) / 2; } else if (c.equals("i")) { val = m_numAttributes; } else if (c.equals("o")) { val = m_numClasses; } else if (c.equals("t")) { val = m_numAttributes + m_numClasses; } else { val = Double.valueOf(c).intValue(); } if (val == 0) { for (int noa = 0; noa < m_numAttributes; noa++) { for (int nob = 0; nob < m_numClasses; nob++) { NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]); } } } else { for (int noa = 0; noa < m_numAttributes; noa++) { for (int nob = m_numClasses; nob < m_numClasses + val; nob++) { NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]); } } for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length; noa++) { for (int nob = 0; nob < m_numClasses; nob++) { NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]); } } } } /** * This will go through all the nodes and check if they are connected to a * pure output unit. If so they will be set to be linear units. If not they * will be set to be sigmoid units. 
*/ private void setEndsToLinear() { for (NeuralConnection m_neuralNode : m_neuralNodes) { if ((m_neuralNode.getType() & NeuralConnection.OUTPUT) == NeuralConnection.OUTPUT) { ((NeuralNode) m_neuralNode).setMethod(m_linearUnit); } else { ((NeuralNode) m_neuralNode).setMethod(m_sigmoidUnit); } } } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.NUMERIC_CLASS); result.enable(Capability.DATE_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** The instances in the validation set (if any) */ protected Instances valSet = null; /** The number of instances in the validation set (if any) */ protected int numInVal = 0; /** Total weight of the instances in the training set */ protected double totalWeight = 0; /** Total weight of the instances in the validation set (if any) */ protected double totalValWeight = 0; /** Drift off counter */ protected double driftOff = 0; /** To keep track of error */ protected double lastRight = Double.POSITIVE_INFINITY; protected double bestError = Double.POSITIVE_INFINITY; /** Data in original format (in case learning rate gets reset */ protected Instances originalFormatData = null; /** * Initializes an iterative classifier. * * @param data the instances to be used in induction * @exception Exception if the model cannot be initialized */ @Override public void initializeClassifier(Instances data) throws Exception { m_numItsPerformed = 0; if (m_instances == null || m_instances.numInstances() == 0) { // can classifier handle the data? 
getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); originalFormatData = data; m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(data); // only class? -> use ZeroR model if (data.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_useDefaultModel = true; return; } else { m_useDefaultModel = false; } m_epoch = 0; m_error = 0; m_instances = null; m_currentInstance = null; m_controlPanel = null; m_nodePanel = null; m_outputs = new NeuralEnd[0]; m_inputs = new NeuralEnd[0]; m_numAttributes = 0; m_numClasses = 0; m_neuralNodes = new NeuralConnection[0]; m_selected = new ArrayList(4); m_nextId = 0; m_stopIt = true; m_stopped = true; m_accepted = false; m_instances = new Instances(data); m_random = new Random(m_randomSeed); m_instances.randomize(m_random); if (m_useNomToBin) { m_nominalToBinaryFilter = new NominalToBinary(); m_nominalToBinaryFilter.setInputFormat(m_instances); m_instances = Filter.useFilter(m_instances, m_nominalToBinaryFilter); } m_numAttributes = m_instances.numAttributes() - 1; m_numClasses = m_instances.numClasses(); setClassType(m_instances); // this sets up the validation set. 
// numinval is needed later numInVal = (int) (m_valSize / 100.0 * m_instances.numInstances()); if (m_valSize > 0) { if (numInVal == 0) { numInVal = 1; } valSet = new Instances(m_instances, 0, numInVal); } // ///////// setupInputs(); setupOutputs(); if (m_autoBuild) { setupHiddenLayer(); } // /////////////////////////// // this sets up the gui for usage if (m_gui) { m_win = Utils.getWekaJFrame("Neural Network", null); m_win.addWindowListener(new WindowAdapter() { @Override public void windowClosing(WindowEvent e) { boolean k = m_stopIt; m_stopIt = true; int well = JOptionPane.showConfirmDialog(m_win, "Are You Sure...\n" + "Click Yes To Accept" + " The Neural Network" + "\n Click No To Return", "Accept Neural Network", JOptionPane.YES_NO_OPTION); if (well == 0) { m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); m_accepted = true; blocker(false); } else { m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE); } m_stopIt = k; } }); m_win.getContentPane().setLayout(new BorderLayout()); m_nodePanel = new NodePanel(); // without the following two lines, the // NodePanel.paintComponents(Graphics) // method will go berserk if the network doesn't fit completely: it will // get called on a constant basis, using 100% of the CPU // see the following forum thread: // http://forum.java.sun.com/thread.jspa?threadID=580929&messageID=2945011 m_nodePanel.setPreferredSize(new Dimension(640, 480)); m_nodePanel.revalidate(); JScrollPane sp = new JScrollPane(m_nodePanel, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS); m_controlPanel = new ControlPanel(); m_win.getContentPane().add(sp, BorderLayout.CENTER); m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH); m_win.setSize(640, 480); m_win.setVisible(true); } // This sets up the initial state of the gui if (m_gui) { blocker(true); m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); } 
// For silly situations in which the network gets accepted before training // commences if (m_numeric) { setEndsToLinear(); } if (m_accepted) { return; } // connections done. totalWeight = 0; totalValWeight = 0; driftOff = 0; lastRight = Double.POSITIVE_INFINITY; bestError = Double.POSITIVE_INFINITY; // ensure that at least 1 instance is trained through. if (numInVal == m_instances.numInstances()) { numInVal--; } if (numInVal < 0) { numInVal = 0; } for (int noa = numInVal; noa < m_instances.numInstances(); noa++) { if (!m_instances.instance(noa).classIsMissing()) { totalWeight += m_instances.instance(noa).weight(); } } if (m_valSize != 0) { for (int noa = 0; noa < valSet.numInstances(); noa++) { if (!valSet.instance(noa).classIsMissing()) { totalValWeight += valSet.instance(noa).weight(); } } } m_stopped = false; } } /** * Performs one iteration. * * @return false if no further iterations could be performed, true otherwise * @exception Exception if this iteration fails for unexpected reasons */ @Override public boolean next() throws Exception { if (m_accepted || m_useDefaultModel || m_instances.numInstances() == 0) { // Has user accepted the network already or do we need to use default model? return false; } m_epoch++; m_numItsPerformed++; double right = 0; for (int nob = numInVal; nob < m_instances.numInstances(); nob++) { m_currentInstance = m_instances.instance(nob); if (!m_currentInstance.classIsMissing()) { // this is where the network updating (and training occurs, for the // training set resetNetwork(); calculateOutputs(); double tempRate = m_learningRate * m_currentInstance.weight(); if (m_decay) { tempRate /= m_epoch; } right += (calculateErrors() / m_instances.numClasses()) * m_currentInstance.weight(); updateNetworkWeights(tempRate, m_momentum); } } right /= totalWeight; if (Double.isInfinite(right) || Double.isNaN(right)) { if ((!m_reset) || (originalFormatData == null)){ m_instances = null; throw new Exception("Network cannot train. 
Try restarting with a smaller learning rate."); } else { // reset the network if possible if (m_learningRate <= Utils.SMALL) { throw new IllegalStateException("Learning rate got too small (" + m_learningRate + " <= " + Utils.SMALL + ")!"); } double origRate = m_learningRate; // only used for when reset m_learningRate /= 2; buildClassifier(originalFormatData); m_learningRate = origRate; return false; } } // //////////////////////do validation testing if applicable if (m_valSize != 0) { right = 0; if (valSet == null) { throw new IllegalArgumentException("Trying to use validation set but validation set is null."); } for (int nob = 0; nob < valSet.numInstances(); nob++) { m_currentInstance = valSet.instance(nob); if (!m_currentInstance.classIsMissing()) { // this is where the network updating occurs, for the validation set resetNetwork(); calculateOutputs(); right += (calculateErrors() / valSet.numClasses()) * m_currentInstance.weight(); // note 'right' could be calculated here just using // the calculate output values. This would be faster. // be less modular } } if (right < lastRight) { if (right < bestError) { bestError = right; // save the network weights at this point for (int noc = 0; noc < m_numClasses; noc++) { m_outputs[noc].saveWeights(); } driftOff = 0; } } else { driftOff++; } lastRight = right; // if (driftOff > m_driftThreshold || m_epoch + 1 >= m_numEpochs) { if (driftOff > m_driftThreshold || m_numItsPerformed + 1 >= m_numEpochs) { for (int noc = 0; noc < m_numClasses; noc++) { m_outputs[noc].restoreWeights(); } m_accepted = true; } right /= totalValWeight; } m_error = right; // shows what the neuralnet is upto if a gui exists. updateDisplay(); // This junction controls what state the gui is in at the end of each // epoch, Such as if it is paused, if it is resumable etc... 
if (m_gui) { // while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && !m_accepted) { while ((m_stopIt || (m_numItsPerformed >= m_numEpochs && m_valSize == 0)) && !m_accepted) { m_stopIt = true; m_stopped = true; //if (m_epoch >= m_numEpochs && m_valSize == 0) { if (m_numItsPerformed >= m_numEpochs && m_valSize == 0) { m_controlPanel.m_startStop.setEnabled(false); } else { m_controlPanel.m_startStop.setEnabled(true); } m_controlPanel.m_startStop.setText("Start"); m_controlPanel.m_startStop.setActionCommand("Start"); m_controlPanel.m_changeEpochs.setEnabled(true); m_controlPanel.m_changeLearning.setEnabled(true); m_controlPanel.m_changeMomentum.setEnabled(true); blocker(true); if (m_numeric) { setEndsToLinear(); } } m_controlPanel.m_changeEpochs.setEnabled(false); m_controlPanel.m_changeLearning.setEnabled(false); m_controlPanel.m_changeMomentum.setEnabled(false); m_stopped = false; // if the network has been accepted stop the training loop if (m_accepted) { return false; } } if (m_accepted) { return false; } // if (m_epoch < m_numEpochs) { if (m_numItsPerformed < m_numEpochs) { return true; // We can keep iterating } else { return false; } } /** * Signal end of iterating, useful for any house-keeping/cleanup * * @exception Exception if cleanup fails */ @Override public void done() throws Exception { if (m_gui) { m_win.dispose(); m_controlPanel = null; m_nodePanel = null; } if (!getResume()) { if (!m_useDefaultModel) { m_instances = new Instances(m_instances, 0); } m_currentInstance = null; valSet = null; originalFormatData = null; } } /** * Tool tip text for resume property * * @return the tool tip text for the finalize property */ public String resumeTipText() { return "Set whether classifier can continue training after performing the" + "requested number of iterations. 
\n\tNote that setting this to true will " + "retain certain data structures which can increase the \n\t" + "size of the model."; } /** * If called with argument true, then the next time done() is called the model is effectively * "frozen" and no further iterations can be performed * * @param resume true if the model is to be finalized after performing iterations */ public void setResume(boolean resume) { m_resume = resume; } /** * Returns true if the model is to be finalized (or has been finalized) after * training. * * @return the current value of finalize */ public boolean getResume() { return m_resume; } /** * Call this function to build and train a neural network for the training * data provided. * * @param i The training data. * @throws Exception if can't build classification properly. */ @Override public void buildClassifier(Instances i) throws Exception { m_instances = null; // Initialize classifier initializeClassifier(i); // For the given number of iterations while (next()) { } // Clean up done(); } /** * Call this function to predict the class of an instance once a * classification model has been built with the buildClassifier call. * * @param i The instance to classify. * @return A double array filled with the probabilities of each class type. * @throws Exception if can't classify instance. */ @Override public double[] distributionForInstance(Instance i) throws Exception { // default model? 
if (m_useDefaultModel) { return m_ZeroR.distributionForInstance(i); } if (m_useNomToBin) { m_nominalToBinaryFilter.input(i); m_currentInstance = m_nominalToBinaryFilter.output(); } else { m_currentInstance = i; } // Make a copy of the instance so that it isn't modified m_currentInstance = (Instance) m_currentInstance.copy(); if (m_normalizeAttributes) { double[] instance = new double[m_currentInstance.numAttributes()]; for (int noa = 0; noa < m_instances.numAttributes(); noa++) { if (noa != m_instances.classIndex()) { if (m_attributeRanges[noa] != 0) { instance[noa] = (m_currentInstance.value(noa) - m_attributeBases[noa]) / m_attributeRanges[noa]; } else { instance[noa] = m_currentInstance.value(noa) - m_attributeBases[noa]; } } else { instance[noa] = m_currentInstance.value(noa); } } m_currentInstance = new DenseInstance(m_currentInstance.weight(), instance); m_currentInstance.setDataset(m_instances); } resetNetwork(); // since all the output values are needed. // They are calculated manually here and the values collected. double[] theArray = new double[m_numClasses]; for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] = m_outputs[noa].outputValue(true); } if (m_instances.classAttribute().isNumeric()) { return theArray; } // now normalize the array double count = 0; for (int noa = 0; noa < m_numClasses; noa++) { count += theArray[noa]; } if (count <= 0) { return m_ZeroR.distributionForInstance(i); } for (int noa = 0; noa < m_numClasses; noa++) { theArray[noa] /= count; } return theArray; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration




© 2015 - 2024 Weber Informatics LLC | Privacy Policy