All downloads are FREE. Search and download functionality uses the official Maven repository.

edu.emory.mathcs.nlp.component.doc.DOCFeatureTemplate Maven / Gradle / Ivy

The newest version!
/**
 * Copyright 2016, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.mathcs.nlp.component.doc;

import edu.emory.mathcs.nlp.common.collection.tuple.ObjectFloatPair;
import edu.emory.mathcs.nlp.common.util.MathUtils;
import edu.emory.mathcs.nlp.common.util.XMLUtils;
import edu.emory.mathcs.nlp.component.template.feature.FeatureItem;
import edu.emory.mathcs.nlp.component.template.feature.FeatureTemplate;
import edu.emory.mathcs.nlp.component.template.feature.Field;
import edu.emory.mathcs.nlp.component.template.node.AbstractNLPNode;
import edu.emory.mathcs.nlp.component.template.train.HyperParameter;
import edu.emory.mathcs.nlp.learning.util.SparseVector;
import it.unimi.dsi.fastutil.objects.Object2FloatMap;
import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap;
import org.w3c.dom.Element;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;

/**
 * @author Jinho D. Choi ({@code [email protected]})
 */
public class DOCFeatureTemplate, S extends DOCState> extends FeatureTemplate
{
	private static final long serialVersionUID = 8581842859392646419L;
	protected List feature_list_type;
	
	/**
	 * Creates this feature template from an XML features element and hyper-parameters.
	 * All work is delegated to the superclass constructor.
	 * <p>
	 * NOTE(review): {@code feature_list_type} is only populated by {@code initFeatureItems};
	 * this presumably happens during the superclass constructor. Confirm in FeatureTemplate.
	 *
	 * @param eFeatures the XML element holding the feature definitions.
	 * @param hp        the hyper-parameters passed through to the superclass.
	 */
	public DOCFeatureTemplate(Element eFeatures, HyperParameter hp)
	{
		super(eFeatures, hp);
	}
	
	/**
	 * Registers the feature items declared by one XML feature element. It then records that
	 * element's lexicon type (its {@code t} attribute) in {@code feature_list_type}.
	 * <p>
	 * If the element's first item is a {@code Field.word_embedding}, only that item is registered,
	 * through {@code addWordEmbedding}. Otherwise the whole item array goes through {@code add}.
	 *
	 * @param element an XML feature element. Its {@code t} attribute must name a {@link Field}
	 *                constant; otherwise {@code Field.valueOf} throws.
	 */
	@Override
	protected void initFeatureItems(Element element)
	{
		FeatureItem[] featureItems = createFeatureItems(element);

		// Created lazily rather than through a field initializer. NOTE(review): this matters if the
		// superclass constructor invokes this method, because a field initializer would run
		// afterwards and discard the entries collected here. Confirm in FeatureTemplate.
		if (feature_list_type == null)
			feature_list_type = new ArrayList<>();

		boolean isEmbedding = featureItems != null
				&& featureItems.length > 0
				&& featureItems[0].field == Field.word_embedding;

		if (isEmbedding)
		{
			addWordEmbedding(featureItems[0]);
		}
		else
		{
			add(featureItems);
		}

		String lexiconType = XMLUtils.getTrimmedAttribute(element, "t");
		feature_list_type.add(Field.valueOf(lexiconType));
	}
	
	/**
	 * Builds the sparse feature vector for {@code state}.
	 * <p>
	 * Every pair {@code s} in a weighted-feature collection {@code t} is added to {@code x}
	 * via {@code add(x, type, s.o, s.f, isTrain)}.
	 *
	 * @param state   the state to extract features from.
	 * @param isTrain forwarded to {@code add}. Presumably it controls whether new features may be
	 *                registered during training (TODO confirm in FeatureTemplate).
	 * @return the populated sparse vector (a fresh instance, never {@code null}).
	 */
	@Override
	public SparseVector createSparseVector(S state, boolean isTrain)
	{
		Collection> t;
		SparseVector x = new SparseVector();
		// NOTE(review): type starts at 0 and no visible line changes it. Check whether the damaged
		// loop header below updated it.
		int i, type = 0;
		
		// NOTE(review): the next line was damaged when this copy was extracted. Text between '<'
		// and '>' was stripped, which merged the loop header with its body. Presumably it iterates
		// over the configured feature groups, obtains each group's pairs via getWeightedFeatures
		// (using feature_list_type for the group's type), and skips null results. Restore it from
		// the upstream source.
		for (i=0; i s : t) add(x, type, s.o, s.f, isTrain);
		}
		
