All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gate.plugin.learningframework.engines.EngineMBLibSVM Maven / Gradle / Ivy

Go to download

A GATE plugin that provides many different machine learning algorithms for a wide range of NLP-related machine learning tasks like text classification, tagging, or chunking.

There is a newer version: 4.2
Show newest version
/*
 * Copyright (c) 2015-2016 The University Of Sheffield.
 *
 * This file is part of gateplugin-LearningFramework 
 * (see https://github.com/GateNLP/gateplugin-LearningFramework).
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 2.1 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this software. If not, see <http://www.gnu.org/licenses/>.
 */

package gate.plugin.learningframework.engines;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import gate.Annotation;
import gate.AnnotationSet;
import gate.plugin.learningframework.EvaluationMethod;
import gate.plugin.learningframework.ModelApplication;
import gate.plugin.learningframework.data.CorpusRepresentationLibSVM;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.Files;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;
import libsvm.svm;
import static libsvm.svm.svm_set_print_string_function;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_print_interface;
import libsvm.svm_problem;

/**
 * Engine that trains, applies, saves and loads models using the LIBSVM library.
 *
 * @author Johann Petrak
 */
public class EngineMBLibSVM extends EngineMB {


  /**
   * Load a previously saved LIBSVM model from the given directory and make it
   * the current model of this engine.
   *
   * @param directory URL of the directory containing the model file; must use the
   *        {@code file:} protocol
   * @param parms not used by this method
   * @throws GateRuntimeException if the URL is not a file URL or the model file
   *         cannot be read
   */
  @Override
  public void loadModel(URL directory, String parms) {
    boolean isFileUrl = "file".equals(directory.getProtocol());
    if (!isFileUrl) {
      throw new GateRuntimeException("The dataDirectory URL must be a file: URL for LibSVM");
    }
    try {
      File modelFile = new File(Files.fileFromURL(directory), FILENAME_MODEL);
      model = svm.svm_load_model(modelFile.getAbsolutePath());
    } catch (IOException | IllegalArgumentException ex) {
      throw new GateRuntimeException("Error loading the LIBSVM model from directory "+directory, ex);
    }
  }

  private svm_parameter makeSvmParms(String parms) {
    int nrIndepFeatures = corpusRepresentation.getRepresentationMallet().getDataAlphabet().size();
    double defaultGamma = 1.0 / nrIndepFeatures;
    // Parse all the necessary parameters. see
    // https://www.csie.ntu.edu.tw/~cjlin/libsvm/
    Parms ps = new Parms(parms, "s:svm_type:i", "t:kernel_type:i", "d:degree:i", "g:gamma:d",
            "r:coef0:d", "c:cost:d", "n:nu:d", "e:epsilon:d", "m:cachesize:i", "h:shrinking:i",
            "b:probability_estimates:i");
    svm_parameter svmparms = new svm_parameter();
    // we use 0 as the default for classification
    svmparms.svm_type = (int) ps.getValueOrElse("svm_type", 0);
    
    // immediately check if the svm type is compatible with regression or classification 
    // as it was selected
    if(algorithm instanceof AlgorithmRegression) {
      Integer algType = (Integer)ps.getValue("svm_type");
      // if the parameter is not given at all, we use 3 as the default for regression
      if(algType == null) {
        svmparms.svm_type = 3;
      }
      if(svmparms.svm_type != 3 && svmparms.svm_type != 4) {
        throw new GateRuntimeException("SvmLib: only -s 3 or -s 4 allowed for regression");
      }
    } else {
      if(svmparms.svm_type != 0 && svmparms.svm_type != 1) {
        throw new GateRuntimeException("SvmLib: only -s 0 or -s 1 allowed for classification");
      }      
    }
    
    svmparms.kernel_type = (int) ps.getValueOrElse("kernel_type", 2);
    svmparms.degree = (int) ps.getValueOrElse("degree", 3);
    svmparms.gamma = (double) ps.getValueOrElse("gamma", defaultGamma);
    svmparms.coef0 = (double) ps.getValueOrElse("coef0", 0.0);
    svmparms.C = (double) ps.getValueOrElse("cost", 1.0);
    svmparms.nu = (double) ps.getValueOrElse("nu", 0.5);
    svmparms.eps = (double) ps.getValueOrElse("epsilon", 0.1);
    svmparms.cache_size = (int) ps.getValueOrElse("cachesize", 100);
    svmparms.shrinking = (int) ps.getValueOrElse("shrinking", 1);
    svmparms.probability = (int) ps.getValueOrElse("probability_estimates", 1); // THIS ONE DIFFERS FROM SVMLIB DEFAULT!
    // for the weights, we need a different strategy: our Parms class cannot parse arbitrary 
    // numbered options so we have to do it ourselves here
    List weights = new ArrayList<>();
    List featureNumbers = new ArrayList<>();
    // make sure we have a parameter at all before trying to parse it
    if (parms != null && !parms.isEmpty()) {
      String[] tokens = parms.split("\\s+", -1);
      for (int i = 0; i < tokens.length - 1; i++) {
        String token = tokens[i];
        if (token.startsWith("-w")) {
          // this should be a weight parameter: we only use it if it really only contains a number
          // in the option name and does have something that can be parsed as a double as its value,
          // otherwise we simply ignore
          if (token.substring(2).matches("[0-9]+")) {
            String valueString = tokens[i + 1];
            Double value = Double.NaN;
            try {
              value = Double.parseDouble(valueString);
            } catch (NumberFormatException ex) {
              // ignore this
            }
            if (!Double.isNaN(value)) {
              int fn = Integer.parseInt(token.substring(2));
              if (fn < nrIndepFeatures) {
                featureNumbers.add(fn);
                weights.add(value);
              }
            }
          }
        }
      } // for int=0; i 0) {
        double[] ws = new double[weights.size()];
        int[] idxs = new int[weights.size()];
        for (int i = 0; i < weights.size(); i++) {
          ws[i] = weights.get(i);
          idxs[i] = featureNumbers.get(i);
        }
        svmparms.weight = ws;
        svmparms.weight_label = idxs;
      }
    }
    return svmparms;
  }
  
  
  /**
   * Train a LIBSVM model on the current corpus representation and store it as the
   * model of this engine.
   * <p>
   * In addition to the LIBSVM options handled by {@code makeSvmParms}, the option
   * {@code -S}/seed sets the seed of the LIBSVM random number generator (default 1).
   * After training, the model info and the feature info are saved to the data directory.
   *
   * @param dataDirectory directory where the info and feature info files are saved
   * @param instanceType not used by this method
   * @param parms the algorithm parameter string
   */
  @Override
  public void trainModel(File dataDirectory, String instanceType, String parms) {

    // we also support the additional parameter -S/seed  to set the random seed (default is 1).
    // The parameter string has to be the first argument, followed by the option
    // definitions, otherwise the seed given by the user is never seen.
    Parms ps = new Parms(parms, "S:seed:i");
    svm_parameter svmparms = makeSvmParms(parms);
    int seed = (int) ps.getValueOrElse("seed", 1);
    System.err.println("SVM parms used: (seed="+seed+") "+libsvmParmsAsString(svmparms));
    // route the LIBSVM progress output to stderr
    svm_set_print_string_function(new svm_print_interface() {
      @Override
      public void print(String string) {
        System.err.print(string);
      }
    });

    libsvm.svm.rand.setSeed(seed);

    // convert the mallet instances to svm problem. For this we can simply use the static method,
    // no need really to create an instance of CorpusRepresentationLibSVM for now
    svm_problem svmprob = CorpusRepresentationLibSVM.getFromMallet(corpusRepresentation);

    svm_model svmModel = libsvm.svm.svm_train(svmprob, svmparms);
    model = svmModel;
    updateInfo();
    SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
    info.modelWhenTrained = sdf.format(new Date());
    info.algorithmParameters = parms;
    info.save(dataDirectory);
    featureInfo.save(dataDirectory);
  }

  /**
   * Apply the current LIBSVM model to every instance annotation and return one
   * {@code ModelApplication} per instance, in document order.
   * <p>
   * For regression the numeric prediction is returned. For classification the
   * predicted label is returned, together with the probability estimate of that
   * label if the model supports probability estimates, otherwise with a
   * confidence of 0.0.
   * <p>
   * Growth of the corpus representation is stopped while the model is applied and
   * started again afterwards, also if an exception is thrown.
   *
   * @param instanceAS the instance annotations to classify
   * @param inputAS the annotation set the features are extracted from
   * @param sequenceAS not used by this method
   * @param parms not used by this method
   * @return the list of model applications, one per instance annotation
   */
  @Override
  public List applyModel(
          AnnotationSet instanceAS, AnnotationSet inputAS, AnnotationSet sequenceAS, String parms) {

    CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget) corpusRepresentation;
    List<ModelApplication> gcs = new ArrayList<>();
    data.stopGrowth();
    try {
      LFPipe pipe = data.getPipe();
      svm_model svmModel = (svm_model) model;
      boolean isRegression = algorithm instanceof AlgorithmRegression;

      // LIBSVM reports probability estimates in the order of the labels stored in the
      // model, which is not necessarily the numeric order of the labels, so we need
      // the model's label order to look up the estimate for the predicted label.
      int[] modelLabels = null;
      if (!isRegression && svm.svm_check_probability_model(svmModel) == 1) {
        modelLabels = new int[svm.svm_get_nr_class(svmModel)];
        svm.svm_get_labels(svmModel, modelLabels);
      }

      // iterate over all the instance annotations
      for (Annotation instAnn : instanceAS.inDocumentOrder()) {
        Instance malletInstance = data.extractIndependentFeatures(instAnn, inputAS);
        malletInstance = pipe.instanceFrom(malletInstance);
        svm_node[] svmInstance = CorpusRepresentationLibSVM.libSVMInstanceIndepFromMalletInstance(malletInstance);

        if (isRegression) {
          double prediction = svm.svm_predict(svmModel, svmInstance);
          gcs.add(new ModelApplication(instAnn, prediction));
          continue;
        }

        int bestLabel = (int) svm.svm_predict(svmModel, svmInstance);
        // For now we are not providing decision values for non-probability
        // models because it is complex, see here:
        // http://www.csie.ntu.edu.tw/~r94100/libsvm-2.8/README
        double bestConf = 0.0;
        if (modelLabels != null) {
          double[] probabilities = new double[modelLabels.length];
          svm.svm_predict_probability(svmModel, svmInstance, probabilities);
          int position = indexOfLabel(modelLabels, bestLabel);
          if (position >= 0) {
            bestConf = probabilities[position];
          }
        }

        String labelstr = pipe.getTargetAlphabet().lookupObject(bestLabel).toString();
        gcs.add(new ModelApplication(instAnn, labelstr, bestConf));
      }
    } finally {
      data.startGrowth();
    }
    return gcs;
  }

  /**
   * Return the position of the given label in the label array, or -1 if the
   * label does not occur in the array.
   */
  private static int indexOfLabel(int[] labels, int label) {
    for (int i = 0; i < labels.length; i++) {
      if (labels[i] == label) {
        return i;
      }
    }
    return -1;
  }

  /**
   * Does nothing for this engine.
   *
   * @param algorithm not used by this method
   * @param parms not used by this method
   */
  @Override
  public void initializeAlgorithm(Algorithm algorithm, String parms) {
    // Intentionally empty: training always uses a predefined class, so there is
    // nothing to set up here.
  }

  /**
   * Write the current LIBSVM model and the model info to the given directory.
   *
   * @param directory the directory to save the model file and the info to
   * @throws GateRuntimeException if writing the model file fails
   */
  @Override
  public void saveModel(File directory) {
    String modelPath = new File(directory, FILENAME_MODEL).getAbsolutePath();
    try {
      svm.svm_save_model(modelPath, (svm_model) model);
    } catch (IOException e) {
      throw new GateRuntimeException("Error saving LIBSVM model", e);
    }
    // The LIBSVM model file does not carry our own info, so store it alongside
    info.save(directory);
  }

  
  private String libsvmParmsAsString(libsvm.svm_parameter parms) {
    StringBuilder sb = new StringBuilder();
    sb.append("svmparms{");
    sb.append("C=");
    sb.append(parms.C);
    sb.append(",cache_size=");
    sb.append(parms.cache_size);
    sb.append(",coef0=");
    sb.append(parms.coef0);
    sb.append(",degree=");
    sb.append(parms.degree);
    sb.append(",eps=");
    sb.append(parms.eps);
    sb.append(",gamma=");
    sb.append(parms.gamma);
    sb.append(",kernel_type=");
    sb.append(parms.kernel_type);
    sb.append(",nr_weight=");
    sb.append(parms.nr_weight);
    sb.append(",nu=");
    sb.append(parms.nu);
    sb.append(",p=");
    sb.append(parms.p);
    sb.append(",probability=");
    sb.append(parms.probability);
    sb.append(",shrinking=");
    sb.append(parms.shrinking);
    sb.append(",svm_type=");
    sb.append(parms.svm_type);
    sb.append(",weight=");
    if(parms.weight!=null) {
      for(int i = 0; i accs = new ArrayList<>();
        int nrCorrectAll = 0;
        int nrIncorrectAll = 0;
        for(int repeat=0; repeat-->
            
    
            


© 2015 - 2024 Weber Informatics LLC | Privacy Policy