All downloads are free. The search and download functionality uses the official Maven repository.

gate.plugin.learningframework.export.CorpusExporterMRMatrixMarket2 Maven / Gradle / Ivy

Go to download

A GATE plugin that provides many different machine learning algorithms for a wide range of NLP-related machine learning tasks like text classification, tagging, or chunking.

There is a newer version: 4.2
Show newest version
/*
 * Copyright (c) 2015-2016 The University Of Sheffield.
 *
 * This file is part of gateplugin-LearningFramework 
 * (see https://github.com/GateNLP/gateplugin-LearningFramework).
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 2.1 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this software. If not, see http://www.gnu.org/licenses/.
 */
package gate.plugin.learningframework.export;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import gate.plugin.learningframework.LFUtils;
import gate.plugin.learningframework.data.CorpusRepresentationMallet;
import gate.plugin.learningframework.engines.Info;
import gate.plugin.learningframework.features.FeatureExtractionMalletSparse;
import gate.plugin.learningframework.features.TargetType;
import gate.plugin.learningframework.mallet.NominalTargetWithCosts;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.text.DecimalFormat;

/**
 * Export data as two files in MatrixMarket format.
 * This exports the attributes as indep.mtx and the targets as dep.mtx.
 * 
 * @author Johann Petrak
 */
public class CorpusExporterMRMatrixMarket2 extends CorpusExporterMR {

  /**
   * Creates the {@link Info} record describing what this exporter produces.
   * <p>
   * The algorithm class is always set to the classification algorithm class
   * name, regardless of the actual target type of the corpus. The algorithm
   * name, engine class and model class are all set to the placeholder string
   * {@code "DUMMY"} (presumably because an exporter has no trainable engine
   * or model of its own -- to be confirmed against the users of {@code Info}).
   *
   * @return a freshly created {@code Info} instance with the four fields
   *         described above filled in
   */
  @Override
  public Info getInfo() {
    final Info exporterInfo = new Info();
    // Placeholder values: nothing here refers to a real engine or model.
    exporterInfo.engineClass = "DUMMY";
    exporterInfo.modelClass = "DUMMY";
    exporterInfo.algorithmName = "DUMMY";
    // TODO: we should check for regression vs classification HERE
    exporterInfo.algorithmClass =
        "gate.plugin.learningframework.engines.AlgorithmClassification";
    return exporterInfo;
  }

