All downloads are free. The search and download functionality uses the official Maven repository.

org.maltparserx.ml.lib.Lib Maven / Gradle / Ivy

package org.maltparserx.ml.lib;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;

import java.io.OutputStreamWriter;
import java.util.ArrayList;

import org.apache.log4j.Logger;

import java.util.LinkedHashMap;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;


import org.maltparserx.core.exception.MaltChainedException;
import org.maltparserx.core.feature.FeatureVector;
import org.maltparserx.core.feature.function.FeatureFunction;
import org.maltparserx.core.feature.value.FeatureValue;
import org.maltparserx.core.feature.value.MultipleFeatureValue;
import org.maltparserx.core.feature.value.SingleFeatureValue;
import org.maltparserx.core.syntaxgraph.DependencyStructure;
import org.maltparserx.ml.LearningMethod;
import org.maltparserx.parser.DependencyParserConfig;
import org.maltparserx.parser.guide.instance.InstanceModel;
import org.maltparserx.parser.history.action.SingleDecision;

public abstract class Lib implements LearningMethod {
	/** Verbosity level; SILENT after construction unless the "lib"/"verbosity" option overrides it. */
	protected Verbostity verbosity;
	/** Verbosity levels accepted by the "lib"/"verbosity" option. The type name is public, so its spelling is kept. */
	public enum Verbostity {
		SILENT, ERROR, ALL
	}
	/** Instance model that owns this learner; cleared by terminate(). */
	protected InstanceModel owner;
	/** Learner mode passed to the constructor (compared against BATCH and CLASSIFY there). */
	protected int learnerMode;
	/** Learning method name; part of every file name this learner builds. */
	protected String name;
	/** Number of instances currently attributed to this learner. */
	protected int numberOfInstances;
	/** Value of the "lib"/"save_instance_files" option. */
	protected boolean saveInstanceFiles;
	/** True when the "singlemalt"/"null_value" option equals "none" (case-insensitive). */
	protected boolean excludeNullValues;
	/** Writer for the ".ins" instance file; opened in BATCH mode, null once closed. */
	protected BufferedWriter instanceOutput = null;
	/** Mapping between (column, feature code) pairs and feature indices. */
	protected FeatureMap featureMap;
	/** Parameter string exposed by getParamString(); this class never assigns it. */
	protected String paramString;
	/** Path of the external trainer taken from the "lib"/"external" option, or null when none is configured. */
	protected String pathExternalTrain;
	/** Library options keyed by their one-character flag, in insertion order. */
	protected LinkedHashMap<String, String> libOptions;
	/** Characters that parseParameters accepts as option flags. */
	protected String allowedLibOptionFlags;
	/** Logger obtained from the owning configuration. */
	protected Logger configLogger;
	/** Splits an instance line into its tab-separated columns. */
	protected final Pattern tabPattern = Pattern.compile("\t");
	/** Splits a column into its '|'-separated items. */
	protected final Pattern pipePattern = Pattern.compile("\\|");
	/** Scratch buffer reused by addInstance to build one instance line. */
	private final StringBuilder sb = new StringBuilder();
	/** Model used by predict(); this class only ever clears it, in terminate(). */
	protected MaltLibModel model = null;
	/**
	 * Constructs a Lib learner.
	 * <p>
	 * The subclass hooks {@link #initLibOptions()} and {@link #initAllowedLibOptionFlags()} are
	 * invoked from here, before the "lib"/"options" string is parsed.
	 * 
	 * @param owner the guide model owner
	 * @param learnerMode the mode of the learner BATCH or CLASSIFY
	 * @param learningMethodName the name of the learning method; it becomes part of the learner's file names
	 * @throws MaltChainedException if the options cannot be read or parsed, or if the instance file
	 *         or feature map cannot be opened
	 */
	public Lib(InstanceModel owner, Integer learnerMode, String learningMethodName) throws MaltChainedException {
		setOwner(owner);
		setLearnerMode(learnerMode.intValue());
		setNumberOfInstances(0);
		setLearningMethodName(learningMethodName);
		verbosity = Verbostity.SILENT;
		configLogger = owner.getGuide().getConfiguration().getConfigLogger();
		// parseParameters reads both libOptions and allowedLibOptionFlags, so the hooks must run first.
		initLibOptions();
		initAllowedLibOptionFlags();
		parseParameters(getConfiguration().getOptionValue("lib", "options").toString());
		initSpecialParameters();
		
		if (learnerMode == BATCH) {
			// Training: start with an empty feature map and open the instance file for writing.
			featureMap = new FeatureMap();
			instanceOutput = new BufferedWriter(getInstanceOutputStreamWriter(".ins"));
		} else if (learnerMode == CLASSIFY) {
			// Parsing: restore the feature map that train() saved next to the model.
			featureMap = loadFeatureMap(getInputStreamFromConfigFileEntry(".map"));
		}
	}
	
	
	/**
	 * Appends one training instance to the instance file as a single tab-separated line.
	 * <p>
	 * The line starts with the decision code, followed by one column per feature:
	 * "-1" for a missing, excluded-null or zero-valued feature, the index code for a feature
	 * with value 1, "code:value" for any other single value, and the codes joined by '|'
	 * for a multiple-valued feature. Every column is followed by a tab.
	 *
	 * @param decision the decision whose code labels the instance
	 * @param featureVector the features describing the instance
	 * @throws MaltChainedException if an argument is null or the instance file cannot be written
	 */
	public void addInstance(SingleDecision decision, FeatureVector featureVector) throws MaltChainedException {
		if (featureVector == null) {
			throw new LibException("The feature vector cannot be found");
		} else if (decision == null) {
			throw new LibException("The decision cannot be found");
		}

		// The buffer is shared between calls; clear it up front so text left behind by a
		// call that failed part-way can never be prepended to this line.
		sb.setLength(0);
		try {
			sb.append(decision.getDecisionCode());
			sb.append('\t');
			final int n = featureVector.size();
			for (int i = 0; i < n; i++) {
				final FeatureValue featureValue = featureVector.getFeatureValue(i);
				if (featureValue == null || (excludeNullValues && featureValue.isNullValue())) {
					sb.append("-1");
				} else if (!featureValue.isMultiple()) {
					final SingleFeatureValue singleFeatureValue = (SingleFeatureValue)featureValue;
					if (singleFeatureValue.getValue() == 1) {
						sb.append(singleFeatureValue.getIndexCode());
					} else if (singleFeatureValue.getValue() == 0) {
						sb.append("-1");
					} else {
						sb.append(singleFeatureValue.getIndexCode());
						sb.append(":");
						sb.append(singleFeatureValue.getValue());
					}
				} else {
					final Set<Integer> codes = ((MultipleFeatureValue)featureValue).getCodes();
					boolean first = true;
					for (Integer code : codes) {
						if (!first) {
							sb.append("|");
						}
						sb.append(code.toString());
						first = false;
					}
				}
				sb.append('\t');
			}
			sb.append('\n');
			instanceOutput.write(sb.toString());
			instanceOutput.flush();
			increaseNumberOfInstances();
		} catch (IOException e) {
			throw new LibException("The learner cannot write to the instance file. ", e);
		} finally {
			// Reset on failure as well as on success.
			sb.setLength(0);
		}
	}

	/** Intentionally empty: this class does nothing when a sentence is finalized. */
	public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { }

	/**
	 * Moves every instance in this learner's ".ins" file to the instance writer of another
	 * learning method, then deletes the file.
	 * <p>
	 * The file is copied token by token. Before a token is written, the number of tabs already
	 * seen on the line, minus one, is looked up in {@code divideFeatureIndexVector}; on a hit the
	 * index code of {@code divideFeature}'s current value is inserted as an extra column (in
	 * front of a tab-terminated token, behind the last token of the line). Each completed line
	 * moves one instance count from this learner to {@code method}. Text after the final
	 * newline is dropped.
	 *
	 * @param method the learning method that receives the instances
	 * @param divideFeature the feature whose index code is inserted
	 * @param divideFeatureIndexVector the positions at which the divide feature is inserted
	 * @throws MaltChainedException if an argument is missing or the instance file cannot be
	 *         found, read or removed
	 */
	public void moveAllInstances(LearningMethod method,
			FeatureFunction divideFeature,
			ArrayList divideFeatureIndexVector)
			throws MaltChainedException { 
		if (method == null) {
			throw new LibException("The learning method cannot be found. ");
		} else if (divideFeature == null) {
			throw new LibException("The divide feature cannot be found. ");
		} 
		
		try {
			final BufferedWriter out = method.getInstanceWriter();
			final BufferedReader in = new BufferedReader(getInstanceInputStreamReader(".ins"));
			try {
				final StringBuilder token = new StringBuilder(6);
				int column = 0;
				for (int next = in.read(); next != -1; next = in.read()) {
					final char c = (char)next;
					if (c == '\t') {
						if (divideFeatureIndexVector.contains(column-1)) {
							out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
							out.write('\t');
						}
						out.write(token.toString());
						column++;
						out.write('\t');
						token.setLength(0);
					} else if (c == '\n') {
						out.write(token.toString());
						if (divideFeatureIndexVector.contains(column-1)) {
							out.write('\t');
							out.write(Integer.toString(((SingleFeatureValue)divideFeature.getFeatureValue()).getIndexCode()));
						}
						out.write('\n');
						token.setLength(0);
						method.increaseNumberOfInstances();
						this.decreaseNumberOfInstances();
						column = 0;
					} else {
						token.append(c);
					}
				}
			} finally {
				// Release the file handle even when copying fails half-way.
				in.close();
			}
			getFile(".ins").delete();
			out.flush();
		} catch (SecurityException e) {
			throw new LibException("The learner cannot remove the instance file. ", e);
		} catch (NullPointerException  e) {
			// Kept so that a missing writer or index vector still reaches callers as a LibException.
			throw new LibException("The instance file cannot be found. ", e);
		} catch (FileNotFoundException e) {
			throw new LibException("The instance file cannot be found. ", e);
		} catch (IOException e) {
			throw new LibException("The learner cannot read from the instance file. ", e);
		}
	}

	/**
	 * Signals that no further instances will be added; flushes and closes the instance writer.
	 *
	 * @throws MaltChainedException if the instance file cannot be closed
	 */
	public void noMoreInstances() throws MaltChainedException { 
		closeInstanceWriter();
	}

	/**
	 * Predicts the next decision from a feature vector.
	 * <p>
	 * Each usable feature value is translated to its feature-map index (columns are numbered
	 * from 1), the resulting list is handed to the model, and the model's answer is added to
	 * the decision's k-best list. Features that are null, excluded null values, unknown to the
	 * feature map, or single-valued with value 0 are left out.
	 *
	 * @param featureVector the features of the current parser state
	 * @param decision the decision that receives the prediction
	 * @return always true
	 * @throws MaltChainedException if the model runs out of memory while predicting
	 */
	public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
		final FeatureList features = new FeatureList();
		final int count = featureVector.size();
		for (int pos = 0; pos < count; pos++) {
			final FeatureValue current = featureVector.getFeatureValue(pos);
			if (current == null || (excludeNullValues == true && current.isNullValue())) {
				continue;
			}
			final int column = pos + 1;
			if (current.isMultiple()) {
				for (Integer code : ((MultipleFeatureValue)current).getCodes()) {
					final int mapped = featureMap.getIndex(column, code);
					if (mapped != -1) {
						features.add(mapped,1);
					}
				}
			} else {
				final SingleFeatureValue single = (SingleFeatureValue)current;
				final int mapped = featureMap.getIndex(column, single.getIndexCode());
				if (mapped != -1 && single.getValue() != 0) {
					features.add(mapped,single.getValue());
				}
			}
		}
		try {
			decision.getKBestList().addList(model.predict(features.toArray()));
		} catch (OutOfMemoryError e) {
			throw new LibException("Out of memory. Please increase the Java heap size (-Xmx). ", e);
		}
		return true;
	}
		
