gate.plugin.learningframework.export.CorpusExporterMRMatrixMarket2 Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of learningframework Show documentation
Show all versions of learningframework Show documentation
A GATE plugin that provides many different machine learning
algorithms for a wide range of NLP-related machine learning tasks like
text classification, tagging, or chunking.
/*
* Copyright (c) 2015-2016 The University Of Sheffield.
*
* This file is part of gateplugin-LearningFramework
* (see https://github.com/GateNLP/gateplugin-LearningFramework).
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 2.1 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this software. If not, see <http://www.gnu.org/licenses/>.
*/
package gate.plugin.learningframework.export;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import gate.plugin.learningframework.LFUtils;
import gate.plugin.learningframework.data.CorpusRepresentationMallet;
import gate.plugin.learningframework.engines.Info;
import gate.plugin.learningframework.features.FeatureExtractionMalletSparse;
import gate.plugin.learningframework.features.TargetType;
import gate.plugin.learningframework.mallet.NominalTargetWithCosts;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.text.DecimalFormat;
/**
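* Export data as files in MatrixMarket format.
* This exports the attributes as indep.mtx and, unless the target type is
* NONE, the targets as dep.mtx. Depending on what is found on the first
* instance, instance weights (instweights.mtx) and per-instance costs
* (instcosts.mtx) may be written as well.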
*
* @author Johann Petrak
*/
public class CorpusExporterMRMatrixMarket2 extends CorpusExporterMR {
/**
 * Creates the {@link Info} descriptor for this exporter.
 *
 * The algorithm class is set to the classification algorithm class name;
 * the algorithm name, engine class and model class are all set to the
 * placeholder string "DUMMY".
 *
 * @return a newly created Info object with the four fields described above set
 */
@Override
public Info getInfo() {
  // Placeholder used for every field this exporter has no real value for.
  final String placeholder = "DUMMY";
  Info descriptor = new Info();
  descriptor.modelClass = placeholder;
  descriptor.engineClass = placeholder;
  descriptor.algorithmName = placeholder;
  // TODO: we should check for regression vs classification HERE
  descriptor.algorithmClass =
      "gate.plugin.learningframework.engines.AlgorithmClassification";
  return descriptor;
}
@Override
public void export() {
exportMeta();
CorpusRepresentationMallet crm = (CorpusRepresentationMallet)corpusRepresentation;
PrintStream outDep = null;
PrintStream outIndep = null;
PrintStream outInstWeights = null;
PrintStream outCosts = null;
File outFileIndep = new File(dataDirFile, "indep.mtx");
File outFileDep = new File(dataDirFile, "dep.mtx");
File outFileCosts = new File(dataDirFile,"instcosts.mtx");
File outFileInstWeights = new File(dataDirFile, "instweights.mtx");
try {
if(corpusRepresentation.getTargetType() != TargetType.NONE) {
outDep = new PrintStream(outFileDep);
}
outIndep = new PrintStream(outFileIndep);
} catch (FileNotFoundException ex) {
throw new GateRuntimeException("Could not open output file ",ex);
}
// NOTE: the following code is based on simple reverse engineering of the format
// as generated by python scipy.io.mmwrite for float vectors.
// Format description see http://math.nist.gov/MatrixMarket/formats.html
InstanceList instances = crm.getRepresentationMallet();
// for the MatrixMarket format we need the following information
// beforehand: the total number of non-zero values and the shape of
// the matrix (rows/columns).
// The only way to find the total number of non-zero values is to actually
// add them up over all the instances.
int nrRows = instances.size();
int nrCols = crm.getPipe().getDataAlphabet().size();
int nrVals = 0;
for(Instance instance : instances) {
Object fvObj = instance.getData();
if(fvObj instanceof FeatureVector) {
FeatureVector fv = (FeatureVector)fvObj;
nrVals += fv.numLocations();
} else {
throw new GateRuntimeException("Instance is not a feature vector but "+fvObj.getClass());
}
}
System.err.println("DEBUG: rows="+nrRows+", cols="+nrCols+", vals="+nrVals);
// TODO: check if these formats are correct!
DecimalFormat DFf = new DecimalFormat("0.0#########");
DecimalFormat DFi = new DecimalFormat("0");
// write the headers of the two files
// indep
outIndep.println("%%MatrixMarket matrix coordinate real general\n%");
outIndep.print(DFi.format(nrRows));
outIndep.print(" ");
outIndep.print(DFi.format(nrCols));
outIndep.print(" ");
outIndep.print(DFi.format(nrVals));
outIndep.println();
if(outDep != null) {
// dep
outDep.println("%%MatrixMarket matrix coordinate real general\n%");
// TODO: we could actually also count the non-zero targets above and
// not include the zeros!
outDep.print(DFi.format(nrRows)); // Each row has one value, non-sparse
outDep.print(" ");
outDep.print(DFi.format(1));
outDep.print(" ");
outDep.print(DFi.format(nrRows));
outDep.println();
}
// NOTE: MatrixMarket numbers are base-1!!
int rowNr = 0;
for(Instance instance : instances) {
rowNr++;
Boolean ignoreInstance = (Boolean)instance.getProperty(FeatureExtractionMalletSparse.PROP_IGNORE_HAS_MV);
if(ignoreInstance != null && ignoreInstance) {
continue;
}
// to export instance weights, we check the first instance if a weight is set:
// if yes, then a third file is created which will contain the weights for each instance
Object instanceWeightObject = instance.getProperty("instanceWeight");
if(rowNr==1) {
if(instanceWeightObject !=null) {
try {
outInstWeights = new PrintStream(outFileInstWeights);
} catch (FileNotFoundException ex) {
throw new GateRuntimeException("Could not open output file "+outFileInstWeights,ex);
}
outInstWeights.println("%%MatrixMarket matrix coordinate real general\n%");
outInstWeights.print(DFi.format(nrRows)); // Each row has one value, non-sparse
outInstWeights.print(" ");
outInstWeights.print(DFi.format(1));
outInstWeights.print(" ");
outInstWeights.print(DFi.format(nrRows));
outInstWeights.println();
} else {
outFileInstWeights.delete();
}
}
if(outInstWeights!=null) {
double weight = LFUtils.anyToDoubleOrElse(instanceWeightObject, 1.0);
outInstWeights.print(rowNr);
outInstWeights.print(" ");
outInstWeights.print("1 ");
outInstWeights.println(DFf.format(weight));
}
Boolean haveMV = (Boolean)instance.getProperty(FeatureExtractionMalletSparse.PROP_HAVE_MV);
Object targetObj = instance.getTarget();
double target = 0.0;
if(targetObj == null) {
// TODO: NOTE: we export instances with missing targets in the ARFF exporter, so
// we do it here too. However, other than there we always have to output a target
// because the format requires it. We therefore always output target 0.0
// which is the default value, so nothing to do
} else if (targetObj instanceof Double) {
target = (Double)targetObj;
} else if(targetObj instanceof Label) {
Label label = (Label)targetObj;
target = label.getIndex();
// TODO: if we have row 1 and we find that the entry of this label is of type LabelWithCosts,
// then open yet another output file for exporting the per-instance costss.
// Then if that file is open, write subsequent costs!
if(rowNr==1) {
if(label.getEntry() instanceof NominalTargetWithCosts) {
NominalTargetWithCosts lwc = (NominalTargetWithCosts)label.getEntry();
try {
outCosts = new PrintStream(outFileCosts);
} catch (FileNotFoundException ex) {
throw new GateRuntimeException("Could not open output file "+outFileCosts,ex);
}
outCosts.println("%%MatrixMarket matrix coordinate real general\n%");
outCosts.print(DFi.format(nrRows)); // Each row has one value, non-sparse
outCosts.print(" ");
outCosts.print(DFi.format(lwc.getCosts().length));
outCosts.print(" ");
outCosts.print(DFi.format(nrRows*lwc.getCosts().length));
outCosts.println();
} else {
outFileCosts.delete();
}
}
if(outCosts != null) {
NominalTargetWithCosts lwc = (NominalTargetWithCosts)label.getEntry();
double[] costs = lwc.getCosts();
for(int i=0;i