org.itc.irst.tcc.sre.Train Maven / Gradle / Ivy
/*
* Copyright 2005 FBK-irst (http://www.fbk.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.itc.irst.tcc.sre;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Iterator;
import java.util.Properties;
import org.itc.irst.tcc.sre.data.ArgumentSet;
import org.itc.irst.tcc.sre.data.ExampleSet;
import org.itc.irst.tcc.sre.data.SentenceSetCopy;
import org.itc.irst.tcc.sre.kernel.expl.Mapping;
import org.itc.irst.tcc.sre.kernel.expl.MappingFactory;
import org.itc.irst.tcc.sre.util.FeatureIndex;
import org.itc.irst.tcc.sre.util.Vector;
import org.itc.irst.tcc.sre.util.ZipModel;
import org.itc.irst.tcc.sre.util.svm_train;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
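* Trains a relation extraction model: reads an example set, embeds it
* into a feature space with the configured kernel mapping, trains an SVM
* and stores the resulting model files in a zip archive.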
*
* @author Claudio Giuliano
* @version %I%, %G%
* @since 1.0
*/
public class Train
{
/**
* Static SLF4J logger shared by all instances; it references the
* Logger named after the fully qualified class name of Train.
*/
static Logger logger = LoggerFactory.getLogger(Train.class.getName());
/*
//
public static final String GLOBAL_CONTEXT_KERNEL = "GC";
//
public static final String LOCAL_CONTEXT_KERNEL = "LC";
//
public static final String SHALLOW_LINGUISTIC_KERNEL = "SL";
//
private int relationType;
*/
// NOTE(review): presumably the maximum number of relation classes supported
// during training; it is not referenced in the visible code, confirm
// against its users.
public static final int MAX_NUMBER_OF_CLASSES = 20;
// Training options read by run() and its helpers, e.g. "model-file",
// "example-file", "kernel-type", "svm-cost" and "train-frac".
private Properties parameter;
/**
 * Creates a trainer driven by the given configuration.
 * <p>
 * The reference is stored as-is (not copied), so the options are read
 * only when {@link #run()} executes.
 *
 * @param parameter the training options
 */
public Train(Properties parameter)
{
    super();
    this.parameter = parameter;
} // end constructor
/**
 * Trains a relation extraction model and writes it to the zip archive
 * named by the "model-file" property.
 * <p>
 * The examples in "example-file" are read and class frequencies and
 * weights are computed. Argument types are then initialised, and the
 * examples are embedded into a feature space by the mapping selected
 * with "kernel-type". The embedded training set, the feature indexes,
 * the trained SVM and the parameters are stored in the archive.
 *
 * @throws IllegalArgumentException if "model-file" or "example-file" is
 *         not set
 * @throws Exception if any training step fails; the archive is closed in
 *         every case, without hiding the original failure
 */
public void run() throws Exception
{
    logger.info("train a relation extraction model");
    // resolve the mandatory paths up front so a missing key is reported by name
    File modelFile = new File(requiredProperty("model-file"));
    File inputFile = new File(requiredProperty("example-file"));
    // create the zip archive that receives every model artifact
    ZipModel model = new ZipModel(modelFile);
    boolean completed = false;
    try
    {
        // read data set
        ExampleSet inputSet = readDataSet(inputFile);
        logger.info("input training set size: " + inputSet.size());
        // class frequencies and the class weights derived from them
        int[] freq = classFreq(inputSet);
        double[] weight = classWeigth(freq);
        // find argument types
        ArgumentSet.getInstance().init(inputSet);
        int count = inputSet.getClassCount();
        logger.debug("number of classes: " + count);
        // create the feature mapping selected by "kernel-type" and configure it
        MappingFactory mappingFactory = MappingFactory.getMappingFactory();
        Mapping mapping = mappingFactory.getInstance(parameter.getProperty("kernel-type"));
        mapping.setParameters(parameter);
        int subspaceCount = mapping.subspaceCount();
        logger.debug("number of subspaces: " + subspaceCount);
        FeatureIndex[] index = createFeatureIndex(subspaceCount);
        // embed the input data into a feature space
        logger.info("embed the training set");
        ExampleSet outputSet = mapping.map(inputSet, index);
        logger.debug("embedded training set size: " + outputSet.size());
        // uses "svm-cost" when given, otherwise a default computed from the data
        double c = calculateC(outputSet);
        logger.info("cost parameter C = " + c);
        // store the training set and indexes, train the SVM, save the parameters
        File training = saveExampleSet(outputSet, model);
        saveFeatureIndexes(index, model);
        svmTrain(training, c, weight, model);
        saveParameters(model);
        completed = true;
    }
    finally
    {
        if (completed)
        {
            // normal path: a close failure is a real error and must propagate
            model.close();
        }
        else
        {
            closeAfterFailure(model);
        }
    }
} // end run

// Returns the value of a mandatory property, failing with a message that
// names the missing key instead of a bare NullPointerException.
private String requiredProperty(String key)
{
    String value = parameter.getProperty(key);
    if (value == null)
    {
        throw new IllegalArgumentException("missing required parameter: " + key);
    }
    return value;
}

// Closes the archive after a failed run; a close error is only logged so
// that it does not replace the exception that caused the failure.
private static void closeAfterFailure(ZipModel model)
{
    try
    {
        model.close();
    }
    catch (Exception e)
    {
        logger.warn("could not close the model archive after a failure", e);
    }
}
/**
 * Reads the example set stored in {@code in}. When the "train-frac"
 * property is set, only that leading fraction of the examples is returned.
 *
 * @param in file containing the training examples (SRE format)
 * @return the examples read, possibly truncated to the training fraction
 * @throws IOException if the file cannot be opened or read
 * @throws NumberFormatException if "train-frac" is not a valid number
 */
private ExampleSet readDataSet(File in) throws IOException
{
    logger.info("read the example set");
    ExampleSet inputSet = new SentenceSetCopy();
    BufferedReader reader = new BufferedReader(new FileReader(in));
    try
    {
        inputSet.read(reader);
    }
    finally
    {
        // always release the file handle; closing an already-closed reader has no effect
        reader.close();
    }
    String trainFrac = parameter.getProperty("train-frac");
    if (trainFrac != null)
    {
        double f = Double.parseDouble(trainFrac);
        logger.info("training original size: " + inputSet.size());
        logger.info("training fraction: " + (100 * f) + "%");
        return inputSet.subSet(0, (int) (inputSet.size() * f));
    }
    return inputSet;
} // end readDataSet
/*
// get the feature mapping function
private AbstractMapping mappingFactory() throws KernelNotFoundException
{
logger.info("get the feature mapping function");
// kernel factory
AbstractMapping mapping = null;
String kernelType = parameter.getProperty("kernel-type").toUpperCase();
if (kernelType.equals(GLOBAL_CONTEXT_KERNEL))
mapping = new GlobalContextMapping();
else if (kernelType.equals(LOCAL_CONTEXT_KERNEL))
mapping = new LocalContextMapping();
else if (kernelType.equals(SHALLOW_LINGUISTIC_KERNEL))
mapping = new ShallowLinguisticMapping();
else
throw new KernelNotFoundException(kernelType + " not found.");
return mapping;
} // end mappingFactory
*/
// calculate parameter C of SVM
//
// To allow some flexibility in separating the categories,
// SVM models have a cost parameter, C, that controls the
// trade off between allowing training errors and forcing
// rigid margins. It creates a soft margin that permits
// some misclassifications. Increasing the value of C
// increases the cost of misclassifying points and forces
// the creation of a model that fits the training data more
// closely but may not generalize well.
private double calculateC(ExampleSet data) //throws Exception
{
String svmCost = parameter.getProperty("svm-cost");
if (svmCost != null)
return Integer.parseInt(svmCost);
logger.info("calculate default SVM cost parameter C");
//double c = 1;
double avr = 0;
// the example set is normalized
// all vectors have the same norm
for (int i=0;i file with training data (SRE format)\n");
sb.append("\tmodel-file\t-> file in which to store resulting model\n");
sb.append("Options:\n");
sb.append("\t-h\t\t-> this help\n");
sb.append("\t-k string\t-> set type of kernel function (default SL):\n");
sb.append("\t\t\t\tLC: Local Context Kernel\n");
sb.append("\t\t\t\tGC: Global Context Kernel\n");
sb.append("\t\t\t\tSL: Shallow Linguistic Context Kernel\n");
sb.append("\t-n [1..]\t-> set the parameter n-gram of kernels SL and GC (default 3)\n");
sb.append("\t-w [0..]\t-> set the window size of kernel LC (default 2)\n");
sb.append("\t-c [0..]\t-> set the trade-off between training error and margin (default 1/[avg. x*x'])\n");
sb.append("\t-f\t-> fraction of training set (default 1)\n");
sb.append("\t-m int\t\t-> set cache memory size in MB (default 128)\n");
return sb.toString();
} // end getHelp
/**
 * Renders these parameters as text by listing them into an in-memory
 * writer.
 *
 * @return the text written by {@code list(PrintWriter)}
 */
@Override
public String toString()
{
    StringWriter buffer = new StringWriter();
    PrintWriter printer = new PrintWriter(buffer);
    list(printer);
    return buffer.toString();
} // end toString
/**
 * Unchecked exception signalling an illegal parameter value.
 */
class IllegalParameterException extends IllegalArgumentException
{
    /**
     * @param message description of the offending parameter
     */
    public IllegalParameterException(String message)
    {
        super(message);
    } // end constructor
} // end IllegalParameterException
} // end class CommandLineParameters
} // end class Train
© 2015 - 2025 Weber Informatics LLC | Privacy Policy