//	protected abstract int[] prediction(FeatureList featureList) throws MaltChainedException;
	
	/**
	 * Trains the model and then saves the feature map to the ".map" file.
	 * <p>
	 * Training is delegated to {@link #trainExternal(FeatureVector)} when an external trainer
	 * path is configured and to {@link #trainInternal(FeatureVector)} otherwise.
	 *
	 * @param featureVector the feature vector passed on to the trainer
	 * @throws MaltChainedException if the feature vector or owner is missing, training fails,
	 *         or the feature map cannot be saved
	 */
	public void train(FeatureVector featureVector) throws MaltChainedException { 
		if (featureVector == null) {
			throw new LibException("The feature vector cannot be found. ");
		} else if (owner == null) {
			throw new LibException("The parent guide model cannot be found. ");
		}

		if (pathExternalTrain != null) {
			trainExternal(featureVector);
		} else {
			trainInternal(featureVector);
		}

		final File mapFile = getFile(".map");
		try {
			saveFeatureMap(new BufferedOutputStream(new FileOutputStream(mapFile.getAbsolutePath())), featureMap);
		} catch (FileNotFoundException e) {
			throw new LibException("The learner cannot save the feature map file '"+mapFile.getAbsolutePath()+"'. ", e);
		}
	}
	/** Trains the model; train() calls this when an external trainer path (pathExternalTrain) is set. */
	protected abstract void trainExternal(FeatureVector featureVector) throws MaltChainedException;
	/** Trains the model; train() calls this when no external trainer path is set. */
	protected abstract void trainInternal(FeatureVector featureVector) throws MaltChainedException;
	
	/**
	 * Closes the instance writer and drops the references to the owner and the model.
	 *
	 * @throws MaltChainedException if the instance file cannot be closed; the references are
	 *         then left untouched
	 */
	public void terminate() throws MaltChainedException { 
		closeInstanceWriter();
		owner = null;
		model = null;
	}

	/**
	 * Returns the writer used for the instance file.
	 *
	 * @return the instance writer, or null when it was never opened or has been closed
	 */
	public BufferedWriter getInstanceWriter() {
		return instanceOutput;
	}
	
	/**
	 * Flushes and closes the instance writer, if one is open, and forgets it.
	 * Calling this again afterwards has no effect.
	 *
	 * @throws MaltChainedException if flushing or closing the instance file fails
	 */
	protected void closeInstanceWriter() throws MaltChainedException {
		if (instanceOutput == null) {
			return;
		}
		try {
			instanceOutput.flush();
			instanceOutput.close();
			instanceOutput = null;
		} catch (IOException ioFailure) {
			throw new LibException("The learner cannot close the instance file. ", ioFailure);
		}
	}
	
	
	/**
	 * Returns the parameter string used to configure the learner.
	 * 
	 * @return the parameter string used to configure the learner
	 */
	public String getParamString() {
		return paramString;
	}
	
	/**
	 * Returns the instance model that owns this learner.
	 *
	 * @return the owner, or null once terminate() has run
	 */
	public InstanceModel getOwner() {
		return owner;
	}

	/**
	 * Sets the instance model that owns this learner.
	 *
	 * @param owner the owning instance model
	 */
	protected void setOwner(InstanceModel owner) {
		this.owner = owner;
	}
	
	/**
	 * Returns the learner mode.
	 *
	 * @return the mode last stored by setLearnerMode
	 */
	public int getLearnerMode() {
		return learnerMode;
	}

	/**
	 * Stores the learner mode. The value is not validated here.
	 *
	 * @param learnerMode the mode of the learner
	 * @throws MaltChainedException never thrown by this implementation
	 */
	public void setLearnerMode(int learnerMode) throws MaltChainedException {
		this.learnerMode = learnerMode;
	}
	
	/**
	 * Returns the name of the learning method.
	 *
	 * @return the name given to the constructor
	 */
	public String getLearningMethodName() {
		return name;
	}
	
	/**
	 * Returns the current configuration, obtained from the owner's guide.
	 * 
	 * @return the current configuration
	 * @throws MaltChainedException if the configuration cannot be obtained
	 */
	public DependencyParserConfig getConfiguration() throws MaltChainedException {
		return owner.getGuide().getConfiguration();
	}
	
	/**
	 * Returns the number of instances held by this learner.
	 * <p>
	 * While the counter is zero, the ".ins" file is read and every line found is counted,
	 * which also increases the owner's frequency once per line.
	 *
	 * @return the number of instances
	 * @throws MaltChainedException if the instance file cannot be opened or read
	 */
	public int getNumberOfInstances() throws MaltChainedException {
		if (numberOfInstances != 0) {
			return numberOfInstances;
		}
		final BufferedReader reader = new BufferedReader(getInstanceInputStreamReader(".ins"));
		try {
			try {
				while (reader.readLine() != null) {
					numberOfInstances++;
					owner.increaseFrequency();
				}
			} finally {
				// Close the file even when reading fails.
				reader.close();
			}
		} catch (IOException e) {
			throw new MaltChainedException("No instances found in file",e);
		}
		return numberOfInstances;
	}

	/** Adds one to the instance counter and increases the owner's frequency. */
	public void increaseNumberOfInstances() {
		numberOfInstances++;
		owner.increaseFrequency();
	}
	
	/** Subtracts one from the instance counter and decreases the owner's frequency. */
	public void decreaseNumberOfInstances() {
		numberOfInstances--;
		owner.decreaseFrequency();
	}
	
	/**
	 * Sets the instance counter. The owner's frequency is not adjusted.
	 *
	 * @param numberOfInstances the new number of instances
	 */
	protected void setNumberOfInstances(int numberOfInstances) {
		this.numberOfInstances = numberOfInstances;
	}

	/**
	 * Sets the name of the learning method.
	 *
	 * @param name the learning method name
	 */
	protected void setLearningMethodName(String name) {
		this.name = name;
	}
	
	/**
	 * Returns the path of the external trainer.
	 *
	 * @return the path, or null when no external trainer is configured
	 */
	public String getPathExternalTrain() {
		return pathExternalTrain;
	}


	/**
	 * Sets the path of the external trainer. The path is not validated here.
	 *
	 * @param pathExternalTrain the path of the external trainer
	 */
	public void setPathExternalTrain(String pathExternalTrain) {
		this.pathExternalTrain = pathExternalTrain;
	}

	/**
	 * Opens an appending writer in the configuration directory for the file named
	 * model name + learning method name + suffix.
	 *
	 * @param suffix the file name suffix, for example ".ins"
	 * @return the writer
	 * @throws MaltChainedException if the writer cannot be opened
	 */
	protected OutputStreamWriter getInstanceOutputStreamWriter(String suffix) throws MaltChainedException {
		return getConfiguration().getConfigurationDir().getAppendOutputStreamWriter(owner.getModelName()+getLearningMethodName()+suffix);
	}
	
	/**
	 * Opens a reader in the configuration directory for the file named
	 * model name + learning method name + suffix.
	 *
	 * @param suffix the file name suffix, for example ".ins"
	 * @return the reader
	 * @throws MaltChainedException if the reader cannot be opened
	 */
	protected InputStreamReader getInstanceInputStreamReader(String suffix) throws MaltChainedException {
		return getConfiguration().getConfigurationDir().getInputStreamReader(owner.getModelName()+getLearningMethodName()+suffix);
	}
	
	/**
	 * Opens a reader on the configuration file entry named
	 * model name + learning method name + suffix.
	 *
	 * @param suffix the entry name suffix
	 * @return the reader
	 * @throws MaltChainedException if the reader cannot be opened
	 */
	protected InputStreamReader getInstanceInputStreamReaderFromConfigFile(String suffix) throws MaltChainedException {
		return getConfiguration().getConfigurationDir().getInputStreamReaderFromConfigFile(owner.getModelName()+getLearningMethodName()+suffix);
	}
	
	/**
	 * Opens an input stream on the configuration file entry named
	 * model name + learning method name + suffix.
	 *
	 * @param suffix the entry name suffix, for example ".map"
	 * @return the input stream
	 * @throws MaltChainedException if the stream cannot be opened
	 */
	protected InputStream getInputStreamFromConfigFileEntry(String suffix) throws MaltChainedException {
		return getConfiguration().getConfigurationDir().getInputStreamFromConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix);
	}
	
	
	/**
	 * Returns the file in the configuration directory named
	 * model name + learning method name + suffix.
	 *
	 * @param suffix the file name suffix, for example ".ins" or ".map"
	 * @return the file handle
	 * @throws MaltChainedException if the file handle cannot be obtained
	 */
	protected File getFile(String suffix) throws MaltChainedException {
		return getConfiguration().getConfigurationDir().getFile(owner.getModelName()+getLearningMethodName()+suffix);
	}
	
	/**
	 * Returns the configuration file entry named
	 * model name + learning method name + suffix.
	 *
	 * @param suffix the entry name suffix
	 * @return the entry
	 * @throws MaltChainedException if the entry cannot be obtained
	 */
	protected JarEntry getConfigFileEntry(String suffix) throws MaltChainedException {
		return getConfiguration().getConfigurationDir().getConfigFileEntry(owner.getModelName()+getLearningMethodName()+suffix);
	}
	
	/**
	 * Reads the options that are not plain library flags: null-value exclusion, whether to
	 * keep instance files, the path of an external trainer and the verbosity level.
	 *
	 * @throws MaltChainedException if an option cannot be read, the external trainer path is
	 *         unusable, or the verbosity level is not one of the {@link Verbostity} names
	 */
	protected void initSpecialParameters() throws MaltChainedException {
		final Object nullValue = getConfiguration().getOptionValue("singlemalt", "null_value");
		excludeNullValues = nullValue != null && nullValue.toString().equalsIgnoreCase("none");

		saveInstanceFiles = ((Boolean)getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();

		final String path = getConfiguration().getOptionValue("lib", "external").toString();
		if (!path.equals("")) {
			try {
				final File trainer = new File(path);
				if (!trainer.exists()) {
					throw new LibException("The path to the external  trainer 'svm-train' is wrong.");
				}
				if (trainer.isDirectory()) {
					throw new LibException("The option --lib-external points to a directory, the path should point at the 'train' file or the 'train.exe' file in the libsvm or the liblinear package");
				}
				if (!(path.endsWith("train") ||path.endsWith("train.exe"))) {
					throw new LibException("The option --lib-external does not specify the path to 'train' file or the 'train.exe' file in the libsvm or the liblinear package. ");
				}
				setPathExternalTrain(path);
			} catch (SecurityException e) {
				throw new LibException("Access denied to the file specified by the option --lib-external. ", e);
			}
		}

		final Object verbosityOption = getConfiguration().getOptionValue("lib", "verbosity");
		if (verbosityOption != null) {
			// Upper-case with a fixed locale: under e.g. a Turkish default locale "silent"
			// would become "SİLENT" and no longer match the enum constant.
			final String levelName = verbosityOption.toString().toUpperCase(java.util.Locale.ENGLISH);
			try {
				verbosity = Verbostity.valueOf(levelName);
			} catch (IllegalArgumentException e) {
				throw new LibException("The verbosity level '"+verbosityOption+"' is not one of silent, error or all. ", e);
			}
		}
	}
	
	/**
	 * Renders the library options as a command-line style string: "-flag value " for every
	 * option, in insertion order (note the trailing space after each value).
	 *
	 * @return the rendered options, or an empty string when there are none
	 */
	public String getLibOptions() {
		// Typed local view, so the loop compiles however the field itself is parameterised.
		@SuppressWarnings("unchecked")
		final LinkedHashMap<String, String> options = libOptions;
		final StringBuilder rendered = new StringBuilder();
		for (String flag : options.keySet()) {
			rendered.append('-');
			rendered.append(flag);
			rendered.append(' ');
			rendered.append(options.get(flag));
			rendered.append(' ');
		}
		return rendered.toString();
	}
	
	/**
	 * Returns the library options as an argument array: "-flag" followed by its value for
	 * every option, in insertion order.
	 *
	 * @return the options as alternating flag and value strings
	 */
	public String[] getLibParamStringArray() {
		// Typed local view, so the loop compiles however the field itself is parameterised.
		@SuppressWarnings("unchecked")
		final LinkedHashMap<String, String> options = libOptions;
		final ArrayList<String> params = new ArrayList<String>(2 * options.size());
		for (String flag : options.keySet()) {
			params.add("-"+flag);
			params.add(options.get(flag));
		}
		return params.toArray(new String[params.size()]);
	}
	
	/** Subclass hook called from the constructor before the "lib"/"options" string is parsed; parseParameters stores into libOptions, so it has to exist afterwards. */
	public abstract void initLibOptions();
	/** Subclass hook called from the constructor before the "lib"/"options" string is parsed; parseParameters looks flags up in allowedLibOptionFlags, so it has to exist afterwards. */
	public abstract void initAllowedLibOptionFlags();
	
	/**
	 * Parses a parameter string of "-flag value" pairs and stores each pair in the library options.
	 * <p>
	 * Tokens are separated by underscores or blanks, for example {@code -s_4_-c_0.1}. Empty
	 * tokens produced by repeated or surrounding separators are ignored. Only the first
	 * character after the '-' identifies the option, and it must occur in the allowed flags.
	 * A null or blank string is accepted and changes nothing.
	 *
	 * @param paramstring the parameter string, may be null
	 * @throws MaltChainedException if a flag does not start with '-', names no option, is
	 *         not allowed, or has no value
	 */
	public void parseParameters(String paramstring) throws MaltChainedException {
		if (paramstring == null) {
			return;
		}
		final String[] argv;
		try {
			argv = paramstring.split("[_\\p{Blank}]");
		} catch (PatternSyntaxException e) {
			throw new LibException("Could not split the parameter string '"+paramstring+"'. ", e);
		}
		// Drop empty tokens so that doubled, leading or trailing separators cannot shift the
		// flag/value pairing or trigger an index error on charAt.
		final ArrayList<String> tokens = new ArrayList<String>(argv.length);
		for (String token : argv) {
			if (token.length() > 0) {
				tokens.add(token);
			}
		}
		if (tokens.isEmpty()) {
			return;
		}
		if (allowedLibOptionFlags == null || libOptions == null) {
			throw new LibException("The learner options are not initialised, so the parameter string '"+paramstring+"' cannot be parsed. ");
		}
		for (int i = 0; i < tokens.size(); i += 2) {
			final String flag = tokens.get(i);
			if (flag.charAt(0) != '-') {
				throw new LibException("The argument flag should start with the following character '-', not with "+flag.charAt(0));
			}
			if (flag.length() < 2) {
				throw new LibException("The argument flag '"+flag+"' does not name a learner parameter. ");
			}
			if (i + 1 >= tokens.size()) {
				throw new LibException("The last argument does not have any value. ");
			}
			final String value = tokens.get(i + 1);
			if (allowedLibOptionFlags.indexOf(flag.charAt(1)) == -1) {
				throw new LibException("Unknown learner parameter: '"+flag+"' with value '"+value+"'. ");
			}
			libOptions.put(Character.toString(flag.charAt(1)), value);
		}
	}
	
	/**
	 * Closes the instance writer as a last resort when the object is finalized.
	 * Callers should rely on noMoreInstances() or terminate() instead.
	 */
	protected void finalize() throws Throwable {
		try {
			closeInstanceWriter();
		} finally {
			super.finalize();
		}
	}
	
	/**
	 * Describes the learner: a header line with the learning method name followed by the
	 * rendered library options.
	 *
	 * @return the description
	 */
	public String toString() {
		final String header = "\n"+getLearningMethodName()+" INTERFACE\n";
		return header + getLibOptions();
	}

	/**
	 * Converts one line of the instance file into a feature list.
	 * <p>
	 * The first tab-separated column is the class label. Every later column holds one or more
	 * '|'-separated items, each either a feature code or "code:value". Codes are registered in
	 * the feature map under their column number; an item equal to -1 is skipped.
	 *
	 * @param line one line of the instance file
	 * @param featureList the list to fill; it is cleared first
	 * @return the class label, or -1 when the line yields no columns
	 * @throws MaltChainedException if the line contains a non-numeric value
	 */
	protected int binariesInstance(String line, FeatureList featureList) throws MaltChainedException {
		int label = -1;
		featureList.clear();
		try {
			final String[] fields = tabPattern.split(line);
			if (fields.length == 0) {
				return -1;
			}
			try {
				label = Integer.parseInt(fields[0]);
			} catch (NumberFormatException e) {
				throw new LibException("The instance file contain a non-integer value '"+fields[0]+"'", e);
			}
			for (int col = 1; col < fields.length; col++) {
				for (String item : pipePattern.split(fields[col])) {
					try {
						final int colon = item.indexOf(':');
						if (colon == -1) {
							final int code = Integer.parseInt(item);
							if (code != -1) {
								final int mapped = featureMap.addIndex(col, code);
								if (mapped != -1) {
									featureList.add(mapped,1);
								}
							}
						} else {
							// Register the code before parsing the value, as the value may be malformed.
							final int mapped = featureMap.addIndex(col, Integer.parseInt(item.substring(0,colon)));
							final String rawValue = item.substring(colon+1);
							final double value;
							if (rawValue.indexOf('.') != -1) {
								value = Double.parseDouble(rawValue);
							} else {
								value = Integer.parseInt(rawValue);
							}
							featureList.add(mapped,value);
						}
					} catch (NumberFormatException e) {
						throw new LibException("The instance file contain a non-numeric value '"+item+"'", e);
					}
				}
			}
		} catch (ArrayIndexOutOfBoundsException e) {
			throw new LibException("Couln't read from the instance file. ", e);
		}
		return label;
	}

	/**
	 * Converts an instance file into the sparse "label index:value ..." text format, one
	 * instance per line. Lines for which binariesInstance returns -1 are skipped.
	 * <p>
	 * Both streams are closed before this method returns, whether or not it succeeds.
	 *
	 * @param isr the reader over the instance file
	 * @param osw the writer that receives the converted instances
	 * @throws MaltChainedException if the instances cannot be read, parsed or written
	 */
	protected void binariesInstances2SVMFileFormat(InputStreamReader isr, OutputStreamWriter osw) throws MaltChainedException {
		try {
			final BufferedReader in = new BufferedReader(isr);
			final BufferedWriter out = new BufferedWriter(osw);
			try {
				final FeatureList featureSet = new FeatureList();
				for (String line = in.readLine(); line != null; line = in.readLine()) {
					final int y = binariesInstance(line, featureSet);
					if (y == -1) {
						continue;
					}
					out.write(Integer.toString(y));
					for (int k=0; k < featureSet.size(); k++) {
						final MaltFeatureNode x = featureSet.get(k);
						out.write(' ');
						out.write(Integer.toString(x.getIndex()));
						out.write(':');
						out.write(Double.toString(x.getValue()));
					}
					out.write('\n');
				}
			} finally {
				// Close both ends even if conversion failed; the nested finally ensures the
				// writer is closed when closing the reader throws.
				try {
					in.close();
				} finally {
					out.close();
				}
			}
		} catch (NumberFormatException e) {
			throw new LibException("The instance file contain a non-numeric value", e);
		} catch (IOException e) {
			throw new LibException("Couln't read from the instance file, when converting the Malt instances into LIBSV/LIBLINEAR format. ", e);
		}
	}
	
	/**
	 * Serializes the feature map to the given stream and closes the stream, whether or not
	 * writing succeeds.
	 *
	 * @param os the stream to write to
	 * @param map the feature map to save
	 * @throws MaltChainedException if the feature map cannot be written
	 */
	protected void saveFeatureMap(OutputStream os, FeatureMap map) throws MaltChainedException {
		try {
			ObjectOutputStream output = null;
			try {
				output = new ObjectOutputStream(os);
				output.writeObject(map);
			} finally {
				if (output != null) {
					output.close();
				} else if (os != null) {
					// The object stream was never created, so the raw stream must be closed directly.
					os.close();
				}
			}
		} catch (IOException e) {
			throw new LibException("Save feature map error", e);
		}
	}

	/**
	 * Reads a serialized feature map from the given stream and closes the stream, whether or
	 * not reading succeeds.
	 * <p>
	 * NOTE(review): this uses Java native deserialization, so the stream should only ever
	 * come from a trusted model file.
	 *
	 * @param is the stream to read from
	 * @return the feature map read from the stream
	 * @throws MaltChainedException if the feature map cannot be read
	 */
	protected FeatureMap loadFeatureMap(InputStream is) throws MaltChainedException {
		try {
			ObjectInputStream input = null;
			try {
				input = new ObjectInputStream(is);
				return (FeatureMap)input.readObject();
			} finally {
				if (input != null) {
					input.close();
				} else if (is != null) {
					// The object stream was never created, so the raw stream must be closed directly.
					is.close();
				}
			}
		} catch (ClassNotFoundException e) {
			throw new LibException("Load feature map error", e);
		} catch (IOException e) {
			throw new LibException("Load feature map error", e);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy