All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.clust4j.data.DataSet Maven / Gradle / Ivy

/*******************************************************************************
 *    Copyright 2015, 2016 Taylor G Smith
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 *******************************************************************************/
package com.clust4j.data;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;

import com.clust4j.Clust4j;
import com.clust4j.log.Loggable;
import com.clust4j.utils.DeepCloneable;
import com.clust4j.utils.MatUtils;
import com.clust4j.utils.MatrixFormatter;
import com.clust4j.utils.TableFormatter;
import com.clust4j.utils.VecUtils;

/**
 * A lightweight dataset wrapper that stores information on
 * header names, matrix data and classification labels.
 *
 * <p>Labels are optional and may be {@code null}; when present there is
 * exactly one label per row. Row-reordering operations ({@link #shuffle()},
 * {@link #sortAscInPlace(int)}, {@link #sortDescInPlace(int)}) move each
 * label together with its row.
 * @author Taylor G Smith
 */
public class DataSet extends Clust4j implements DeepCloneable, java.io.Serializable {
	private static final long serialVersionUID = -1203771047711850121L;
	
	/** Prefix for auto-generated column names ("V0", "V1", ...). */
	final static String COL_PREFIX = "V";
	/** Default number of rows printed by {@link #head()}. */
	final static int DEF_HEAD = 6;
	
	/** Shared table formatter used by {@link #head(int)}. */
	public final static TableFormatter TABLE_FORMATTER= new TableFormatter();
	/** Formatter used when a constructor receives none (or {@code null}). */
	public final static MatrixFormatter DEF_FORMATTER = new MatrixFormatter();
	/** Formatter used by {@link #toString()} to render the data matrix. */
	private final MatrixFormatter formatter;
	
	private Array2DRowRealMatrix data;
	/** One label per row, or {@code null} when the dataset is unlabeled. */
	private int[] labels;
	private String[] headers;
	
	
	/**
	 * Generate default column names {@code V0 ... V(size-1)}.
	 * @param size number of names to generate
	 * @return a new array of generated names
	 */
	private static String[] genHeaders(int size) {
		String[] out = new String[size];
		for(int i = 0; i < size; i++)
			out[i] = COL_PREFIX + i;
		return out;
	}
	
	/** Unlabeled dataset with generated headers; the data is copied. */
	public DataSet(double[][] data) {
		this(new Array2DRowRealMatrix(data, false /*Going to copy later anyways*/));
	}
	
	/** Unlabeled dataset with generated headers; the matrix is copied. */
	public DataSet(Array2DRowRealMatrix data) {
		this(data, genHeaders(data.getColumnDimension()));
	}
	
	/** Unlabeled dataset with the given headers; the data is copied. */
	public DataSet(double[][] data, String[] headers) {
		this(new Array2DRowRealMatrix(data, false /*Going to copy later anyways*/), headers);
	}
	
	/** Unlabeled dataset with the given headers; the matrix is copied. */
	public DataSet(Array2DRowRealMatrix data, String[] headers) {
		this(data, null, headers);
	}
	
	/** Labeled dataset with generated headers; the data is copied. */
	public DataSet(double[][] data, int[] labels) {
		this(new Array2DRowRealMatrix(data, false /*Going to copy later anyways*/), labels);
	}
	
	/** Labeled dataset with generated headers; the matrix is copied. */
	public DataSet(Array2DRowRealMatrix data, int[] labels) {
		this(data, labels, genHeaders(data.getColumnDimension()), DEF_FORMATTER, true);
	}
	
	/** Labeled dataset with generated headers and a custom formatter; the matrix is copied. */
	public DataSet(Array2DRowRealMatrix data, int[] labels, MatrixFormatter formatter) {
		this(data, labels, genHeaders(data.getColumnDimension()), formatter, true);
	}
	
	/** Labeled dataset with the given headers; the data is copied exactly once. */
	public DataSet(double[][] data, int[] labels, String[] headers) {
		this(new Array2DRowRealMatrix(data, true), labels, headers, DEF_FORMATTER, false);
	}
	
	/** Labeled dataset with the given headers; the matrix is copied. */
	public DataSet(Array2DRowRealMatrix data, int[] labels, String[] headers) {
		this(data, labels, headers, DEF_FORMATTER);
	}
	
	/** Labeled dataset with headers and formatter; the data is copied exactly once. */
	public DataSet(double[][] data, int[] labels, String[] headers, MatrixFormatter formatter) {
		this(new Array2DRowRealMatrix(data, true), labels, headers, formatter, false);
	}
	
	/** Labeled dataset with headers and formatter; the matrix is copied. */
	public DataSet(Array2DRowRealMatrix data, int[] labels, String[] headers, MatrixFormatter formatter) {
		this(data, labels, headers, formatter, true);
	}
	
	/**
	 * Primary constructor.
	 * @param data the matrix; must not be null
	 * @param labels one label per row, or null for an unlabeled dataset (copied)
	 * @param hdrz column names, or null to generate {@code V0, V1, ...} (copied)
	 * @param formatter matrix formatter for {@link #toString()}; null means {@link #DEF_FORMATTER}
	 * @param copyData whether to store a copy of {@code data} rather than the reference
	 * @throws IllegalArgumentException if data is null
	 * @throws DimensionMismatchException if labels or headers do not match the matrix dimensions
	 */
	public DataSet(Array2DRowRealMatrix data, int[] labels, String[] hdrz, MatrixFormatter formatter, boolean copyData) {
		// Null labels are intentionally allowed (unlabeled dataset).
		if(null == data)
			throw new IllegalArgumentException("data cannot be null");
		if(null == hdrz)
			this.headers = genHeaders(data.getColumnDimension());
		else 
			this.headers = VecUtils.copy(hdrz);
		
		// Check to make sure dims match up...
		if((null != labels) && labels.length != data.getRowDimension())
			throw new DimensionMismatchException(labels.length, data.getRowDimension());
		if(this.headers.length != data.getColumnDimension())
			throw new DimensionMismatchException(this.headers.length, data.getColumnDimension());
		
		this.data = copyData ? (Array2DRowRealMatrix)data.copy() : data;
		this.labels = null == labels ? null : VecUtils.copy(labels);
		this.formatter = null == formatter ? DEF_FORMATTER : formatter;
	}
	
	/**
	 * Append a column with a generated name ({@code V<numCols>}).
	 * @param col the values, one per row
	 */
	public void addColumn(double[] col) {
		addColumn(COL_PREFIX + numCols(), col);
	}
	
	/**
	 * Append several columns with generated names continuing from the
	 * current column count.
	 * @param cols row-major block of new values: {@code cols[row][newCol]}
	 */
	public void addColumns(double[][] cols) {
		MatUtils.checkDims(cols);
		
		final int n = data.getColumnDimension(), added = cols[0].length;
		final String[] newCols = new String[added];
		for(int j = 0; j < added; j++)
			newCols[j] = COL_PREFIX + (n + j);
		
		addColumns(newCols, cols);
	}
	
	/**
	 * Append a single named column.
	 * @param s the column name; null means a generated name ({@code V<numCols>})
	 * @param col the values, one per row
	 * @throws DimensionMismatchException if {@code col.length != numRows()}
	 */
	public void addColumn(String s, double[] col) {
		VecUtils.checkDims(col);
		
		final int m = col.length;
		if(m != data.getRowDimension())
			throw new DimensionMismatchException(m, data.getRowDimension());
		
		final int n = data.getColumnDimension();
		final String[] newHeaders = Arrays.copyOf(headers, n + 1);
		newHeaders[n] = null == s ? (COL_PREFIX + n) : s;
		
		final double[][] oldData = data.getDataRef();
		final double[][] newData = new double[m][];
		for(int i = 0; i < m; i++) {
			newData[i] = Arrays.copyOf(oldData[i], n + 1);
			newData[i][n] = col[i];
		}
		
		this.headers = newHeaders;
		this.data = new Array2DRowRealMatrix(newData, false);
	}
	
	/**
	 * Append several named columns. The caller's name array is not modified.
	 * @param s one name per new column, or null to generate all names; null
	 * entries are replaced by generated names ({@code V<index>})
	 * @param cols row-major block of new values: {@code cols[row][newCol]}
	 * @throws DimensionMismatchException if {@code cols.length != numRows()} or
	 * {@code s.length} differs from the number of new columns
	 */
	public void addColumns(String[] s, double[][] cols) {
		MatUtils.checkDims(cols);
		
		final int m = cols.length;
		if(m != data.getRowDimension())
			throw new DimensionMismatchException(m, data.getRowDimension());
		
		final int n = data.getColumnDimension(), added = cols[0].length, newN = n + added;
		if(null != s && s.length != added)
			throw new DimensionMismatchException(s.length, added);
		
		// build headers: keep existing, fill null/absent names with generated ones
		final String[] newHeaders = Arrays.copyOf(headers, newN);
		for(int j = 0; j < added; j++) {
			final String name = null == s ? null : s[j];
			newHeaders[n + j] = null == name ? (COL_PREFIX + (n + j)) : name;
		}
		
		final double[][] oldData = data.getDataRef();
		final double[][] newData = new double[m][];
		for(int i = 0; i < m; i++) {
			newData[i] = Arrays.copyOf(oldData[i], newN);
			System.arraycopy(cols[i], 0, newData[i], n, added);
		}
		
		this.headers = newHeaders;
		this.data = new Array2DRowRealMatrix(newData, false);
	}
	
	/**
	 * Deep copy: data, labels and headers are copied; the formatter is shared.
	 */
	@Override
	public DataSet copy() {
		return new DataSet(data, labels, headers, formatter, true);
	}
	
	/**
	 * Remove the first column with the given name.
	 * @param nm the column name
	 * @return the removed column's values
	 */
	public double[] dropCol(String nm) {
		return dropCol(getColumnIdx(nm));
	}
	
	/**
	 * Remove the column at {@code idx}.
	 * @param idx the column index
	 * @return the removed column's values
	 * @throws IllegalArgumentException if idx is out of range
	 * @throws IllegalStateException if this is the only column
	 */
	public double[] dropCol(int idx) {
		final int n = numCols();
		if(idx >= n || idx < 0)
			throw new IllegalArgumentException("illegal column index: "+idx);
		if(n == 1)
			throw new IllegalStateException("cannot drop last column");
		
		final double[] res = data.getColumn(idx);
		final int m = numRows(), tail = n - idx - 1;
		final double[][] dataRef = data.getDataRef();
		
		// copy everything left of idx, then everything right of it shifted by one
		final double[][] newData = new double[m][n - 1];
		for(int i = 0; i < m; i++) {
			System.arraycopy(dataRef[i], 0, newData[i], 0, idx);
			System.arraycopy(dataRef[i], idx + 1, newData[i], idx, tail);
		}
		
		final String[] newHeader = new String[n - 1];
		System.arraycopy(headers, 0, newHeader, 0, idx);
		System.arraycopy(headers, idx + 1, newHeader, idx, tail);
		
		data = new Array2DRowRealMatrix(newData, false);
		headers = newHeader;
		return res;
	}
	
	/**
	 * Two datasets are equal when their data (compared element-wise with
	 * {@code Double.equals} semantics), headers and labels (both possibly
	 * null) are equal. The formatter is not compared.
	 */
	@Override
	public boolean equals(Object o) {
		if(this == o)
			return true;
		if(!(o instanceof DataSet))
			return false;
		
		final DataSet other = (DataSet)o;
		return Arrays.deepEquals(data.getDataRef(), other.data.getDataRef())
			&& Arrays.equals(headers, other.headers)
			&& Arrays.equals(labels, other.labels);
	}
	
	/**
	 * Return a copy of the data
	 * @return a copy of the underlying matrix
	 */
	public Array2DRowRealMatrix getData() {
		return (Array2DRowRealMatrix)data.copy();
	}
	
	/**
	 * Returns the column index of the header. If
	 * multiple columns share the same name (bad practice),
	 * returns the first which meets the criteria.
	 * Null header entries are matched only by a null query.
	 * @param header the column name
	 * @return the index of the first matching column
	 * @throws IllegalArgumentException if no column matches
	 */
	private int getColumnIdx(String header) {
		for(int idx = 0; idx < headers.length; idx++) {
			final String head = headers[idx];
			if(null == head ? null == header : head.equals(header))
				return idx;
		}
		
		throw new IllegalArgumentException("no such header: "+header);
	}
	
	/**
	 * Return a copy of the column 
	 * corresponding to the header
	 * @param header the column name
	 * @return the column's values
	 */
	public double[] getColumn(String header) {
		return getColumn(getColumnIdx(header));
	}
	
	/**
	 * Return a copy of the column at the given index
	 * @param i the column index
	 * @return the column's values
	 */
	public double[] getColumn(int i) {
		return data.getColumn(i);
	}
	
	/**
	 * Return a reference to the data
	 * @return the underlying matrix (mutations are visible to this dataset)
	 */
	public Array2DRowRealMatrix getDataRef() {
		return data;
	}
	
	/**
	 * Get the entry at the given row/col indices
	 * @param row the row index
	 * @param col the column index
	 * @return the value at (row, col)
	 */
	public double getEntry(int row, int col) {
		return this.data.getEntry(row, col);
	}
	
	/**
	 * Return a copy of the headers
	 * @return a copy of the column names
	 */
	public String[] getHeaders() {
		return VecUtils.copy(headers);
	}
	
	/**
	 * Return a reference to the headers
	 * @return the internal column-name array
	 */
	public String[] getHeaderRef() {
		return headers;
	}
	
	/**
	 * Return a copy of the labels
	 * @return a copy of the labels, or null if the dataset is unlabeled
	 */
	public int[] getLabels() {
		return null == labels ? null : VecUtils.copy(labels);
	}
	
	/**
	 * Return a reference to the labels
	 * @return the internal label array, or null if the dataset is unlabeled
	 */
	public int[] getLabelRef() {
		return labels;
	}
	
	/**
	 * Hash consistent with {@link #equals(Object)}: combines the contents
	 * of the data, headers and labels (null labels are allowed).
	 */
	@Override
	public int hashCode() {
		int h = 31;
		h = 31 * h + Arrays.deepHashCode(data.getDataRef());
		h = 31 * h + Arrays.hashCode(headers);
		h = 31 * h + Arrays.hashCode(labels);
		return h;
	}
	
	/**
	 * Build the rows shown by {@link #head(int)}: a header row followed by
	 * at most {@code length} data rows (fewer if the dataset is shorter).
	 * @param length maximum number of data rows
	 * @return the table rows
	 * @throws IllegalArgumentException if length is less than 1
	 */
	private ArrayList<Object[]> buildHead(int length) {
		if(length < 1)
			throw new IllegalArgumentException("length cannot be less than 1");
		
		final int n = data.getColumnDimension();
		final int rows = Math.min(length, data.getRowDimension());
		final double[][] d = data.getDataRef();
		final ArrayList<Object[]> o = new ArrayList<Object[]>(rows + 1);
		
		// The first row always holds the column names
		final Object[] headerRow = new Object[n];
		System.arraycopy(headers, 0, headerRow, 0, n);
		o.add(headerRow);
		
		for(int i = 0; i < rows; i++) {
			final Object[] row = new Object[n];
			for(int j = 0; j < n; j++)
				row[j] = d[i][j];
			o.add(row);
		}
		
		return o;
	}
	
	/** Print the first {@link #DEF_HEAD} rows to standard out. */
	public void head() {
		head(DEF_HEAD);
	}
	
	/**
	 * Print up to {@code numRows} rows (plus headers) to standard out.
	 * @param numRows maximum number of data rows to print; must be at least 1
	 */
	public void head(int numRows) {
		System.out.println(TABLE_FORMATTER.format(buildHead(numRows)));
	}
	
	/**
	 * View the dataset in the log
	 * @param logger receives {@link #toString()} at info level
	 */
	public void log(Loggable logger) {
		logger.info(this.toString());
	}
	
	/** @return the number of columns */
	public int numCols() {
		return data.getColumnDimension();
	}
	
	/** @return the number of rows */
	public int numRows() {
		return data.getRowDimension();
	}
	
	/**
	 * Overwrite the first column with the given name.
	 * @param name the column name
	 * @param col the new values, one per row
	 */
	public void setColumn(String name, final double[] col) {
		setColumn(getColumnIdx(name), col);
	}
	
	/**
	 * Overwrite the column at {@code idx}.
	 * @param idx the column index
	 * @param col the new values, one per row
	 * @throws IllegalArgumentException if idx is out of range
	 */
	public void setColumn(final int idx, final double[] col) {
		final int n = data.getColumnDimension();
		if(idx >= n || idx < 0)
			throw new IllegalArgumentException("illegal column index: "+idx);
		
		data.setColumn(idx, col);
	}
	
	/**
	 * Set the indices of row/col to the new value and
	 * return the old value
	 * @param row the row index
	 * @param col the column index
	 * @param newValue the value to store
	 * @return the previous value
	 */
	public double setEntry(int row, int col, double newValue) {
		double d = getEntry(row, col);
		this.data.setEntry(row, col, newValue);
		return d;
	}
	
	/**
	 * Replace the labels. The array is stored by reference (not copied).
	 * @param labels one label per row, or null to remove labels
	 * @throws DimensionMismatchException if the length differs from numRows()
	 */
	public void setLabels(final int[] labels) {
		if(null == labels) // null out existing labels
			this.labels = labels;
		else if(labels.length == data.getRowDimension()) {
			this.labels = labels;
		} else {
			throw new DimensionMismatchException(labels.length, data.getRowDimension());
		}
	}
	
	/**
	 * Overwrite the row at {@code idx}.
	 * @param idx the row index
	 * @param newRow the new values, one per column
	 * @throws IllegalArgumentException if idx is out of range
	 */
	public void setRow(final int idx, final double[] newRow) {
		final int m = data.getRowDimension();
		if(idx >= m || idx < 0)
			throw new IllegalArgumentException("illegal row index: "+idx);
		
		data.setRow(idx, newRow);
	}
	
	/**
	 * Return a new dataset whose rows (and corresponding labels, if they
	 * exist) are a random permutation of this one's. This dataset is not
	 * modified; headers and formatter are carried over.
	 * @return the shuffled copy
	 */
	public DataSet shuffle() {
		final int m = numRows();
		final boolean hasLabels = null != labels;
		
		final ArrayList<Integer> indices = new ArrayList<Integer>(m);
		for(int i = 0; i < m; i++)
			indices.add(i);
		Collections.shuffle(indices);
		
		final double[][] dataRef = data.getDataRef();
		final int[] newLabels = hasLabels ? new int[m] : null;
		final double[][] newData = new double[m][];
		for(int j = 0; j < m; j++) {
			final int idx = indices.get(j);
			if(hasLabels)
				newLabels[j] = labels[idx];
			newData[j] = dataRef[idx].clone(); // the single copy of each row
		}
		
		// newData is already private, so wrap it without another copy;
		// the constructor copies headers and labels itself.
		return new DataSet(
			new Array2DRowRealMatrix(newData, false),
			newLabels,
			headers,
			formatter,
			false
		);
	}
	
	/**
	 * Return a new dataset containing rows {@code [startInc, endExc)} and
	 * their labels (if any), with this dataset's headers and formatter.
	 * @param startInc first row, inclusive
	 * @param endExc last row, exclusive
	 * @return the sliced copy
	 */
	public DataSet slice(int startInc, int endExc) {
		final int[] labs = (null == labels) ? null : VecUtils.slice(labels, startInc, endExc);
		
		return new DataSet(
			MatUtils.slice(data.getDataRef(), startInc, endExc),
			labs,
			headers,
			formatter
		);
	}
	
	/**
	 * Sort rows ascending by the first column with the given name.
	 * @param col the column name
	 */
	public void sortAscInPlace(String col) {
		sortAscInPlace(getColumnIdx(col));
	}
	
	/**
	 * Sort rows ascending by column {@code colIdx}; labels move with their rows.
	 * @param colIdx the key column
	 * @throws IllegalArgumentException if colIdx is out of range
	 */
	public void sortAscInPlace(int colIdx) {
		sortRowsByColumn(colIdx, true);
	}
	
	/**
	 * Sort rows descending by the first column with the given name.
	 * @param col the column name
	 */
	public void sortDescInPlace(String col) {
		sortDescInPlace(getColumnIdx(col));
	}
	
	/**
	 * Sort rows descending by column {@code colIdx}; labels move with their rows.
	 * @param colIdx the key column
	 * @throws IllegalArgumentException if colIdx is out of range
	 */
	public void sortDescInPlace(int colIdx) {
		sortRowsByColumn(colIdx, false);
	}
	
