org.deeplearning4j.util.StringGrid Maven / Gradle / Ivy
/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.util;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import static org.deeplearning4j.berkeley.StringUtils.splitOnCharWithQuoting;
/**
 * A grid (matrix) of strings: each element of this list is one row,
 * and each row is a list of cell values.
 *
 * @author Adam Gibson
 */
public class StringGrid extends ArrayList> {
/** Serialization version identifier of this list subclass. */
private static final long serialVersionUID = 4702427632483221813L;
/** Column separator; the parsing constructor splits lines on its first character. */
private String sep;
/** Expected number of cells per row; stays -1 until a constructor argument or a parsed row sets it. */
private int numColumns = -1;
private static final Logger log = LoggerFactory.getLogger(StringGrid.class);
/** Placeholder cell value used to pad rows that have fewer cells than {@code numColumns}. */
public final static String NONE = "NONE";
/**
 * Copy constructor: creates a grid with the same separator, expected column
 * count and cell values as {@code grid}.
 *
 * <p>Every row is copied into a new list, so padding or editing a row of the
 * copy does not modify the rows held by {@code grid}.
 *
 * @param grid the grid to copy
 */
public StringGrid(StringGrid grid) {
    this.sep = grid.sep;
    this.numColumns = grid.numColumns;
    // Copy row by row: addAll(grid) would share the row lists with the source
    // grid, and fillOut() below mutates rows in place.
    for (List<String> row : grid) {
        add(new ArrayList<>(row));
    }
    fillOut();
}
/**
 * Creates an empty grid with the given separator and expected column count.
 *
 * @param sep the column separator
 * @param numColumns the number of cells each row is expected to have
 */
public StringGrid(String sep, int numColumns) {
    this(sep, Collections.<String>emptyList());
    this.numColumns = numColumns;
    // the grid has no rows at this point, so this call pads nothing
    fillOut();
}
/**
 * Returns the column count recorded for this grid.
 *
 * @return the expected number of columns, or -1 if none has been set
 */
public int getNumColumns() {
    return this.numColumns;
}
/**
 * Pads every row that is shorter than {@code numColumns} with {@link #NONE}
 * cells until it reaches that length. Rows that are already long enough are
 * left as they are.
 */
private void fillOut() {
    for (List<String> row : this) {
        while (row.size() < numColumns) {
            row.add(NONE);
        }
    }
}
/**
 * Builds a grid from the lines of a file, one row per line.
 *
 * @param file path of the file to read
 * @param sep the column separator
 * @return the grid built from the file's lines
 * @throws IOException if reading the file fails
 * @throws IllegalStateException if the file yields no lines
 */
public static StringGrid fromFile(String file, String sep) throws IOException {
    File source = new File(file);
    List<String> lines = FileUtils.readLines(source);
    if (!lines.isEmpty()) {
        return new StringGrid(sep, lines);
    }
    throw new IllegalStateException("Nothing to read; file is empty");
}
/**
 * Builds a grid from the lines of an input stream, one row per line.
 *
 * @param from the stream to read the lines from
 * @param sep the column separator
 * @return the grid built from the stream's lines
 * @throws IOException if reading the stream fails
 * @throws IllegalStateException if the stream yields no lines
 */
public static StringGrid fromInput(InputStream from, String sep) throws IOException {
    List<String> lines = IOUtils.readLines(from);
    if (!lines.isEmpty()) {
        return new StringGrid(sep, lines);
    }
    throw new IllegalStateException("Nothing to read; file is empty");
}
/**
 * Parses each element of {@code data} as one row of the grid.
 *
 * <p>Every line is trimmed and split with {@code splitOnCharWithQuoting},
 * passing the first character of {@code sep} together with {@code "} and
 * {@code \}. The first row that takes part in the column check fixes
 * {@code numColumns}; later rows of a different size are still added, but a
 * warning is logged. A line that has a quote after its first character and
 * contains more than one quote is added without taking part in that check.
 * Once all lines are added, short rows are padded with {@link #NONE}.
 *
 * @param sep the column separator; only its first character is used here
 * @param data the lines to parse, one row per element
 */
public StringGrid(String sep, Collection<String> data) {
    super();
    this.sep = sep;
    List<String> lines = new ArrayList<>(data);
    int rowIndex = 0;
    for (String rawLine : lines) {
        String line = rawLine.trim();
        //text delimiter
        // NOTE(review): "> 0" means a quote at the very start of the line is not detected here
        boolean quoted = line.indexOf('\"') > 0;
        boolean severalQuotes = false;
        if (quoted) {
            Counter<Character> charCounts = new Counter<>();
            for (char c : line.toCharArray()) {
                charCounts.incrementCount(c, 1.0);
            }
            severalQuotes = charCounts.getCount('"') > 1;
        }
        String[] cells = splitOnCharWithQuoting(line, sep.charAt(0), '"', '\\');
        List<String> row = new ArrayList<>(Arrays.asList(cells));
        if (!severalQuotes) {
            if (numColumns < 0) {
                numColumns = row.size();
            } else if (row.size() != numColumns) {
                if (quoted) {
                    log.warn("Row " + rowIndex + " had invalid number of columns line was " + line);
                } else {
                    log.warn("Could not add " + line);
                }
            }
        }
        add(row);
        rowIndex++;
    }
    fillOut();
}
/**
 * Removes every row whose cell in the given column equals {@link #NONE}.
 *
 * @param column the index of the column to inspect
 */
public void removeRowsWithEmptyColumn(int column) {
    // NONE is the missing-value marker for this overload
    removeRowsWithEmptyColumn(column, NONE);
}
/**
 * Logs the first {@code num} rows at info level, one row per line.
 *
 * @param num the number of rows to log; a value above the grid size is capped to it
 */
public void head(int num) {
    int limit = Math.min(num, size());
    StringBuilder text = new StringBuilder();
    int row = 0;
    while (row < limit) {
        text.append(get(row)).append("\n");
        row++;
    }
    log.info(text.toString());
}
/**
 * Removes the specified columns from every row of the grid.
 *
 * <p>Cells are removed by position, so cells in other columns that happen to
 * hold the same value are kept. Duplicate indices are ignored, the caller's
 * array is not reordered, and the expected column count is reduced by the
 * number of removed columns that lay inside it.
 *
 * @param columns the zero-based indices of the columns to remove
 * @throws IllegalArgumentException if no column is given
 * @throws IndexOutOfBoundsException if an index lies outside a row
 */
public void removeColumns(Integer... columns) {
    if (columns.length < 1)
        throw new IllegalArgumentException("Columns must contain at least one column");
    // Copy into a sorted set instead of sorting Arrays.asList(columns), which
    // would reorder the caller's array through the write-through view.
    TreeSet<Integer> distinct = new TreeSet<>(Arrays.asList(columns));
    for (List<String> row : this) {
        // Highest index first, so removing a cell never shifts the position
        // of a cell that still has to be removed.
        for (Integer index : distinct.descendingSet()) {
            // unbox explicitly: remove(Object) would delete by value, not by position
            row.remove(index.intValue());
        }
    }
    if (numColumns > 0) {
        // keep the expected column count in step with the rows
        numColumns -= distinct.subSet(0, numColumns).size();
    }
}
/**
 * Removes every row whose cell in the given column equals {@code missingValue}.
 *
 * @param column the index of the column to inspect
 * @param missingValue the sentinel value that marks a missing cell
 */
public void removeRowsWithEmptyColumn(int column, String missingValue) {
    List<List<String>> matches = new ArrayList<>();
    for (int i = 0; i < size(); i++) {
        List<String> row = get(i);
        if (row.get(column).equals(missingValue)) {
            matches.add(row);
        }
    }
    // collected first and removed afterwards, so the grid is not modified while it is scanned
    removeAll(matches);
}
public List> getRowsWithColumnValues(Collection values, int column) {
List> ret = new ArrayList<>();
for (List val : this) {
if (values.contains(val.get(column)))
ret.add(val);
}
return ret;
}
public void sortColumnsByWordLikelihoodIncluded(final int column) {
final Counter counter = new Counter<>();
List col = getColumn(column);
for (String s : col) {
StringTokenizer tokenizer = new StringTokenizer(s);
while (tokenizer.hasMoreTokens()) {
counter.incrementCount(tokenizer.nextToken(), 1.0);
}
}
if (counter.totalCount() <= 0.0) {
log.warn("Unable to calculate probability; nothing found");
return;
}
//laplace smoothing
counter.incrementAll(counter.keySet(), 1.0);
Set remove = new HashSet<>();
for (String key : counter.keySet())
if (key.length() < 2 || key.matches("[a-z]+"))
remove.add(key);
for (String key : remove)
counter.removeKey(key);
counter.pruneKeysBelowThreshold(4.0);
final double totalCount = counter.totalCount();
Collections.sort(this, new Comparator>() {
@Override
public int compare(List o1, List o2) {
double c1 = sumOverTokens(counter, o1.get(column), totalCount);
double c2 = sumOverTokens(counter, o2.get(column), totalCount);
return Double.compare(c1, c2);
}
});
}
/**
 * Returns the log sum of the column relative to the word frequencies
 * (equivalent to the probability in log space): the sum, over the tokens of
 * {@code column}, of {@code log(count(token) / totalCount)}.
 *
 * @param counter the word counts to look tokens up in
 * @param column the cell text to score, split with the default {@link StringTokenizer} delimiters
 * @param totalCount the total count the word counts are divided by
 * @return the summed log ratios; 0 when the cell contains no token
 */
private double sumOverTokens(Counter<String> counter, String column, double totalCount) {
    StringTokenizer tokenizer = new StringTokenizer(column);
    double count = 0;
    while (tokenizer.hasMoreTokens()) {
        // nextToken() advances the tokenizer; without it hasMoreTokens() never turns false
        String token = tokenizer.nextToken();
        // NOTE(review): a token with a zero count contributes Math.log(0) = -Infinity;
        // this assumes getCount yields 0 for words the counter does not hold - confirm
        count += Math.log(counter.getCount(token) / totalCount);
    }
    return count;
}
/**
 * Builds a {@code StringCluster} from the values of one column.
 *
 * @param column the index of the column to cluster
 * @return the cluster created from that column's values
 */
public StringCluster clusterColumn(int column) {
    List<String> values = getColumn(column);
    return new StringCluster(values);
}
/**
 * Runs {@code dedupeByCluster} once for every column of the grid.
 *
 * <p>The number of columns is taken from {@code numColumns}; when that is
 * still unset (negative), the size of the first row is used instead.
 */
public void dedupeByClusterAll() {
    // dedupeByCluster expects a column index, so the loop is bounded by the
    // column count rather than by size(), which is the number of rows
    int columns = numColumns;
    if (columns < 0 && !isEmpty()) {
        columns = get(0).size();
    }
    for (int column = 0; column < columns; column++) {
        dedupeByCluster(column);
    }
}
/**
* Deduplicate based on the column clustering signature
* @param column
*/
public void dedupeByCluster(int column) {
StringCluster cluster = clusterColumn(column);
System.out.println(cluster.get("family mcdonalds restaurant"));
System.out.println(cluster.get("family mcdonalds restaurants"));
List
© 2015 - 2024 Weber Informatics LLC | Privacy Policy