All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.itc.irst.tcc.sre.Map Maven / Gradle / Ivy

/*
 * Copyright 2005 FBK-irst (http://www.fbk.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.itc.irst.tcc.sre;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Iterator;
import java.util.Properties;

import org.itc.irst.tcc.sre.data.ArgumentSet;
import org.itc.irst.tcc.sre.data.ExampleSet;
import org.itc.irst.tcc.sre.data.SentenceSetCopy;
import org.itc.irst.tcc.sre.kernel.expl.Mapping;
import org.itc.irst.tcc.sre.kernel.expl.MappingFactory;
import org.itc.irst.tcc.sre.util.FeatureIndex;
import org.itc.irst.tcc.sre.util.Vector;
import org.itc.irst.tcc.sre.util.svm_train;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

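/**
 * Builds a relation extraction model: reads the training set,
 * embeds it into a feature space through the mapping selected by
 * the "kernel-type" parameter, saves the embedded examples, the
 * feature indexes and the parameters, and finally trains an SVM
 * on the embedded examples (see {@link #run()}).
 *
 * @author 	Claudio Giuliano
 * @version %I%, %G%
 * @since		1.0
 */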
public class Map
{
	/**
	 * Define a static logger variable so that it references the
	 * Logger instance named Map.
	 */
	static Logger logger = LoggerFactory.getLogger(Map.class.getName());

	// upper bound on the number of classes
	// NOTE(review): not referenced in the visible part of this file -- confirm where it is enforced
	public static final int MAX_NUMBER_OF_CLASSES = 20;

	// configuration properties handed to the constructor; the methods of this
	// class look up their settings here with getProperty(...)
	// (e.g. "train-file", "test-file", "kernel-type", "svm-cost")
	private Properties parameter;

	//
	/**
	 * Creates a new instance that takes its configuration from the
	 * specified properties. The reference is stored as is (no copy
	 * is made) and no validation is performed here.
	 *
	 * @param parameter	the properties later queried with getProperty(...)
	 */
	public Map(Properties parameter)
	{
		this.parameter = parameter;
	} // end constructor

	//
	/**
	 * Runs the whole pipeline, in this order: reads the training and
	 * the test set, computes class frequencies and class weights from
	 * the training set, initializes the argument set, embeds the
	 * training set into a feature space with the mapping selected by
	 * the "kernel-type" parameter, determines the SVM cost parameter C,
	 * saves the embedded training set, the feature indexes and the
	 * parameters, and finally trains the SVM.
	 *
	 * The test set is only read and its size logged; it is not used
	 * further by this method.
	 *
	 * @throws IllegalArgumentException	if the "train-file" or the
	 *									"test-file" parameter is not set
	 * @throws Exception				if any of the steps fails
	 */
	public void run() throws Exception
	{
		logger.info("Map a relation extraction model");

		// read training set
		File trainingFile = requiredFile("train-file");
		ExampleSet trainingSet = readDataSet(trainingFile);
		logger.info("input training set size: {}", trainingSet.size());

		// read test set
		File testFile = requiredFile("test-file");
		ExampleSet testSet = readDataSet(testFile);
		logger.info("input test set size: {}", testSet.size());

		// get the class freq
		int[] freq = classFreq(trainingSet);

		// calculate the class weight
		double[] weight = classWeigth(freq);

		// find argument types
		ArgumentSet.getInstance().init(trainingSet);

		// number of classes in the training set (logged only)
		int count = trainingSet.getClassCount();
		logger.debug("number of classes: {}", count);

		// create the mapping factory
		MappingFactory mappingFactory = MappingFactory.getMappingFactory();
		Mapping mapping = mappingFactory.getInstance(parameter.getProperty("kernel-type"));

		// set the command line parameters
		mapping.setParameters(parameter);

		// get the number of subspaces
		int subspaceCount = mapping.subspaceCount();
		logger.debug("number of subspaces: {}", subspaceCount);

		// create the index
		FeatureIndex[] index = createFeatureIndex(subspaceCount);

		// embed the input data into a feature space
		logger.info("embed the training set");
		ExampleSet outputSet = mapping.map(trainingSet, index);
		logger.debug("embedded training set size: {}", outputSet.size());

		// if not specified, calculate SVM parameter C
		double c = calculateC(outputSet);
		logger.info("cost parameter C = {}", c);

		// save the training set
		File training = saveExampleSet(outputSet);

		// save the indexes
		saveFeatureIndexes(index);

		// save param
		saveParameters();

		// train the svm on the saved training file
		svmTrain(training, c, weight);
	} // end run

	// resolve a mandatory file-valued parameter; fails with a clear
	// message instead of the bare NullPointerException that
	// new File(null) would raise when the parameter is missing
	private File requiredFile(String key)
	{
		String path = parameter.getProperty(key);
		if (path == null)
			throw new IllegalArgumentException("missing required parameter: " + key);

		return new File(path);
	} // end requiredFile

	// read the data set
	/**
	 * Reads an example set from the specified file. If the "Map-frac"
	 * parameter is set, only the leading fraction of the examples,
	 * i.e. subSet(0, (int) (size * fraction)), is returned; otherwise
	 * the complete set is returned.
	 *
	 * The reader opened on the file is always closed before this
	 * method returns, also when reading fails.
	 *
	 * @param in	the file to read the examples from
	 * @return		the example set, possibly reduced to a leading fraction
	 * @throws IOException				if the file cannot be opened, read or closed
	 * @throws NumberFormatException	if "Map-frac" is set but is not a valid double
	 */
	private ExampleSet readDataSet(File in) throws IOException
	{
		logger.info("read the example set");

		ExampleSet exampleSet = new SentenceSetCopy();

		// try/finally so that the file handle is released even if read() throws
		BufferedReader reader = new BufferedReader(new FileReader(in));
		try
		{
			exampleSet.read(reader);
		}
		finally
		{
			reader.close();
		}

		String trainFrac = parameter.getProperty("Map-frac");
		if (trainFrac == null)
			return exampleSet;

		double f = Double.parseDouble(trainFrac);
		logger.info("training original size: {}", exampleSet.size());
		logger.info("training fraction: {}%", 100 * f);

		// the cast truncates towards zero, so a partial example is never counted
		return exampleSet.subSet(0, (int) (exampleSet.size() * f));
	}	// end readDataSet
/*
	// get the feature mapping function
	private AbstractMapping mappingFactory() throws KernelNotFoundException
	{
		logger.info("get the feature mapping function");

		// kernel factory
		AbstractMapping mapping = null;

		String kernelType = parameter.getProperty("kernel-type").toUpperCase();
		if (kernelType.equals(GLOBAL_CONTEXT_KERNEL))
			mapping = new GlobalContextMapping();
		else if (kernelType.equals(LOCAL_CONTEXT_KERNEL))
			mapping = new LocalContextMapping();
		else if (kernelType.equals(SHALLOW_LINGUISTIC_KERNEL))
			mapping = new ShallowLinguisticMapping();
		else
			throw new KernelNotFoundException(kernelType + " not found.");

		return mapping;
	} // end mappingFactory
*/

	// calculate parameter C of SVM
	//
	// To allow some flexibility in separating the categories,
	// SVM models have a cost parameter, C, that controls the
	// trade off between allowing training errors and forcing
	// rigid margins. It creates a soft margin that permits
	// some misclassifications. Increasing the value of C
	// increases the cost of misclassifying points and forces
	// the creation of a more accurate model that may not
	// generalize well
	private double calculateC(ExampleSet data) //throws Exception
	{
		String svmCost = parameter.getProperty("svm-cost");
		if (svmCost != null)
			return Integer.parseInt(svmCost);

		logger.info("calculate default SVM cost parameter C");

		//double c = 1;
		double avr = 0;

		// the example set is normalized
		// all vectors have the same norm
		for (int i=0;i file with training data (SRE format)\n");
			sb.append("\ttest-file\t-> file with training data (SRE format)\n");

			sb.append("Options:\n");
			sb.append("\t-h\t\t-> this help\n");
			sb.append("\t-k string\t-> set type of kernel function (default SL):\n");
			sb.append("\t\t\t\tLC: Local Context Kernel\n");
			sb.append("\t\t\t\tGC: Global Context Kernel\n");
			sb.append("\t\t\t\tSL: Shallow Linguistic Context Kernel\n");

			sb.append("\t-n [1..]\t-> set the parameter n-gram of kernels SL and GC  (default 3)\n");
			sb.append("\t-w [0..]\t-> set the window size of kernel LC (default 2)\n");
			sb.append("\t-c [0..]\t-> set the trade-off between training error and margin (default 1/[avg. x*x'])\n");

			sb.append("\t-f\t-> fraction of training set (default 1)\n");
			sb.append("\t-m int\t\t-> set cache memory size in MB (default 128)\n");

			return sb.toString();
		} // end getHelp

		//
		/**
		 * Returns whatever list(PrintWriter) prints, captured in a string.
		 * NOTE(review): list(...) is not defined in the visible part of
		 * this class -- presumably inherited from java.util.Properties; confirm.
		 *
		 * @return	the text written by list(...)
		 */
		@Override
		public String toString()
		{
			// in-memory sink that collects the listing
			StringWriter buffer = new StringWriter();
			PrintWriter out = new PrintWriter(buffer);

			list(out);

			return buffer.toString();
		} // end toString

		//
		/**
		 * A thin {@link IllegalArgumentException} subclass that adds no
		 * state or behavior of its own; it only forwards a detail message.
		 * Judging by its name it reports an invalid parameter -- the
		 * places where it is thrown are not visible here.
		 */
		class IllegalParameterException extends IllegalArgumentException
		{
			/**
			 * Creates the exception with the specified detail message.
			 *
			 * @param message	the detail message, handed unchanged to
			 *					IllegalArgumentException
			 */
			public IllegalParameterException(String message)
			{
				super(message);
			} // end constructor

		} // end IllegalParameterException

	} // end class CommandLineParameters

} // end class Map




© 2015 - 2025 Weber Informatics LLC | Privacy Policy