	/**
	 * Reorder rows (and labels, if present) by the values in one column.
	 * The sort is stable, so rows with equal keys keep their relative
	 * order; keys are compared with {@link Double#compare}.
	 */
	private void sortRowsByColumn(final int colIdx, final boolean ascending) {
		if(colIdx < 0 || colIdx >= data.getColumnDimension())
			throw new IllegalArgumentException("col out of bounds");
		
		final double[][] dataRef = data.getDataRef();
		final int m = dataRef.length;
		
		final Integer[] order = new Integer[m];
		for(int i = 0; i < m; i++)
			order[i] = i;
		
		// Arrays.sort on objects is a stable merge sort
		Arrays.sort(order, new Comparator<Integer>() {
			@Override
			public int compare(Integer a, Integer b) {
				final int c = Double.compare(dataRef[a][colIdx], dataRef[b][colIdx]);
				return ascending ? c : -c;
			}
		});
		
		final double[][] sorted = new double[m][];
		final int[] newLabels = null == labels ? null : new int[m];
		for(int i = 0; i < m; i++) {
			final int src = order[i];
			sorted[i] = dataRef[src].clone();
			if(null != newLabels)
				newLabels[i] = labels[src];
		}
		
		data = new Array2DRowRealMatrix(sorted, false);
		labels = newLabels;
	}
	
	/**
	 * View the dataset in the console
	 */
	public void stdOut() {
		System.out.println(this.toString());
	}
	
	/**
	 * Write the dataset to a CSV
	 * @param header whether to write a header line first
	 * @param file destination file (overwritten)
	 * @throws IOException if the file cannot be opened or written
	 */
	public void toFlatFile(boolean header, final File file) throws IOException {
		toFlatFile(header, file, ',');
	}
	
	/**
	 * Write the dataset to a flat file. When labels exist they are appended
	 * as a final "target" column. Lines are separated by the platform line
	 * separator, with no trailing separator.
	 * @param header whether to write a header line first
	 * @param file destination file (overwritten)
	 * @param sep field separator
	 * @throws IOException if the file cannot be opened or written
	 */
	public void toFlatFile(boolean header, final File file, char sep) throws IOException {
		synchronized(this) {
			final boolean target = null != labels;
			final int width = headers.length + (target ? 1 : 0);
			final String[] output = new String[this.numRows() + (header ? 1 : 0)];
			int idx = 0;
			
			if(header) {
				final Object[] row = new Object[width];
				System.arraycopy(headers, 0, row, 0, headers.length);
				if(target) row[width - 1] = "target";
				output[idx++] = toString(row, sep);
			}
			
			// Read-only traversal, so the internal reference is safe to use
			final double[][] d = data.getDataRef();
			for(int r = 0; r < d.length; r++) {
				final Object[] row = new Object[width];
				for(int c = 0; c < d[r].length; c++)
					row[c] = d[r][c];
				if(target) row[width - 1] = labels[r];
				output[idx++] = toString(row, sep);
			}
			
			final String newline = System.getProperty("line.separator");
			// Construct before the try so a failed open propagates instead of
			// triggering a NullPointerException in the finally block.
			final BufferedWriter bw = new BufferedWriter(new FileWriter(file));
			try {
				for(int i = 0; i < output.length; i++) {
					if(i > 0) bw.write(newline);
					bw.write(output[i]);
				}
				
				// Flush explicitly so write failures surface to the caller
				bw.flush();
			} finally {
				try {
					bw.close();
				} catch(IOException ignored) {
					// Content was already flushed (or a primary error is propagating)
				}
			}
		}
	}
	
	/**
	 * Join the string forms of {@code obj} with {@code sep}.
	 */
	private static String toString(Object[] obj, char sep) {
		StringBuilder sb = new StringBuilder();
		for(int i = 0; i < obj.length; i++) {
			sb.append(obj[i]);
			if(i!=obj.length - 1) sb.append(sep);
		}
		
		return sb.toString();
	}

	/**
	 * Human-readable dump of headers, the formatted data matrix and labels.
	 */
	@Override
	public String toString() {
		String ls = System.getProperty("line.separator");
		String lsls = ls + ls;
		
		StringBuilder sb = new StringBuilder();
		sb.append("Headers:" + ls);
		sb.append(Arrays.toString(headers) + lsls);
		
		sb.append("Data:");
		sb.append(formatter.format(data) + ls);
		
		sb.append("Labels:"+ls);
		sb.append(Arrays.toString(labels));
		
		return sb.toString();
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy