All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.functions.neural.NeuralConnection Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    NeuralConnection.java
 *    Copyright (C) 2000-2012 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.functions.neural;

import java.awt.Color;
import java.awt.Graphics;
import java.io.Serializable;

import weka.core.RevisionHandler;

/** 
 * Abstract unit in a NeuralNetwork.
 *
 * @author Malcolm Ware ([email protected])
 * @version $Revision: 8034 $
 */
public abstract class NeuralConnection
  implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = -286208828571059163L;

  // Bitwise flags for the types of unit; combine with | and test with &.

  /** This unit is not connected to any others. */
  public static final int UNCONNECTED = 0;

  /** This unit is a pure input unit. */
  public static final int PURE_INPUT = 1;

  /** This unit is a pure output unit. */
  public static final int PURE_OUTPUT = 2;

  /** This unit is fed directly by at least one pure input unit. */
  public static final int INPUT = 4;

  /** This unit directly feeds a pure output unit. */
  public static final int OUTPUT = 8;

  /** This flag is set once the unit has a connection. */
  public static final int CONNECTED = 16;

  // The difference between pure and non-pure units is that pure units are
  // only used to feed the network the attribute values and the errors on
  // the outputs. Beyond that they do no calculations, and connect() places
  // restrictions on the connections they can take part in.

  /** Number of extra slots added each time a connection array fills up. */
  private static final int GROWTH_INCREMENT = 15;

  /** The list of inputs to this unit. */
  protected NeuralConnection[] m_inputList;

  /** The list of outputs from this unit. */
  protected NeuralConnection[] m_outputList;

  /** The numbering for the connections at the other end of the input lines. */
  protected int[] m_inputNums;

  /** The numbering for the connections at the other end of the out lines. */
  protected int[] m_outputNums;

  /** The number of inputs. */
  protected int m_numInputs;

  /** The number of outputs. */
  protected int m_numOutputs;

  /** The output value for this unit, NaN if not calculated. */
  protected double m_unitValue;

  /** The error value for this unit, NaN if not calculated. */
  protected double m_unitError;

  /** True if the weights have already been updated. */
  protected boolean m_weightsUpdated;

  /** The string that uniquely (provided naming is done properly) identifies
   * this unit. */
  protected String m_id;

  /** The type of unit this is (a combination of the flag constants above). */
  protected int m_type;

  /** The x coord of this unit purely for displaying purposes. */
  protected double m_x;

  /** The y coord of this unit purely for displaying purposes. */
  protected double m_y;

  /**
   * Constructs the unit with the basic connection information prepared for
   * use. The unit starts with no connections, NaN output/error values and
   * type {@link #UNCONNECTED}.
   *
   * @param id the unique id of the unit
   */
  public NeuralConnection(String id) {

    m_id = id;
    m_inputList = new NeuralConnection[0];
    m_outputList = new NeuralConnection[0];
    m_inputNums = new int[0];
    m_outputNums = new int[0];

    m_numInputs = 0;
    m_numOutputs = 0;

    m_unitValue = Double.NaN;
    m_unitError = Double.NaN;

    m_weightsUpdated = false;
    m_x = 0;
    m_y = 0;
    m_type = UNCONNECTED;
  }

  /**
   * @return The identity string of this unit.
   */
  public String getId() {
    return m_id;
  }

  /**
   * @return The type of this unit (a combination of the flag constants).
   */
  public int getType() {
    return m_type;
  }

  /**
   * @param t The new type of this unit (a combination of the flag constants).
   */
  public void setType(int t) {
    m_type = t;
  }

  /**
   * Call this to reset the unit for another run.
   * It is expected that this unit will call the reset functions of all
   * input units to it. It is also expected that this will not be done
   * if the unit has already been reset (or at least appears to be).
   */
  public abstract void reset();

  /**
   * Call this to get the output value of this unit.
   * @param calculate True if the value should be calculated if it hasn't been
   * already.
   * @return The output value, or NaN, if the value has not been calculated.
   */
  public abstract double outputValue(boolean calculate);

  /**
   * Call this to get the error value of this unit.
   * @param calculate True if the value should be calculated if it hasn't been
   * already.
   * @return The error value, or NaN, if the value has not been calculated.
   */
  public abstract double errorValue(boolean calculate);

  /**
   * Call this to have the connection save the current
   * weights.
   */
  public abstract void saveWeights();

  /**
   * Call this to have the connection restore from the saved
   * weights.
   */
  public abstract void restoreWeights();

  /**
   * Call this to get the weight value on a particular connection.
   * @param n The connection number to get the weight for, -1 if the threshold
   * weight should be returned.
   * @return This function will default to return 1. If overridden, it should
   * return the value for the specified connection or if -1 then it should
   * return the threshold value. If no value exists for the specified
   * connection, NaN will be returned.
   */
  public double weightValue(int n) {
    return 1;
  }

  /**
   * Call this function to update the weight values at this unit.
   * After the weights have been updated at this unit, all the
   * input connections will then be called from this to have their
   * weights updated. Repeated calls do nothing until m_weightsUpdated
   * is cleared again (presumably by a subclass's reset()).
   * @param l The learning rate to use.
   * @param m The momentum to use.
   */
  public void updateWeights(double l, double m) {

    // The action the subclasses should perform is up to them,
    // but if they override this they should call it (via super) so that
    // the update is propagated to all their inputs.

    if (!m_weightsUpdated) {
      for (int noa = 0; noa < m_numInputs; noa++) {
        m_inputList[noa].updateWeights(l, m);
      }
      m_weightsUpdated = true;
    }
  }

  /**
   * Use this to get easy access to the inputs.
   * It is not advised to change the entries in this list
   * (use the connecting and disconnecting functions to do that).
   * Only the first {@link #getNumInputs()} entries are meaningful.
   * @return The inputs list.
   */
  public NeuralConnection[] getInputs() {
    return m_inputList;
  }

  /**
   * Use this to get easy access to the outputs.
   * It is not advised to change the entries in this list
   * (use the connecting and disconnecting functions to do that).
   * Only the first {@link #getNumOutputs()} entries are meaningful.
   * @return The outputs list.
   */
  public NeuralConnection[] getOutputs() {
    return m_outputList;
  }

  /**
   * Use this to get easy access to the input numbers.
   * It is not advised to change the entries in this list
   * (use the connecting and disconnecting functions to do that).
   * @return The input nums list.
   */
  public int[] getInputNums() {
    return m_inputNums;
  }

  /**
   * Use this to get easy access to the output numbers.
   * It is not advised to change the entries in this list
   * (use the connecting and disconnecting functions to do that).
   * @return The output nums list.
   */
  public int[] getOutputNums() {
    return m_outputNums;
  }

  /**
   * @return the x coord.
   */
  public double getX() {
    return m_x;
  }

  /**
   * @return the y coord.
   */
  public double getY() {
    return m_y;
  }

  /**
   * @param x The new value for its x pos.
   */
  public void setX(double x) {
    m_x = x;
  }

  /**
   * @param y The new value for its y pos.
   */
  public void setY(double y) {
    m_y = y;
  }

  /**
   * Call this function to determine if the point at x,y is on the unit.
   * The hit area is a square extending 10 pixels either side of the unit's
   * scaled position (inclusive).
   * @param g The graphics context (not used by this implementation).
   * @param x The x coord.
   * @param y The y coord.
   * @param w The width of the display.
   * @param h The height of the display.
   * @return True if the point is on the unit, false otherwise.
   */
  public boolean onUnit(Graphics g, int x, int y, int w, int h) {

    int m = (int)(m_x * w);
    int c = (int)(m_y * h);
    return x <= m + 10 && x >= m - 10 && y <= c + 10 && y >= c - 10;
  }

  /**
   * Call this function to draw the node: orange if it feeds a pure output
   * (OUTPUT flag set), red otherwise, with a gray centre.
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  public void drawNode(Graphics g, int w, int h) {

    if ((m_type & OUTPUT) == OUTPUT) {
      g.setColor(Color.orange);
    }
    else {
      g.setColor(Color.red);
    }
    g.fillOval((int)(m_x * w) - 9, (int)(m_y * h) - 9, 19, 19);
    g.setColor(Color.gray);
    g.fillOval((int)(m_x * w) - 5, (int)(m_y * h) - 5, 11, 11);
  }

  /**
   * Call this function to draw the node highlighted (yellow centre).
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  public void drawHighlight(Graphics g, int w, int h) {

    drawNode(g, w, h);
    g.setColor(Color.yellow);
    g.fillOval((int)(m_x * w) - 5, (int)(m_y * h) - 5, 11, 11);
  }

  /**
   * Call this function to draw the node's input connections.
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  public void drawInputLines(Graphics g, int w, int h) {

    g.setColor(Color.black);

    int px = (int)(m_x * w);
    int py = (int)(m_y * h);
    for (int noa = 0; noa < m_numInputs; noa++) {
      g.drawLine((int)(m_inputList[noa].getX() * w),
                 (int)(m_inputList[noa].getY() * h),
                 px, py);
    }
  }

  /**
   * Call this function to draw the node's output connections.
   * @param g The graphics context.
   * @param w The width of the drawing area.
   * @param h The height of the drawing area.
   */
  public void drawOutputLines(Graphics g, int w, int h) {

    g.setColor(Color.black);

    int px = (int)(m_x * w);
    int py = (int)(m_y * h);
    for (int noa = 0; noa < m_numOutputs; noa++) {
      g.drawLine(px, py,
                 (int)(m_outputList[noa].getX() * w),
                 (int)(m_outputList[noa].getY() * h));
    }
  }

  /**
   * This will connect the specified unit to be an input to this unit
   * (only on this end).
   * @param i The unit.
   * @param n Its connection number for this connection (the index of this
   * unit in i's output list).
   * @return True if the connection was made, false if i is already an input.
   */
  protected boolean connectInput(NeuralConnection i, int n) {

    for (int noa = 0; noa < m_numInputs; noa++) {
      if (i == m_inputList[noa]) {
        return false;
      }
    }
    if (m_numInputs >= m_inputList.length) {
      // then allocate more space to it.
      allocateInputs();
    }
    m_inputList[m_numInputs] = i;
    m_inputNums[m_numInputs] = n;
    m_numInputs++;
    return true;
  }

  /**
   * This will allocate more space for input connection information
   * if the arrays for this have been filled up. Existing entries are kept.
   */
  protected void allocateInputs() {

    NeuralConnection[] grownList
      = new NeuralConnection[m_inputList.length + GROWTH_INCREMENT];
    int[] grownNums = new int[m_inputNums.length + GROWTH_INCREMENT];

    System.arraycopy(m_inputList, 0, grownList, 0, m_numInputs);
    System.arraycopy(m_inputNums, 0, grownNums, 0, m_numInputs);
    m_inputList = grownList;
    m_inputNums = grownNums;
  }

  /**
   * This will connect the specified unit to be an output to this unit
   * (only on this end).
   * @param o The unit.
   * @param n Its connection number for this connection (the index of this
   * unit in o's input list).
   * @return True if the connection was made, false if o is already an output.
   */
  protected boolean connectOutput(NeuralConnection o, int n) {

    for (int noa = 0; noa < m_numOutputs; noa++) {
      if (o == m_outputList[noa]) {
        return false;
      }
    }
    if (m_numOutputs >= m_outputList.length) {
      // then allocate more space to it.
      allocateOutputs();
    }
    m_outputList[m_numOutputs] = o;
    m_outputNums[m_numOutputs] = n;
    m_numOutputs++;
    return true;
  }

  /**
   * Allocates more space for output connection information
   * if the arrays have been filled up. Existing entries are kept.
   */
  protected void allocateOutputs() {

    NeuralConnection[] grownList
      = new NeuralConnection[m_outputList.length + GROWTH_INCREMENT];
    int[] grownNums = new int[m_outputNums.length + GROWTH_INCREMENT];

    System.arraycopy(m_outputList, 0, grownList, 0, m_numOutputs);
    System.arraycopy(m_outputNums, 0, grownNums, 0, m_numOutputs);
    m_outputList = grownList;
    m_outputNums = grownNums;
  }

  /**
   * This will disconnect the input with the specific connection number
   * from this node (only on this end however). Remaining inputs are shifted
   * down and the units at their other ends are told their new slot numbers.
   * @param i The unit to disconnect.
   * @param n The connection number at the other end, -1 if all the connections
   * to this unit should be severed.
   * @return True if the connection was removed, false if the connection was
   * not found.
   */
  protected boolean disconnectInput(NeuralConnection i, int n) {

    int loc = -1;
    boolean removed = false;
    do {
      loc = -1;
      for (int noa = 0; noa < m_numInputs; noa++) {
        if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) {
          loc = noa;
          break;
        }
      }

      if (loc >= 0) {
        for (int noa = loc + 1; noa < m_numInputs; noa++) {
          m_inputList[noa - 1] = m_inputList[noa];
          m_inputNums[noa - 1] = m_inputNums[noa];
          // set the other end to have the right connection number.
          m_inputList[noa - 1].changeOutputNum(m_inputNums[noa - 1], noa - 1);
        }
        m_numInputs--;
        // clear the vacated trailing slot so the removed unit is not retained
        m_inputList[m_numInputs] = null;
        removed = true;
      }
    } while (n == -1 && loc != -1);

    return removed;
  }

  /**
   * This function will remove all the inputs to this unit.
   * In doing so it will also terminate the connections at the other end,
   * and clear the OUTPUT / CONNECTED flags of those units when they no
   * longer apply.
   */
  public void removeAllInputs() {

    boolean pureOutput = (getType() & PURE_OUTPUT) == PURE_OUTPUT;
    for (int noa = 0; noa < m_numInputs; noa++) {
      NeuralConnection source = m_inputList[noa];
      // this command will simply remove any connections this node has
      // with the other in 1 go, rather than separately.
      source.disconnectOutput(this, -1);
      if (pureOutput) {
        // a unit may feed at most one pure output (see connect), so it
        // no longer feeds any
        source.setType(source.getType() & (~OUTPUT));
      }
      clearConnectedFlagIfIsolated(source);
    }

    // now reset the inputs.
    m_inputList = new NeuralConnection[0];
    setType(getType() & (~INPUT));
    if (getNumOutputs() == 0) {
      setType(getType() & (~CONNECTED));
    }
    m_inputNums = new int[0];
    m_numInputs = 0;
  }

  /**
   * Changes the connection value information for one of the connections.
   * Out-of-range indices are ignored.
   * @param n The connection number to change.
   * @param v The value to change it to.
   */
  protected void changeInputNum(int n, int v) {

    if (n >= m_numInputs || n < 0) {
      return;
    }

    m_inputNums[n] = v;
  }

  /**
   * This will disconnect the output with the specific connection number
   * from this node (only on this end however). Remaining outputs are shifted
   * down and the units at their other ends are told their new slot numbers.
   * @param o The unit to disconnect.
   * @param n The connection number at the other end, -1 if all the connections
   * to this unit should be severed.
   * @return True if the connection was removed, false if the connection was
   * not found.
   */
  protected boolean disconnectOutput(NeuralConnection o, int n) {

    int loc = -1;
    boolean removed = false;
    do {
      loc = -1;
      for (int noa = 0; noa < m_numOutputs; noa++) {
        if (o == m_outputList[noa] && (n == -1 || n == m_outputNums[noa])) {
          loc = noa;
          break;
        }
      }

      if (loc >= 0) {
        for (int noa = loc + 1; noa < m_numOutputs; noa++) {
          m_outputList[noa - 1] = m_outputList[noa];
          m_outputNums[noa - 1] = m_outputNums[noa];

          // set the other end to have the right connection number
          m_outputList[noa - 1].changeInputNum(m_outputNums[noa - 1], noa - 1);
        }
        m_numOutputs--;
        // clear the vacated trailing slot so the removed unit is not retained
        m_outputList[m_numOutputs] = null;
        removed = true;
      }
    } while (n == -1 && loc != -1);

    return removed;
  }

  /**
   * This function will remove all outputs from this unit.
   * In doing so it will also terminate the connections at the other end,
   * and clear the INPUT / CONNECTED flags of those units when they no
   * longer apply.
   */
  public void removeAllOutputs() {

    boolean pureInput = (getType() & PURE_INPUT) == PURE_INPUT;
    for (int noa = 0; noa < m_numOutputs; noa++) {
      NeuralConnection target = m_outputList[noa];
      // this command will simply remove any connections this node has
      // with the other in 1 go, rather than separately.
      target.disconnectInput(this, -1);
      if (pureInput) {
        clearInputFlagIfNoPureInputs(target);
      }
      clearConnectedFlagIfIsolated(target);
    }

    // now reset the outputs.
    m_outputList = new NeuralConnection[0];
    m_outputNums = new int[0];
    setType(getType() & (~OUTPUT));
    if (getNumInputs() == 0) {
      setType(getType() & (~CONNECTED));
    }
    m_numOutputs = 0;
  }

  /**
   * Changes the connection value information for one of the connections.
   * Out-of-range indices are ignored.
   * @param n The connection number to change.
   * @param v The value to change it to.
   */
  protected void changeOutputNum(int n, int v) {

    if (n >= m_numOutputs || n < 0) {
      return;
    }

    m_outputNums[n] = v;
  }

  /**
   * @return The number of input connections.
   */
  public int getNumInputs() {
    return m_numInputs;
  }

  /**
   * @return The number of output connections.
   */
  public int getNumOutputs() {
    return m_numOutputs;
  }

  /**
   * Connects two units together. Any existing connection between them is
   * removed first, which also causes the current weight there to be lost.
   * @param s The source unit.
   * @param t The target unit.
   * @return True if the units were connected, false otherwise (null or
   * identical units, or a connection forbidden for pure input/output units).
   */
  public static boolean connect(NeuralConnection s, NeuralConnection t) {

    if (s == null || t == null || s == t) {
      return false;
    }
    // this ensures that there is no existing connection between these
    // two units already.
    disconnect(s, t);

    if ((t.getType() & PURE_INPUT) == PURE_INPUT) {
      return false;   // target is an input node.
    }
    if ((s.getType() & PURE_OUTPUT) == PURE_OUTPUT) {
      return false;   // source is an output node
    }
    if ((s.getType() & PURE_INPUT) == PURE_INPUT
        && (t.getType() & PURE_OUTPUT) == PURE_OUTPUT) {
      return false;   // there is no actual working node in use
    }
    if ((t.getType() & PURE_OUTPUT) == PURE_OUTPUT && t.getNumInputs() > 0) {
      return false;   // more than 1 node is trying to feed a particular output
    }
    if ((t.getType() & PURE_OUTPUT) == PURE_OUTPUT
        && (s.getType() & OUTPUT) == OUTPUT) {
      return false;   // an output node already feeding out a final answer
    }

    if (!s.connectOutput(t, t.getNumInputs())) {
      return false;
    }
    if (!t.connectInput(s, s.getNumOutputs() - 1)) {
      // roll back the half-made connection on the source side
      s.disconnectOutput(t, t.getNumInputs());
      return false;
    }

    // now amend the type.
    if ((s.getType() & PURE_INPUT) == PURE_INPUT) {
      t.setType(t.getType() | INPUT);
    }
    else if ((t.getType() & PURE_OUTPUT) == PURE_OUTPUT) {
      s.setType(s.getType() | OUTPUT);
    }
    t.setType(t.getType() | CONNECTED);
    s.setType(s.getType() | CONNECTED);
    return true;
  }

  /**
   * Disconnects two units.
   * @param s The source unit.
   * @param t The target unit.
   * @return True if the units were disconnected, false if they weren't
   * (probably due to there being no connection).
   */
  public static boolean disconnect(NeuralConnection s, NeuralConnection t) {

    if (s == null || t == null) {
      return false;
    }

    boolean stat1 = s.disconnectOutput(t, -1);
    boolean stat2 = t.disconnectInput(s, -1);
    if (stat1 && stat2) {
      if ((s.getType() & PURE_INPUT) == PURE_INPUT) {
        // t may still be fed by other pure inputs; only then keep INPUT
        clearInputFlagIfNoPureInputs(t);
      }
      else if ((t.getType() & PURE_OUTPUT) == PURE_OUTPUT) {
        s.setType(s.getType() & (~OUTPUT));
      }
      clearConnectedFlagIfIsolated(s);
      clearConnectedFlagIfIsolated(t);
    }
    return stat1 && stat2;
  }

  /**
   * Clears the INPUT flag of the given unit unless at least one of its
   * remaining inputs is still a pure input unit.
   * @param u The unit to update.
   */
  private static void clearInputFlagIfNoPureInputs(NeuralConnection u) {

    NeuralConnection[] inputs = u.getInputs();
    for (int noa = 0; noa < u.getNumInputs(); noa++) {
      if ((inputs[noa].getType() & PURE_INPUT) == PURE_INPUT) {
        return;
      }
    }
    u.setType(u.getType() & (~INPUT));
  }

  /**
   * Clears the CONNECTED flag of the given unit if it has no inputs and no
   * outputs left.
   * @param u The unit to update.
   */
  private static void clearConnectedFlagIfIsolated(NeuralConnection u) {

    if (u.getNumInputs() == 0 && u.getNumOutputs() == 0) {
      u.setType(u.getType() & (~CONNECTED));
    }
  }
}
















© 2015 - 2025 Weber Informatics LLC | Privacy Policy