All downloads are free. Search and download functionality uses the official Maven repository.

edu.emory.mathcs.nlp.decode.AbstractNLPDecoder Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2015, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.mathcs.nlp.decode;

import edu.emory.mathcs.nlp.common.constant.StringConst;
import edu.emory.mathcs.nlp.common.util.FileUtils;
import edu.emory.mathcs.nlp.common.util.IOUtils;
import edu.emory.mathcs.nlp.common.util.Joiner;
import edu.emory.mathcs.nlp.common.util.Language;
import edu.emory.mathcs.nlp.component.morph.MorphologicalAnalyzer;
import edu.emory.mathcs.nlp.component.template.NLPComponent;
import edu.emory.mathcs.nlp.component.template.lexicon.GlobalLexica;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.reader.TSVReader;
import edu.emory.mathcs.nlp.component.tokenizer.Tokenizer;
import edu.emory.mathcs.nlp.component.tokenizer.token.Token;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

/**
 * @author Jinho D. Choi ({@code [email protected]})
 */
public abstract class AbstractNLPDecoder>
{
	/** Logger shared by every decoder instance. */
	private static final Logger LOG = LoggerFactory.getLogger(AbstractNLPDecoder.class);

	/** Input format: free text that is segmented by the tokenizer's {@code segmentize}. */
	public static final String FORMAT_RAW = "raw";
	/** Input format: one sentence per line; each line is tokenized and decoded on its own. */
	public static final String FORMAT_LINE = "line";
	/** Input format: tab-separated values read through a {@link TSVReader}. */
	public static final String FORMAT_TSV = "tsv";

	/** Components applied in list order by {@code decode(N[])}; replaced as a whole by {@code setComponents}. */
	// NOTE(review): the generic type arguments of this declaration appear to have been stripped from this copy.
	private volatile List> components;
	/** Tokenizer used by the raw, line, and document decoding paths. */
	private volatile Tokenizer tokenizer;
	/** Configuration passed to the most recent {@code init} call (also read by {@code createTSVReader}). */
	private DecodeConfig decode_config;

//	======================================== CONSTRUCTORS ========================================
	
	/**
	 * Creates an empty decoder with no tokenizer or components; call {@link #init(DecodeConfig)}
	 * (or the setters) before decoding.
	 */
	public AbstractNLPDecoder()
	{
	}
	
	/**
	 * Creates a decoder and loads its tokenizer and components from {@code config}.
	 * NOTE(review): this calls the overridable {@code init} from a constructor, so a subclass override runs
	 * before the subclass's own field initializers.
	 * @param config the decoding configuration.
	 */
	public AbstractNLPDecoder(DecodeConfig config)
	{
		init(config);
	}
	
	/**
	 * Creates a decoder from a configuration parsed out of {@code configuration}.
	 * @param configuration stream holding the decoding configuration.
	 */
	public AbstractNLPDecoder(InputStream configuration)
	{
		this(new DecodeConfig(configuration));
	}
	
	/**
	 * Loads the tokenizer and the NLP components enabled by {@code config}, replacing the current ones.
	 * <p>
	 * The component list always starts with {@link GlobalLexica}. The following are appended in this order
	 * when their configuration entries are non-null:
	 * <ul>
	 * <li>the part-of-speech tagger, followed by the morphological analyzer;</li>
	 * <li>the named entity recognizer;</li>
	 * <li>the dependency parser.</li>
	 * </ul>
	 * Everything is loaded into locals first and installed only at the end. A failure while loading therefore
	 * leaves the decoder's previous configuration, tokenizer, and components untouched.
	 * @param config the decoding configuration.
	 */
	public void init(DecodeConfig config)
	{
		// NOTE(review): the generic type arguments of this declaration appear to have been stripped from this copy.
		List> components = new ArrayList<>();
		Language language = config.getLanguage();
		
		components.add(new GlobalLexica<>(config.getDocumentElement()));
		
		LOG.info("Loading tokenizer");
		Tokenizer loadedTokenizer = edu.emory.mathcs.nlp.common.util.NLPUtils.createTokenizer(language);
		
		if (config.getPartOfSpeechTagging() != null)
		{
			LOG.info("Loading part-of-speech tagger");
			components.add(edu.emory.mathcs.nlp.common.util.NLPUtils.getComponent(config.getPartOfSpeechTagging()));
			
			// the morphological analyzer is only loaded together with the tagger
			LOG.info("Loading morphological analyzer");
			components.add(new MorphologicalAnalyzer<>(language));
		}
		
		if (config.getNamedEntityRecognition() != null)
		{
			LOG.info("Loading named entity recognizer");
			components.add(edu.emory.mathcs.nlp.common.util.NLPUtils.getComponent(config.getNamedEntityRecognition()));
		}
		
		if (config.getDependencyParsing() != null)
		{
			LOG.info("Loading dependency parser");
			components.add(edu.emory.mathcs.nlp.common.util.NLPUtils.getComponent(config.getDependencyParsing()));
		}
		
		// TODO: semantic role labeling is not loaded yet.
		
		// install the new state only after every component has loaded successfully
		decode_config = config;
		setTokenizer(loadedTokenizer);
		setComponents(components);
	}
	
//	======================================== GETTERS/SETTERS ========================================
	
	/** @return the tokenizer used by the raw, line, and document decoding paths (may be {@code null} before init). */
	public Tokenizer getTokenizer()
	{
		return this.tokenizer;
	}
	
	/** @return the live component list applied by {@code decode(N[])}; this is the stored list, not a copy. */
	public List> getComponents()
	{
		return this.components;
	}
	
	/** Replaces the tokenizer used by the raw, line, and document decoding paths. */
	public void setTokenizer(Tokenizer tokenizer)
	{
		this.tokenizer = tokenizer;
	}
	
	/**
	 * Replaces the components applied by {@code decode(N[])}.
	 * The list is stored by reference, so later changes to it are visible to decoding.
	 */
	public void setComponents(List> components)
	{
		this.components = components;
	}
	
//	======================================== DECODE ========================================

	/**
	 * Decodes every input file on a fixed pool of {@code threads} worker threads.
	 * The output for {@code inputFile} is written to {@code inputFile + "." + outputExt}.
	 * <p>
	 * This method returns once all tasks are queued; it does not wait for them to finish. The pool is shut
	 * down right after queuing, so its threads exit when the queued work is done.
	 * @param inputFiles paths of the files to decode.
	 * @param outputExt extension appended to each input path to form its output path.
	 * @param format one of {@link #FORMAT_RAW}, {@link #FORMAT_LINE}, {@link #FORMAT_TSV}.
	 * @param threads number of worker threads; must be positive ({@link Executors#newFixedThreadPool(int)}
	 *                throws {@link IllegalArgumentException} otherwise).
	 */
	public void decode(List inputFiles, String outputExt, String format, int threads)
	{
		ExecutorService executor = Executors.newFixedThreadPool(threads);
		
		try
		{
			for (String inputFile : inputFiles)
			{
				String outputFile = inputFile + StringConst.PERIOD + outputExt;
				// execute (not submit): an exception escaping the task reaches the worker thread's
				// uncaught-exception handler instead of vanishing into a Future nobody reads
				executor.execute(new NLPTask(inputFile, outputFile, format));
			}
		}
		finally
		{
			executor.shutdown();
		}
	}
	
	/**
	 * Decodes {@code s} in {@code format} and returns the decoder output as a string.
	 * NOTE(review): the output bytes are converted with the platform default charset. An explicit charset
	 * would have to match the one the IOUtils print streams write with; confirm before changing.
	 * @param s input text.
	 * @param format one of {@link #FORMAT_RAW}, {@link #FORMAT_LINE}, {@link #FORMAT_TSV}.
	 * @return the decoder output.
	 */
	public String decode(String s, String format)
	{
		byte[] output = decodeByteArray(s, format);
		return new String(output);
	}
	
	/**
	 * Decodes {@code s} in {@code format} through in-memory streams and returns the raw output bytes.
	 * {@code s} is encoded with the platform default charset. Closing the in-memory streams has no effect.
	 * @param s input text.
	 * @param format one of {@link #FORMAT_RAW}, {@link #FORMAT_LINE}, {@link #FORMAT_TSV}.
	 * @return the bytes written by the decoder.
	 */
	public byte[] decodeByteArray(String s, String format)
	{
		final ByteArrayOutputStream sink = new ByteArrayOutputStream();
		
		try (InputStream source = new ByteArrayInputStream(s.getBytes()); OutputStream target = sink)
		{
			decode(source, target, format);
		}
		catch (IOException e) {e.printStackTrace();}
		
		return sink.toByteArray();
	}
	
	/**
	 * Decodes {@code in} in {@code format} and writes the result to {@code out}.
	 * <p>
	 * This method never throws a checked or runtime exception. Failures are logged and swallowed, and an
	 * unrecognized {@code format} is logged and produces no output.
	 * @param in input to decode.
	 * @param out destination for the decoded output.
	 * @param format one of {@link #FORMAT_RAW}, {@link #FORMAT_LINE}, {@link #FORMAT_TSV}.
	 */
	public void decode(InputStream in, OutputStream out, String format)
	{
		try
		{
			switch (format)
			{
			case FORMAT_RAW : decodeRaw (in, out); break;
			case FORMAT_LINE: decodeLine(in, out); break;
			case FORMAT_TSV : decodeTSV (createTSVReader(), in, out); break;
			default: LOG.warn("Unknown decoding format: {}", format);
			}
		}
		catch (Exception e)
		{
			// best-effort contract: report through the class logger instead of printStackTrace
			LOG.error("Failed to decode input in format {}", format, e);
		}
	}
	
	/**
	 * Segments and decodes the text {@code s}, which is encoded with the platform default charset.
	 * @param s document text.
	 * @return one decoded node array per segment, in input order.
	 * @throws IOException if reading the in-memory stream fails.
	 */
	public List decodeDocument(String s) throws IOException
	{
		InputStream documentStream = new ByteArrayInputStream(s.getBytes());
		return decodeDocument(documentStream);
	}
	
	/**
	 * Segments {@code in} into token lists with the tokenizer and runs every component on each segment.
	 * {@code in} is closed before returning, including when decoding fails.
	 * @param in document input.
	 * @return the decoded node arrays in input order; index 0 of each array is the root node.
	 * @throws IOException if reading or closing {@code in} fails.
	 */
	public List decodeDocument(InputStream in) throws IOException
	{
		List document = new ArrayList<>();
		
		// try-with-resources closes the input even if segmentation or a component throws
		try (InputStream input = in)
		{
			for (List tokens : tokenizer.segmentize(input))
			{
				N[] nodes = toNodeArray(tokens);
				decode(nodes);
				document.add(nodes);
			}
		}
		
		return document;
	}
	
	/**
	 * Decodes the raw text {@code s} (encoded with the platform default charset) and writes it to {@code out}.
	 * @param s raw input text.
	 * @param out destination for the decoded output.
	 * @throws IOException if reading or closing the streams fails.
	 */
	public void decodeRaw(String s, OutputStream out) throws IOException
	{
		byte[] rawBytes = s.getBytes();
		decodeRaw(new ByteArrayInputStream(rawBytes), out);
	}
	
	/**
	 * Segments raw text from {@code in} with the tokenizer, decodes each segment, and prints it to {@code out}
	 * followed by a blank line.
	 * Both streams are closed before returning, including when decoding fails.
	 * @param in raw input text.
	 * @param out destination for the decoded output.
	 * @throws IOException if reading or closing the input fails.
	 */
	public void decodeRaw(InputStream in, OutputStream out) throws IOException
	{
		// resources close in reverse order: the input first, then the print stream (same order as before)
		try (PrintStream fout = IOUtils.createBufferedPrintStream(out);
		     InputStream input = in)
		{
			for (List tokens : tokenizer.segmentize(input))
			{
				N[] nodes = toNodeArray(tokens);
				decode(nodes);
				fout.println(toString(nodes)+"\n");
			}
		}
	}
	
	/**
	 * Reads {@code in} line by line, tokenizes and decodes each line as one sentence, and prints the result
	 * to {@code out} followed by a blank line.
	 * Both streams are closed before returning, including when decoding fails.
	 * @param in input with one sentence per line.
	 * @param out destination for the decoded output.
	 * @throws IOException if reading or closing the input fails.
	 */
	public void decodeLine(InputStream in, OutputStream out) throws IOException
	{
		// resources close in reverse order: the reader first, then the print stream (same order as before)
		try (PrintStream fout = IOUtils.createBufferedPrintStream(out);
		     BufferedReader reader = IOUtils.createBufferedReader(in))
		{
			String line;
			
			while ((line = reader.readLine()) != null)
			{
				N[] nodes = decode(line);
				fout.println(toString(nodes)+"\n");
			}
		}
	}
	
	/**
	 * Reads pre-tokenized sentences from {@code in} with {@code reader}, runs every component on each one, and
	 * prints the result to {@code out} followed by a blank line.
	 * Once opened, {@code reader} is closed before returning, even on failure. The print stream wrapping
	 * {@code out} is always closed.
	 * @param reader TSV reader used to parse {@code in}.
	 * @param in TSV input.
	 * @param out destination for the decoded output.
	 * @throws IOException if reading or closing the input fails.
	 */
	public void decodeTSV(TSVReader reader, InputStream in, OutputStream out) throws IOException
	{
		PrintStream fout = IOUtils.createBufferedPrintStream(out);
		
		try
		{
			reader.open(in);
			
			try
			{
				N[] nodes;
				
				while ((nodes = reader.next()) != null)
				{
					decode(nodes);
					fout.println(toString(nodes)+"\n");
				}
			}
			finally
			{
				reader.close();
			}
		}
		finally
		{
			// closed last, matching the original reader-then-output order
			fout.close();
		}
	}
	
	/**
	 * Tokenizes {@code sentence} with {@code tokenizer.tokenize} and runs every component on the result.
	 * @param sentence a single sentence.
	 * @return the decoded node array; index 0 holds the root node created by {@code toNodeArray}.
	 */
	public N[] decode(String sentence)
	{
		N[] nodes = toNodeArray(tokenizer.tokenize(sentence));
		return decode(nodes);
	}
	
	/**
	 * Runs every configured component, in list order, on {@code nodes}.
	 * Components work on the array in place (presumably annotating its nodes).
	 * @param nodes the sentence to decode.
	 * @return {@code nodes} itself.
	 */
	public N[] decode(N[] nodes)
	{
		for (NLPComponent next : components)
		{
			next.process(nodes);
		}
		
		return nodes;
	}
	
	/**
	 * Converts {@code tokens} into a node array, using {@code create(token)} as the per-token node factory.
	 * The array has {@code tokens.size() + 1} slots and holds a root node at index 0.
	 * @param tokens tokens of one sentence.
	 * @return the node array.
	 */
	public N[] toNodeArray(List tokens)
	{
		return toNodeArray(tokens, token -> create(token));
	}
	
	/**
	 * Builds a node array of {@code tokens.size() + 1} slots whose index 0 is a root node obtained from
	 * {@code create()}. The array's runtime component type is the concrete class returned by {@code create()}.
	 * NOTE(review): in this copy, text was lost during extraction (see the fused {@code for} line below). The
	 * missing text is the rest of this method (the per-token loop and its return) and the header of
	 * {@code createTSVReader()}, presumably along with declarations between them such as {@code create()}.
	 * Restore this span from the canonical source before editing it.
	 */
	@SuppressWarnings("unchecked")
	public N[] toNodeArray(List tokens, Function f)
	{
		N node = create(); node.toRoot();
		// allocate reflectively so the array has the concrete node class; hence the unchecked cast to N[]
		N[] nodes = (N[])Array.newInstance(node.getClass(), tokens.size() + 1);
		nodes[0] = node;	// root
		
		for (int i=0,j=1; i createTSVReader()
	{
		// createTSVReader: an anonymous TSVReader built from decode_config.getReaderFieldMap() whose node
		// factory delegates to this decoder's create()
		return new TSVReader(decode_config.getReaderFieldMap())
		{
			@Override
			protected N create() {return AbstractNLPDecoder.this.create();}
		};
	}
	
	/**
	 * Renders a decoded sentence as the nodes' string forms joined by newlines. The join starts at index 1,
	 * presumably skipping the root node at index 0 (Joiner's third argument is taken to be a start index;
	 * confirm).
	 * @param nodes decoded node array.
	 * @return the rendered sentence.
	 */
	public String toString(N[] nodes)
	{
		final int firstIndex = 1;
		return Joiner.join(nodes, "\n", firstIndex);
	}
	
	class NLPTask implements Runnable
	{
		private String input_file;
		private String output_file;
		private String format;
		
		public NLPTask(String inputFile, String outputFile, String format)
		{
			this.input_file  = inputFile;
			this.output_file = outputFile;
			this.format      = format;
		}
		
		@Override
		public void run()
		{
			LOG.info(FileUtils.getBaseName(input_file));
			InputStream  in  = IOUtils.createFileInputStream (input_file);
			OutputStream out = IOUtils.createFileOutputStream(output_file);
			decode(in, out, format);
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy