gate.plugin.learningframework.engines.EngineMBSklearnBase Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of learningframework Show documentation
A GATE plugin that provides many different machine learning
algorithms for a wide range of NLP-related machine learning tasks like
text classification, tagging, or chunking.
/*
* Copyright (c) 2015-2016 The University Of Sheffield.
*
* This file is part of gateplugin-LearningFramework
* (see https://github.com/GateNLP/gateplugin-LearningFramework).
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 2.1 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this software. If not, see <https://www.gnu.org/licenses/>.
*/
package gate.plugin.learningframework.engines;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;

import gate.Annotation;
import gate.AnnotationSet;
import gate.lib.interaction.process.Process4JsonStream;
import gate.lib.interaction.process.ProcessBase;
import gate.lib.interaction.process.ProcessSimple;
import gate.plugin.learningframework.EvaluationMethod;
import gate.plugin.learningframework.ModelApplication;
import gate.plugin.learningframework.data.CorpusRepresentationMallet;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.export.CorpusExporter;
import gate.plugin.learningframework.export.Exporter;
import gate.plugin.learningframework.features.FeatureInfo;
import gate.plugin.learningframework.features.TargetType;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.Files;
import gate.util.GateRuntimeException;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.yaml.snakeyaml.Yaml;
/**
 * An engine that represents Python Scikit-Learn through an external process.
 *
 * This requires that the user configures the location of where sklearn-wrapper is installed.
 * This can be done by setting the environment variable SKLEARN_WRAPPER_HOME, the Java property
 * gate.plugin.learningframework.sklearnwrapper.home or by adding another yaml file "sklearn.yaml"
 * to the data directory which contains the setting sklearnwrapper.home.
 * If the path starts with a slash
 * it is an absolute path, otherwise the path is resolved relative to the
 * data directory.
 *
 *
 * @author Johann Petrak
 */
public abstract class EngineMBSklearnBase extends EngineMB {
// constants for the wrapper
// NOTE(review): despite the ALL_CAPS names these are mutable instance fields;
// each concrete subclass is presumably expected to initialize them — TODO confirm.
// Display name of the wrapper, used in error messages.
protected String WRAPPER_NAME;
// Name of the environment variable pointing at the wrapper installation directory.
protected String ENV_WRAPPER_HOME;
// Name of the Java system property pointing at the wrapper installation directory.
protected String PROP_WRAPPER_HOME;
// Name of the per-data-directory yaml config file (e.g. "sklearn.yaml").
protected String YAML_FILE;
// Key inside the yaml config file that holds the wrapper-home setting.
protected String YAML_SETTING_WRAPPER_HOME;
// Base names (without the .sh/.cmd extension) of the wrapper scripts in bin/.
protected String SCRIPT_APPLY_BASENAME;
protected String SCRIPT_TRAIN_BASENAME;
protected String SCRIPT_EVAL_BASENAME;
// Base file name of the saved model inside the data directory.
protected String MODEL_BASENAME;
// Placeholder assigned to the inherited model field so null-checks in the
// application PR pass (see loadModel/trainModel).
protected Object MODEL_INSTANCE;
// Handle on the running external wrapper process.
protected ProcessBase process;
// These variables get set from the wrapper-specific config file, java properties or
// environment variables.
protected String shellcmd = null;
protected String shellparms = null;
protected String wrapperhome = null;
// Exporter used to write the training corpus (created in initWhenCreating).
protected CorpusExporter corpusExporter = null;
@Override
protected void initWhenCreating(URL directory, Algorithm algorithm, String parameters, FeatureInfo fi, TargetType tt) {
//Previously, this would create the proper corpus representation in the MB base class,
//now we instead create the corpus exporter we use later and get the CR from it
//super.initWhenCreating(directory, algorithm, parameters, fi, tt);
// NOTE(review): the fi and tt parameters are unused here; the inherited
// featureInfo field is passed to the exporter instead, and parameters is
// passed both inside the option string and as a separate argument —
// TODO confirm this is intended.
corpusExporter = CorpusExporter.create(Exporter.CSV_CL_MR, "-t -n "+parameters, featureInfo, parameters, directory);
corpusRepresentation = (CorpusRepresentationMallet)corpusExporter.getCorpusRepresentation();
}
/**
 * Try to find the script running the sklearn-wrapper command.
 *
 * If apply is true, the executable for application is searched,
 * otherwise the one for training.
 * This checks the following settings (increasing priority):
 * environment variable SKLEARN_WRAPPER_HOME,
 * java property gate.plugin.learningframework.sklearnwrapper.home and
 * the setting "sklearnwrapper.home" in file "sklearn.yaml" in the data directory,
 * if it exists.
 * The setting for the sklearn wrapper home can be relative in which case it
 * will be resolved relative to the dataDirectory.
 *
 * As a side effect this also picks up the optional "shellcmd"/"shellparms"
 * settings from the yaml file and caches the resolved wrapper home in the
 * wrapperhome field.
 *
 * @param dataDirectory data/model directory file
 * @param apply true for application, false for training
 * @return command path file
 */
protected File findWrapperCommand(File dataDirectory, boolean apply) {
  // Lowest priority: environment variable, overridden by the Java property if set.
  String homeDir = System.getenv(ENV_WRAPPER_HOME);
  String tmp = System.getProperty(PROP_WRAPPER_HOME);
  if (tmp != null) {
    homeDir = tmp;
  }
  // Highest priority: the per-data-directory yaml config file, if present.
  File sklearnInfoFile = new File(dataDirectory, YAML_FILE);
  if (sklearnInfoFile.exists()) {
    Yaml yaml = new Yaml();
    Object obj;
    try {
      // Read the yaml file explicitly as UTF-8, independent of the platform charset.
      obj = yaml.load(new InputStreamReader(new FileInputStream(sklearnInfoFile), StandardCharsets.UTF_8));
    } catch (FileNotFoundException ex) {
      throw new GateRuntimeException("Could not load yaml file "+sklearnInfoFile, ex);
    }
    if (!(obj instanceof Map)) {
      throw new GateRuntimeException("Info file has strange format: "+sklearnInfoFile.getAbsolutePath());
    }
    Map<?, ?> map = (Map<?, ?>) obj;
    tmp = (String) map.get(YAML_SETTING_WRAPPER_HOME);
    if (tmp != null) {
      homeDir = tmp;
    }
    // Also get any other settings that may be present:
    // shell command and its parameters
    shellcmd = (String) map.get("shellcmd");
    shellparms = (String) map.get("shellparms");
  }
  if (homeDir == null) {
    throw new GateRuntimeException(WRAPPER_NAME+" home not set, please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn");
  }
  File wrapperHome = new File(homeDir);
  if (!wrapperHome.isAbsolute()) {
    // Relative setting: resolve against the data directory.
    wrapperHome = new File(dataDirectory, homeDir);
  }
  if (!wrapperHome.isDirectory()) {
    throw new GateRuntimeException(WRAPPER_NAME+" home is not a directory: "+wrapperHome.getAbsolutePath());
  }
  wrapperhome = wrapperHome.getAbsolutePath();
  // Now, depending on the operating system, and on train/apply,
  // find the correct script to execute.
  // Simple heuristic: file separator "/" means a unix-like OS (bash script),
  // "\" means Windows (cmd script), anything else is unsupported.
  String sep = System.getProperty("file.separator");
  String extension;
  if ("/".equals(sep)) {
    extension = ".sh";
  } else if ("\\".equals(sep)) {
    extension = ".cmd";
  } else {
    throw new GateRuntimeException("It appears this OS is not supported");
  }
  String scriptBase = apply ? SCRIPT_APPLY_BASENAME : SCRIPT_TRAIN_BASENAME;
  File commandFile = new File(new File(wrapperHome, "bin"), scriptBase + extension);
  // NOTE: wrapperHome is absolute at this point so commandFile should always be
  // absolute already; the relative branch is kept for safety.
  commandFile = commandFile.isAbsolute()
          ? commandFile
          : new File(dataDirectory, commandFile.getPath());
  if (!commandFile.canExecute()) {
    throw new GateRuntimeException("Not an executable file or not found: "+commandFile+" please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn");
  }
  return commandFile;
}
/**
 * "Load" the model: start the external sklearn application process.
 *
 * Instead of deserializing a model, this establishes a connection with the
 * external sklearn process, which loads the model itself; the inherited
 * model field is set to a placeholder so null-checks elsewhere pass.
 *
 * @param directoryURL data directory, must be a file: URL
 * @param parms currently unused
 */
@Override
protected void loadModel(URL directoryURL, String parms) {
  if (!"file".equals(directoryURL.getProtocol())) {
    throw new GateRuntimeException("The dataDirectory URL must be a file: URL for sklearn");
  }
  File directory = Files.fileFromURL(directoryURL);
  File commandFile = findWrapperCommand(directory, true);
  String modelFileName = new File(directory, MODEL_BASENAME).getAbsolutePath();
  ArrayList<String> finalCommand = new ArrayList<>();
  finalCommand.add(commandFile.getAbsolutePath());
  finalCommand.add(modelFileName);
  // if we have a shell command prepend that, and if we have shell parms too, include them
  if (shellcmd != null) {
    finalCommand.add(0, shellcmd);
    if (shellparms != null) {
      // Insert the shell parameters right after the shell command (indices 1..n),
      // keeping their original order.
      String[] sps = shellparms.trim().split("\\s+");
      int i = 0;
      for (String sp : sps) {
        finalCommand.add(++i, sp);
      }
    }
  }
  //System.err.println("Running: "+finalCommand);
  // Create a fake Model just to make LF_Apply... happy which checks if this is null
  model = MODEL_INSTANCE;
  Map<String, String> env = new HashMap<>();
  env.put(ENV_WRAPPER_HOME, wrapperhome);
  process = Process4JsonStream.create(directory, env, finalCommand);
}
@Override
protected void saveModel(File directory) {
// NOTE: we do not need to save the model here because the external
// sklearnWrapper command does this.
// However we still need to make sure a usable info file is saved!
// Record the concrete engine class so the correct engine can be
// re-instantiated when the model directory is loaded later.
info.engineClass = this.getClass().getName();
info.save(directory);
}
/**
 * Train a model by exporting the corpus and running the external training script.
 *
 * The first whitespace-separated word in parms must be the full sklearn
 * class name; any remaining words are passed through to the wrapper as
 * algorithm parameters. This method blocks until the external process
 * terminates, then updates and saves the info and feature-info files.
 *
 * @param dataDirectory directory to export data to and save the model/info in
 * @param instanceType annotation type of the instances (unused here, the
 *        exporter was already configured in initWhenCreating)
 * @param parms algorithm class name followed by optional parameters
 */
@Override
public void trainModel(File dataDirectory, String instanceType, String parms) {
  // NOTE: for this the first word in parms must be the full sklearn class name,
  // the rest are parms
  if (parms == null || parms.trim().isEmpty()) {
    throw new GateRuntimeException(WRAPPER_NAME+": Cannot train, algorithmParameter must contain full algorithm class name as first word");
  }
  parms = parms.trim();
  String sklearnClass;
  String sklearnParms = "";
  int spaceIdx = parms.indexOf(" ");
  if (spaceIdx < 0) {
    sklearnClass = parms;
  } else {
    sklearnClass = parms.substring(0, spaceIdx);
    sklearnParms = parms.substring(spaceIdx).trim();
  }
  File commandFile = findWrapperCommand(dataDirectory, false);
  // Export the data
  // Note: any scaling was already done in the PR before calling this method!
  // The exporter created in initWhenCreating handles both classification and
  // regression based on the corpus representation.
  corpusExporter.export();
  String dataFileName = dataDirectory.getAbsolutePath()+File.separator;
  String modelFileName = new File(dataDirectory, MODEL_BASENAME).getAbsolutePath();
  ArrayList<String> finalCommand = new ArrayList<>();
  finalCommand.add(commandFile.getAbsolutePath());
  finalCommand.add(dataFileName);
  finalCommand.add(modelFileName);
  finalCommand.add(sklearnClass);
  if (!sklearnParms.isEmpty()) {
    finalCommand.addAll(Arrays.asList(sklearnParms.split("\\s+", -1)));
  }
  // if we have a shell command prepend that, and if we have shell parms too, include them
  if (shellcmd != null) {
    finalCommand.add(0, shellcmd);
    if (shellparms != null) {
      // Insert the shell parameters right after the shell command (indices 1..n).
      String[] sps = shellparms.trim().split("\\s+");
      int i = 0;
      for (String sp : sps) {
        finalCommand.add(++i, sp);
      }
    }
  }
  // Create a fake Model just to make LF_Apply... happy which checks if this is null
  model = MODEL_INSTANCE;
  Map<String, String> env = new HashMap<>();
  env.put(ENV_WRAPPER_HOME, wrapperhome);
  process = ProcessSimple.create(dataDirectory, env, finalCommand);
  // Block until the external training process has finished, then record
  // the training metadata alongside the model.
  process.waitFor();
  updateInfo();
  SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
  info.modelWhenTrained = sdf.format(new Date());
  info.algorithmParameters = parms;
  info.save(dataDirectory);
  featureInfo.save(dataDirectory);
}
@Override
public EvaluationResult evaluate(String algorithmParameters, EvaluationMethod evaluationMethod, int numberOfFolds, double trainingFraction, int numberOfRepeats) {
// Evaluation is not implemented for the external sklearn wrapper engines;
// use the wrapper's own evaluation facilities instead.
throw new UnsupportedOperationException("Not supported yet."); // not implemented for external-process engines
}
/**
 * Apply the trained model to the instances in instanceAS.
 *
 * The instances are converted to a sparse-matrix message (command "CSR1")
 * which is sent to the external sklearn process; the process answers with a
 * map containing a status, the predicted targets and, for classification,
 * per-class probabilities, from which ModelApplication objects are built.
 *
 * NOTE(review): several lines in this method appear to have been garbled by
 * text extraction (angle-bracket generics and whole loop bodies are missing,
 * see the markers below); the code is kept byte-identical here and should be
 * recovered from version control rather than reconstructed by hand.
 */
@Override
@SuppressWarnings("unchecked")
public List applyModel(AnnotationSet instanceAS, AnnotationSet inputAS,
AnnotationSet sequenceAS, String parms) {
CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentation;
// Freeze the alphabets so application does not add new features.
data.stopGrowth();
int nrCols = data.getPipe().getDataAlphabet().size();
//System.err.println("Running EngineSklearn.applyModel on document "+instanceAS.getDocument().getName());
List gcs = new ArrayList<>();
LFPipe pipe = (LFPipe)data.getRepresentationMallet().getPipe();
ArrayList classList = null;
// If we have a classification problem, pre-calculate the class label list
if(pipe.getTargetAlphabet() != null) {
classList = new ArrayList<>();
// NOTE(review): the next line is garbled — the loop header, its body
// (presumably filling classList from the target alphabet) and the
// declaration of the request map have been fused; original code is missing.
for(int i = 0; i map = new HashMap<>();
map.put("cmd", "CSR1");
ArrayList values = new ArrayList<>();
ArrayList rowinds = new ArrayList<>();
ArrayList colinds = new ArrayList<>();
int rowIndex = 0;
List instances = instanceAS.inDocumentOrder();
for(Annotation instAnn : instances) {
Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
//FeatureVector fv = (FeatureVector)inst.getData();
//System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations());
inst = pipe.instanceFrom(inst);
FeatureVector fv = (FeatureVector)inst.getData();
//System.out.println("Mallet instance, fv: "+fv.toString(true)+", len="+fv.numLocations());
// Convert to the sparse vector we use to send to the weka process
int locs = fv.numLocations();
// NOTE(review): the next line is garbled — the loop filling
// values/rowinds/colinds, the round-trip with the external process
// (which should define "ret") and the declaration of the response map
// have been fused; original code is missing.
for(int i=0;i response = null;
if(ret instanceof Map) {
@SuppressWarnings("unchecked")
Map tmpresponse = (Map)ret;
response = tmpresponse;
}
if(response == null) {
throw new RuntimeException("Got a response from Sklearn process which cannot be used: "+response);
}
// the response has the following format:
// - status: should be "OK" or an error message
// - targets: a vector of target indices/values
// - probas: if probabilities are supported, a vector of vectors of class probabilities, otherwise null
String status = (String)response.get("status");
if(status == null || !status.equals("OK")) {
throw new RuntimeException("Status of response is not OK but "+status);
}
@SuppressWarnings("unchecked")
ArrayList targets = (ArrayList)response.get("targets");
@SuppressWarnings("unchecked")
ArrayList> probas = (ArrayList>)response.get("probas");
ModelApplication gc;
// now check if the mallet representation and the weka process agree
// on if we have regression or classification
if(pipe.getTargetAlphabet() == null) {
// we expect a regression result, i.e probas should be null
if(probas != null) {
throw new RuntimeException("We think we have regression but the Sklearn process sent probabilities");
}
}
// now go through all the instances again and do the target assignment from the vector(s) we got
int instNr = 0;
for(Annotation instAnn : instances) {
if(pipe.getTargetAlphabet() == null) { // we have regression
gc = new ModelApplication(instAnn, targets.get(instNr));
} else {
int bestlabel = targets.get(instNr).intValue();
String cl
= pipe.getTargetAlphabet().lookupObject(bestlabel).toString();
double bestprob = Double.NaN;
if(probas != null) {
bestprob = Collections.max(probas.get(instNr));
}
gc = new ModelApplication(
instAnn, cl, bestprob, classList, probas.get(instNr));
}
gcs.add(gc);
instNr++;
}
// Re-enable alphabet growth after application.
data.startGrowth();
return gcs;
}
@Override
public void initializeAlgorithm(Algorithm algorithm, String parms) {
// Deliberately a no-op: the external wrapper process is configured via the
// algorithmParameters string at training time instead (see trainModel).
}
}