  @Override
  public void export() {
    exportMeta();
    CorpusRepresentationMallet crm = (CorpusRepresentationMallet)corpusRepresentation;
    PrintStream outDep = null;
    PrintStream outIndep = null;
    PrintStream outInstWeights = null;
    PrintStream outCosts = null;
    File outFileIndep = new File(dataDirFile, "indep.mtx");
    File outFileDep = new File(dataDirFile, "dep.mtx");
    File outFileCosts = new File(dataDirFile,"instcosts.mtx");
    File outFileInstWeights = new File(dataDirFile, "instweights.mtx");
    try {      
      if(corpusRepresentation.getTargetType() != TargetType.NONE) {
        outDep = new PrintStream(outFileDep);
      }
      outIndep = new PrintStream(outFileIndep);
    } catch (FileNotFoundException ex) {
      throw new GateRuntimeException("Could not open output file ",ex);
    }
    // NOTE: the following code is based on simple reverse engineering of the format
    // as generated by python scipy.io.mmwrite for float vectors.
    // Format description see http://math.nist.gov/MatrixMarket/formats.html
    
    InstanceList instances = crm.getRepresentationMallet();
    // for the MatrixMarket format we need the following information 
    // beforehand: the total number of non-zero values and the shape of 
    // the matrix (rows/columns).
    // The only way to find the total number of non-zero values is to actually
    // add them up over all the instances.
    int nrRows = instances.size();
    int nrCols = crm.getPipe().getDataAlphabet().size();
    int nrVals = 0;
    for(Instance instance : instances) {
      Object fvObj = instance.getData();
      if(fvObj instanceof FeatureVector) {
        FeatureVector fv = (FeatureVector)fvObj;
        nrVals += fv.numLocations();
      } else {
        throw new GateRuntimeException("Instance is not a feature vector but "+fvObj.getClass());
      }
    }
    System.err.println("DEBUG: rows="+nrRows+", cols="+nrCols+", vals="+nrVals);
    // TODO: check if these formats are correct!
    DecimalFormat DFf = new DecimalFormat("0.0#########");
    DecimalFormat DFi = new DecimalFormat("0");
    // write the headers of the two files
    // indep
    outIndep.println("%%MatrixMarket matrix coordinate real general\n%");
    outIndep.print(DFi.format(nrRows));
    outIndep.print(" ");
    outIndep.print(DFi.format(nrCols));
    outIndep.print(" ");
    outIndep.print(DFi.format(nrVals));
    outIndep.println();
    if(outDep != null) {
      // dep
      outDep.println("%%MatrixMarket matrix coordinate real general\n%");
      // TODO: we could actually also count the non-zero targets above and
      // not include the zeros!
      outDep.print(DFi.format(nrRows)); // Each row has one value, non-sparse
      outDep.print(" ");
      outDep.print(DFi.format(1));
      outDep.print(" ");
      outDep.print(DFi.format(nrRows));
      outDep.println();
    }
    // NOTE: MatrixMarket numbers are base-1!!
    int rowNr = 0;
    for(Instance instance : instances) {
      rowNr++;
      Boolean ignoreInstance = (Boolean)instance.getProperty(FeatureExtractionMalletSparse.PROP_IGNORE_HAS_MV);
      if(ignoreInstance != null && ignoreInstance) {
        continue;
      }
      // to export instance weights, we check the first instance if a weight is set: 
      // if yes, then a third file is created which will contain the weights for each instance
      Object instanceWeightObject = instance.getProperty("instanceWeight");
      if(rowNr==1) {
        if(instanceWeightObject !=null) {
          try {
            outInstWeights = new PrintStream(outFileInstWeights);
          } catch (FileNotFoundException ex) {
            throw new GateRuntimeException("Could not open output file "+outFileInstWeights,ex);
          }        
          outInstWeights.println("%%MatrixMarket matrix coordinate real general\n%");
          outInstWeights.print(DFi.format(nrRows)); // Each row has one value, non-sparse
          outInstWeights.print(" ");
          outInstWeights.print(DFi.format(1));
          outInstWeights.print(" ");
          outInstWeights.print(DFi.format(nrRows));
          outInstWeights.println();        
        } else {
          outFileInstWeights.delete();
        }
      }
      if(outInstWeights!=null) {
        double weight = LFUtils.anyToDoubleOrElse(instanceWeightObject, 1.0);
        outInstWeights.print(rowNr);
        outInstWeights.print(" ");
        outInstWeights.print("1 ");
        outInstWeights.println(DFf.format(weight));
      }
      Boolean haveMV = (Boolean)instance.getProperty(FeatureExtractionMalletSparse.PROP_HAVE_MV);      
      Object targetObj = instance.getTarget();
      double target = 0.0;
      if(targetObj == null) {
        // TODO: NOTE: we export instances with missing targets in the ARFF exporter, so 
        // we do it here too. However, other than there we always have to output a target
        // because the format requires it. We therefore always output target 0.0 
        // which is the default value, so nothing to do        
      } else if (targetObj instanceof Double) {
        target = (Double)targetObj;
      } else if(targetObj instanceof Label) {
        Label label = (Label)targetObj;
        target = label.getIndex();
        // TODO: if we have row 1 and we find that the entry of this label is of type LabelWithCosts,
        // then open yet another output file for exporting the per-instance costs.
        // Then if that file is open, write subsequent costs!
        if(rowNr==1) {
          if(label.getEntry() instanceof NominalTargetWithCosts) {
            NominalTargetWithCosts lwc = (NominalTargetWithCosts)label.getEntry();
            try {
              outCosts = new PrintStream(outFileCosts);
            } catch (FileNotFoundException ex) {
              throw new GateRuntimeException("Could not open output file "+outFileCosts,ex);
            }        
            outCosts.println("%%MatrixMarket matrix coordinate real general\n%");
            outCosts.print(DFi.format(nrRows)); // Each row has one value, non-sparse
            outCosts.print(" ");
            outCosts.print(DFi.format(lwc.getCosts().length));
            outCosts.print(" ");
            outCosts.print(DFi.format(nrRows*lwc.getCosts().length));
            outCosts.println();                    
          } else {
            outFileCosts.delete();
          }
        }
        if(outCosts != null) {
          NominalTargetWithCosts lwc = (NominalTargetWithCosts)label.getEntry();
          double[] costs = lwc.getCosts();
          for(int i=0;i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy