![JAR search and dependency download from the Maven repository](/logo.png)
com.clust4j.data.DataSet Maven / Gradle / Ivy
/*******************************************************************************
* Copyright 2015, 2016 Taylor G Smith
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package com.clust4j.data;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import com.clust4j.Clust4j;
import com.clust4j.log.Loggable;
import com.clust4j.utils.DeepCloneable;
import com.clust4j.utils.MatUtils;
import com.clust4j.utils.MatrixFormatter;
import com.clust4j.utils.TableFormatter;
import com.clust4j.utils.VecUtils;
/**
* A lightweight dataset wrapper that stores information on
* header names, matrix data and classification labels.
* @author Taylor G Smith
*/
public class DataSet extends Clust4j implements DeepCloneable, java.io.Serializable {
/** Serialization version identifier for this class. */
private static final long serialVersionUID = -1203771047711850121L;
/** Prefix of auto-generated column names; the column index is appended to it. */
final static String COL_PREFIX = "V";
/** NOTE(review): presumably the default row count for a head-style preview -- its usage is not visible in this section; confirm. */
final static int DEF_HEAD = 6;
/** Shared table formatter instance. */
public final static TableFormatter TABLE_FORMATTER= new TableFormatter();
/** Default matrix formatter; used when a constructor is not given one. */
public final static MatrixFormatter DEF_FORMATTER = new MatrixFormatter();
/** Formatter associated with this dataset; the primary constructor never leaves it null. */
private final MatrixFormatter formatter;
/** The numeric data matrix. */
private Array2DRowRealMatrix data;
/** Optional labels; when present at construction there is one per row of {@link #data}. May be null. */
private int[] labels;
/** Column names; one per column of {@link #data}. */
private String[] headers;
/**
 * Builds default column names: {@code COL_PREFIX} followed by the
 * zero-based position of the column.
 * @param size the number of names to generate
 * @return a new array holding {@code size} generated names
 */
private static String[] genHeaders(int size) {
    final String[] names = new String[size];
    int col = 0;
    while(col < size) {
        names[col] = COL_PREFIX + col;
        col++;
    }
    return names;
}
/**
 * Creates an unlabeled dataset with generated column names from a raw
 * 2D array. The array is wrapped in a matrix here and handed to
 * {@link #DataSet(Array2DRowRealMatrix)}; per the inline note, the copy
 * is deferred to the delegation chain.
 * @param data the row-major values
 */
public DataSet(double[][] data) {
this(new Array2DRowRealMatrix(data, false /*Going to copy later anyways*/));
}
/**
 * Creates an unlabeled dataset whose column names are generated from
 * the matrix's column count.
 * Note that {@code data} is dereferenced in the delegation call itself,
 * i.e. before the primary constructor's explicit null check can run.
 * @param data the data matrix
 */
public DataSet(Array2DRowRealMatrix data) {
this(data, genHeaders(data.getColumnDimension()));
}
/**
 * Creates an unlabeled dataset from a raw 2D array and the given column
 * names. The array is wrapped in a matrix here; per the inline note, the
 * copy is deferred to the delegation chain.
 * @param data the row-major values
 * @param headers the column names
 */
public DataSet(double[][] data, String[] headers) {
this(new Array2DRowRealMatrix(data, false /*Going to copy later anyways*/), headers);
}
/**
 * Creates an unlabeled dataset (labels are passed on as {@code null})
 * with the given column names.
 * @param data the data matrix
 * @param headers the column names
 */
public DataSet(Array2DRowRealMatrix data, String[] headers) {
this(data, null, headers);
}
/**
 * Creates a labeled dataset with generated column names from a raw 2D
 * array. The array is wrapped in a matrix here; per the inline note, the
 * copy is deferred to the delegation chain.
 * @param data the row-major values
 * @param labels the labels
 */
public DataSet(double[][] data, int[] labels) {
this(new Array2DRowRealMatrix(data, false /*Going to copy later anyways*/), labels);
}
/**
 * Creates a labeled dataset with generated column names and the default
 * formatter; the matrix is copied ({@code copyData} is passed as true).
 * Note that {@code data} is dereferenced in the delegation call itself,
 * i.e. before the primary constructor's explicit null check can run.
 * @param data the data matrix
 * @param labels the labels
 */
public DataSet(Array2DRowRealMatrix data, int[] labels) {
this(data, labels, genHeaders(data.getColumnDimension()), DEF_FORMATTER, true);
}
/**
 * Creates a labeled dataset with generated column names and the given
 * formatter; the matrix is copied ({@code copyData} is passed as true).
 * Note that {@code data} is dereferenced in the delegation call itself,
 * i.e. before the primary constructor's explicit null check can run.
 * @param data the data matrix
 * @param labels the labels
 * @param formatter the matrix formatter
 */
public DataSet(Array2DRowRealMatrix data, int[] labels, MatrixFormatter formatter) {
this(data, labels, genHeaders(data.getColumnDimension()), formatter, true);
}
/**
 * Creates a dataset from a raw 2D array, labels and column names, using
 * the default formatter. Unlike the other array-based overloads, the
 * matrix is built with its second constructor argument set to true and
 * the primary constructor is told not to copy again
 * ({@code copyData} is false).
 * NOTE(review): presumably this means the array is copied exactly once,
 * at wrap time -- confirm against the Commons Math constructor docs.
 * @param data the row-major values
 * @param labels the labels
 * @param headers the column names
 */
public DataSet(double[][] data, int[] labels, String[] headers) {
this(new Array2DRowRealMatrix(data, true), labels, headers, DEF_FORMATTER, false);
}
/**
 * Creates a dataset from a matrix, labels and column names, using the
 * default formatter.
 * @param data the data matrix
 * @param labels the labels
 * @param headers the column names
 */
public DataSet(Array2DRowRealMatrix data, int[] labels, String[] headers) {
this(data, labels, headers, DEF_FORMATTER);
}
/**
 * Creates a dataset from a raw 2D array, labels, column names and a
 * formatter. The matrix is built with its second constructor argument
 * set to true and the primary constructor is told not to copy again
 * ({@code copyData} is false).
 * NOTE(review): presumably this means the array is copied exactly once,
 * at wrap time -- confirm against the Commons Math constructor docs.
 * @param data the row-major values
 * @param labels the labels
 * @param headers the column names
 * @param formatter the matrix formatter
 */
public DataSet(double[][] data, int[] labels, String[] headers, MatrixFormatter formatter) {
this(new Array2DRowRealMatrix(data, true), labels, headers, formatter, false);
}
/**
 * Creates a dataset from a matrix, labels, column names and a
 * formatter; the matrix is copied ({@code copyData} is passed as true).
 * @param data the data matrix
 * @param labels the labels
 * @param headers the column names
 * @param formatter the matrix formatter
 */
public DataSet(Array2DRowRealMatrix data, int[] labels, String[] headers, MatrixFormatter formatter) {
this(data, labels, headers, formatter, true);
}
/**
 * Primary constructor; every other constructor delegates here.
 * <p>
 * Labels are optional: a {@code null} label array is accepted and stored
 * as {@code null}. Headers are optional too: when {@code hdrz} is
 * {@code null}, names are generated from the column count. The headers
 * (and labels, when present) are copied before being stored.
 * @param data the data matrix; must not be null
 * @param labels the labels, or null; if given, one per row of {@code data}
 * @param hdrz the column names, or null to generate them; if given, one per column
 * @param formatter the matrix formatter, or null to use {@code DEF_FORMATTER}
 * @param copyData whether to store a copy of {@code data} rather than the reference
 * @throws IllegalArgumentException if {@code data} is null
 * @throws DimensionMismatchException if the label count differs from the row
 * count, or the header count differs from the column count
 */
public DataSet(Array2DRowRealMatrix data, int[] labels, String[] hdrz, MatrixFormatter formatter, boolean copyData) {
    if(null == data)
        throw new IllegalArgumentException("data cannot be null");

    // Fall back to generated names when no headers are supplied
    this.headers = (null == hdrz) ? genHeaders(data.getColumnDimension()) : VecUtils.copy(hdrz);

    // Labels are optional, but when present there must be one per row
    if((null != labels) && labels.length != data.getRowDimension())
        throw new DimensionMismatchException(labels.length, data.getRowDimension());
    if(this.headers.length != data.getColumnDimension())
        throw new DimensionMismatchException(this.headers.length, data.getColumnDimension());

    this.data = copyData ? (Array2DRowRealMatrix)data.copy() : data;
    // Guard null explicitly (as getLabels does) instead of relying on the copy helper
    this.labels = (null == labels) ? null : VecUtils.copy(labels);
    this.formatter = (null == formatter) ? DEF_FORMATTER : formatter;
}
/**
 * Appends a column under a generated name ({@code COL_PREFIX} plus the
 * current column count).
 * @param col the values of the new column
 */
public void addColumn(double[] col) {
    final String generatedName = COL_PREFIX + numCols();
    addColumn(generatedName, col);
}
/**
 * Appends several columns under generated names. Each name is
 * {@code COL_PREFIX} plus the index the column will occupy.
 * @param cols the values to append; {@code cols[0].length} columns are added
 */
public void addColumns(double[][] cols) {
    MatUtils.checkDims(cols);
    final int existing = data.getColumnDimension();
    final int added = cols[0].length;
    final String[] names = new String[added];
    for(int k = 0; k < added; k++)
        names[k] = COL_PREFIX + (existing + k);
    addColumns(names, cols);
}
/**
 * Appends a single named column to the right of the existing data.
 * <p>
 * The replacement headers and matrix are fully built before either field
 * is reassigned, so a failure while building them leaves this dataset
 * unchanged.
 * @param s the name of the new column; if null, a name is generated
 * from {@code COL_PREFIX} and the new column's index
 * @param col the values of the new column, one per existing row
 * @throws DimensionMismatchException if {@code col.length} differs from
 * the current row count
 */
public void addColumn(String s, double[] col) {
    VecUtils.checkDims(col);
    final int m = col.length;
    if(m != data.getRowDimension())
        throw new DimensionMismatchException(m, data.getRowDimension());

    final int n = data.getColumnDimension();
    final String name = (null == s) ? (COL_PREFIX + n) : s;

    // Headers: the existing names followed by the new one
    final String[] newHeaders = new String[n + 1];
    System.arraycopy(headers, 0, newHeaders, 0, n);
    newHeaders[n] = name;

    // Data: each existing row followed by its new trailing value
    final double[][] oldData = data.getDataRef();
    final double[][] newData = new double[m][n + 1];
    for(int i = 0; i < m; i++) {
        System.arraycopy(oldData[i], 0, newData[i], 0, n);
        newData[i][n] = col[i];
    }

    final Array2DRowRealMatrix widened = new Array2DRowRealMatrix(newData, false);
    this.headers = newHeaders;
    this.data = widened;
}
/**
 * Appends several named columns to the right of the existing data.
 * <p>
 * {@code cols} is indexed row-first: {@code cols[i][k]} is the value of
 * the {@code k}-th new column in row {@code i}. The caller's name array
 * is only read, never modified. The replacement headers and matrix are
 * fully built before either field is reassigned.
 * @param s the names of the new columns; may be null, and may contain
 * null entries, in which case the missing names are generated from
 * {@code COL_PREFIX} and the column's index
 * @param cols the values to append, one row per existing row
 * @throws DimensionMismatchException if {@code cols.length} differs from
 * the current row count, or if {@code s} holds fewer names than there are
 * new columns
 */
public void addColumns(String[] s, double[][] cols) {
    MatUtils.checkDims(cols);
    final int m = cols.length;
    if(m != data.getRowDimension())
        throw new DimensionMismatchException(m, data.getRowDimension());

    final int n = data.getColumnDimension();
    final int added = cols[0].length;
    final int newN = n + added;

    // Fail clearly instead of with a bare index error part-way through
    if(null != s && s.length < added)
        throw new DimensionMismatchException(s.length, added);

    // Headers: the existing names, then supplied names with gaps filled in
    final String[] newHeaders = new String[newN];
    System.arraycopy(headers, 0, newHeaders, 0, n);
    for(int k = 0; k < added; k++) {
        final String supplied = (null == s) ? null : s[k];
        newHeaders[n + k] = (null == supplied) ? (COL_PREFIX + (n + k)) : supplied;
    }

    // Data: each existing row followed by its new trailing values
    final double[][] oldData = data.getDataRef();
    final double[][] newData = new double[m][newN];
    for(int i = 0; i < m; i++) {
        System.arraycopy(oldData[i], 0, newData[i], 0, n);
        System.arraycopy(cols[i], 0, newData[i], n, added);
    }

    final Array2DRowRealMatrix widened = new Array2DRowRealMatrix(newData, false);
    this.headers = newHeaders;
    this.data = widened;
}
/**
 * Creates a new dataset from this one's matrix, labels, headers and
 * formatter, asking the constructor to copy the matrix.
 * @return the new dataset
 */
@Override
public DataSet copy() {
    final boolean copyMatrix = true;
    return new DataSet(data, labels, headers, formatter, copyMatrix);
}
/**
 * Removes the first column whose header equals {@code nm}.
 * @param nm the header of the column to remove
 * @return the values of the removed column
 * @throws IllegalArgumentException if no column has that header
 */
public double[] dropCol(String nm) {
    final int position = getColumnIdx(nm);
    return dropCol(position);
}
/**
 * Removes the column at the given index, shifting later columns (and
 * their headers) one position to the left.
 * @param idx the zero-based index of the column to remove
 * @return the values of the removed column
 * @throws IllegalArgumentException if {@code idx} is out of range
 * @throws IllegalStateException if the dataset has only one column
 */
public double[] dropCol(int idx) {
    final int width = numCols();
    if(idx < 0 || idx >= width)
        throw new IllegalArgumentException("illegal column index: "+idx);

    final int height = numRows();
    final double[][] source = data.getDataRef();

    // The index has been validated, so this lookup is in range
    final double[] dropped = data.getColumn(idx);
    if(width == 1)
        throw new IllegalStateException("cannot drop last column");

    // Number of columns to the right of the one being removed
    final int tail = width - idx - 1;

    final double[][] narrowed = new double[height][width - 1];
    for(int r = 0; r < height; r++) {
        System.arraycopy(source[r], 0, narrowed[r], 0, idx);
        System.arraycopy(source[r], idx + 1, narrowed[r], idx, tail);
    }

    final String[] keptHeaders = new String[width - 1];
    System.arraycopy(headers, 0, keptHeaders, 0, idx);
    System.arraycopy(headers, idx + 1, keptHeaders, idx, tail);

    data = new Array2DRowRealMatrix(narrowed, false);
    headers = keptHeaders;
    return dropped;
}
/**
 * Compares this dataset with another object. Two datasets are equal
 * when their data, headers and labels all match; datasets that both
 * have no labels ({@code null}) are treated as matching on labels.
 * @param o the object to compare against
 * @return true if {@code o} is a {@code DataSet} with equal content
 */
@Override
public boolean equals(Object o) {
    if(this == o)
        return true;
    if(!(o instanceof DataSet))
        return false;

    final DataSet other = (DataSet)o;
    return MatUtils.equalsExactly(data.getDataRef(), other.data.getDataRef())
        && VecUtils.equalsExactly(headers, other.headers)
        // Labels may be null, so use the null-safe stdlib comparison
        && Arrays.equals(labels, other.labels);
}
/**
 * Returns a copy of the data matrix, obtained from the matrix's own
 * {@code copy()}.
 * @return a copy of the data
 */
public Array2DRowRealMatrix getData() {
    final Array2DRowRealMatrix duplicate = (Array2DRowRealMatrix)data.copy();
    return duplicate;
}
/**
 * Finds the index of the column with the given header. If several
 * columns share the name (bad practice), the first match wins.
 * @param header the header to look up
 * @return the zero-based index of the first matching column
 * @throws IllegalArgumentException if no column has that header
 */
private int getColumnIdx(String header) {
    for(int pos = 0; pos < headers.length; pos++) {
        if(headers[pos].equals(header))
            return pos;
    }
    throw new IllegalArgumentException("no such header: "+header);
}
/**
 * Returns the values of the first column whose header equals the
 * given name.
 * @param header the header of the column to fetch
 * @return the column values
 * @throws IllegalArgumentException if no column has that header
 */
public double[] getColumn(String header) {
    final int position = getColumnIdx(header);
    return getColumn(position);
}
/**
 * Returns the values of the column at the given index, as produced by
 * the underlying matrix's {@code getColumn}.
 * @param i the zero-based column index
 * @return the column values
 */
public double[] getColumn(int i) {
    final double[] column = data.getColumn(i);
    return column;
}
/**
 * Returns the internal data matrix itself, not a copy; changes made
 * through the returned reference affect this dataset.
 * @return the internal data matrix
 */
public Array2DRowRealMatrix getDataRef() {
    return this.data;
}
/**
 * Looks up a single value in the data matrix.
 * @param row the row index
 * @param col the column index
 * @return the value stored at {@code (row, col)}
 */
public double getEntry(int row, int col) {
    final double value = data.getEntry(row, col);
    return value;
}
/**
 * Returns a copy of the column names.
 * @return a copy of the headers
 */
public String[] getHeaders() {
    final String[] duplicate = VecUtils.copy(headers);
    return duplicate;
}
/**
 * Returns the internal header array itself, not a copy; changes made
 * through the returned reference affect this dataset.
 * @return the internal header array
 */
public String[] getHeaderRef() {
    return this.headers;
}
/**
 * Returns a copy of the labels, or {@code null} when this dataset has
 * no labels.
 * @return a copy of the labels, or null
 */
public int[] getLabels() {
    if(null == labels)
        return null;
    return VecUtils.copy(labels);
}
/**
 * Returns the internal label array itself, not a copy; it is
 * {@code null} when this dataset has no labels.
 * @return the internal label array, possibly null
 */
public int[] getLabelRef() {
    return this.labels;
}
/**
 * Computes a hash from the data matrix's own hash plus the contents
 * of the header and label arrays.
 * <p>
 * The arrays are hashed by content via {@link Arrays#hashCode}, not by
 * identity, so datasets with the same headers and labels hash alike on
 * those parts. A {@code null} label array contributes 0 instead of
 * causing a {@code NullPointerException}.
 * @return the hash code
 */
@Override
public int hashCode() {
    return 31
        ^ data.hashCode()
        ^ Arrays.hashCode(headers)
        ^ Arrays.hashCode(labels);
}
private ArrayList
© 2015 - 2025 Weber Informatics LLC | Privacy Policy