		return x;
	}
	
	/**
	 * Collects lexicon counts for one feature group and converts them into weighted
	 * (lexicon, weight) pairs.
	 *
	 * @param state the state to extract features from.
	 * @param items the group's feature items. Only the word-based types use them; the
	 *              cluster-based types ignore them.
	 * @param type  the bag-of-lexicons type. It selects both how lexicons are collected and how
	 *              they are weighted.
	 * @return the weighted pairs, or {@code null} if {@code type} yields no map or an empty one.
	 */
	// NOTE(review): the generic return type was stripped during extraction. Presumably it is a
	// Collection of ObjectFloatPair values; restore it from the upstream source.
	protected Collection> getWeightedFeatures(S state, FeatureItem[] items, Field type)
	{
		Object2FloatMap map = getBagOfLexicons(state, items, type);
		// Empty and unsupported results are both reported as null so that callers can skip them.
		return (map == null || map.isEmpty()) ? null : getBagOfLexicons(map, type);
	}
	
	/**
	 * Collects the raw lexicon counts for one feature group.
	 * <p>
	 * Word types ({@code bag_of_words*}) are delegated to {@code getBagOfWords}, and cluster types
	 * ({@code bag_of_clusters*}) to {@code getBagOfClusters}. The {@code _stopwords} variants pass
	 * {@code true} as the stop-word flag and all other variants pass {@code false}. The
	 * {@code _norm} / {@code _count} suffix has no effect here; it only matters when the counts
	 * are later converted into weights.
	 *
	 * @param state the state to collect lexicons from.
	 * @param items the feature items (used only for word types).
	 * @param type  the bag-of-lexicons type.
	 * @return a lexicon-to-count map, or {@code null} if {@code type} is not a bag-of-lexicons type.
	 */
	protected Object2FloatMap getBagOfLexicons(S state, FeatureItem[] items, Field type)
	{
		boolean useWords;
		boolean stopwordFlag;

		switch (type)
		{
		case bag_of_words:
		case bag_of_words_norm:
		case bag_of_words_count:
			useWords     = true;
			stopwordFlag = false;
			break;
		case bag_of_words_stopwords:
		case bag_of_words_stopwords_norm:
		case bag_of_words_stopwords_count:
			useWords     = true;
			stopwordFlag = true;
			break;
		case bag_of_clusters:
		case bag_of_clusters_norm:
		case bag_of_clusters_count:
			useWords     = false;
			stopwordFlag = false;
			break;
		case bag_of_clusters_stopwords:
		case bag_of_clusters_stopwords_norm:
		case bag_of_clusters_stopwords_count:
			useWords     = false;
			stopwordFlag = true;
			break;
		default:
			return null;
		}

		return useWords ? getBagOfWords(state, items, stopwordFlag) : getBagOfClusters(state, stopwordFlag);
	}
	
	/**
	 * Converts a lexicon-count map into (lexicon, weight) pairs according to {@code type}:
	 * <ul>
	 * <li>plain types ({@code bag_of_words}, {@code bag_of_clusters} and their {@code _stopwords}
	 *     variants): weight {@code 1} for every lexicon present;</li>
	 * <li>{@code _count} types: the raw count;</li>
	 * <li>{@code _norm} types: {@code MathUtils.sigmoid(count)}.</li>
	 * </ul>
	 *
	 * @param map  lexicon counts, presumably as produced by getBagOfWords / getBagOfClusters.
	 * @param type the bag-of-lexicons type that selects the weighting scheme.
	 * @return a new list of pairs, or {@code null} if {@code type} is not a bag-of-lexicons type.
	 */
	// NOTE(review): the generic return type and the map's key type were stripped during
	// extraction. Restore them from the upstream source.
	protected Collection> getBagOfLexicons(Object2FloatMap map, Field type)
	{
		// NOTE(review): entrySet() exposes boxed Float values. fastutil's object2FloatEntrySet()
		// with getFloatValue() would avoid the boxing.
		switch (type)
		{
		case bag_of_words:
		case bag_of_clusters:
		case bag_of_words_stopwords:
		case bag_of_clusters_stopwords:
			return map.entrySet().stream().map(e -> new ObjectFloatPair<>(e.getKey(), 1f)).collect(Collectors.toList());
		case bag_of_words_count:
		case bag_of_clusters_count:
		case bag_of_words_stopwords_count:
		case bag_of_clusters_stopwords_count:
			return map.entrySet().stream().map(e -> new ObjectFloatPair<>(e.getKey(), e.getValue())).collect(Collectors.toList());
		case bag_of_words_norm:
		case bag_of_clusters_norm:
		case bag_of_words_stopwords_norm:
		case bag_of_clusters_stopwords_norm:
			// NOTE(review): despite the name, the _norm variants do not divide by the total count.
			// Each count is squashed independently with MathUtils.sigmoid; if that is the logistic
			// function, every count >= 1 maps into roughly [0.73, 1). The commented-out line is an
			// unused total-count normalization.
//			float total = (float)map.entrySet().stream().mapToDouble(e -> e.getValue()).sum();
			return map.entrySet().stream().map(e -> new ObjectFloatPair<>(e.getKey(), (float)MathUtils.sigmoid(e.getValue()))).collect(Collectors.toList());
		default: return null;
		}
	}
	
	/**
	 * Counts joined word-level features over every node sequence returned by
	 * {@code state.getDocument(stopwords)}.
	 * <p>
	 * Positions {@code i} are visited starting at 1. NOTE(review): index 0 is presumably an
	 * artificial root node; confirm. For each position, one feature string is built by joining
	 * the feature value of every item in {@code items}. The position is skipped ({@code continue
	 * outer}) if any item's relative node or feature value is {@code null}. It is presumably also
	 * skipped when the item's offset falls outside the sequence. Each surviving string's count is
	 * incremented by 1.
	 *
	 * @param state     the state providing the document, relative-node lookup and feature values.
	 * @param items     the feature items whose values are joined into a single feature.
	 * @param stopwords forwarded to {@code state.getDocument}. NOTE(review): whether {@code true}
	 *                  keeps or drops stop words is not visible here; confirm in DOCState.
	 * @return a new map from joined feature string to occurrence count (possibly empty).
	 */
	protected Object2FloatMap getBagOfWords(S state, FeatureItem[] items, boolean stopwords)
	{
		Object2FloatMap map = new Object2FloatOpenHashMap<>();
		N node;
		int index;
		String f;
		
		for (N[] nodes : state.getDocument(stopwords))
		{
			// NOTE(review): the next line was damaged during extraction ('<'...'>' text stripped).
			// Presumably it holds the position loop header, the creation of the `join` joiner, the
			// loop over `items` computing `index` from i, and a bounds check on `index` whose tail
			// ("= nodes.length) continue outer;") survives. Restore it from the upstream source.
			outer: for (int i=1; i= nodes.length) continue outer;
					node = state.getRelativeNode(nodes[index], item.relation);
					if (node == null) continue outer;
					f = getFeature(state, item, node);
					if (f == null) continue outer;
					join.add(f);
				}
				
				// Every item produced a value: count one more occurrence of the joined feature.
				map.merge(join.toString(), 1f, (oldCount, newCount) -> oldCount + newCount);
			}
		}
		
		return map;
	}
	
	/**
	 * Counts cluster occurrences over every node sequence returned by
	 * {@code state.getDocument(stopwords)}.
	 *
	 * @param state     the state providing the document.
	 * @param stopwords forwarded to {@code state.getDocument}. NOTE(review): whether {@code true}
	 *                  keeps or drops stop words is not visible here; confirm in DOCState.
	 * @return a new map from cluster to occurrence count (possibly empty).
	 */
	protected Object2FloatMap getBagOfClusters(S state, boolean stopwords)
	{
		Object2FloatMap map = new Object2FloatOpenHashMap<>();
		Set clusters;
		
		for (N[] nodes : state.getDocument(stopwords))
		{
			// NOTE(review): the next line was damaged during extraction ('<'...'>' text stripped).
			// Presumably, for each node from position 1 onward, it fetches the node's cluster set
			// into `clusters` and increments each cluster's count by 1 via map.merge. Restore it
			// from the upstream source.
			for (int i=1; i oldCount + newCount);
			}
		}
		
		return map;
	}
	
	/**
	 * Returns the dense vector for {@code state}, computed by {@code getEmbeddings} with
	 * {@code average = true}. Returns {@code null} when {@code word_embeddings} is null or empty.
	 *
	 * @param state the state to embed.
	 * @return the dense vector, or {@code null} if no word embeddings are configured.
	 */
	@Override
	public float[] createDenseVector(S state)
	{
		boolean hasEmbeddings = word_embeddings != null && !word_embeddings.isEmpty();
		return hasEmbeddings ? getEmbeddings(state, true) : null;
	}
	
	public float[] getEmbeddings(S state, boolean average)
	{
		float[] w, v = null;
		int count = 0;
		N node;
		
		for (N[] nodes : state.getDocument())
		{
			for (int i=1; i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy