org.deeplearning4j.util.StringGrid Maven / Gradle / Ivy
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.util;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.*;
import static org.deeplearning4j.berkeley.StringUtils.splitOnCharWithQuoting;
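/**
* A matrix of strings, stored as a list of rows where each row is a list of column values.
* Rows shorter than the expected column count are padded with {@link #NONE} on construction.
* @author Adam Gibson
*
*/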
public class StringGrid extends ArrayList> {
private static final long serialVersionUID = 4702427632483221813L;
/** Separator given at construction; the parsing constructor splits lines on its first character. */
private String sep;
/**
 * Expected number of columns per row; -1 until it is set explicitly or inferred from a parsed row.
 * fillOut() pads shorter rows with NONE up to this width.
 */
private int numColumns = -1;
private static final Logger log = LoggerFactory.getLogger(StringGrid.class);
/** Placeholder value used to pad rows that have fewer than numColumns values. */
public final static String NONE = "NONE";
/**
 * Creates a copy of {@code grid}.
 *
 * <p>Every row is copied into a new list, so changes made afterwards to the rows of either grid
 * (including the {@link #NONE} padding applied by {@code fillOut()}) do not leak into the other.
 *
 * @param grid the grid to copy
 */
public StringGrid(StringGrid grid) {
    super(grid.size());
    this.sep = grid.sep;
    this.numColumns = grid.numColumns;
    for (List<String> row : grid) {
        // copy the row itself; sharing it would alias the two grids' mutable rows
        add(new ArrayList<>(row));
    }
    fillOut();
}
/**
 * Creates an empty grid with a fixed expected column count.
 *
 * @param sep        the separator associated with this grid
 * @param numColumns the expected number of columns per row
 */
public StringGrid(String sep, int numColumns) {
    this(sep, Collections.<String>emptyList());
    this.numColumns = numColumns;
    // no rows exist yet, so this pads nothing; kept so every constructor ends the same way
    fillOut();
}
/**
 * Returns the expected number of columns per row.
 *
 * @return the column count, or -1 if it has not been determined
 */
public int getNumColumns() {
    return this.numColumns;
}
/**
 * Pads every row shorter than {@code numColumns} with {@link #NONE} until it reaches that width.
 * Rows that are already at least that wide are left untouched.
 */
private void fillOut() {
    for (List<String> row : this) {
        int missing = numColumns - row.size();
        if (missing > 0) {
            row.addAll(Collections.nCopies(missing, NONE));
        }
    }
}
/**
 * Reads a grid from a text file, decoding it with the platform default charset.
 *
 * @param file path of the file to read
 * @param sep  the separator; the parsing constructor uses only its first character
 * @return the parsed grid
 * @throws IOException           if the file cannot be read
 * @throws IllegalStateException if the file contains no lines
 */
public static StringGrid fromFile(String file, String sep) throws IOException {
    return fromFile(file, sep, Charset.defaultCharset());
}

/**
 * Reads a grid from a text file decoded with an explicit charset, so that the result does not
 * depend on the platform default encoding.
 *
 * @param file    path of the file to read
 * @param sep     the separator; the parsing constructor uses only its first character
 * @param charset the charset used to decode the file
 * @return the parsed grid
 * @throws IOException           if the file cannot be read
 * @throws IllegalStateException if the file contains no lines
 */
public static StringGrid fromFile(String file, String sep, Charset charset) throws IOException {
    List<String> read = FileUtils.readLines(new File(file), charset);
    if (read.isEmpty())
        throw new IllegalStateException("Nothing to read; file is empty");
    return new StringGrid(sep, read);
}
/**
 * Reads a grid from an input stream, decoding it with the platform default charset.
 * This method does not close {@code from}; the caller remains responsible for it.
 *
 * @param from the stream to read lines from
 * @param sep  the separator; the parsing constructor uses only its first character
 * @return the parsed grid
 * @throws IOException           if reading fails
 * @throws IllegalStateException if the stream contains no lines
 */
public static StringGrid fromInput(InputStream from, String sep) throws IOException {
    return fromInput(from, sep, Charset.defaultCharset());
}

/**
 * Reads a grid from an input stream decoded with an explicit charset, so that the result does
 * not depend on the platform default encoding. This method does not close {@code from}.
 *
 * @param from    the stream to read lines from
 * @param sep     the separator; the parsing constructor uses only its first character
 * @param charset the charset used to decode the stream
 * @return the parsed grid
 * @throws IOException           if reading fails
 * @throws IllegalStateException if the stream contains no lines
 */
public static StringGrid fromInput(InputStream from, String sep, Charset charset) throws IOException {
    List<String> read = IOUtils.readLines(from, charset);
    if (read.isEmpty())
        throw new IllegalStateException("Nothing to read; file is empty");
    return new StringGrid(sep, read);
}
/**
 * Parses {@code data} into a grid with one row per line.
 *
 * <p>Each line is trimmed and split on the first character of {@code sep} using
 * {@code splitOnCharWithQuoting}, with a double quote as the quote character and a backslash as
 * the escape character. The expected column count is taken from the first line that is not
 * treated as quoted (see the note in the loop). Later such lines whose width differs are still
 * added, but a warning is logged. Finally, rows shorter than the column count are padded with
 * {@link #NONE}.
 *
 * @param sep  the separator; only its first character is used
 * @param data the lines to parse
 */
public StringGrid(String sep, Collection<String> data) {
    super();
    this.sep = sep;
    List<String> lines = new ArrayList<>(data);
    for (int i = 0; i < lines.size(); i++) {
        String line = lines.get(i).trim();
        List<String> row = new ArrayList<>(
                Arrays.asList(splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\')));

        // Lines containing two or more quote characters are treated as quoted text: they
        // neither set nor get checked against the expected column count.
        // NOTE(review): because of the "> 0" test (preserved from the original), a line whose
        // first quote is at index 0 is NOT treated as quoted -- confirm whether ">= 0" was meant.
        int firstQuote = line.indexOf('"');
        boolean quoted = firstQuote > 0 && line.indexOf('"', firstQuote + 1) >= 0;

        if (!quoted) {
            if (numColumns < 0) {
                numColumns = row.size();
            } else if (row.size() != numColumns) {
                // the row is still added below, so the message must not claim it was rejected
                log.warn("Row {} had invalid number of columns; line was {}", i, line);
            }
        }
        add(row);
    }
    fillOut();
}
/**
 * Removes all rows whose value in {@code column} is {@link #NONE}, the placeholder used to pad
 * short rows.
 *
 * @param column the index of the column to inspect
 */
public void removeRowsWithEmptyColumn(int column) {
    // identical rule to the sentinel overload, with NONE as the sentinel; delegating keeps the
    // two overloads from drifting apart
    removeRowsWithEmptyColumn(column, NONE);
}
/**
 * Logs, at INFO level, the first {@code num} rows, one per line. If {@code num} exceeds the
 * number of rows, every row is logged.
 *
 * @param num the number of leading rows to log
 */
public void head(int num) {
    int limit = Math.min(num, size());
    StringBuilder out = new StringBuilder();
    for (int row = 0; row < limit; row++) {
        out.append(get(row)).append('\n');
    }
    log.info(out.toString());
}
/**
 * Removes the specified columns from every row of the grid.
 *
 * <p>Columns are removed by position, highest index first, so that removing one column never
 * shifts the index of another column still to be removed. A duplicate index is removed only
 * once, and the caller's array is left unmodified. The expected column count
 * ({@link #getNumColumns()}) is reduced by the number of removed columns that fell within it.
 *
 * @param columns the zero-based indices of the columns to remove
 * @throws IllegalArgumentException if no column index is given
 */
public void removeColumns(Integer... columns) {
    if (columns.length < 1)
        throw new IllegalArgumentException("Columns must contain at least one column");
    // descending and de-duplicated; copying also avoids reordering the caller's array
    SortedSet<Integer> removeOrder = new TreeSet<>(Collections.<Integer>reverseOrder());
    removeOrder.addAll(Arrays.asList(columns));
    for (List<String> row : this) {
        for (int column : removeOrder) {
            // int argument selects List.remove(int): remove by position, not by value, so equal
            // values sitting in other columns are left alone
            row.remove(column);
        }
    }
    // keep the expected width in sync so later padding (e.g. when copying) does not
    // re-create the removed columns
    if (numColumns > 0) {
        int removedWithinWidth = 0;
        for (int column : removeOrder) {
            if (column >= 0 && column < numColumns)
                removedWithinWidth++;
        }
        numColumns -= removedWithinWidth;
    }
}
/**
 * Removes all rows whose value in {@code column} equals {@code missingValue}.
 *
 * @param column       the index of the column to inspect
 * @param missingValue the sentinel value that marks a missing entry
 */
public void removeRowsWithEmptyColumn(int column, String missingValue) {
    // Single pass: collect the rows to keep, then swap them in. This replaces removeAll(List),
    // which ran a linear contains() (with element-wise list equality) for every row.
    List<List<String>> kept = new ArrayList<>(size());
    for (List<String> row : this) {
        // Objects.equals tolerates a null cell instead of throwing NullPointerException
        if (!Objects.equals(row.get(column), missingValue)) {
            kept.add(row);
        }
    }
    if (kept.size() != size()) {
        clear();
        addAll(kept);
    }
}
public List> getRowsWithColumnValues(Collection values,int column) {
List> ret = new ArrayList<>();
for(List val : this) {
if(values.contains(val.get(column)))
ret.add(val);
}
return ret;
}
public void sortColumnsByWordLikelihoodIncluded(final int column) {
final Counter counter = new Counter<>();
List col = getColumn(column);
for(String s : col) {
StringTokenizer tokenizer = new StringTokenizer(s);
while(tokenizer.hasMoreTokens()) {
counter.incrementCount(tokenizer.nextToken(),1.0);
}
}
if(counter.totalCount() <= 0.0) {
log.warn("Unable to calculate probability; nothing found");
return;
}
//laplace smoothing
counter.incrementAll(counter.keySet(), 1.0);
Set remove = new HashSet<>();
for(String key : counter.keySet())
if(key.length() < 2 || key.matches("[a-z]+"))
remove.add(key);
for(String key : remove)
counter.removeKey(key);
counter.pruneKeysBelowThreshold(4.0);
final double totalCount = counter.totalCount();
Collections.sort(this,new Comparator>() {
@Override
public int compare(List o1, List o2) {
double c1 = sumOverTokens(counter,o1.get(column),totalCount);
double c2 = sumOverTokens(counter,o2.get(column),totalCount);
return Double.compare(c1, c2);
}
});
}
/**
 * Returns the sum of the log relative frequencies of the whitespace-separated tokens in a cell,
 * i.e. the log of the product of the per-token probabilities (the probability in log space).
 *
 * <p>NOTE(review): if {@code Counter.getCount} returns 0 for an unseen token (presumably it
 * does), that token contributes {@code Math.log(0) == -Infinity}. Confirm this is intended.
 *
 * @param counter    token counts
 * @param column     the cell text to score (a cell value, despite the parameter name)
 * @param totalCount the denominator used to turn counts into relative frequencies
 * @return the log-space score; 0.0 when the cell has no tokens
 */
private double sumOverTokens(Counter<String> counter, String column, double totalCount) {
    StringTokenizer tokenizer = new StringTokenizer(column);
    double logSum = 0;
    while (tokenizer.hasMoreTokens()) {
        // advance the tokenizer and score each token individually; without nextToken() the
        // loop never terminates, and looking up the whole cell text scores the wrong key
        String token = tokenizer.nextToken();
        logSum += Math.log(counter.getCount(token) / totalCount);
    }
    return logSum;
}
/**
 * Builds a {@link StringCluster} over the values of one column.
 *
 * @param column the index of the column to cluster
 * @return a cluster built from that column's values
 */
public StringCluster clusterColumn(int column) {
    List<String> values = getColumn(column);
    return new StringCluster(values);
}
/**
 * Applies {@link #dedupeByCluster(int)} to every column of the grid.
 */
public void dedupeByClusterAll() {
    // dedupeByCluster takes a column index, so bound the loop by the column count rather than
    // size() (the row count). Fall back to the first row's width if the expected column count
    // was never determined (numColumns < 0).
    int width = numColumns >= 0 ? numColumns : (isEmpty() ? 0 : get(0).size());
    for (int column = 0; column < width; column++) {
        dedupeByCluster(column);
    }
}
/**
* Deduplicates the grid using the clustering signatures of the given column.
* @param column the index of the column to cluster
*/
public void dedupeByCluster(int column) {
StringCluster cluster = clusterColumn(column);
System.out.println(cluster.get("family mcdonalds restaurant"));
System.out.println(cluster.get("family mcdonalds restaurants"));
List
© 2015 - 2024 Weber Informatics LLC | Privacy Policy