edu.stanford.nlp.neural.Embedding (stanford-parser)
Stanford Parser processes raw text in English, Chinese, German, Arabic, and French, and extracts constituency parse trees.
/**
*
*/
package edu.stanford.nlp.neural;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Pattern;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.Generics;
/**
* @author Minh-Thang Luong
* @author John Bauer
* @author Richard Socher
*/
public class Embedding {
private Map<String, SimpleMatrix> wordVectors;
private int embeddingSize;
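// Canonical keys for the sentence-boundary vectors and for the classes of unknown tokens handled below.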
static final String START_WORD = "*START*";
static final String END_WORD = "*END*";
static final String UNKNOWN_WORD = "*UNK*";
static final String UNKNOWN_NUMBER = "*NUM*";
static final String UNKNOWN_CAPS = "*CAPS*";
static final String UNKNOWN_CHINESE_YEAR = "*ZH_YEAR*";
static final String UNKNOWN_CHINESE_NUMBER = "*ZH_NUM*";
static final String UNKNOWN_CHINESE_PERCENT = "*ZH_PERCENT*";
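// Patterns for recognizing numbers, capitalized words, and Chinese numeric expressions,
// which callers can use to back off to the corresponding unknown-class vectors above.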
static final Pattern NUMBER_PATTERN = Pattern.compile("-?[0-9][-0-9,.:]*");
static final Pattern CAPS_PATTERN = Pattern.compile("[a-zA-Z]*[A-Z][a-zA-Z]*");
static final Pattern CHINESE_YEAR_PATTERN = Pattern.compile("[〇零一二三四五六七八九0123456789]{4}+年");
static final Pattern CHINESE_NUMBER_PATTERN = Pattern.compile("(?:[〇0零一二三四五六七八九0123456789十百万千亿]+[点多]?)+");
static final Pattern CHINESE_PERCENT_PATTERN = Pattern.compile("百分之[〇0零一二三四五六七八九0123456789十点]+");
/**
* Some word vectors are trained with DG representing number.
* We mix all of those into the unknown number vectors.
*/
static final Pattern DG_PATTERN = Pattern.compile(".*DG.*");
public Embedding(Map<String, SimpleMatrix> wordVectors) {
this.wordVectors = wordVectors;
this.embeddingSize = getEmbeddingSize(wordVectors);
}
public Embedding(String wordVectorFile) {
this(wordVectorFile, 0);
}
public Embedding(String wordVectorFile, int embeddingSize) {
this.wordVectors = Generics.newHashMap();
this.embeddingSize = embeddingSize;
loadWordVectors(wordVectorFile);
}
public Embedding(String wordFile, String vectorFile) {
this(wordFile, vectorFile, 0);
}
public Embedding(String wordFile, String vectorFile, int embeddingSize) {
this.wordVectors = Generics.newHashMap();
this.embeddingSize = embeddingSize;
loadWordVectors(wordFile, vectorFile);
}
/**
* This method reads a file of raw word vectors, with a given expected size, and fills the wordVectors map from word to vector.
*
* The file should be in the format
* WORD X1 X2 X3 ...
* If vectors in the file are smaller than expectedSize, an
* exception is thrown. If vectors are larger, the vectors are
* truncated and a warning is printed.
*/
private void loadWordVectors(String wordVectorFile) {
System.err.println("# Loading embedding ...\n word vector file = " + wordVectorFile);
boolean warned = false;
int numWords = 0;
for (String line : IOUtils.readLines(wordVectorFile, "utf-8")) {
String[] lineSplit = line.split("\\s+");
String word = lineSplit[0];
// check for unknown token
if(word.equals("UNKNOWN") || word.equals("UUUNKKK") || word.equals("UNK") || word.equals("*UNKNOWN*") || word.equals("")){
word = UNKNOWN_WORD;
}
// check for start token
if(word.equals("")){
word = START_WORD;
}
// check for end token
if(word.equals("")){
word = END_WORD;
}
int dimOfWords = lineSplit.length - 1;
if (embeddingSize <= 0) {
embeddingSize = dimOfWords;
System.err.println(" detected embedding size = " + dimOfWords);
}
// the first entry is the word itself
// the other entries will all be entries in the word vector
if (dimOfWords > embeddingSize) {
if (!warned) {
warned = true;
System.err.println("WARNING: Dimensionality of numHid parameter and word vectors do not match, deleting word vector dimensions to fit!");
}
dimOfWords = embeddingSize;
} else if (dimOfWords < embeddingSize) {
throw new RuntimeException("Word vectors file has dimension too small for requested numHid of " + embeddingSize);
}
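// build a column vector (dimOfWords x 1) from the whitespace-separated values after the word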
double[][] vec = new double[dimOfWords][1];
for (int i = 1; i <= dimOfWords; i++) {
vec[i-1][0] = Double.parseDouble(lineSplit[i]);
}
SimpleMatrix vector = new SimpleMatrix(vec);
wordVectors.put(word, vector);
numWords++;
}
System.err.println(" num words = " + numWords);
}
/**
* This method takes as input two files: wordFile (one word per line) and a raw word vector file
* with a given expected size, and fills the wordVectors map from word to vector.
*
* The word vector file should be in the format
* X1 X2 X3 ...
* If vectors in the file are smaller than expectedSize, an
* exception is thrown. If vectors are larger, the vectors are
* truncated and a warning is printed.
*/
private void loadWordVectors(String wordFile, String vectorFile) {
System.err.println("# Loading embedding ...\n word file = " + wordFile + "\n vector file = " + vectorFile);
boolean warned = false;
int numWords = 0;
Iterator<String> wordIterator = IOUtils.readLines(wordFile, "utf-8").iterator();
for (String line : IOUtils.readLines(vectorFile, "utf-8")) {
String[] lineSplit = line.split("\\s+");
String word = wordIterator.next();
// check for unknown token
// FIXME cut and paste code
if(word.equals("UNKNOWN") || word.equals("UUUNKKK") || word.equals("UNK") || word.equals("*UNKNOWN*") || word.equals("")){
word = UNKNOWN_WORD;
}
// check for start token
if(word.equals("")){
word = START_WORD;
}
// check for end token
if(word.equals("")){
word = END_WORD;
}
int dimOfWords = lineSplit.length;
if (embeddingSize <= 0) {
embeddingSize = dimOfWords;
System.err.println(" detected embedding size = " + dimOfWords);
}
// the first entry is the word itself
// the other entries will all be entries in the word vector
if (dimOfWords > embeddingSize) {
if (!warned) {
warned = true;
System.err.println("WARNING: Dimensionality of numHid parameter and word vectors do not match, deleting word vector dimensions to fit!");
}
dimOfWords = embeddingSize;
} else if (dimOfWords < embeddingSize) {
throw new RuntimeException("Word vectors file has dimension too small for requested numHid of " + embeddingSize);
}
double[][] vec = new double[dimOfWords][1];
for (int i = 0; i < dimOfWords; i++) {
vec[i][0] = Double.parseDouble(lineSplit[i]);
}
SimpleMatrix vector = new SimpleMatrix(vec);
wordVectors.put(word, vector);
numWords++;
}
System.err.println(" num words = " + numWords);
}
/*** Getters & Setters ***/
public int size(){
return wordVectors.size();
}
public Collection<SimpleMatrix> values(){
return wordVectors.values();
}
public Set<String> keySet(){
return wordVectors.keySet();
}
public Set<Entry<String, SimpleMatrix>> entrySet(){
return wordVectors.entrySet();
}
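// Look up a word vector, falling back to the unknown-word vector for out-of-vocabulary words.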
public SimpleMatrix get(String word) {
if(wordVectors.containsKey(word)){
return wordVectors.get(word);
} else {
return wordVectors.get(UNKNOWN_WORD);
}
}
public SimpleMatrix getStartWordVector() {
return wordVectors.get(START_WORD);
}
public SimpleMatrix getEndWordVector() {
return wordVectors.get(END_WORD);
}
public SimpleMatrix getUnknownWordVector() {
return wordVectors.get(UNKNOWN_WORD);
}
public Map<String, SimpleMatrix> getWordVectors() {
return wordVectors;
}
public int getEmbeddingSize() {
return embeddingSize;
}
public void setWordVectors(Map<String, SimpleMatrix> wordVectors) {
this.wordVectors = wordVectors;
this.embeddingSize = getEmbeddingSize(wordVectors);
}
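// Infers the embedding size from the unknown-word vector, first canonicalizing any recognized unknown token to *UNK*.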
private static int getEmbeddingSize(Map<String, SimpleMatrix> wordVectors){
if (!wordVectors.containsKey(UNKNOWN_WORD)){
// find if there's any other unk string
String unkStr = "";
if (wordVectors.containsKey("UNK")) { unkStr = "UNK"; }
if (wordVectors.containsKey("UUUNKKK")) { unkStr = "UUUNKKK"; }
if (wordVectors.containsKey("UNKNOWN")) { unkStr = "UNKNOWN"; }
if (wordVectors.containsKey("*UNKNOWN*")) { unkStr = "*UNKNOWN*"; }
if (wordVectors.containsKey("")) { unkStr = ""; }
// set UNKNOWN_WORD
if (!unkStr.equals("")){
wordVectors.put(UNKNOWN_WORD, wordVectors.get(unkStr));
} else {
throw new RuntimeException("! wordVectors used to initialize Embedding doesn't contain any recognized form of " + UNKNOWN_WORD);
}
}
return wordVectors.get(UNKNOWN_WORD).getNumElements();
}
}
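Below is a minimal usage sketch (not part of the original source). The class name EmbeddingDemo, the path "vectors.txt", and the lookup word "parser" are illustrative; the file is assumed to follow the "WORD X1 X2 ..." format described in loadWordVectors above.

// EmbeddingDemo.java -- illustrative only; "vectors.txt" is a placeholder path.
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.neural.Embedding;

public class EmbeddingDemo {
  public static void main(String[] args) {
    // Each line of vectors.txt is expected to be "WORD X1 X2 ... Xn".
    Embedding embedding = new Embedding("vectors.txt");
    System.err.println("embedding size = " + embedding.getEmbeddingSize());
    // Known words return their stored vector; unseen words fall back to the *UNK* vector
    // (which exists only if the file contains a recognized unknown token such as UNK).
    SimpleMatrix vec = embedding.get("parser");
    if (vec != null) {
      System.err.println("rows in vector = " + vec.numRows());
    }
  }
}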