All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.functions.neural.NeuralNode Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see .
 */

/*
 *    NeuralNode.java
 *    Copyright (C) 2000-2012 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.functions.neural;

import weka.core.RevisionUtils;

import java.util.Random;

/**
 * This class is used to represent a node in the neuralnet.
 * 
 * @author Malcolm Ware ([email protected])
 * @version $Revision: 8034 $
 */
public class NeuralNode
  extends NeuralConnection {

  /** for serialization */
  private static final long serialVersionUID = -1085750607680839163L;
    
  /** The weights for each of the input connections, and the threshold. */
  private double[] m_weights;
  
  /** The best (lowest error) weights. Only used when validation set is used */
  private double[] m_bestWeights;
  
  /** The change in the weights. */
  private double[] m_changeInWeights;
  
  private Random m_random;

  /** Performs the operations for this node. Currently this
   * defines that the node is either a sigmoid or a linear unit. */
  private NeuralMethod m_methods;

  /** 
   * @param id The string name for this node (used to id this node).
   * @param r A random number generator used to generate initial weights.
   * @param m The methods this node should use to update.
   */
  public NeuralNode(String id, Random r, NeuralMethod m) {
    super(id);
    m_weights = new double[1];
    m_bestWeights = new double[1];
    m_changeInWeights = new double[1];
    
    m_random = r;
    
    m_weights[0] = m_random.nextDouble() * .1 - .05;
    m_changeInWeights[0] = 0;

    m_methods = m;
  }
  
  /**
   * Set how this node should operate (note that the neural method has no
   * internal state, so the same object can be used by any number of nodes.
   * @param m The new method.
   */
  public void setMethod(NeuralMethod m) {
    m_methods = m;
  } 

  public NeuralMethod getMethod() {
    return m_methods;
  }

  /**
   * Call this to get the output value of this unit. 
   * @param calculate True if the value should be calculated if it hasn't been
   * already.
   * @return The output value, or NaN, if the value has not been calculated.
   */
  public double outputValue(boolean calculate) {
    
    if (Double.isNaN(m_unitValue) && calculate) {
      //then calculate the output value;
      m_unitValue = m_methods.outputValue(this);
    }
    
    return m_unitValue;
  }

  
  /**
   * Call this to get the error value of this unit.
   * @param calculate True if the value should be calculated if it hasn't been
   * already.
   * @return The error value, or NaN, if the value has not been calculated.
   */
  public double errorValue(boolean calculate) {

    if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) {
      //then calculate the error.
      m_unitError = m_methods.errorValue(this);
    }
    return m_unitError;
  }

  /**
   * Call this to reset the value and error for this unit, ready for the next
   * run. This will also call the reset function of all units that are 
   * connected as inputs to this one.
   * This is also the time that the update for the listeners will be performed.
   */
  public void reset() {
    
    if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
      m_unitValue = Double.NaN;
      m_unitError = Double.NaN;
      m_weightsUpdated = false;
      for (int noa = 0; noa < m_numInputs; noa++) {
	m_inputList[noa].reset();
      }
    }
  }
  
  /**
   * Call this to have the connection save the current
   * weights.
   */
  public void saveWeights() {
    // copy the current weights
    System.arraycopy(m_weights, 0, m_bestWeights, 0, m_weights.length);
    
    // tell inputs to save weights
    for (int i = 0; i < m_numInputs; i++) {
      m_inputList[i].saveWeights();
    }
  }
  
  /**
   * Call this to have the connection restore from the saved
   * weights.
   */
  public void restoreWeights() {
    // copy the saved best weights back into the weights
    System.arraycopy(m_bestWeights, 0, m_weights, 0, m_weights.length);
    
    // tell inputs to restore weights
    for (int i = 0; i < m_numInputs; i++) {
      m_inputList[i].restoreWeights();
    }
  }

  /**
   * Call this to get the weight value on a particular connection.
   * @param n The connection number to get the weight for, -1 if The threshold
   * weight should be returned.
   * @return The value for the specified connection or if -1 then it should 
   * return the threshold value. If no value exists for the specified 
   * connection, NaN will be returned.
   */
  public double weightValue(int n) {
    if (n >= m_numInputs || n < -1) {
      return Double.NaN;
    }
    return m_weights[n + 1];
  }

  /**
   * call this function to get the weights array.
   * This will also allow the weights to be updated.
   * @return The weights array.
   */
  public double[] getWeights() {
    return m_weights;
  }

  /**
   * call this function to get the chnage in weights array.
   * This will also allow the change in weights to be updated.
   * @return The change in weights array.
   */
  public double[] getChangeInWeights() {
    return m_changeInWeights;
  }

  /**
   * Call this function to update the weight values at this unit.
   * After the weights have been updated at this unit, All the
   * input connections will then be called from this to have their
   * weights updated.
   * @param l The learning rate to use.
   * @param m The momentum to use.
   */
  public void updateWeights(double l, double m) {
    
    if (!m_weightsUpdated && !Double.isNaN(m_unitError)) {
      m_methods.updateWeights(this, l, m);
     
      //note that the super call to update the inputs is done here and
      //not in the m_method updateWeights, because it is not deemed to be
      //required to update the weights at this node (while the error and output
      //value ao need to be recursively calculated)
      super.updateWeights(l, m); //to call all of the inputs.
    }
    
  }

  /**
   * This will connect the specified unit to be an input to this unit.
   * @param i The unit.
   * @param n It's connection number for this connection.
   * @return True if the connection was made, false otherwise.
   */
  protected boolean connectInput(NeuralConnection i, int n) {
    
    //the function that this overrides can do most of the work.
    if (!super.connectInput(i, n)) {
      return false;
    }
    
    //note that the weights are shifted 1 forward in the array so
    //it leaves the numinputs aligned on the space the weight needs to go.
    m_weights[m_numInputs] = m_random.nextDouble() * .1 - .05;
    m_changeInWeights[m_numInputs] = 0;
    
    return true;
  }

  /**
   * This will allocate more space for input connection information
   * if the arrays for this have been filled up.
   */
  protected void allocateInputs() {
    
    NeuralConnection[] temp1 = new NeuralConnection[m_inputList.length + 15];
    int[] temp2 = new int[m_inputNums.length + 15];
    double[] temp4 = new double[m_weights.length + 15];
    double[] temp5 = new double[m_changeInWeights.length + 15];
    double[] temp6 = new double[m_bestWeights.length + 15];

    temp4[0] = m_weights[0];
    temp5[0] = m_changeInWeights[0];
    temp6[0] = m_bestWeights[0];
    for (int noa = 0; noa < m_numInputs; noa++) {
      temp1[noa] = m_inputList[noa];
      temp2[noa] = m_inputNums[noa];
      temp4[noa+1] = m_weights[noa+1];
      temp5[noa+1] = m_changeInWeights[noa+1];
      temp6[noa+1] = m_bestWeights[noa+1];
    }
    
    m_inputList = temp1;
    m_inputNums = temp2;
    m_weights = temp4;
    m_changeInWeights = temp5;
    m_bestWeights = temp6;
  }

  
  

  /**
   * This will disconnect the input with the specific connection number
   * From this node (only on this end however).
   * @param i The unit to disconnect.
   * @param n The connection number at the other end, -1 if all the connections
   * to this unit should be severed (not the same as removeAllInputs).
   * @return True if the connection was removed, false if the connection was 
   * not found.
   */
  protected boolean disconnectInput(NeuralConnection i, int n) {
    
    int loc = -1;
    boolean removed = false;
    do {
      loc = -1;
      for (int noa = 0; noa < m_numInputs; noa++) {
	if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) {
	  loc = noa;
	  break;
	}
      }
      
      if (loc >= 0) {
	for (int noa = loc+1; noa < m_numInputs; noa++) {
	  m_inputList[noa-1] = m_inputList[noa];
	  m_inputNums[noa-1] = m_inputNums[noa];
	  
	  m_weights[noa] = m_weights[noa+1];
	  m_changeInWeights[noa] = m_changeInWeights[noa+1];
	  
	  m_inputList[noa-1].changeOutputNum(m_inputNums[noa-1], noa-1);
	}
	m_numInputs--;
	removed = true;
      }      
    } while (n == -1 && loc != -1);
    return removed;
  }
  
  /**
   * This function will remove all the inputs to this unit.
   * In doing so it will also terminate the connections at the other end.
   */
  public void removeAllInputs() {
    super.removeAllInputs();
    
    double temp1 = m_weights[0];
    double temp2 = m_changeInWeights[0];

    m_weights = new double[1];
    m_changeInWeights = new double[1];

    m_weights[0] = temp1;
    m_changeInWeights[0] = temp2;
    
  }  
  
  /**
   * Returns the revision string.
   * 
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 8034 $");
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy