weka.classifiers.functions.MultilayerPerceptron Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* MultilayerPerceptron.java
* Copyright (C) 2000-2012 University of Waikato, Hamilton, New Zealand
*/
package weka.classifiers.functions;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.FontMetrics;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.Vector;
import javax.swing.BorderFactory;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTextField;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.functions.neural.LinearUnit;
import weka.classifiers.functions.neural.NeuralConnection;
import weka.classifiers.functions.neural.NeuralNode;
import weka.classifiers.functions.neural.SigmoidUnit;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Randomizable;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
/**
* A Classifier that uses backpropagation to classify
* instances.
* This network can be built by hand, created by an algorithm or both. The
* network can also be monitored and modified during training time. The nodes in
* this network are all sigmoid (except when the class is numeric, in which
* case the output nodes become unthresholded linear units).
*
*
*
* Valid options are:
*
*
*
* -L <learning rate>
* Learning Rate for the backpropagation algorithm.
* (Value should be between 0 - 1, Default = 0.3).
*
*
*
* -M <momentum>
* Momentum Rate for the backpropagation algorithm.
* (Value should be between 0 - 1, Default = 0.2).
*
*
*
* -N <number of epochs>
* Number of epochs to train through.
* (Default = 500).
*
*
*
* -V <percentage size of validation set>
* Percentage size of validation set to use to terminate
* training (if this is non zero it can pre-empt num of epochs).
* (Value should be between 0 - 100, Default = 0).
*
*
*
* -S <seed>
* The value used to seed the random number generator
* (Value should be >= 0 and a long, Default = 0).
*
*
*
* -E <threshold for number of consecutive errors>
* The consecutive number of errors allowed for validation
* testing before the network terminates.
* (Value should be > 0, Default = 20).
*
*
*
* -G
* GUI will be opened.
* (Use this to bring up a GUI).
*
*
*
* -A
* Autocreation of the network connections will NOT be done.
* (This will be ignored if -G is NOT set)
*
*
*
* -B
* A NominalToBinary filter will NOT automatically be used.
* (Set this to not use a NominalToBinary filter).
*
*
*
* -H <comma separated numbers for nodes on each layer>
* The hidden layers to be created for the network.
* (Value should be a list of comma separated Natural
* numbers or the letters 'a' = (attribs + classes) / 2,
* 'i' = attribs, 'o' = classes, 't' = attribs + classes
* for wildcard values, Default = a).
*
*
*
* -C
* Normalizing a numeric class will NOT be done.
* (Set this to not normalize the class if it's numeric).
*
*
*
* -I
* Normalizing the attributes will NOT be done.
* (Set this to not normalize the attributes).
*
*
*
* -R
* Resetting the network will NOT be allowed.
* (Set this to not allow the network to reset).
*
*
*
* -D
* Learning rate decay will occur.
* (Set this to cause the learning rate to decay).
*
*
*
*
* @author Malcolm Ware ([email protected])
* @version $Revision: 10169 $
*/
public class MultilayerPerceptron extends AbstractClassifier implements
OptionHandler, WeightedInstancesHandler, Randomizable {
/** Serial version identifier used for Java serialization of this classifier. */
private static final long serialVersionUID = -5990607817048210779L;
/**
 * Command-line entry point: builds/evaluates a MultilayerPerceptron using the
 * standard classifier runner.
 *
 * @param argv the command-line options (see setOptions)
 */
public static void main(String[] argv) {
  MultilayerPerceptron scheme = new MultilayerPerceptron();
  runClassifier(scheme, argv);
}
/**
 * This inner class is used to connect the nodes in the network up to the data
 * that they are classifying. Note that objects of this class are only suitable
 * to go on the attribute side or the class side of the network, not both.
 */
protected class NeuralEnd extends NeuralConnection {

  /** for serialization */
  static final long serialVersionUID = 7305185603191183338L;

  /**
   * The instance value this end represents: for an input it is the attribute
   * index; for an output with a nominal class it is the class value index.
   */
  private int m_link;

  /** True if node is an input, False if it's an output. */
  private boolean m_input;

  /**
   * Constructor. The new unit starts out as an input linked to index 0.
   *
   * @param id The identifier for this unit.
   */
  public NeuralEnd(String id) {
    super(id);

    m_link = 0;
    m_input = true;
  }

  /**
   * Call this function to determine if the point at x,y is on the unit. The
   * unit's box is centred on (m_x * w, m_y * h) and sized to fit its id text.
   *
   * @param g The graphics context for font size info.
   * @param x The x coord.
   * @param y The y coord.
   * @param w The width of the display.
   * @param h The height of the display.
   * @return True if the point is on the unit, false otherwise.
   */
  @Override
  public boolean onUnit(Graphics g, int x, int y, int w, int h) {
    FontMetrics fm = g.getFontMetrics();
    int l = (int) (m_x * w) - fm.stringWidth(m_id) / 2;
    int t = (int) (m_y * h) - fm.getHeight() / 2;
    if (x < l || x > l + fm.stringWidth(m_id) + 4 || y < t
      || y > t + fm.getHeight() + fm.getDescent() + 4) {
      return false;
    }
    return true;
  }

  /**
   * This will draw the node id to the graphics context. Input ends are drawn
   * green, output ends orange.
   *
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  @Override
  public void drawNode(Graphics g, int w, int h) {
    if ((m_type & PURE_INPUT) == PURE_INPUT) {
      g.setColor(Color.green);
    } else {
      g.setColor(Color.orange);
    }

    FontMetrics fm = g.getFontMetrics();
    int l = (int) (m_x * w) - fm.stringWidth(m_id) / 2;
    int t = (int) (m_y * h) - fm.getHeight() / 2;
    g.fill3DRect(l, t, fm.stringWidth(m_id) + 4,
      fm.getHeight() + fm.getDescent() + 4, true);
    g.setColor(Color.black);

    g.drawString(m_id, l + 2, t + fm.getHeight() + 2);
  }

  /**
   * Call this function to draw the node highlighted: a black rectangle slightly
   * larger than the node, with the node drawn on top of it.
   *
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  @Override
  public void drawHighlight(Graphics g, int w, int h) {
    g.setColor(Color.black);
    FontMetrics fm = g.getFontMetrics();
    int l = (int) (m_x * w) - fm.stringWidth(m_id) / 2;
    int t = (int) (m_y * h) - fm.getHeight() / 2;
    g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8,
      fm.getHeight() + fm.getDescent() + 8);
    drawNode(g, w, h);
  }

  /**
   * Call this to get the output value of this unit. For an input this is the
   * linked attribute value of the current instance (0 if missing). For an
   * output it is the sum of the outputs of its inputs, rescaled back to the
   * class range when the class is numeric and class normalization is on.
   *
   * @param calculate True if the value should be calculated if it hasn't been
   *          already.
   * @return The output value, or NaN, if the value has not been calculated.
   */
  @Override
  public double outputValue(boolean calculate) {

    if (Double.isNaN(m_unitValue) && calculate) {
      if (m_input) {
        if (m_currentInstance.isMissing(m_link)) {
          m_unitValue = 0;
        } else {
          m_unitValue = m_currentInstance.value(m_link);
        }
      } else {
        // node is an output.
        m_unitValue = 0;
        for (int noa = 0; noa < m_numInputs; noa++) {
          m_unitValue += m_inputList[noa].outputValue(true);
        }
        if (m_numeric && m_normalizeClass) {
          // undo the class normalization: value * range + base
          m_unitValue = m_unitValue
            * m_attributeRanges[m_instances.classIndex()]
            + m_attributeBases[m_instances.classIndex()];
        }
      }
    }
    return m_unitValue;
  }

  /**
   * Call this to get the error value of this unit. For an output this is the
   * difference between the actual class and the prediction; a missing class
   * gets a fixed error of 0.1. For an input it is the sum of the errors of the
   * units it feeds. The error is only computed once the value is known.
   *
   * @param calculate True if the value should be calculated if it hasn't been
   *          already.
   * @return The error value, or NaN, if the value has not been calculated.
   */
  @Override
  public double errorValue(boolean calculate) {

    if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) {

      if (m_input) {
        m_unitError = 0;
        for (int noa = 0; noa < m_numOutputs; noa++) {
          m_unitError += m_outputList[noa].errorValue(true);
        }
      } else {
        if (m_currentInstance.classIsMissing()) {
          m_unitError = .1;
        } else if (m_instances.classAttribute().isNominal()) {
          // target is 1 for the class value this end represents, else 0
          if (m_currentInstance.classValue() == m_link) {
            m_unitError = 1 - m_unitValue;
          } else {
            m_unitError = 0 - m_unitValue;
          }
        } else if (m_numeric) {

          if (m_normalizeClass) {
            if (m_attributeRanges[m_instances.classIndex()] == 0) {
              // degenerate class range: avoid dividing by zero
              m_unitError = 0;
            } else {
              m_unitError = (m_currentInstance.classValue() - m_unitValue)
                / m_attributeRanges[m_instances.classIndex()];
            }
          } else {
            m_unitError = m_currentInstance.classValue() - m_unitValue;
          }
        }
      }
    }
    return m_unitError;
  }

  /**
   * Clears the cached value and error (and the weights-updated flag) so the
   * unit is ready for the next run, then recursively resets every unit that
   * feeds into this one. A unit whose value and error are both already NaN is
   * left alone, which stops the recursion.
   */
  @Override
  public void reset() {

    if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
      m_unitValue = Double.NaN;
      m_unitError = Double.NaN;
      m_weightsUpdated = false;
      for (int noa = 0; noa < m_numInputs; noa++) {
        m_inputList[noa].reset();
      }
    }
  }

  /**
   * Asks every input connection of this unit to save its current weights.
   */
  @Override
  public void saveWeights() {
    for (int i = 0; i < m_numInputs; i++) {
      m_inputList[i].saveWeights();
    }
  }

  /**
   * Asks every input connection of this unit to restore its saved weights.
   */
  @Override
  public void restoreWeights() {
    for (int i = 0; i < m_numInputs; i++) {
      m_inputList[i].restoreWeights();
    }
  }

  /**
   * Call this function to set what this end unit represents. An out-of-range
   * value links the unit to index 0 instead.
   *
   * @param input True if this unit is used for entering an attribute, False
   *          if it's used for determining a class value.
   * @param val The attribute index (inputs) or class value index (outputs of
   *          a nominal class) that this unit represents.
   * @throws Exception declared for API compatibility; not thrown here.
   */
  public void setLink(boolean input, int val) throws Exception {
    m_input = input;
    m_type = input ? PURE_INPUT : PURE_OUTPUT;

    // Valid attribute indices are 0..numAttributes()-1 and valid nominal class
    // values are 0..numValues()-1, so the upper bound itself is out of range.
    // (A '>' test here would let it through, and it would later be used as an
    // out-of-range index into the current instance.)
    boolean outOfRange = val < 0
      || (input && val >= m_instances.numAttributes())
      || (!input && m_instances.classAttribute().isNominal()
        && val >= m_instances.classAttribute().numValues());

    m_link = outOfRange ? 0 : val;
  }

  /**
   * @return link for this node.
   */
  public int getLink() {
    return m_link;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 10169 $");
  }
}
/**
 * Inner class used to draw the nodes onto (it works directly on the node
 * lists). This will also handle the user input.
 */
private class NodePanel extends JPanel implements RevisionHandler {

  /** for serialization */
  static final long serialVersionUID = -3067621833388149984L;

  /**
   * The constructor. Installs a mouse listener that, while the network is
   * stopped, lets the user select, connect, disconnect, add and remove units.
   */
  public NodePanel() {

    addMouseListener(new MouseAdapter() {

      @Override
      public void mousePressed(MouseEvent e) {
        if (!m_stopped) {
          return;
        }
        handleMousePress(e);
      }
    });
  }

  /**
   * Reacts to a mouse press. A left click (without Alt) on a unit amends or
   * connects the selection, and on empty space creates a new node there. Any
   * other click on a unit disconnects it from (or, with nothing selected,
   * removes) it, and on empty space clears the selection.
   *
   * @param e The mouse event.
   */
  private void handleMousePress(MouseEvent e) {
    int modifiers = e.getModifiers();
    boolean left = (modifiers & MouseEvent.BUTTON1_MASK) == MouseEvent.BUTTON1_MASK
      && !e.isAltDown();
    boolean ctrl = (modifiers & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK;
    int w = getWidth();
    int h = getHeight();

    NeuralConnection hit;
    // getGraphics() creates a new graphics context on every call; it is only
    // needed for font metrics during the hit test, so release it right after.
    Graphics g = getGraphics();
    try {
      hit = findUnitAt(g, e.getX(), e.getY(), w, h);
    } finally {
      if (g != null) {
        g.dispose();
      }
    }

    ArrayList<NeuralConnection> picked = new ArrayList<NeuralConnection>(4);
    if (hit != null) {
      picked.add(hit);
      selection(picked, ctrl, left);
    } else if (left) {
      // clicked on empty space: create a new node at the click position
      NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
        m_sigmoidUnit);
      m_nextId++;
      temp.setX((double) e.getX() / w);
      temp.setY((double) e.getY() / h);
      picked.add(temp);
      addNode(temp);
      selection(picked, ctrl, true);
    } else {
      selection(null, ctrl, false);
    }
  }

  /**
   * Finds the unit drawn at the given point. Inputs are searched first, then
   * outputs, then the other nodes; the first match wins.
   *
   * @param g Graphics context used for font metrics.
   * @param x The x coord.
   * @param y The y coord.
   * @param w The width of the panel.
   * @param h The height of the panel.
   * @return The unit under the point, or null if there is none.
   */
  private NeuralConnection findUnitAt(Graphics g, int x, int y, int w, int h) {
    for (int noa = 0; noa < m_numAttributes; noa++) {
      if (m_inputs[noa].onUnit(g, x, y, w, h)) {
        return m_inputs[noa];
      }
    }
    for (int noa = 0; noa < m_numClasses; noa++) {
      if (m_outputs[noa].onUnit(g, x, y, w, h)) {
        return m_outputs[noa];
      }
    }
    for (NeuralConnection node : m_neuralNodes) {
      if (node.onUnit(g, x, y, w, h)) {
        return node;
      }
    }
    return null;
  }

  /**
   * Returns the i-th currently selected unit.
   *
   * @param i Index into the selection.
   * @return The selected unit.
   */
  private NeuralConnection selectedUnit(int i) {
    return (NeuralConnection) m_selected.get(i);
  }

  /**
   * Finds a unit in the current selection by identity (not equals()).
   *
   * @param unit The unit to look for.
   * @return Its index in the selection, or -1 if it is not selected.
   */
  private int indexOfSelected(NeuralConnection unit) {
    for (int i = 0; i < m_selected.size(); i++) {
      if (m_selected.get(i) == unit) {
        return i;
      }
    }
    return -1;
  }

  /**
   * This function gets called when the user has clicked something. It will
   * amend the current selection or connect the current selection to the new
   * selection. If nothing was selected and the right button was used, it will
   * delete the clicked node.
   *
   * @param v The units that were selected (null clears the selection).
   * @param ctrl True if ctrl was held down.
   * @param left True if it was the left mouse button.
   */
  private void selection(ArrayList<NeuralConnection> v, boolean ctrl,
    boolean left) {

    if (v == null) {
      // then unselect all.
      m_selected.clear();
      repaint();
      return;
    }

    if ((ctrl || m_selected.size() == 0) && left) {
      // exclusive-or the new units with the current selection
      for (NeuralConnection unit : v) {
        int index = indexOfSelected(unit);
        if (index >= 0) {
          m_selected.remove(index);
        } else {
          m_selected.add(unit);
        }
      }
      repaint();
      return;
    }

    if (left) {
      // connect the current selection to the new units
      for (int noa = 0; noa < m_selected.size(); noa++) {
        for (NeuralConnection target : v) {
          NeuralConnection.connect(selectedUnit(noa), target);
        }
      }
    } else if (m_selected.size() > 0) {
      // disconnect the current selection from the new units (both directions)
      for (int noa = 0; noa < m_selected.size(); noa++) {
        for (NeuralConnection target : v) {
          NeuralConnection.disconnect(selectedUnit(noa), target);
          NeuralConnection.disconnect(target, selectedUnit(noa));
        }
      }
    } else {
      // right click with nothing selected: remove the clicked units
      for (NeuralConnection unit : v) {
        unit.removeAllInputs();
        unit.removeAllOutputs();
        removeNode(unit);
      }
    }
    repaint();
  }

  /**
   * This will paint the nodes onto the panel: connection lines first, then the
   * units, then highlights for the selected units.
   *
   * @param g The graphics context.
   */
  @Override
  public void paintComponent(Graphics g) {

    super.paintComponent(g);
    int x = getWidth();
    int y = getHeight();
    // make the panel tall enough for the longer of the input/output columns
    // at 25 pixels per unit
    if (25 * m_numAttributes > 25 * m_numClasses && 25 * m_numAttributes > y) {
      setSize(x, 25 * m_numAttributes);
    } else if (25 * m_numClasses > y) {
      setSize(x, 25 * m_numClasses);
    } else {
      setSize(x, y);
    }

    y = getHeight();
    for (int noa = 0; noa < m_numAttributes; noa++) {
      m_inputs[noa].drawInputLines(g, x, y);
    }
    for (int noa = 0; noa < m_numClasses; noa++) {
      m_outputs[noa].drawInputLines(g, x, y);
      m_outputs[noa].drawOutputLines(g, x, y);
    }
    for (NeuralConnection node : m_neuralNodes) {
      node.drawInputLines(g, x, y);
    }
    for (int noa = 0; noa < m_numAttributes; noa++) {
      m_inputs[noa].drawNode(g, x, y);
    }
    for (int noa = 0; noa < m_numClasses; noa++) {
      m_outputs[noa].drawNode(g, x, y);
    }
    for (NeuralConnection node : m_neuralNodes) {
      node.drawNode(g, x, y);
    }

    for (int noa = 0; noa < m_selected.size(); noa++) {
      selectedUnit(noa).drawHighlight(g, x, y);
    }
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 10169 $");
  }
}
/**
 * This provides the basic controls for working with the neural network.
 *
 * @author Malcolm Ware ([email protected])
 * @version $Revision: 10169 $
 */
class ControlPanel extends JPanel implements RevisionHandler {

  /** for serialization */
  static final long serialVersionUID = 7393543302294142271L;

  /** The start stop button. */
  public JButton m_startStop;

  /** The button to accept the network (even if it hasn't done all epochs). */
  public JButton m_acceptButton;

  /** A label to state the number of epochs processed so far. */
  public JPanel m_epochsLabel;

  /** A label to state the total number of epochs to be processed. */
  public JLabel m_totalEpochsLabel;

  /** A text field to allow the changing of the total number of epochs. */
  public JTextField m_changeEpochs;

  /** A label to state the learning rate. */
  public JLabel m_learningLabel;

  /** A label to state the momentum. */
  public JLabel m_momentumLabel;

  /** A text field to allow the changing of the learning rate. */
  public JTextField m_changeLearning;

  /** A text field to allow the changing of the momentum. */
  public JTextField m_changeMomentum;

  /**
   * A label to state roughly the accuracy of the network. (Only roughly,
   * because the error is calculated per epoch while the network keeps changing
   * throughout each epoch's training.)
   */
  public JPanel m_errorLabel;

  /** The constructor. Builds the widgets and wires up the button handlers. */
  public ControlPanel() {
    setBorder(BorderFactory.createTitledBorder("Controls"));

    m_totalEpochsLabel = new JLabel("Num Of Epochs ");
    m_epochsLabel = new JPanel() {
      /** for serialization */
      private static final long serialVersionUID = 2562773937093221399L;

      @Override
      public void paintComponent(Graphics g) {
        super.paintComponent(g);
        // use this panel's own label rather than the outer m_controlPanel
        // field, which may not be assigned yet
        g.setColor(ControlPanel.this.m_totalEpochsLabel.getForeground());
        g.drawString("Epoch " + m_epoch, 0, 10);
      }
    };
    m_epochsLabel.setFont(m_totalEpochsLabel.getFont());

    m_changeEpochs = new JTextField();
    m_changeEpochs.setText("" + m_numEpochs);
    m_errorLabel = new JPanel() {
      /** for serialization */
      private static final long serialVersionUID = 4390239056336679189L;

      @Override
      public void paintComponent(Graphics g) {
        super.paintComponent(g);
        g.setColor(ControlPanel.this.m_totalEpochsLabel.getForeground());
        if (m_valSize == 0) {
          g.drawString(
            "Error per Epoch = " + Utils.doubleToString(m_error, 7), 0, 10);
        } else {
          g.drawString(
            "Validation Error per Epoch = "
              + Utils.doubleToString(m_error, 7), 0, 10);
        }
      }
    };
    m_errorLabel.setFont(m_epochsLabel.getFont());

    m_learningLabel = new JLabel("Learning Rate = ");
    m_momentumLabel = new JLabel("Momentum = ");
    m_changeLearning = new JTextField();
    m_changeMomentum = new JTextField();
    m_changeLearning.setText("" + m_learningRate);
    m_changeMomentum.setText("" + m_momentum);
    setLayout(new BorderLayout(15, 10));

    m_stopIt = true;
    m_accepted = false;
    m_startStop = new JButton("Start");
    m_startStop.setActionCommand("Start");

    m_acceptButton = new JButton("Accept");
    m_acceptButton.setActionCommand("Accept");

    JPanel buttons = new JPanel();
    buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS));
    buttons.add(m_startStop);
    buttons.add(m_acceptButton);
    add(buttons, BorderLayout.WEST);

    JPanel data = new JPanel();
    data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));

    Box ab = new Box(BoxLayout.X_AXIS);
    ab.add(m_epochsLabel);
    data.add(ab);

    ab = new Box(BoxLayout.X_AXIS);
    Component b = Box.createGlue();
    ab.add(m_totalEpochsLabel);
    ab.add(m_changeEpochs);
    m_changeEpochs.setMaximumSize(new Dimension(200, 20));
    ab.add(b);
    data.add(ab);

    ab = new Box(BoxLayout.X_AXIS);
    ab.add(m_errorLabel);
    data.add(ab);

    add(data, BorderLayout.CENTER);

    data = new JPanel();
    data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
    ab = new Box(BoxLayout.X_AXIS);
    b = Box.createGlue();
    ab.add(m_learningLabel);
    ab.add(m_changeLearning);
    m_changeLearning.setMaximumSize(new Dimension(200, 20));
    ab.add(b);
    data.add(ab);

    ab = new Box(BoxLayout.X_AXIS);
    b = Box.createGlue();
    ab.add(m_momentumLabel);
    ab.add(m_changeMomentum);
    m_changeMomentum.setMaximumSize(new Dimension(200, 20));
    ab.add(b);
    data.add(ab);

    add(data, BorderLayout.EAST);

    m_startStop.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent e) {
        if (e.getActionCommand().equals("Start")) {
          m_stopIt = false;
          m_startStop.setText("Stop");
          m_startStop.setActionCommand("Stop");
          // Unparsable text keeps the current value (just as out-of-range
          // values are ignored by the setters); each field is then rewritten
          // with the value that actually took effect. Previously a bad entry
          // threw here, leaving the button on "Stop" without resuming.
          m_numEpochs = parseIntOrDefault(m_changeEpochs.getText(), m_numEpochs);
          m_changeEpochs.setText("" + m_numEpochs);

          setLearningRate(parseDoubleOrDefault(m_changeLearning.getText(),
            m_learningRate));
          m_changeLearning.setText("" + m_learningRate);

          setMomentum(parseDoubleOrDefault(m_changeMomentum.getText(),
            m_momentum));
          m_changeMomentum.setText("" + m_momentum);

          blocker(false);
        } else if (e.getActionCommand().equals("Stop")) {
          m_stopIt = true;
          m_startStop.setText("Start");
          m_startStop.setActionCommand("Start");
        }
      }
    });

    m_acceptButton.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent e) {
        m_accepted = true;
        blocker(false);
      }
    });

    m_changeEpochs.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent e) {
        // unparsable text is treated like a non-positive value: ignored
        int n = parseIntOrDefault(m_changeEpochs.getText(), -1);
        if (n > 0) {
          m_numEpochs = n;
          blocker(false);
        }
      }
    });
  }

  /**
   * Parses user-entered text as an int.
   *
   * @param text The text to parse (surrounding whitespace is ignored).
   * @param fallback Value returned when the text is null or not an integer.
   * @return The parsed value, or fallback.
   */
  private int parseIntOrDefault(String text, int fallback) {
    if (text == null) {
      return fallback;
    }
    try {
      return Integer.parseInt(text.trim());
    } catch (NumberFormatException ex) {
      return fallback;
    }
  }

  /**
   * Parses user-entered text as a double.
   *
   * @param text The text to parse (surrounding whitespace is ignored).
   * @param fallback Value returned when the text is null or not a number.
   * @return The parsed value, or fallback.
   */
  private double parseDoubleOrDefault(String text, double fallback) {
    if (text == null) {
      return fallback;
    }
    try {
      return Double.parseDouble(text.trim());
    } catch (NumberFormatException ex) {
      return fallback;
    }
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 10169 $");
  }
}
/**
* Fallback ZeroR model, used in case no model can be built from the data or
* the network predicts all zeros for the classes.
*/
private Classifier m_ZeroR;
/** Whether predictions should come from the fallback ZeroR model. */
private boolean m_useDefaultModel = false;
/** The training instances. */
private Instances m_instances;
/** The instance currently running through the network. */
private Instance m_currentInstance;
/** True if the class attribute is numeric. */
private boolean m_numeric;
/**
* Per-attribute scaling ranges, indexed by attribute index (e.g. the class
* entry rescales a normalized numeric class output and its error).
*/
private double[] m_attributeRanges;
/** Per-attribute base (offset) values, indexed by attribute index. */
private double[] m_attributeBases;
/** The output units (they only feed the errors back; they do no calculation). */
private NeuralEnd[] m_outputs;
/** The input units (they only feed the inputs in; they do no calculation). */
private NeuralEnd[] m_inputs;
/** All the nodes that actually comprise the logical neural net. */
private NeuralConnection[] m_neuralNodes;
/** The number of classes (output units). */
private int m_numClasses = 0;
/** The number of attributes (input units). */
private int m_numAttributes = 0; // note: this count does not include the class.
/** The panel the nodes are displayed on. */
private NodePanel m_nodePanel;
/** The control panel. */
private ControlPanel m_controlPanel;
/** The next id number available for naming a newly created node. */
private int m_nextId;
/** The units (NeuralConnection objects) currently selected in the GUI. */
private ArrayList m_selected;
/** The number of epochs to train through. */
private int m_numEpochs;
/** Set when the network should stop running (Stop button); cleared by Start. */
private boolean m_stopIt;
/**
* True when the network has in fact stopped; the node panel only accepts
* mouse edits while this is set.
*/
private boolean m_stopped;
/** Set when the user accepts the network the way it is (Accept button). */
private boolean m_accepted;
/** The window for the network. */
private JFrame m_win;
/**
* A flag to tell the build classifier to automatically build a neural net.
*/
private boolean m_autoBuild;
/**
* A flag to state that the GUI for the network should be brought up, to allow
* interaction while training.
*/
private boolean m_gui;
/**
* Size of the validation set as a percentage of the training data
* (0 means no validation set is used).
*/
private int m_valSize;
/**
* Number of consecutive increases in validation error tolerated before
* training is stopped.
*/
private int m_driftThreshold;
/** The number used to seed the random number generator. */
private int m_randomSeed;
/** The actual random number generator. */
private Random m_random;
/** A flag to state that a nominal to binary filter should be used. */
private boolean m_useNomToBin;
/** The actual filter. */
private NominalToBinary m_nominalToBinaryFilter;
/** Comma separated description of the hidden layers (see setHiddenLayers). */
private String m_hiddenLayers;
/** This flag states that the user wants the input values normalized. */
private boolean m_normalizeAttributes;
/** This flag states that the user wants the learning rate to decay. */
private boolean m_decay;
/** This is the learning rate for the network. */
private double m_learningRate;
/** This is the momentum for the network. */
private double m_momentum;
/** The number of the epoch that the network just finished. */
private int m_epoch;
/**
* The error of the epoch that the network just finished (displayed as the
* validation error when a validation set is used).
*/
private double m_error;
/**
* This flag states that the user wants the network to restart if it is found
* to be generating infinity or NaN for the error value. This would restart
* the network with the current options except that the learning rate would be
* smaller than before (perhaps half of its current value). This option will
* not be available if the GUI is chosen (if the GUI is open the user can fix
* the network themselves; it is an architectural minefield for the network to
* be reset with the GUI open).
*/
private boolean m_reset;
/**
* This flag states that the user wants the class to be normalized while
* processing in the network is done (the final answer will be in the
* original range regardless). This option will only be used when the class is
* numeric.
*/
private boolean m_normalizeClass;
/**
* Shared sigmoid activation unit (handed to nodes created in the GUI).
*/
private final SigmoidUnit m_sigmoidUnit;
/**
* Shared linear activation unit.
*/
private final LinearUnit m_linearUnit;
/**
 * Creates a perceptron with no data attached and no network built yet, with
 * every option at its default value.
 */
public MultilayerPerceptron() {
  // ---- network / training state: nothing built yet ----
  m_instances = null;
  m_currentInstance = null;
  m_inputs = new NeuralEnd[0];
  m_outputs = new NeuralEnd[0];
  m_neuralNodes = new NeuralConnection[0];
  m_numAttributes = 0;
  m_numClasses = 0;
  m_numeric = false;
  m_nextId = 0;
  m_epoch = 0;
  m_error = 0;
  m_random = null;

  // ---- GUI state ----
  m_nodePanel = null;
  m_controlPanel = null;
  m_selected = new ArrayList(4);
  m_stopIt = true;
  m_stopped = true;
  m_accepted = false;

  // ---- helper objects ----
  m_nominalToBinaryFilter = new NominalToBinary();
  m_sigmoidUnit = new SigmoidUnit();
  m_linearUnit = new LinearUnit();

  // ---- option defaults ----
  // To change a default completely it must also be changed in setOptions,
  // and the accompanying tip-text/documentation should be updated to match.
  m_hiddenLayers = "a";
  m_learningRate = .3;
  m_momentum = .2;
  m_numEpochs = 500;
  m_valSize = 0;
  m_driftThreshold = 20;
  m_randomSeed = 0;
  m_decay = false;
  m_reset = true;
  m_gui = false;
  m_autoBuild = true;
  m_useNomToBin = true;
  m_normalizeAttributes = true;
  m_normalizeClass = true;
}
/**
 * Turns learning-rate decay on or off.
 *
 * @param d whether the learning rate should decay
 */
public void setDecay(boolean d) {
  this.m_decay = d;
}

/**
 * Reports whether learning-rate decay is enabled.
 *
 * @return true if the learning rate decays
 */
public boolean getDecay() {
  return this.m_decay;
}
/**
 * Allows the network to reset itself with the current settings and half the
 * current learning rate whenever it produces NaN or infinite errors. This
 * repeats until the network trains properly, after which the learning rate is
 * set back to its original value. Resetting can never be enabled while the
 * GUI is in use: with the GUI on, this always stores false.
 *
 * @param r true if the network should restart with its current options and
 *          half the current learning rate
 */
public void setReset(boolean r) {
  m_reset = !m_gui && r;
}

/**
 * @return whether the network is allowed to reset itself
 */
public boolean getReset() {
  return m_reset;
}
/**
 * Chooses whether a numeric class is normalized (into the range -1 to 1)
 * while the network processes it. Nominal classes are never normalized.
 *
 * @param c true to normalize a numeric class
 */
public void setNormalizeNumericClass(boolean c) {
  this.m_normalizeClass = c;
}

/**
 * @return whether a numeric class is normalized
 */
public boolean getNormalizeNumericClass() {
  return this.m_normalizeClass;
}
/**
 * Chooses whether the input attributes are normalized (into the range -1 to
 * 1). Even nominal attributes get normalized when this is on.
 *
 * @param a true to normalize the attributes
 */
public void setNormalizeAttributes(boolean a) {
  this.m_normalizeAttributes = a;
}

/**
 * @return whether the attributes are normalized
 */
public boolean getNormalizeAttributes() {
  return this.m_normalizeAttributes;
}
/**
 * Chooses whether a NominalToBinary filter is applied to the data.
 *
 * @param f true to use the filter
 */
public void setNominalToBinaryFilter(boolean f) {
  this.m_useNomToBin = f;
}

/**
 * @return whether the NominalToBinary filter is used
 */
public boolean getNominalToBinaryFilter() {
  return this.m_useNomToBin;
}
/**
 * Sets the seed for the random number generator used whenever the network
 * needs a random number. Negative seeds are ignored.
 *
 * @param l the seed
 */
@Override
public void setSeed(int l) {
  if (l < 0) {
    return;
  }
  m_randomSeed = l;
}

/**
 * @return the seed for the random number generator
 */
@Override
public int getSeed() {
  return m_randomSeed;
}
/**
 * Sets the threshold used during validation testing: training ends once the
 * error on the validation set has increased this many times in a row.
 * Non-positive values are ignored.
 *
 * @param t the number of consecutive increases tolerated
 */
public void setValidationThreshold(int t) {
  if (t <= 0) {
    return;
  }
  m_driftThreshold = t;
}

/**
 * @return the threshold used for validation testing
 */
public int getValidationThreshold() {
  return m_driftThreshold;
}
/**
 * Sets the learning rate of this network. The value must be greater than 0
 * and no more than 1; anything else (including NaN) is ignored and the
 * current rate is kept. The rate is stored per instance, so it only affects
 * this network. If the GUI control panel exists, its learning-rate field is
 * updated to show the new value.
 *
 * @param l The new learning rate.
 */
public void setLearningRate(double l) {
  // written as a positive range check so NaN is rejected too
  if (l > 0 && l <= 1) {
    m_learningRate = l;

    if (m_controlPanel != null) {
      m_controlPanel.m_changeLearning.setText("" + l);
    }
  }
}

/**
 * @return The learning rate for the nodes.
 */
public double getLearningRate() {
  return m_learningRate;
}
/**
 * Sets the momentum of this network. The value must lie in [0, 1]; unlike the
 * learning rate, 0 is allowed. Anything else (including NaN) is ignored and
 * the current momentum is kept. If the GUI control panel exists, its momentum
 * field is updated to show the new value.
 *
 * @param m The new momentum.
 */
public void setMomentum(double m) {
  // written as a positive range check so NaN is rejected too
  if (m >= 0 && m <= 1) {
    m_momentum = m;

    if (m_controlPanel != null) {
      m_controlPanel.m_changeMomentum.setText("" + m);
    }
  }
}

/**
 * @return The momentum for the nodes.
 */
public double getMomentum() {
  return m_momentum;
}
/**
 * Sets whether the network's hidden layers are generated automatically or
 * left up to the user. When the GUI is disabled, auto build is always forced
 * on. Nothing stops the user from altering an auto-built network in the GUI.
 *
 * @param a True if the network should be auto built.
 */
public void setAutoBuild(boolean a) {
  // without a GUI the user has no way of building the network by hand
  m_autoBuild = a || !m_gui;
}
/**
 * Returns whether the hidden layers are generated automatically.
 *
 * @return The auto build state.
 */
public boolean getAutoBuild() {
  return this.m_autoBuild;
}
/**
 * Sets the hidden layer specification used when auto build is enabled.
 * <p>
 * The string is a comma separated list with one entry per hidden layer. Each
 * entry is either a non-negative whole number of nodes or one of these
 * wildcards:
 * <ul>
 * <li>'a' = (number of attributes + number of classes) / 2</li>
 * <li>'i' = number of attributes</li>
 * <li>'o' = number of classes</li>
 * <li>'t' = number of attributes + number of classes</li>
 * </ul>
 * To have no hidden units, pass a single 0.
 * <p>
 * The current setting is left unchanged if the specification has no entries,
 * or contains a 0 anywhere other than as the sole entry, or contains a
 * negative or fractional number. An entry that is neither a wildcard nor a
 * number makes {@code Double.valueOf} throw a NumberFormatException.
 *
 * @param h A string with a comma separated list of layer sizes.
 */
public void setHiddenLayers(String h) {
  StringTokenizer tokens = new StringTokenizer(h, ",");
  if (tokens.countTokens() == 0) {
    return;
  }
  StringBuilder spec = new StringBuilder();
  boolean isFirst = true;
  while (tokens.hasMoreTokens()) {
    String token = tokens.nextToken().trim();
    boolean isWildcard = token.equals("a") || token.equals("i")
      || token.equals("o") || token.equals("t");
    if (isWildcard) {
      spec.append(token);
    } else {
      double asDouble = Double.valueOf(token).doubleValue();
      int asInt = (int) asDouble;
      boolean whole = asInt == asDouble;
      if (!whole || asInt < 0) {
        return;
      }
      // 0 is only accepted as the one and only entry ("no hidden layer")
      if (asInt == 0 && !(tokens.countTokens() == 0 && isFirst)) {
        return;
      }
      spec.append(asInt);
    }
    isFirst = false;
    if (tokens.hasMoreTokens()) {
      spec.append(", ");
    }
  }
  m_hiddenLayers = spec.toString();
}
/**
 * Returns the hidden layer specification.
 *
 * @return A comma separated string with one entry (a node count or wildcard)
 *         per hidden layer.
 */
public String getHiddenLayers() {
  return this.m_hiddenLayers;
}
/**
 * Sets whether a GUI is shown during training so the user can interact with
 * the network. Disabling the GUI forces auto build on. Enabling it calls
 * {@code setReset(false)}.
 *
 * @param a True if the GUI should be created.
 */
public void setGUI(boolean a) {
  m_gui = a;
  if (a) {
    setReset(false);
  } else {
    setAutoBuild(true);
  }
}
/**
 * Returns whether a GUI is shown during training.
 *
 * @return True if the GUI should be shown.
 */
public boolean getGUI() {
  return this.m_gui;
}
/**
 * Sets the size of the validation set as a percentage of the training data.
 * 0 means no validation set is used. Values outside [0, 99] are ignored.
 *
 * @param a The size of the validation set, as a percentage of the whole.
 */
public void setValidationSetSize(int a) {
  boolean valid = a >= 0 && a <= 99;
  if (valid) {
    m_valSize = a;
  }
}
/**
 * Returns the validation set size as a percentage of the training data.
 *
 * @return The percentage size of the validation set.
 */
public int getValidationSetSize() {
  return this.m_valSize;
}
/**
 * Sets the number of training epochs to perform. Values less than 1 are
 * ignored.
 *
 * @param n The number of epochs to train through.
 */
public void setTrainingTime(int n) {
  if (n <= 0) {
    return;
  }
  this.m_numEpochs = n;
}
/**
 * Returns the number of training epochs to perform.
 *
 * @return The number of epochs to train through.
 */
public int getTrainingTime() {
  return this.m_numEpochs;
}
/**
 * Appends a node to the end of the network node list (m_neuralNodes), growing
 * the array by one.
 *
 * @param n The node to place in the list.
 */
private void addNode(NeuralConnection n) {
  int oldLength = m_neuralNodes.length;
  NeuralConnection[] grown = new NeuralConnection[oldLength + 1];
  System.arraycopy(m_neuralNodes, 0, grown, 0, oldLength);
  grown[oldLength] = n;
  m_neuralNodes = grown;
}
/**
 * Removes the given node from the network node list (m_neuralNodes). Nodes
 * are matched by reference, and every occurrence is removed. The list is left
 * untouched if the node is not in it, including when the list is empty.
 *
 * @param n The neuralConnection to remove.
 * @return True if removed, false if not (because it wasn't there).
 */
private boolean removeNode(NeuralConnection n) {
  int matches = 0;
  for (NeuralConnection node : m_neuralNodes) {
    if (node == n) {
      matches++;
    }
  }
  if (matches == 0) {
    // report absence instead of sizing an array of length - 1, which is
    // negative for an empty list
    return false;
  }
  // size by the number of matches so duplicates never leave a null slot
  NeuralConnection[] shrunk = new NeuralConnection[m_neuralNodes.length - matches];
  int next = 0;
  for (NeuralConnection node : m_neuralNodes) {
    if (node != n) {
      shrunk[next++] = node;
    }
  }
  m_neuralNodes = shrunk;
  return true;
}
/**
 * Sets the m_numeric flag from the class attribute and records, for every
 * attribute, the half-range and midpoint of its non-missing values in
 * m_attributeRanges and m_attributeBases. This is done regardless of the
 * options, and the class attribute is included.
 * <p>
 * When m_normalizeAttributes is set, every non-class attribute is rescaled in
 * place as (value - midpoint) / half-range, which maps the observed
 * [min, max] onto [-1, 1]. Attributes with zero range are only shifted by
 * their midpoint.
 *
 * @param inst the instances (may be null, in which case nothing happens).
 * @return the same Instances object that was passed in, after modification.
 * @throws Exception if the underlying data access fails.
 */
private Instances setClassType(Instances inst) throws Exception {
  if (inst == null) {
    return inst;
  }
  int numAtts = inst.numAttributes();
  int numInsts = inst.numInstances();
  m_attributeRanges = new double[numAtts];
  m_attributeBases = new double[numAtts];
  for (int att = 0; att < numAtts; att++) {
    // bounds of the non-missing values of this attribute
    double lo = Double.POSITIVE_INFINITY;
    double hi = Double.NEGATIVE_INFINITY;
    for (int row = 0; row < numInsts; row++) {
      Instance current = inst.instance(row);
      if (current.isMissing(att)) {
        continue;
      }
      double v = current.value(att);
      if (v < lo) {
        lo = v;
      }
      if (v > hi) {
        hi = v;
      }
    }
    double half = (hi - lo) / 2;
    double mid = (hi + lo) / 2;
    m_attributeRanges[att] = half;
    m_attributeBases[att] = mid;

    boolean rescale = att != inst.classIndex() && m_normalizeAttributes;
    if (!rescale) {
      continue;
    }
    for (int row = 0; row < numInsts; row++) {
      Instance current = inst.instance(row);
      double shifted = current.value(att) - mid;
      // a constant attribute has no range to divide by
      current.setValue(att, half != 0 ? shifted / half : shifted);
    }
  }
  m_numeric = inst.classAttribute().isNumeric();
  return inst;
}
/**
 * Parks or releases the thread running buildClassifier so the user can
 * interact with the network in the GUI.
 *
 * @param tf True to make the calling thread wait on this object; false to
 *          wake every thread waiting here (if any).
 */
public synchronized void blocker(boolean tf) {
  if (!tf) {
    notifyAll();
    return;
  }
  try {
    wait();
  } catch (InterruptedException ignored) {
    // Deliberately ignored: an interrupt simply lets the caller continue.
    // Re-asserting the interrupt flag would make every later wait() return
    // immediately, turning the GUI pause loop into a busy spin.
  }
}
/**
 * Repaints the error and epoch labels of the GUI control panel. Does nothing
 * when the GUI is disabled.
 */
private void updateDisplay() {
  if (!m_gui) {
    return;
  }
  m_controlPanel.m_errorLabel.repaint();
  m_controlPanel.m_epochsLabel.repaint();
}
/**
 * Calls reset() on each of the m_numClasses output ends. Whether this also
 * resets the hidden nodes depends on NeuralEnd.reset(), which is not shown
 * here; presumably it propagates back through the network.
 */
private void resetNetwork() {
  int idx = 0;
  while (idx < m_numClasses) {
    m_outputs[idx].reset();
    idx++;
  }
}
/**
 * Asks each of the m_numClasses output ends for its output value, discarding
 * the results. Presumably this computes and caches values through the
 * network for m_currentInstance; the mechanism lives in NeuralEnd, which is
 * not shown here.
 */
private void calculateOutputs() {
  int idx = 0;
  while (idx < m_numClasses) {
    m_outputs[idx].outputValue(true);
    idx++;
  }
}
/**
 * Triggers error calculation for the current instance and returns the sum of
 * the squared errors of the output ends. Output values should have been
 * calculated first.
 * <p>
 * Each input end is queried with errorValue(true) before the output ends are
 * read with errorValue(false). Presumably the former computes errors through
 * the network and the latter returns cached values; this is not visible here.
 *
 * @return The sum over all outputs of the squared error.
 * @throws Exception if error calculation fails.
 */
private double calculateErrors() throws Exception {
  for (int a = 0; a < m_numAttributes; a++) {
    m_inputs[a].errorValue(true);
  }
  double sumOfSquares = 0;
  for (int c = 0; c < m_numClasses; c++) {
    double err = m_outputs[c].errorValue(false);
    sumOfSquares += err * err;
  }
  return sumOfSquares;
}
/**
 * Calls updateWeights(l, m) on each of the m_numClasses output ends, so that
 * weights are adjusted using the errors calculated beforehand.
 *
 * @param l The learning rate to update with.
 * @param m The momentum to update with.
 */
private void updateNetworkWeights(double l, double m) {
  int idx = 0;
  while (idx < m_numClasses) {
    m_outputs[idx].updateWeights(l, m);
    idx++;
  }
}
/**
 * Creates one input end per non-class attribute of m_instances. Each end is
 * linked to its attribute index and laid out in a vertical column at x = .1.
 *
 * @throws Exception if an input end cannot be created.
 */
private void setupInputs() throws Exception {
  m_inputs = new NeuralEnd[m_numAttributes];
  int slot = 0; // next free position in m_inputs
  for (int att = 0; att <= m_numAttributes; att++) {
    if (att == m_instances.classIndex()) {
      continue; // the class attribute gets no input end
    }
    NeuralEnd end = new NeuralEnd(m_instances.attribute(att).name());
    m_inputs[slot] = end;
    end.setX(.1);
    end.setY((slot + 1.0) / (m_numAttributes + 1));
    end.setLink(true, att);
    slot++;
  }
}
/**
 * Creates one output end per class (a single one for a numeric class). For
 * each end, a sigmoid node is also added to m_neuralNodes and connected to
 * it. Consequently the first m_numClasses entries of m_neuralNodes are these
 * output nodes.
 *
 * @throws Exception if an output end cannot be created.
 */
private void setupOutputs() throws Exception {
  m_outputs = new NeuralEnd[m_numClasses];
  for (int k = 0; k < m_numClasses; k++) {
    String label;
    if (m_numeric) {
      label = m_instances.classAttribute().name();
    } else {
      label = m_instances.classAttribute().value(k);
    }
    NeuralEnd end = new NeuralEnd(label);
    m_outputs[k] = end;
    double y = (k + 1.0) / (m_numClasses + 1);
    end.setX(.9);
    end.setY(y);
    end.setLink(false, k);

    NeuralNode node = new NeuralNode(String.valueOf(m_nextId), m_random,
      m_sigmoidUnit);
    m_nextId++;
    node.setX(.75);
    node.setY(y);
    addNode(node);
    NeuralConnection.connect(node, end);
  }
}
/**
 * Automatically generates the hidden layers described by m_hiddenLayers and
 * wires the network together.
 * <ul>
 * <li>The inputs feed the first hidden layer.</li>
 * <li>Each hidden layer is fully connected to the next.</li>
 * <li>The last hidden layer feeds the output nodes created by setupOutputs,
 * which occupy the first m_numClasses slots of m_neuralNodes.</li>
 * </ul>
 * If the first layer has no nodes, or the specification is empty, the inputs
 * are connected straight to the output nodes.
 */
private void setupHiddenLayer() {
  StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ",");
  int numLayers = tok.countTokens();
  int firstLayerSize = 0; // size of the layer the inputs connect to
  int prev = 0; // size of the previously created layer
  for (int layer = 0; layer < numLayers; layer++) {
    int size = hiddenLayerSize(tok.nextToken());
    if (layer == 0) {
      firstLayerSize = size;
    }
    // index at which this layer's nodes start in m_neuralNodes
    int layerStart = m_neuralNodes.length;
    for (int nob = 0; nob < size; nob++) {
      // creation and connection order is kept fixed because both presumably
      // draw from m_random
      NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
        m_sigmoidUnit);
      m_nextId++;
      temp.setX(.5 / numLayers * layer + .25);
      temp.setY((nob + 1.0) / (size + 1));
      addNode(temp);
      if (layer > 0) {
        // connect every node of the previous layer to this node
        for (int noc = layerStart - prev; noc < layerStart; noc++) {
          NeuralConnection.connect(m_neuralNodes[noc], temp);
        }
      }
    }
    prev = size;
  }
  if (firstLayerSize == 0) {
    // no hidden layer: inputs feed the output nodes directly
    for (int noa = 0; noa < m_numAttributes; noa++) {
      for (int nob = 0; nob < m_numClasses; nob++) {
        NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
      }
    }
  } else {
    // inputs feed the first hidden layer, stored right after the output nodes
    for (int noa = 0; noa < m_numAttributes; noa++) {
      for (int nob = m_numClasses; nob < m_numClasses + firstLayerSize; nob++) {
        NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
      }
    }
    // the last hidden layer (the final prev nodes) feeds the output nodes
    for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length; noa++) {
      for (int nob = 0; nob < m_numClasses; nob++) {
        NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]);
      }
    }
  }
}

/**
 * Converts one entry of the hidden layer specification into a node count,
 * expanding the wildcards 'a', 'i', 'o' and 't'.
 *
 * @param token one comma separated entry; surrounding whitespace is ignored.
 * @return the number of nodes for that layer.
 */
private int hiddenLayerSize(String token) {
  String c = token.trim();
  if (c.equals("a")) {
    return (m_numAttributes + m_numClasses) / 2;
  } else if (c.equals("i")) {
    return m_numAttributes;
  } else if (c.equals("o")) {
    return m_numClasses;
  } else if (c.equals("t")) {
    return m_numAttributes + m_numClasses;
  }
  // parsed as a Double so that entries such as "2.0" are accepted as well
  return Double.valueOf(c).intValue();
}
/**
 * Sets every node in m_neuralNodes whose type includes the OUTPUT flag to use
 * the linear unit, and every other node to use the sigmoid unit. Every entry
 * of m_neuralNodes is assumed to be a NeuralNode.
 */
private void setEndsToLinear() {
  for (int k = 0; k < m_neuralNodes.length; k++) {
    NeuralConnection conn = m_neuralNodes[k];
    boolean feedsOutput = (conn.getType() & NeuralConnection.OUTPUT) == NeuralConnection.OUTPUT;
    NeuralNode node = (NeuralNode) conn;
    if (feedsOutput) {
      node.setMethod(m_linearUnit);
    } else {
      node.setMethod(m_sigmoidUnit);
    }
  }
}
/**
 * Returns the capabilities of this classifier.
 * <ul>
 * <li>Attributes: nominal, numeric and date, with missing values.</li>
 * <li>Class: nominal, numeric and date, with missing class values.</li>
 * </ul>
 *
 * @return the capabilities of this classifier
 */
@Override
public Capabilities getCapabilities() {
  Capabilities result = super.getCapabilities();
  result.disableAll();
  Capability[] supported = {
    // attributes
    Capability.NOMINAL_ATTRIBUTES, Capability.NUMERIC_ATTRIBUTES,
    Capability.DATE_ATTRIBUTES, Capability.MISSING_VALUES,
    // class
    Capability.NOMINAL_CLASS, Capability.NUMERIC_CLASS,
    Capability.DATE_CLASS, Capability.MISSING_CLASS_VALUES };
  for (Capability cap : supported) {
    result.enable(cap);
  }
  return result;
}
/**
 * Builds and trains a neural network on the given training data.
 * <p>
 * The steps are:
 * <ol>
 * <li>Copy the data, remove instances with a missing class, and always build
 * a ZeroR model as a fallback. If only the class attribute is present, the
 * ZeroR model is used instead of a network.</li>
 * <li>Shuffle the data with the configured seed and optionally apply
 * NominalToBinary.</li>
 * <li>Normalize the data (see setClassType) and split off an optional
 * validation set from the front.</li>
 * <li>Build the network and train it for up to the configured number of
 * epochs. Training is interactive when the GUI is enabled.</li>
 * </ol>
 * If the training error becomes infinite or NaN, an exception is thrown unless
 * m_reset is set. In that case training restarts recursively with half the
 * learning rate.
 *
 * @param i The training data.
 * @throws Exception if the data cannot be handled or the network cannot be
 *           trained.
 * @throws IllegalStateException if repeated resets shrink the learning rate
 *           to Utils.SMALL or below.
 */
@Override
public void buildClassifier(Instances i) throws Exception {
// can classifier handle the data?
getCapabilities().testWithFail(i);
// remove instances with missing class
i = new Instances(i);
i.deleteWithMissingClass();
// fallback model, also used by distributionForInstance when the network's
// outputs do not sum to a positive value
m_ZeroR = new weka.classifiers.rules.ZeroR();
m_ZeroR.buildClassifier(i);
// only class? -> use ZeroR model
if (i.numAttributes() == 1) {
System.err
.println("Cannot build model (only class attribute present in data!), "
+ "using ZeroR model instead!");
m_useDefaultModel = true;
return;
} else {
m_useDefaultModel = false;
}
// reset all state left over from a previous build
m_epoch = 0;
m_error = 0;
m_instances = null;
m_currentInstance = null;
m_controlPanel = null;
m_nodePanel = null;
m_outputs = new NeuralEnd[0];
m_inputs = new NeuralEnd[0];
m_numAttributes = 0;
m_numClasses = 0;
m_neuralNodes = new NeuralConnection[0];
m_selected = new ArrayList(4);
m_nextId = 0;
m_stopIt = true;
m_stopped = true;
m_accepted = false;
m_instances = new Instances(i);
// seeded shuffle; the validation set below is taken from the front of the
// shuffled data
m_random = new Random(m_randomSeed);
m_instances.randomize(m_random);
if (m_useNomToBin) {
m_nominalToBinaryFilter = new NominalToBinary();
m_nominalToBinaryFilter.setInputFormat(m_instances);
m_instances = Filter.useFilter(m_instances, m_nominalToBinaryFilter);
}
m_numAttributes = m_instances.numAttributes() - 1;
m_numClasses = m_instances.numClasses();
// setClassType returns the same object it is given; it normalizes
// m_instances in place (when enabled) and sets m_numeric
setClassType(m_instances);
// this sets up the validation set.
Instances valSet = null;
// numinval is needed later
int numInVal = (int) (m_valSize / 100.0 * m_instances.numInstances());
if (m_valSize > 0) {
if (numInVal == 0) {
numInVal = 1;
}
// validation set = the first numInVal instances of the shuffled data
valSet = new Instances(m_instances, 0, numInVal);
}
// /////////
// network construction; the order of these calls matters because node
// creation and connection presumably draw from m_random
setupInputs();
setupOutputs();
if (m_autoBuild) {
setupHiddenLayer();
}
// ///////////////////////////
// this sets up the gui for usage
if (m_gui) {
m_win = new JFrame();
m_win.addWindowListener(new WindowAdapter() {
@Override
public void windowClosing(WindowEvent e) {
// pause training while the confirmation dialog is open, then restore
// the previous running state
boolean k = m_stopIt;
m_stopIt = true;
int well = JOptionPane.showConfirmDialog(m_win, "Are You Sure...\n"
+ "Click Yes To Accept" + " The Neural Network"
+ "\n Click No To Return", "Accept Neural Network",
JOptionPane.YES_NO_OPTION);
if (well == 0) {
// accepted: close the window and release the training thread
m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
m_accepted = true;
blocker(false);
} else {
m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
}
m_stopIt = k;
}
});
m_win.getContentPane().setLayout(new BorderLayout());
m_win.setTitle("Neural Network");
m_nodePanel = new NodePanel();
// without the following two lines, the
// NodePanel.paintComponents(Graphics)
// method will go berserk if the network doesn't fit completely: it will
// get called on a constant basis, using 100% of the CPU
// see the following forum thread:
// http://forum.java.sun.com/thread.jspa?threadID=580929&messageID=2945011
m_nodePanel.setPreferredSize(new Dimension(640, 480));
m_nodePanel.revalidate();
JScrollPane sp = new JScrollPane(m_nodePanel,
JScrollPane.VERTICAL_SCROLLBAR_ALWAYS,
JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
m_controlPanel = new ControlPanel();
m_win.getContentPane().add(sp, BorderLayout.CENTER);
m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
m_win.setSize(640, 480);
m_win.setVisible(true);
}
// This sets up the initial state of the gui
if (m_gui) {
// wait here until another thread calls blocker(false), e.g. the window
// listener above when the network is accepted
blocker(true);
m_controlPanel.m_changeEpochs.setEnabled(false);
m_controlPanel.m_changeLearning.setEnabled(false);
m_controlPanel.m_changeMomentum.setEnabled(false);
}
// For silly situations in which the network gets accepted before training
// commences
if (m_numeric) {
setEndsToLinear();
}
if (m_accepted) {
m_win.dispose();
m_controlPanel = null;
m_nodePanel = null;
m_instances = new Instances(m_instances, 0);
m_currentInstance = null;
return;
}
// connections done.
double right = 0;
double driftOff = 0;
double lastRight = Double.POSITIVE_INFINITY;
double bestError = Double.POSITIVE_INFINITY;
double tempRate;
double totalWeight = 0;
double totalValWeight = 0;
double origRate = m_learningRate; // only used for when reset
// ensure that at least 1 instance is trained through.
// NOTE(review): valSet was already built from the unadjusted count, so when
// it covered every instance, the last instance ends up in both the training
// and the validation set.
if (numInVal == m_instances.numInstances()) {
numInVal--;
}
if (numInVal < 0) {
numInVal = 0;
}
// training set = instances [numInVal, numInstances)
for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
if (!m_instances.instance(noa).classIsMissing()) {
totalWeight += m_instances.instance(noa).weight();
}
}
if (m_valSize != 0) {
for (int noa = 0; noa < valSet.numInstances(); noa++) {
if (!valSet.instance(noa).classIsMissing()) {
totalValWeight += valSet.instance(noa).weight();
}
}
}
m_stopped = false;
// epochs are numbered from 1 to m_numEpochs
for (int noa = 1; noa < m_numEpochs + 1; noa++) {
right = 0;
for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
m_currentInstance = m_instances.instance(nob);
if (!m_currentInstance.classIsMissing()) {
// this is where the network updating (and training occurs, for the
// training set
resetNetwork();
calculateOutputs();
// per-instance learning rate, scaled by the instance weight and,
// with decay enabled, divided by the epoch number
tempRate = m_learningRate * m_currentInstance.weight();
if (m_decay) {
tempRate /= noa;
}
// weighted squared error, averaged over the outputs
right += (calculateErrors() / m_instances.numClasses())
* m_currentInstance.weight();
updateNetworkWeights(tempRate, m_momentum);
}
}
right /= totalWeight;
// diverged: either give up or restart with half the learning rate
if (Double.isInfinite(right) || Double.isNaN(right)) {
if (!m_reset) {
m_instances = null;
throw new Exception("Network cannot train. Try restarting with a"
+ " smaller learning rate.");
} else {
// reset the network if possible
if (m_learningRate <= Utils.SMALL) {
throw new IllegalStateException("Learning rate got too small ("
+ m_learningRate + " <= " + Utils.SMALL + ")!");
}
m_learningRate /= 2;
buildClassifier(i);
m_learningRate = origRate;
m_instances = new Instances(m_instances, 0);
m_currentInstance = null;
return;
}
}
// //////////////////////do validation testing if applicable
if (m_valSize != 0) {
right = 0;
for (int nob = 0; nob < valSet.numInstances(); nob++) {
m_currentInstance = valSet.instance(nob);
if (!m_currentInstance.classIsMissing()) {
// this is where the network updating occurs, for the validation set
resetNetwork();
calculateOutputs();
right += (calculateErrors() / valSet.numClasses())
* m_currentInstance.weight();
// note 'right' could be calculated here just using
// the calculate output values. This would be faster.
// be less modular
}
}
// early stopping: the comparisons use the un-normalized sum, which is
// consistent across epochs because totalValWeight is constant
if (right < lastRight) {
if (right < bestError) {
bestError = right;
// save the network weights at this point
for (int noc = 0; noc < m_numClasses; noc++) {
m_outputs[noc].saveWeights();
}
driftOff = 0;
}
} else {
driftOff++;
}
lastRight = right;
// NOTE(review): 'noa + 1 >= m_numEpochs' ends training after epoch
// m_numEpochs - 1 when a validation set is used, one epoch fewer than
// without one -- confirm whether this is intended.
if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
// roll back to the weights with the best validation error
for (int noc = 0; noc < m_numClasses; noc++) {
m_outputs[noc].restoreWeights();
}
m_accepted = true;
}
right /= totalValWeight;
}
m_epoch = noa;
m_error = right;
// shows what the neuralnet is upto if a gui exists.
updateDisplay();
// This junction controls what state the gui is in at the end of each
// epoch, Such as if it is paused, if it is resumable etc...
if (m_gui) {
// stay paused while the user has stopped training, or when all epochs
// are done without a validation set, until the network is accepted
while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0))
&& !m_accepted) {
m_stopIt = true;
m_stopped = true;
if (m_epoch >= m_numEpochs && m_valSize == 0) {
m_controlPanel.m_startStop.setEnabled(false);
} else {
m_controlPanel.m_startStop.setEnabled(true);
}
m_controlPanel.m_startStop.setText("Start");
m_controlPanel.m_startStop.setActionCommand("Start");
m_controlPanel.m_changeEpochs.setEnabled(true);
m_controlPanel.m_changeLearning.setEnabled(true);
m_controlPanel.m_changeMomentum.setEnabled(true);
blocker(true);
if (m_numeric) {
setEndsToLinear();
}
}
m_controlPanel.m_changeEpochs.setEnabled(false);
m_controlPanel.m_changeLearning.setEnabled(false);
m_controlPanel.m_changeMomentum.setEnabled(false);
m_stopped = false;
// if the network has been accepted stop the training loop
if (m_accepted) {
m_win.dispose();
m_controlPanel = null;
m_nodePanel = null;
m_instances = new Instances(m_instances, 0);
m_currentInstance = null;
return;
}
}
if (m_accepted) {
// keep only the header; the training data is not needed any more
m_instances = new Instances(m_instances, 0);
m_currentInstance = null;
return;
}
}
if (m_gui) {
m_win.dispose();
m_controlPanel = null;
m_nodePanel = null;
}
// keep only the header; the training data is not needed any more
m_instances = new Instances(m_instances, 0);
m_currentInstance = null;
}
/**
 * Predicts the class distribution for an instance once buildClassifier has
 * been called.
 * <p>
 * The instance is passed through the NominalToBinary filter (if used),
 * copied, and normalized with the ranges recorded during training (if
 * enabled). It is then fed through the network.
 * <ul>
 * <li>Numeric class: the raw output values are returned.</li>
 * <li>Nominal class: the outputs are scaled to sum to 1. If their sum is not
 * positive, the ZeroR fallback's distribution is returned instead.</li>
 * </ul>
 *
 * @param i The instance to classify.
 * @return A double array filled with the probabilities of each class type.
 * @throws Exception if can't classify instance.
 */
@Override
public double[] distributionForInstance(Instance i) throws Exception {
  if (m_useDefaultModel) {
    return m_ZeroR.distributionForInstance(i);
  }
  Instance prepared = i;
  if (m_useNomToBin) {
    m_nominalToBinaryFilter.input(i);
    prepared = m_nominalToBinaryFilter.output();
  }
  // work on a copy so the caller's instance is never modified
  m_currentInstance = (Instance) prepared.copy();

  if (m_normalizeAttributes) {
    int classIdx = m_instances.classIndex();
    for (int att = 0; att < m_instances.numAttributes(); att++) {
      if (att == classIdx) {
        continue;
      }
      double range = m_attributeRanges[att];
      double shifted = m_currentInstance.value(att) - m_attributeBases[att];
      m_currentInstance.setValue(att, range != 0 ? shifted / range : shifted);
    }
  }

  resetNetwork();
  // every output value is needed, so they are collected here directly
  double[] dist = new double[m_numClasses];
  for (int k = 0; k < m_numClasses; k++) {
    dist[k] = m_outputs[k].outputValue(true);
  }
  if (m_instances.classAttribute().isNumeric()) {
    return dist;
  }

  double total = 0;
  for (double p : dist) {
    total += p;
  }
  if (total <= 0) {
    return m_ZeroR.distributionForInstance(i);
  }
  for (int k = 0; k < m_numClasses; k++) {
    dist[k] /= total;
  }
  return dist;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration
© 2015 - 2024 Weber Informatics LLC | Privacy Policy