// org.deeplearning4j.text.movingwindow.Window (Maven / Gradle / Ivy)
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.text.movingwindow;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
/**
 * A representation of a sliding window over a sequence of tokens.
 * This is used for creating training examples.
 *
 * <p>The token at index {@code words.size() / 2} is the focus word. The tokens before and
 * after it are scanned for label markup: a begin tag such as {@code <PER>} or an end tag
 * such as {@code </PER>}. When a tag is found, its name (with {@code <}, {@code /} and
 * {@code >} removed) becomes the label of this window; if several tags are present, the
 * last one scanned (before-context first, then after-context) wins. A window without any
 * tag carries the label {@code "NONE"}. The focus word itself is never treated as a tag.
 *
 * @author Adam Gibson
 */
public class Window implements Serializable {

    private static final long serialVersionUID = 6359906393699230579L;

    /** Window size used by the three-argument constructor. */
    private static final int DEFAULT_WINDOW_SIZE = 5;

    /** Sentinel label meaning "no label markup was found". */
    private static final String NO_LABEL = "NONE";

    /** A begin tag, e.g. {@code <PER>} or {@code <1>}. */
    private static final Pattern BEGIN_LABEL = Pattern.compile("<([A-Z]+|\\d+)>");

    /**
     * An end tag, e.g. {@code </PER>}. The bare form {@code PER>} (without the leading
     * {@code </}) is also accepted. Never overlaps with {@link #BEGIN_LABEL}.
     */
    private static final Pattern END_LABEL = Pattern.compile("(</)?([A-Z]+|\\d+)>");

    /** Markup characters removed from a tag to obtain the bare label name. */
    private static final Pattern TAG_CHARS = Pattern.compile("[</>]");

    private List<String> words;
    private String label = NO_LABEL;
    private boolean beginLabel;
    private boolean endLabel;
    private int windowSize;
    private int median;
    private int begin;
    private int end;

    /**
     * Creates a window using the default window size of 5.
     *
     * @param words the tokens in the window (copied); the middle token is the focus word
     * @param begin the begin index for the window
     * @param end the end index for the window
     * @throws IllegalArgumentException if {@code words} is null
     */
    public Window(Collection<String> words, int begin, int end) {
        this(words, DEFAULT_WINDOW_SIZE, begin, end);
    }

    /**
     * Initialize a window with the given size.
     *
     * @param words the tokens in the window (copied); the middle token is the focus word
     * @param windowSize the size of the window; stored as given, while
     *        {@link #getWindowSize()} reports the number of words actually held
     * @param begin the begin index for the window
     * @param end the end index for the window
     * @throws IllegalArgumentException if {@code words} is null
     */
    public Window(Collection<String> words, int windowSize, int begin, int end) {
        if (words == null) {
            throw new IllegalArgumentException("Words must not be null");
        }
        this.words = new ArrayList<>(words);
        this.windowSize = windowSize;
        this.begin = begin;
        this.end = end;
        initContext();
    }

    /**
     * Joins the words of this window with single spaces; null words are rendered as empty
     * strings.
     *
     * @return the space-separated tokens, or {@code null} if the word list was set to null
     */
    public String asTokens() {
        if (words == null) {
            return null;
        }
        StringBuilder joined = new StringBuilder();
        boolean first = true;
        for (String word : words) {
            if (!first) {
                joined.append(' ');
            }
            first = false;
            if (word != null) {
                joined.append(word);
            }
        }
        return joined.toString();
    }

    /**
     * Locates the focus word and scans the tokens on either side of it for label tags.
     */
    private void initContext() {
        int focus = words.size() / 2;
        // Clamp so an empty window yields an empty after-context instead of an invalid range.
        int afterStart = Math.min(focus + 1, words.size());
        scanForLabels(words.subList(0, focus));
        scanForLabels(words.subList(afterStart, words.size()));
        this.median = focus;
    }

    /** Updates the label state from every begin/end tag found in {@code tokens}. */
    private void scanForLabels(List<String> tokens) {
        for (String token : tokens) {
            if (token == null) {
                continue;
            }
            if (BEGIN_LABEL.matcher(token).matches()) {
                beginLabel = true;
                label = TAG_CHARS.matcher(token).replaceAll("");
            } else if (END_LABEL.matcher(token).matches()) {
                endLabel = true;
                label = TAG_CHARS.matcher(token).replaceAll("");
            }
        }
    }

    @Override
    public String toString() {
        return words.toString();
    }

    /**
     * Returns the backing word list (not a copy); changes to it are visible to this window.
     *
     * @return the words of this window
     */
    public List<String> getWords() {
        return words;
    }

    /**
     * Replaces the word list. The focus position and label are NOT recomputed.
     *
     * @param words the new word list
     */
    public void setWords(List<String> words) {
        this.words = words;
    }

    /**
     * @param i index into the word list
     * @return the word at index {@code i}
     */
    public String getWord(int i) {
        return words.get(i);
    }

    /** @return the word at the median (focus) position computed at construction */
    public String getFocusWord() {
        return words.get(median);
    }

    /** @return true if a begin tag was found and a label is set */
    public boolean isBeginLabel() {
        return !NO_LABEL.equals(label) && beginLabel;
    }

    /** @return true if an end tag was found and a label is set */
    public boolean isEndLabel() {
        return !NO_LABEL.equals(label) && endLabel;
    }

    /** @return the label with any {@code /} characters removed ({@code "NONE"} if unlabeled) */
    public String getLabel() {
        return label == null ? null : label.replace("/", "");
    }

    /**
     * Returns the number of words currently in this window, which may differ from the
     * {@code windowSize} passed to the constructor.
     *
     * @return the word count
     */
    public int getWindowSize() {
        return words.size();
    }

    /** @return the index of the focus word ({@code size / 2} at construction) */
    public int getMedian() {
        return median;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    public int getBegin() {
        return begin;
    }

    public void setBegin(int begin) {
        this.begin = begin;
    }

    public int getEnd() {
        return end;
    }

    public void setEnd(int end) {
        this.end = end;
    }
}