All downloads are free. Search and download functionality uses the official Maven repository.

hr.irb.fastRandomForest.Benchmark Maven / Gradle / Ivy

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Benchmark.java
 *    Copyright (C) 2008 Fran Supek
 */
package hr.irb.fastRandomForest;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.Utils;
import weka.experiment.PairedStatsCorrected;

/**
 * Runs {@link #numRuns} iterations of {@link #numFolds}-fold crossvalidation
 * on the supplied arff file(s) using weka.classifiers.trees.RandomForest and
 * FastRandomForest, and prints accuracy, AUC (averaged over all classes) and
 * computation time to stderr ({@code System.err}).
 *
 * As first command-line parameter, supply an arff file, a directory name, or
 * a text file with a list of arff datasets (one per line).
 *
 * As second command-line parameter, supply the number of trees in the forests
 * (-I parameter for the classifiers).
 *
 * As third command-line parameter, supply a comma separated list of number of
 * threads to use, eg. "1,2,4". One RandomForest/FastRandomForest pair is
 * benchmarked per listed value. If omitted, a single pair is created with the
 * thread count option set to 0 (intended as: default value / autodetect
 * number of cores in machine).
 *
 * @author Fran Supek (fran.supek[AT]irb.hr)
 */
public class Benchmark {

  /** Number of repetitions of the whole crossvalidation, each with its own random seed. */
  public static final int numRuns = 10;

  /** Number of folds used in each crossvalidation. */
  public static final int numFolds = 10;

  /**
   * Command-line entry point; see the class description for the meaning of
   * the arguments.
   *
   * @param args {@code args[0]} = arff file, directory or list file;
   *             {@code args[1]} = number of trees;
   *             {@code args[2]} (optional) = comma separated thread counts
   * @throws IllegalArgumentException if fewer than two arguments are given
   * @throws Exception on any failure while loading data or evaluating classifiers
   */
  public static void main(String[] args) throws Exception {

    if (args.length < 2) {
      throw new IllegalArgumentException(
              "Usage: Benchmark <arff file | directory | list file> <numTrees> [threads, eg. 1,2,4]");
    }

    List<File> trainFiles = getMatchingFiles(args[0], ".arff");

    List<Integer> threadNums = new ArrayList<Integer>();
    if (args.length < 3) {
      threadNums.add(0);
    } else {
      for (String curNum : args[2].split(",")) {
        threadNums.add(Integer.parseInt(curNum.trim()));
      }
    }

    // create classifiers to compare and set their parameters;
    // slot 2*t holds the Weka RF and slot 2*t+1 the FastRF for thread count #t
    AbstractClassifier[] classifiers = new AbstractClassifier[threadNums.size() * 2];

    for (int t = 0; t < threadNums.size(); t++) {

      // the thread list is indexed by pair number, not by classifier slot
      String threads = Integer.toString(threadNums.get(t));
      int i = 2 * t;

      classifiers[i] = new weka.classifiers.trees.RandomForest();
      classifiers[i].setOptions(new String[]{"-I", args[1],
        "-num-slots", threads});

      classifiers[i + 1] = new hr.irb.fastRandomForest.FastRandomForest();
      classifiers[i + 1].setOptions(new String[]{"-I", args[1], // "-import",
        "-threads", threads});

    }

    // header line: dataset description, then one column group per classifier per run
    StringBuilder s = new StringBuilder("dataset\tnumInstances\tnumNumericAtt\t"
            + "numNominalAtt\tnumClasses");
    for (int i = 0; i < classifiers.length * numRuns; i++) {
      s.append("\tcfr|run\tAUC\tAccy\tmillis");
    }
    s.append("\tsummary");
    System.err.println(s.toString());


    for (File curArff : trainFiles) {

      // load data; fall back to the last attribute as class if none is set
      Instances data = new weka.core.converters.ConverterUtils.DataSource(curArff.toString()).getDataSet();
      if (data.classIndex() == -1) {
        data.setClassIndex(data.numAttributes() - 1);
      }
      data.deleteWithMissingClass();

      // count numeric and nominal attributes (the class attribute is excluded)
      int numNumeric = 0, numNominal = 0;
      for (int i = 0; i < data.numAttributes(); i++) {
        if (data.classIndex() == i) {
          continue;
        }
        if (data.attribute(i).isNominal()) {
          numNominal++;
        }
        if (data.attribute(i).isNumeric()) {
          numNumeric++;
        }
      }
      System.err.printf( "%s\t%d\t%d\t%d\t%d\t", curArff.getName(), data.numInstances(), numNumeric, numNominal, data.numClasses() );


      /* We have adopted a generalization of AUC score to multiclass problems as
       * described in Provost and Domingos (CeDER Working Paper #IS-00-04, Stern
       * School of Business, New York University, 2001), computed by taking a
       * weighted average over all one-vs-all binary classification problems
       * that can be derived from the multiclass problem, where weights
       * correspond to class prior probabilities. */
      double[] classProps = new double[data.numClasses()];
      for (int i = 0; i < data.numInstances(); i++) {
        classProps[ (int) data.instance(i).classValue() ] += data.instance(i).weight();
      }
      Utils.normalize(classProps);

      double[][] aucScore = new double[classifiers.length][numRuns];
      double[][] accyScore = new double[classifiers.length][numRuns];
      double[][] timeScore = new double[classifiers.length][numRuns];

      for (int curRun = 1; curRun <= numRuns; curRun++) {

        s = new StringBuilder();

        for (int curCfr = 0; curCfr < classifiers.length; curCfr++) {

          AbstractClassifier aClassifier = classifiers[curCfr];

          Evaluation eval = new Evaluation(data);

          // the run number doubles as the random seed, so every classifier
          // is evaluated with the same seed within a run
          long startMillis = System.currentTimeMillis();
          eval.crossValidateModel(aClassifier, data, numFolds, new Random(curRun));
          long elapsedTime = System.currentTimeMillis() - startMillis;

          double aucSum = 0.0;
          double sumClassProps = 0;
          for (int c = 0; c < data.numClasses(); c++) {
            double classAuc = eval.areaUnderROC(c);
            if (Double.isNaN(classAuc)) {
              continue;
            }
            aucSum += classAuc * classProps[c];
            // this should sum to 1.0 in the end, as all the classes with AUC==NaN should have weight 0
            sumClassProps += classProps[c];
          }

          double weightedAuc = aucSum / sumClassProps;

          aucScore[curCfr][curRun-1] = weightedAuc;
          accyScore[curCfr][curRun-1] = eval.pctCorrect();
          timeScore[curCfr][curRun-1] = elapsedTime;

          s.append(String.format( Locale.US, "%02d|%02d\t%.5f\t%.2f\t%6d\t",
                  curCfr, curRun, weightedAuc,
                  eval.pctCorrect(), elapsedTime));

          System.gc();

        } // classifier by classifier

        System.err.print(s.toString());

      } // run by run

      // the t-test for accuracy is always performed only between classifier 0 and classifier 1
      // meaning, the first instance of Weka RF and first instance of FastRF
      // the following instances use a different # of threads but that doesn't affect results
      double testTrainRatio = 1 / (double) (numFolds - 1);
      PairedStatsCorrected pscAuc = new PairedStatsCorrected(0.05, testTrainRatio);
      pscAuc.add(aucScore[0], aucScore[1]);
      pscAuc.calculateDerived();
      PairedStatsCorrected pscAccy = new PairedStatsCorrected(0.05, testTrainRatio);
      pscAccy.add(accyScore[0], accyScore[1]);
      pscAccy.calculateDerived();
      PairedStatsCorrected pscTime = new PairedStatsCorrected(0.05, testTrainRatio);
      pscTime.add(timeScore[0], timeScore[1]);
      pscTime.calculateDerived();

      System.err.printf( Locale.US, "| Statistical significance of difference in mean of " +
              "AUC scores is p=%6.4f (%s wins), " +
              "in accuracy is p=%6.4f (%s wins). " +
              " Average speedup is: %4.2f times.\n",
              pscAuc.differencesProbability, getTextForSignificance( pscAuc.differencesSignificance, "WekaRF", "FastRF" ),
              pscAccy.differencesProbability, getTextForSignificance( pscAccy.differencesSignificance, "WekaRF", "FastRF" ),
              pscTime.xStats.mean / pscTime.yStats.mean
              );

    } // arff by arff


  }



  /**
   * When supplied with a directory name, returns a list with all the
   * files inside that directory that have the specified extension.
   *
   * When supplied with a single filename...
   * (a) if the extension matches the specified extension, returns a
   *     list with a single File object inside.
   * (b) if the extension DOES NOT match the specified extension, treats the
   *     file as a list and extracts filenames from it - one per line - and
   *     returns them within the list. Names lacking the extension get it
   *     appended; surrounding whitespace is trimmed and blank lines are skipped.
   *
   * @param fileOrDir path of a directory, a data file, or a list file
   * @param extension wanted extension, with or without the leading dot
   * @return the matching files; never null
   * @throws FileNotFoundException if {@code fileOrDir} does not exist
   * @throws IOException if the directory cannot be listed or the list file
   *                     cannot be read
   */
  public static List<File> getMatchingFiles(
          String fileOrDir, String extension)
          throws FileNotFoundException, IOException {

    List<File> result = new ArrayList<File>();

    File myFile = new File(fileOrDir);

    // normalize the extension so that it always starts with a dot (unless empty)
    String myExt;
    if (extension.length() > 0 && extension.charAt(0) != '.') {
      myExt = "." + extension;
    } else {
      myExt = extension;
    }

    if (!myFile.exists()) {
      throw new FileNotFoundException(
              "Specified File or directory doesn't exist!");
    }

    if (myFile.isDirectory()) {

      // listFiles() yields null when the directory cannot be read
      File[] trainFiles = myFile.listFiles();
      if (trainFiles == null) {
        throw new IOException("Unable to list directory: " + fileOrDir);
      }
      for (File candidate : trainFiles) {
        if (candidate.getName().endsWith(myExt)) {
          result.add(candidate);
        }
      }

    } else if (myFile.getName().endsWith(myExt)) {

      result.add(myFile);

    } else {

      BufferedReader bufRdr = new BufferedReader(new FileReader(fileOrDir));
      try {
        String line;
        while ((line = bufRdr.readLine()) != null) {
          String name = line.trim();
          if (name.length() == 0) {
            continue; // a blank line would otherwise turn into a bogus ".arff" entry
          }
          if (name.endsWith(myExt)) {
            result.add(new File(name));
          } else {
            result.add(new File(name + myExt));
          }
        }
      } finally {
        bufRdr.close();
      }

    }

    return result;

  }


  /**
   * Translates a significance flag into the label of the winning party.
   *
   * @param significanceFlag 0 for no winner, positive for {@code party1},
   *                         negative for {@code party2}
   * @param party1 label returned for a positive flag
   * @param party2 label returned for a negative flag
   * @return {@code "noone"}, {@code party1} or {@code party2}
   */
  private static String getTextForSignificance( int significanceFlag, String party1, String party2 ) {

    if (significanceFlag == 0) {
      return "noone";
    } else if (significanceFlag > 0) {
      return party1;
    } else {
      return party2;
    }

  }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy