All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.gui.beans.CrossValidationFoldMaker Maven / Gradle / Ivy

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    CrossValidationFoldMaker.java
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.gui.beans;

import weka.core.Instances;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 * Bean for splitting instances into training ant test sets according to
 * a cross validation
 *
 * @author Mark Hall
 * @version $Revision: 7059 $
 */
public class CrossValidationFoldMaker 
  extends AbstractTrainAndTestSetProducer
  implements DataSourceListener, TrainingSetListener, TestSetListener, 
	     UserRequestAcceptor, EventConstraints, Serializable {

  /** for serialization */
  private static final long serialVersionUID = -6350179298851891512L;

  private int m_numFolds = 10;
  private int m_randomSeed = 1;

  private transient Thread m_foldThread = null;

  public CrossValidationFoldMaker() {
    m_visual.loadIcons(BeanVisual.ICON_PATH
		       +"CrossValidationFoldMaker.gif",
		       BeanVisual.ICON_PATH
		       +"CrossValidationFoldMaker_animated.gif");
    m_visual.setText("CrossValidationFoldMaker");
  }

  /**
   * Set a custom (descriptive) name for this bean
   * 
   * @param name the name to use
   */
  public void setCustomName(String name) {
    m_visual.setText(name);
  }

  /**
   * Get the custom (descriptive) name for this bean (if one has been set)
   * 
   * @return the custom name (or the default name)
   */
  public String getCustomName() {
    return m_visual.getText();
  }

  /**
   * Global info for this bean
   *
   * @return a String value
   */
  public String globalInfo() {
    return Messages.getInstance().getString("CrossValidationFoldMaker_GlobalInfo_Text");
  }

  /**
   * Accept a training set
   *
   * @param e a TrainingSetEvent value
   */
  public void acceptTrainingSet(TrainingSetEvent e) {
    Instances trainingSet = e.getTrainingSet();
    DataSetEvent dse = new DataSetEvent(this, trainingSet);
    acceptDataSet(dse);
  }

  /**
   * Accept a test set
   *
   * @param e a TestSetEvent value
   */
  public void acceptTestSet(TestSetEvent e) {
    Instances testSet = e.getTestSet();
    DataSetEvent dse = new DataSetEvent(this, testSet);
    acceptDataSet(dse);
  }

  /**
   * Accept a data set
   *
   * @param e a DataSetEvent value
   */
  public void acceptDataSet(DataSetEvent e) {
    if (e.isStructureOnly()) {
      // Pass on structure to training and test set listeners
      TrainingSetEvent tse = new TrainingSetEvent(this, e.getDataSet());
      TestSetEvent tsee = new TestSetEvent(this, e.getDataSet());
      notifyTrainingSetProduced(tse);
      notifyTestSetProduced(tsee);
      return;
    }
    if (m_foldThread == null) {
      final Instances dataSet = new Instances(e.getDataSet());
      m_foldThread = new Thread() {
	  public void run() {
	    boolean errorOccurred = false;
	    try {
	      Random random = new Random(getSeed());
	      dataSet.randomize(random);
	      if (dataSet.classIndex() >= 0 && 
		  dataSet.attribute(dataSet.classIndex()).isNominal()) {
		dataSet.stratify(getFolds());
		if (m_logger != null) {
		  m_logger.logMessage(Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_First") + getCustomName() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_Second"));
		}
	      }
	      
	      for (int i = 0; i < getFolds(); i++) {
		if (m_foldThread == null) {
		  if (m_logger != null) {
		    m_logger.logMessage(Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_Third") + getCustomName() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_Fourth"));
		  }
		  // exit gracefully
		  break;
		}
		Instances train = dataSet.trainCV(getFolds(), i, random);
		Instances test  = dataSet.testCV(getFolds(), i);
		
		// inform all training set listeners
		TrainingSetEvent tse = new TrainingSetEvent(this, train);
		tse.m_setNumber = i+1; tse.m_maxSetNumber = getFolds();
		String msg = getCustomName() + "$" 
		  + CrossValidationFoldMaker.this.hashCode() + "|";
		if (m_logger != null) {
		  m_logger.statusMessage(msg + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_First") + getSeed() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Second")
		      + getFolds() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Third") + (i+1));
		}
		if (m_foldThread != null) {
		  //		  System.err.println("--Just before notify training set");
		  notifyTrainingSetProduced(tse);
		  //		  System.err.println("---Just after notify");
		}
	      
		// inform all test set listeners
		TestSetEvent teste = new TestSetEvent(this, test);
		teste.m_setNumber = i+1; teste.m_maxSetNumber = getFolds();
		
		if (m_logger != null) {
		  m_logger.statusMessage(msg + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Fourth") + getSeed() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Fifth")
		      + getFolds() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Sixth") + (i+1));
		}
		if (m_foldThread != null) {
		  notifyTestSetProduced(teste);
		}
	      }
	    } catch (Exception ex) {
	      // stop all processing
	      errorOccurred = true;
	      CrossValidationFoldMaker.this.stop();
	      if (m_logger != null) {
	        m_logger.logMessage(Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_Fifth") + getCustomName() 
	            + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_Sixth")
	            + ex.getMessage());
	      }
	      ex.printStackTrace();
	    } finally {
	      m_foldThread = null;
	      if (errorOccurred) {
	        if (m_logger != null) {
	          m_logger.statusMessage(getCustomName() 
	              + "$" + CrossValidationFoldMaker.this.hashCode()
	              + "|"
	              + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Seventh"));
	        }
	      } else if (isInterrupted()) {
	        String msg = Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_Msg_Text_First") + getCustomName() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_Msg_Text_Second");
	        if (m_logger != null) {
	          m_logger.logMessage(Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_Fifth") + getCustomName() + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_LogMessage_Text_Sixth_Alpha"));
	          m_logger.statusMessage(getCustomName() + "$"
	              + CrossValidationFoldMaker.this.hashCode() + "|"
	              + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Eighth"));
	        } else {
	          System.err.println(msg);
	        }
	      } else {
	        String msg = getCustomName() + "$" 
	        + CrossValidationFoldMaker.this.hashCode() + "|";
	        if (m_logger != null) {
	          m_logger.statusMessage(msg + Messages.getInstance().getString("CrossValidationFoldMaker_AcceptDataSet_StatusMessage_Text_Nineth"));
	        }
	      }
	      block(false);
	    }
	  }
	};
      m_foldThread.setPriority(Thread.MIN_PRIORITY);
      m_foldThread.start();

      //      if (m_foldThread.isAlive()) {
      block(true);
	//      }
      m_foldThread = null;
    }
  }


  /**
   * Notify all test set listeners of a TestSet event
   *
   * @param tse a TestSetEvent value
   */
  private void notifyTestSetProduced(TestSetEvent tse) {
    Vector l;
    synchronized (this) {
      l = (Vector)m_testListeners.clone();
    }
    if (l.size() > 0) {
      for(int i = 0; i < l.size(); i++) {
        if (m_foldThread == null) {
          break;
        }
	//	System.err.println("Notifying test listeners "
	//			   +"(cross validation fold maker)");
	((TestSetListener)l.elementAt(i)).acceptTestSet(tse);
      }
    }
  }

  /**
   * Notify all listeners of a TrainingSet event
   *
   * @param tse a TrainingSetEvent value
   */
  protected void notifyTrainingSetProduced(TrainingSetEvent tse) {
    Vector l;
    synchronized (this) {
      l = (Vector)m_trainingListeners.clone();
    }
    if (l.size() > 0) {
      for(int i = 0; i < l.size(); i++) {
        if (m_foldThread == null) {
          break;
        }
	//	System.err.println("Notifying training listeners "
	//			   +"(cross validation fold maker)");
	((TrainingSetListener)l.elementAt(i)).acceptTrainingSet(tse);
      }
    }
  }

  /**
   * Set the number of folds for the cross validation
   *
   * @param numFolds an int value
   */
  public void setFolds(int numFolds) {
    m_numFolds = numFolds;
  }
  
  /**
   * Get the currently set number of folds
   *
   * @return an int value
   */
  public int getFolds() {
    return m_numFolds;
  }

  /**
   * Tip text for this property
   *
   * @return a String value
   */
  public String foldsTipText() {
    return Messages.getInstance().getString("CrossValidationFoldMaker_FoldsTipText_Text");
  }
    
  /**
   * Set the seed
   *
   * @param randomSeed an int value
   */
  public void setSeed(int randomSeed) {
    m_randomSeed = randomSeed;
  }
  
  /**
   * Get the currently set seed
   *
   * @return an int value
   */
  public int getSeed() {
    return m_randomSeed;
  }
  
  /**
   * Tip text for this property
   *
   * @return a String value
   */
  public String seedTipText() {
    return Messages.getInstance().getString("CrossValidationFoldMaker_SeedTipText_Text");
  }
  
  /**
   * Returns true if. at this time, the bean is busy with some
   * (i.e. perhaps a worker thread is performing some calculation).
   * 
   * @return true if the bean is busy.
   */
  public boolean isBusy() {
    return (m_foldThread != null);
  }

  /**
   * Stop any action
   */
  public void stop() {
    // tell the listenee (upstream bean) to stop
    if (m_listenee instanceof BeanCommon) {
      //      System.err.println("Listener is BeanCommon");
      ((BeanCommon)m_listenee).stop();
    }

    // stop the fold thread
    if (m_foldThread != null) {
      Thread temp = m_foldThread;
      m_foldThread = null;
      temp.interrupt();
      temp.stop();
    }
  }

  /**
   * Function used to stop code that calls acceptDataSet. This is 
   * needed as cross validation is performed inside a separate
   * thread of execution.
   *
   * @param tf a boolean value
   */
  private synchronized void block(boolean tf) {
    if (tf) {
      try {
	// make sure the thread is still running before we block
	if (m_foldThread.isAlive()) {
	  wait();
	}
      } catch (InterruptedException ex) {
      }
    } else {
      notifyAll();
    }
  }

  /**
   * Return an enumeration of user requests
   *
   * @return an Enumeration value
   */
  public Enumeration enumerateRequests() {
    Vector newVector = new Vector(0);
    if (m_foldThread != null) {
      newVector.addElement("Stop");
    }
    return newVector.elements();
  }

  /**
   * Perform the named request
   *
   * @param request a String value
   * @exception IllegalArgumentException if an error occurs
   */
  public void performRequest(String request) {
    if (request.compareTo("Stop") == 0) {
      stop();
    } else {
      throw new IllegalArgumentException(request
					 + Messages.getInstance().getString("CrossValidationFoldMaker_PerformRequest_IllegalArgumentException_Text"));
    }
  }

  /**
   * Returns true, if at the current time, the named event could
   * be generated. Assumes that the supplied event name is
   * an event that could be generated by this bean
   *
   * @param eventName the name of the event in question
   * @return true if the named event could be generated at this point in
   * time
   */
  public boolean eventGeneratable(String eventName) {
    if (m_listenee == null) {
      return false;
    }
    
    if (m_listenee instanceof EventConstraints) {
      if (((EventConstraints)m_listenee).eventGeneratable("dataSet") ||
	  ((EventConstraints)m_listenee).eventGeneratable("trainingSet") ||
	  ((EventConstraints)m_listenee).eventGeneratable("testSet")) {
	return true;
      } else {
	return false;
      }
    }
    return true;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy