All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gate.plugin.learningframework.engines.EngineMBWekaWrapper Maven / Gradle / Ivy

Go to download

A GATE plugin that provides many different machine learning algorithms for a wide range of NLP-related machine learning tasks like text classification, tagging, or chunking.

There is a newer version: 4.2
Show newest version
/*
 * Copyright (c) 2015-2016 The University Of Sheffield.
 *
 * This file is part of gateplugin-LearningFramework 
 * (see https://github.com/GateNLP/gateplugin-LearningFramework).
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 2.1 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this software. If not, see .
 */
package gate.plugin.learningframework.engines;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import gate.Annotation;
import gate.AnnotationSet;
import gate.lib.interaction.data.SparseDoubleVector;
import gate.lib.interaction.process.Process4ObjectStream;
import gate.lib.interaction.process.ProcessBase;
import gate.lib.interaction.process.ProcessSimple;
import gate.plugin.learningframework.EvaluationMethod;
import gate.plugin.learningframework.export.Exporter;
import gate.plugin.learningframework.ModelApplication;
import gate.plugin.learningframework.Globals;
import gate.plugin.learningframework.data.CorpusRepresentationMallet;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.export.CorpusExporter;
import gate.plugin.learningframework.features.FeatureInfo;
import gate.plugin.learningframework.features.TargetType;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.Files;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.yaml.snakeyaml.Yaml;

/**
 * An engine that represents Weka through en external process.
 * 
 * 
 * This requires that the user configures the location of where weka-wrapper is installed.
 * This can be done by setting the environment variable WEKA_WRAPPPER_HOME, the Java property
 * gate.plugin.learningframework.wekawrapper.home or by adding another yaml file "weka.yaml" 
 * to the data directory which contains the setting wekawrapper.home.
 * If the path starts with a slash
 * it is an absolute path, otherwise the path is resolved relative to the 
 * directory. 
 * 
 * The data directory also needs to contain files lf.model, pipe.pipe, header.arff
 * 
 * 
 * @author Johann Petrak
 */
public class EngineMBWekaWrapper extends EngineMB {

  protected ProcessBase process;
  
  // These variables get set from the wrapper-specific config file, java properties or
  // environment variables.
  private String shellcmd = null;
  private String shellparms = null;
  private String wrapperhome = null;
  private boolean linuxLike = true;
  private boolean windowsLike = false;
  
  protected final String ENV_WRAPPER_HOME = "WEKA_WRAPPER_HOME";
  protected CorpusExporter corpusExporter = null;
  
  @Override
  protected void initWhenCreating(URL directory, Algorithm algorithm, String parameters, FeatureInfo fi, TargetType tt) {
    //Previously, this would create the proper corpus representation in the MB base class,
    //now we instead create the corpus exporter we use later and get the CR from it
    //super.initWhenCreating(directory, algorithm, parameters, fi, tt);
    corpusExporter = CorpusExporter.create(
            Exporter.CSV_CL_MR, "-t -n "+parameters, featureInfo, parameters, 
            directory);
    corpusRepresentation = (CorpusRepresentationMallet)corpusExporter.getCorpusRepresentation();
  } 

  /**
   * Try to find the script running the Weka-Wrapper command.
   * If apply is true, the executable for application is searched,
   * otherwise the one for training.
   * This checks the following settings (increasing priority): 
   * environment variable WEKA_WRAPPER_HOME,
   * java property gate.plugin.learningframework.wekawrapper.home and
   * the setting "wekawrapper.home" in file "weka.yaml" in the data directory,
   * if it exists. 
   * The setting for the weka wrapper home can be relative in which case it
   * will be resolved relative to the dataDirectory
   * @param dataDirectory
   * @return 
   */
  private File findWrapperCommand(File dataDirectory, boolean apply) {
    String homeDir = System.getenv(ENV_WRAPPER_HOME);
    String tmp = System.getProperty("gate.plugin.learningframework.wekawrapper.home");
    if(tmp!=null) {
      homeDir = tmp;
    }
    File wekaInfoFile = new File(dataDirectory,"weka.yaml");
    if(wekaInfoFile.exists()) {
      Yaml yaml = new Yaml();
      Object obj;
      try {
        obj = yaml.load(new InputStreamReader(new FileInputStream(wekaInfoFile),"UTF-8"));
      } catch (FileNotFoundException | UnsupportedEncodingException ex) {
        throw new GateRuntimeException("Could not load yaml file "+wekaInfoFile,ex);
      }    
      tmp = null;
      Map map = null;
      if(obj instanceof Map) {
        map = (Map)obj;
        tmp = (String)map.get("wekawrapper.home");      
      } else {
        throw new GateRuntimeException("Info file has strange format: "+wekaInfoFile.getAbsolutePath());
      }
      if(tmp != null) {
        homeDir = tmp;
      }      
      // Also get any other settings that may be present:
      // shell command
      shellcmd = (String)map.get("shellcmd");
      shellparms = (String)map.get("shellparms");
    }
    if(homeDir == null) {
      throw new GateRuntimeException("WekaWrapper home not set, please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingWeka");
    }
    File wrapperHome = new File(homeDir);
    if(!wrapperHome.isAbsolute()) {
      wrapperHome = new File(dataDirectory,homeDir);
    }
    if(!wrapperHome.isDirectory()) {
      throw new GateRuntimeException("WekaWrapper home is not a directory: "+wrapperHome.getAbsolutePath());
    }
    wrapperhome = wrapperHome.getAbsolutePath();
    // Now, depending on the operating system, and on train/apply,
    // find the correct script to execute
    File commandFile;
    // we use the simple heuristic that if the file separator is "/" 
    // we assume we can use the bash script, if it is "\" we use the windows
    // script and otherwise we give up
    linuxLike = System.getProperty("file.separator").equals("/");
    windowsLike = System.getProperty("file.separator").equals("\\");
    if(linuxLike) {
      if(apply) { 
        commandFile = new File(new File(wrapperHome,"bin"),"wekaWrapperApply.sh");
      } else {
        commandFile = new File(new File(wrapperHome,"bin"),"wekaWrapperTrain.sh");
      }
    } else if(windowsLike) {
      if(apply) { 
        commandFile = new File(new File(wrapperHome,"bin"),"wekaWrapperApply.cmd");
      } else {
        commandFile = new File(new File(wrapperHome,"bin"),"wekaWrapperTrain.cmd");
      }      
    } else {
      throw new GateRuntimeException("It appears this OS is not supported");
    }
    commandFile = commandFile.isAbsolute() ? 
            commandFile :
            new File(dataDirectory,commandFile.getPath());
    if(!commandFile.canExecute()) {
      throw new GateRuntimeException("Not an executable file or not found: "+commandFile+" please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingWeka");
    }
    return commandFile;
  }
  
  
  @Override
  protected void loadModel(URL directoryURL, String parms) {
    ArrayList finalCommand = new ArrayList<>();
    // TODO: for now, we only allow URLs which are file: URLs here.
    // This is because the script wrapping Weka is currently not able to access 
    // the model from any other location. Also, we need to export the 
    // data and currently this is done into the directoryURL.
    // At some later point, we may be able to e.g. copy the model into 
    // a temporary directory and use the demporary directory also to store 
    // the data! 
    File directoryFile = null;
    if("file".equals(directoryURL.getProtocol())) {
      directoryFile = Files.fileFromURL(directoryURL);
    } else {
      throw new GateRuntimeException("The dataDirectory for WekaWrapper must be a file: URL");
    }
    
    // Instead of loading a model, this establishes a connection with the 
    // external weka process. For this, we expect an additional file in the 
    // directory, weka.yaml, which describes how to run the weka wrapper
    File commandFile = findWrapperCommand(directoryFile, true);
    // If the directoryURL 
    String modelFileName = new File(directoryFile,FILENAME_MODEL).getAbsolutePath();
    if(!new File(modelFileName).exists()) {
      throw new GateRuntimeException("File not found: "+modelFileName);
    }
    String header = new File(directoryFile,"header.arff").getAbsolutePath();
    if(!new File(header).exists()) {
      throw new GateRuntimeException("File not found: "+header);
    }
    
    if(shellcmd != null) {
      finalCommand.add(shellcmd);
      if(shellparms != null) {
        String[] sps = shellparms.trim().split("\\s+");
        for(String sp : sps) { finalCommand.add(sp); }
      }
    }
    
    finalCommand.add(commandFile.getAbsolutePath());
    finalCommand.add(modelFileName);
    finalCommand.add(header);
    
    //System.err.println("Running: "+finalCommand);
    // Create a fake Model jsut to make LF_Apply... happy which checks if this is null
    model = "ExternalWekaWrapperModel";
    Map env = new HashMap<>();
    env.put(ENV_WRAPPER_HOME,wrapperhome);
    // NOTE: if the directoryFile is null, the current Java process' directory is used
    process = Process4ObjectStream.create(directoryFile,env,finalCommand);
  }

  @Override
  protected void saveModel(File directory) {
    // NOTE: we do not need to save the model here because the external
    // WekaWrapper command does this.
    // However we still need to make sure a usable info file is saved!
    info.engineClass = EngineMBWekaWrapper.class.getName();
    info.save(directory);
  }

  @Override
  public void trainModel(File dataDirectory, String instanceType, String parms) {
    ArrayList finalCommand = new ArrayList<>();
    // TODO: invoke the weka wrapper
    // NOTE: for this the first word in parms must be the full weka class name, the rest are parms
    if(parms == null || parms.trim().isEmpty()) {
      throw new GateRuntimeException("Cannot train using WekaWrapper, algorithmParameter must contain Weka algorithm class as first word");
    }
    String wekaClass;
    String wekaParms = "";
    parms = parms.trim();
    int spaceIdx = parms.indexOf(" ");
    if(spaceIdx<0) {
      wekaClass = parms;
    } else {
      wekaClass = parms.substring(0,spaceIdx);
      wekaParms = parms.substring(spaceIdx).trim();
    }
    File commandFile = findWrapperCommand(dataDirectory, false);
    // Export the data 
    // Note: any scaling was already done in the PR before calling this method!
    // find out if we train classification or regression
    // NOTE: not sure if classification/regression matters here as long as
    // the actual exporter class does the right thing based on the corpus representation!
    
    // Was previously:
    //Exporter.export(corpusRepresentation, 
    //        Exporter.ARFF_CL_MR, dataDirectory, instanceType, parms);
    corpusExporter.export();

    String dataFileName = new File(dataDirectory,Globals.dataBasename+".arff").getAbsolutePath();
    String modelFileName = new File(dataDirectory, FILENAME_MODEL).getAbsolutePath();

    if(shellcmd != null) {
      finalCommand.add(shellcmd);
      if(shellparms != null) {
        String[] sps = shellparms.trim().split("\\s+");
        for(String sp : sps) { finalCommand.add(sp); }
      }
    }
    
    finalCommand.add(commandFile.getAbsolutePath());
    finalCommand.add(dataFileName);
    finalCommand.add(modelFileName);
    finalCommand.add(wekaClass);
    if(!wekaParms.isEmpty()) {
      String[] tmp = wekaParms.split("\\s+",-1);
      finalCommand.addAll(Arrays.asList(tmp));
    }
    // Create a fake Model jsut to make LF_Apply... happy which checks if this is null
    model = "ExternalWekaWrapperModel";
    
    model = "ExternalWekaWrapperModel";
    Map env = new HashMap<>();
    process = ProcessSimple.create(dataDirectory,env,finalCommand);
    process.waitFor();
    SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
    info.modelWhenTrained = sdf.format(new Date());    
    info.save(dataDirectory);    
    featureInfo.save(dataDirectory);
  }

  @Override
  public EvaluationResult evaluate(String algorithmParameters, EvaluationMethod evaluationMethod, int numberOfFolds, double trainingFraction, int numberOfRepeats) {
    throw new UnsupportedOperationException("Not supported yet."); //To change body of generated methods, choose Tools | Templates.
  }

  @Override
  public List applyModel(AnnotationSet instanceAS, AnnotationSet inputAS, 
          AnnotationSet sequenceAS, String parms) {
    CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentation;
    data.stopGrowth();
    //System.err.println("Running EngineWeka.applyModel on document "+instanceAS.getDocument().getName());
    List gcs = new ArrayList<>();
    LFPipe pipe = (LFPipe)data.getRepresentationMallet().getPipe();
    for(Annotation instAnn : instanceAS.inDocumentOrder()) {
      Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
      
      //FeatureVector fv = (FeatureVector)inst.getData();      
      //System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations());
      inst = pipe.instanceFrom(inst);
      
      FeatureVector fv = (FeatureVector)inst.getData();
      //System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations());
      
      double weight = Double.NaN;
      Object weightObj = inst.getProperty("instanceWeight");
      if(weightObj != null) {
        weight = (double)weightObj;
      }
      // Convert to the sparse vector we use to send to the weka process
      int locs = fv.numLocations();
      SparseDoubleVector sdv = new SparseDoubleVector(locs);
      sdv.setInstanceWeight(weight);
      int[] locations = sdv.getLocations();
      double[] values = sdv.getValues();
      for(int i=0;i= 2
        if(ret.length < 2) {
          throw new RuntimeException("We think we have classification but Weka process sent a ret of length "+ret.length);
        }
        double bestprob = 0.0;
        int bestlabel = 0;
        /*
        System.err.print("DEBUG: got classes from pipe: ");
    		Object[] cls = pipe.getTargetAlphabet().toArray();
        boolean first = true;
        for(Object cl : cls) {
          if(first) { first = false; } else { System.err.print(", "); }
          System.err.print(">"+cl+"<");
        }
        System.err.println();
         */
        List classList = new ArrayList<>();
        List confidenceList = new ArrayList<>();
        for (int i = 0; i < ret.length; i++) {
          int thislabel = i;
          double thisprob = ret[i];
          String labelstr = pipe.getTargetAlphabet().lookupObject(thislabel).toString();
          classList.add(labelstr);
          confidenceList.add(thisprob);
          if (thisprob > bestprob) {
            bestlabel = thislabel;
            bestprob = thisprob;
          }
        } // end for i < predictionDistribution.length

        String cl
                = pipe.getTargetAlphabet().lookupObject(bestlabel).toString();

        gc = new ModelApplication(
                instAnn, cl, bestprob, classList, confidenceList);
      }
      gcs.add(gc);
    }
    data.startGrowth();
    return gcs;
  }

  @Override
  public void initializeAlgorithm(Algorithm algorithm, String parms) {
    // do not do anything
  }

  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy