All Downloads are FREE. Search and download functionalities are using the official Maven repository.

umcg.genetica.math.matrix2.DoubleMatrixDataset Maven / Gradle / Ivy

There is a newer version: 1.0.7
Show newest version
/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package umcg.genetica.math.matrix2;

import cern.colt.matrix.tdouble.DoubleMatrix2D;
import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.tdouble.impl.DenseLargeDoubleMatrix2D;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import umcg.genetica.io.text.TextFile;

/**
 *
 * @author MarcJan
 */
public class DoubleMatrixDataset {

    static final IOException doubleMatrixDatasetNonUniqueHeaderException = new IOException("Tried to use a non-unique header set in an identifier HashMap");
    static final Logger LOGGER = Logger.getLogger(DoubleMatrixDataset.class.getName());
    protected DoubleMatrix2D matrix;
    protected LinkedHashMap hashRows;
    protected LinkedHashMap hashCols;

    public DoubleMatrixDataset() {
        hashRows = new LinkedHashMap();
        hashCols = new LinkedHashMap();
    }

    public DoubleMatrixDataset(int rows, int columns) {
        hashRows = new LinkedHashMap((int) Math.ceil(rows / 0.75));
        hashCols = new LinkedHashMap((int) Math.ceil(columns / 0.75));
        if ((rows * (long) columns) < (Integer.MAX_VALUE - 2)) {
            matrix = new DenseDoubleMatrix2D(rows, columns);
        } else {
            matrix = new DenseLargeDoubleMatrix2D(rows, columns);
        }
    }

    public DoubleMatrixDataset(LinkedHashMap hashRows, LinkedHashMap hashCols) {
        this.hashRows = hashRows;
        this.hashCols = hashCols;
        if ((hashRows.size() * (long) hashCols.size()) < (Integer.MAX_VALUE - 2)) {
            matrix = new DenseDoubleMatrix2D(hashRows.size(), hashCols.size());
        } else {
            matrix = new DenseLargeDoubleMatrix2D(hashRows.size(), hashCols.size());
        }
    }

    public DoubleMatrixDataset(DoubleMatrix2D matrix, LinkedHashMap hashRows, LinkedHashMap hashCols) {
        this.hashRows = hashRows;
        this.hashCols = hashCols;
        this.matrix = matrix;
    }

    public DoubleMatrixDataset(List rowNames, List colNames) {

        hashRows = new LinkedHashMap(rowNames.size());
        hashCols = new LinkedHashMap(colNames.size());

        int i = 0;
        for (R row : rowNames) {
            hashRows.put(row, i);
            ++i;
        }

        i = 0;
        for (C col : colNames) {
            hashCols.put(col, i);
        }

        if ((hashRows.size() * (long) hashCols.size()) < (Integer.MAX_VALUE - 2)) {
            matrix = new DenseDoubleMatrix2D(hashRows.size(), hashCols.size());
        } else {
            matrix = new DenseLargeDoubleMatrix2D(hashRows.size(), hashCols.size());
        }

    }

    public static DoubleMatrixDataset loadDoubleData(String fileName) throws IOException {
        if ((fileName.endsWith(".txt") || fileName.endsWith(".tsv") || fileName.endsWith(".txt.gz"))) {
            return loadDoubleTextData(fileName, "\t");
        } else if (fileName.endsWith(".binary")) {
            return null;
        } else {
            throw new IllegalArgumentException("File type must be \".txt\", \".tsv\" or \".txt.gz\" when delimiter is set to: \"tab\" \n Input filename: " + fileName);
        }
    }

    public static DoubleMatrixDataset loadDoubleTextData(String fileName, String delimiter) throws IOException {
        if (!(fileName.endsWith(".txt") || fileName.endsWith(".tsv") || fileName.endsWith(".txt.gz"))) {
            throw new IllegalArgumentException("File type must be \".txt\", \".tsv\" or \".txt.gz\" when delimiter is set. \n Input filename: " + fileName);
        }

        Pattern splitPatern = Pattern.compile(delimiter);

        int columnOffset = 1;

        TextFile in = new TextFile(fileName, TextFile.R);
        String str = in.readLine(); // header
        String[] data = splitPatern.split(str);

        int tmpCols = (data.length - columnOffset);

        LinkedHashMap colMap = new LinkedHashMap((int) Math.ceil(tmpCols / 0.75));

        for (int s = 0; s < tmpCols; s++) {
            String colName = data[s + columnOffset];
            if (!colMap.containsKey(colName)) {
                colMap.put(colName, s);
            } else {
                LOGGER.warning("Duplicated column name!");
                throw (doubleMatrixDatasetNonUniqueHeaderException);
            }
        }

        int tmpRows = 0;

        while (in.readLine() != null) {
            tmpRows++;
        }
        in.close();

        DoubleMatrixDataset dataset;

        if ((tmpRows * (long) tmpCols) < (Integer.MAX_VALUE - 2)) {
            LinkedHashMap rowMap = new LinkedHashMap((int) Math.ceil(tmpRows / 0.75));
            DenseDoubleMatrix2D tmpMatrix = new DenseDoubleMatrix2D(tmpRows, tmpCols);

            in.open();
            in.readLine(); // read header
            int row = 0;

            boolean correctData = true;
            while ((str = in.readLine()) != null) {
                data = splitPatern.split(str);

                if (!rowMap.containsKey(data[0])) {
                    rowMap.put(data[0], row);
                    for (int s = 0; s < tmpCols; s++) {
                        double d;
                        try {
                            d = Double.parseDouble(data[s + columnOffset]);
                        } catch (NumberFormatException e) {
                            correctData = false;
                            d = Double.NaN;
                        }
                        tmpMatrix.setQuick(row, s, d);
                    }
                    row++;
                } else {
                    LOGGER.warning("Duplicated row name!");
                    throw (doubleMatrixDatasetNonUniqueHeaderException);
                }

            }
            if (!correctData) {
                LOGGER.warning("Your data contains NaN/unparseable values!");
            }
            in.close();

            dataset = new DoubleMatrixDataset(tmpMatrix, rowMap, colMap);

        } else {
            LinkedHashMap rowMap = new LinkedHashMap((int) Math.ceil(tmpRows / 0.75));
            DenseLargeDoubleMatrix2D tmpMatrix = new DenseLargeDoubleMatrix2D(tmpRows, tmpCols);
            in.open();
            in.readLine(); // read header
            int row = 0;

            boolean correctData = true;
            while ((str = in.readLine()) != null) {
                data = splitPatern.split(str);

                if (!rowMap.containsKey(data[0])) {
                    rowMap.put(data[0], row);
                    for (int s = 0; s < tmpCols; s++) {
                        double d;
                        try {
                            d = Double.parseDouble(data[s + columnOffset]);
                        } catch (NumberFormatException e) {
                            correctData = false;
                            d = Double.NaN;
                        }
                        tmpMatrix.setQuick(row, s, d);
                    }
                    row++;
                } else {
                    LOGGER.warning("Duplicated row name!");
                    throw (doubleMatrixDatasetNonUniqueHeaderException);
                }

            }
            if (!correctData) {
                LOGGER.warning("Your data contains NaN/unparseable values!");
            }
            in.close();
            dataset = new DoubleMatrixDataset(tmpMatrix, rowMap, colMap);
        }

        LOGGER.log(Level.INFO, "''{0}'' has been loaded, nrRows: {1} nrCols: {2}", new Object[]{fileName, dataset.matrix.rows(), dataset.matrix.columns()});
        return dataset;
    }

    public static DoubleMatrixDataset loadSubsetOfTextDoubleData(String fileName, String delimiter, HashSet desiredRows, HashSet desiredCols) throws IOException {
        if (!(fileName.endsWith(".txt") || fileName.endsWith(".txt.gz"))) {
            throw new IllegalArgumentException("File type must be .txt when delimiter is given (given filename: " + fileName + ")");
        }

        LinkedHashSet desiredColPos = new LinkedHashSet();

        Pattern splitPatern = Pattern.compile(delimiter);

        int columnOffset = 1;

        TextFile in = new TextFile(fileName, TextFile.R);
        String str = in.readLine(); // header
        String[] data = splitPatern.split(str);

        int tmpCols = (data.length - columnOffset);

        LinkedHashMap colMap = new LinkedHashMap((int) Math.ceil(tmpCols / 0.75));

        int storedCols = 0;
        for (int s = 0; s < tmpCols; s++) {
            String colName = data[s + columnOffset];
            if (!colMap.containsKey(colName) && (desiredCols == null || desiredCols.contains(colName) || desiredCols.isEmpty())) {
                colMap.put(colName, storedCols);
                desiredColPos.add((s));
                storedCols++;
            } else if (colMap.containsKey(colName)) {
                LOGGER.warning("Duplicated column name!");
                System.out.println("Tried to add: " + colName);
                throw (doubleMatrixDatasetNonUniqueHeaderException);
            }
        }

        LinkedHashSet desiredRowPos = new LinkedHashSet();
        int rowsToStore = 0;
        int totalRows = 0;
        //System.out.println(desiredRows.toString());
        while ((str = in.readLine()) != null) {
            String[] info = splitPatern.split(str);
            if (desiredRows == null || desiredRows.contains(info[0]) || desiredRows.isEmpty()) {
                rowsToStore++;
                desiredRowPos.add(totalRows);
            }
            totalRows++;
        }
        in.close();

        DoubleMatrixDataset dataset;

        if ((rowsToStore * (long) tmpCols) < (Integer.MAX_VALUE - 2)) {
            DenseDoubleMatrix2D matrix = new DenseDoubleMatrix2D(rowsToStore, storedCols);
            in.open();
            in.readLine(); // read header
            int storingRow = 0;
            totalRows = 0;
            LinkedHashMap rowMap = new LinkedHashMap((int) Math.ceil(rowsToStore / 0.75));

            boolean correctData = true;
            while ((str = in.readLine()) != null) {

                if (desiredRowPos.contains(totalRows)) {
                    data = splitPatern.split(str);
                    if (!rowMap.containsKey(data[0])) {
                        rowMap.put(data[0], storingRow);
                        int storingCol = 0;
                        for (int s : desiredColPos) {
                            double d;
                            try {
                                d = Double.parseDouble(data[s + columnOffset]);
                            } catch (NumberFormatException e) {
                                correctData = false;
                                d = Double.NaN;
                            }
                            matrix.setQuick(storingRow, storingCol, d);
                            storingCol++;
                        }
                        storingRow++;
                    } else if (rowMap.containsKey(data[0])) {
                        LOGGER.warning("Duplicated row name!");
                        System.out.println("Tried to add: " + data[0]);
                        throw (doubleMatrixDatasetNonUniqueHeaderException);
                    }
                }
                totalRows++;
            }
            if (!correctData) {
                LOGGER.warning("Your data contains NaN/unparseable values!");
            }
            in.close();

            dataset = new DoubleMatrixDataset(matrix, rowMap, colMap);

        } else {
            DenseLargeDoubleMatrix2D matrix = new DenseLargeDoubleMatrix2D(rowsToStore, storedCols);

            in.open();
            in.readLine(); // read header
            int storingRow = 0;
            totalRows = 0;
            LinkedHashMap rowMap = new LinkedHashMap((int) Math.ceil(rowsToStore / 0.75));

            boolean correctData = true;
            while ((str = in.readLine()) != null) {
                if (desiredRowPos.contains(totalRows)) {
                    data = splitPatern.split(str);
                    if (!rowMap.containsKey(data[0])) {
                        rowMap.put(data[0], storingRow);
                        int storingCol = 0;
                        for (int s : desiredColPos) {
                            double d;
                            try {
                                d = Double.parseDouble(data[s + columnOffset]);
                            } catch (NumberFormatException e) {
                                correctData = false;
                                d = Double.NaN;
                            }
                            matrix.setQuick(storingRow, storingCol, d);
                            storingCol++;
                        }
                        storingRow++;
                    } else if (rowMap.containsKey(data[0])) {
                        LOGGER.warning("Duplicated row name!");
                        System.out.println("Tried to add: " + data[0]);
                        throw (doubleMatrixDatasetNonUniqueHeaderException);
                    }
                }
                totalRows++;
            }
            if (!correctData) {
                LOGGER.warning("Your data contains NaN/unparseable values!");
            }
            in.close();

            dataset = new DoubleMatrixDataset(matrix, rowMap, colMap);
        }

        LOGGER.log(Level.INFO, "''{0}'' has been loaded, nrRows: {1} nrCols: {2}", new Object[]{fileName, dataset.matrix.rows(), dataset.matrix.columns()});
        return dataset;
    }

    public void save(File file) throws IOException {
        TextFile out = new TextFile(file, TextFile.W);

        out.append('-');
        for (C col : hashCols.keySet()) {

            out.append('\t');
            out.append(col.toString());
        }
        out.append('\n');
        int r = 0;
        for (R row : hashRows.keySet()) {
            out.append(row.toString());
            for (int c = 0; c < matrix.columns(); c++) {
                out.append('\t');
                out.append(String.valueOf(matrix.getQuick(r, c)));
            }
            out.append('\n');
            ++r;
        }
        out.close();
    }

    public void save(String fileName) throws IOException {
        save(new File(fileName));
    }

    public void saveDice(String fileName) throws IOException {
        TextFile out = new TextFile(fileName, TextFile.W);

        out.append('-');
        for (R row : hashRows.keySet()) {
            out.append('\t');
            out.append(row.toString());
        }
        out.append('\n');

        int c = 0;
        for (C col : hashCols.keySet()) {
            out.append(col.toString());
            for (int r = 0; r < matrix.rows(); r++) {

                out.append('\t');
                out.append(String.valueOf(matrix.getQuick(r, c)));
            }
            out.append('\n');
            ++c;
        }
        out.close();
    }

    //Getters and setters
    public int rows() {
        return matrix.rows();
    }

    public int columns() {
        return matrix.columns();
    }

    public LinkedHashMap getHashRows() {
        return hashRows;
    }

    public void setHashRows(LinkedHashMap hashRows) {
        this.hashRows = hashRows;
    }

    public LinkedHashMap getHashCols() {
        return hashCols;
    }

    public void setHashCols(LinkedHashMap hashCols) {
        this.hashCols = hashCols;
    }

    public ArrayList getRowObjects() {
        return new ArrayList(hashRows.keySet());
    }

    public void setRowObjects(List arrayList) throws Exception {
        LinkedHashMap newHashRows = new LinkedHashMap((int) Math.ceil(arrayList.size() / 0.75));
        int i = 0;
        for (R s : arrayList) {
            if (!newHashRows.containsKey(s)) {
                newHashRows.put(s, i);
            } else {
                System.out.println("Error, new row names contains dupilcates.");
                throw (doubleMatrixDatasetNonUniqueHeaderException);
            }
            i++;
        }

        this.hashRows = newHashRows;
    }

    public ArrayList getColObjects() {
        return new ArrayList(hashCols.keySet());
    }

    public void setColObjects(List arrayList) throws Exception {
        LinkedHashMap newHashCols = new LinkedHashMap((int) Math.ceil(arrayList.size() / 0.75));
        int i = 0;
        for (C s : arrayList) {
            if (!newHashCols.containsKey(s)) {
                newHashCols.put(s, i);
            } else {
                System.out.println("Error, new column names contains dupilcates.");
                throw (doubleMatrixDatasetNonUniqueHeaderException);
            }
            i++;
        }
        this.hashCols = newHashCols;
    }

    public DoubleMatrix2D getMatrix() {
        return matrix;
    }

    public void setMatrix(DoubleMatrix2D matrix) {
        this.matrix = matrix;
    }

    public void setMatrix(double[][] matrix) {
        if ((matrix.length * (long) matrix[0].length) < (Integer.MAX_VALUE - 2)) {
            this.matrix = new DenseDoubleMatrix2D(matrix);
        } else {
            this.matrix = new DenseLargeDoubleMatrix2D(matrix.length, matrix[0].length);
            this.matrix.assign(matrix);
        }
    }

    /**
     * Order columns
     *
     * @param dataset DoubleMatrixDataset Expression matrix
     */
    public void OrderOnColumnnames() {
        LinkedHashMap newColHash = new LinkedHashMap((int) Math.ceil(this.matrix.columns() / 0.75));
        ArrayList names = this.getColObjects();
        Collections.sort(names);

        int pos = 0;
        for (C name : names) {
            newColHash.put(name, pos);
            pos++;
        }
        reorderCols(newColHash);
    }

    /**
     * Order rows
     *
     * @param dataset DoubleMatrixDataset Expression matrix
     */
    public void OrderOnRownames() {
        LinkedHashMap newRowHash = new LinkedHashMap((int) Math.ceil(this.matrix.rows() / 0.75));
        ArrayList names = this.getRowObjects();
        Collections.sort(names);

        int pos = -1;
        for (R name : names) {
            pos++;
            newRowHash.put(name, pos);
        }
        reorderRows(newRowHash);

    }

    public void reorderRows(LinkedHashMap mappingIndex) {
        boolean equal = compareHashRows(mappingIndex, this.hashRows);
        if (!equal) {
            DoubleMatrix2D newRawData;
            if ((this.rows() * (long) this.columns()) < (Integer.MAX_VALUE - 2)) {
                newRawData = new DenseDoubleMatrix2D(this.rows(), this.columns());
            } else {
                newRawData = new DenseLargeDoubleMatrix2D(this.rows(), this.columns());
            }

            for (Map.Entry ent : mappingIndex.entrySet()) {
                int pos = this.getHashRows().get(ent.getKey());
                for (int s = 0; s < this.columns(); ++s) {
                    newRawData.set(ent.getValue(), s, this.getMatrix().get(pos, s));
                }
            }
            this.setHashRows(mappingIndex);
            this.setMatrix(newRawData);
        }

    }

    public void reorderCols(LinkedHashMap mappingIndex) {
        boolean equal = compareHashCols(mappingIndex, this.hashCols);
        if (!equal) {
            DoubleMatrix2D newRawData;
            if ((this.rows() * (long) this.columns()) < (Integer.MAX_VALUE - 2)) {
                newRawData = new DenseDoubleMatrix2D(this.rows(), this.columns());
            } else {
                newRawData = new DenseLargeDoubleMatrix2D(this.rows(), this.columns());
            }

            for (Map.Entry ent : mappingIndex.entrySet()) {
                int pos = this.getHashCols().get(ent.getKey());
                for (int p = 0; p < this.rows(); ++p) {
                    newRawData.set(p, ent.getValue(), this.getMatrix().get(p, pos));
                }
            }

            this.setHashCols(mappingIndex);
            this.setMatrix(newRawData);
        }
    }

    public DoubleMatrixDataset viewDice() {
        return new DoubleMatrixDataset(matrix.viewDice(), hashCols, hashRows);
    }

    private boolean compareHashCols(LinkedHashMap mappingIndex, LinkedHashMap originalHash) {

        for (Entry entry : mappingIndex.entrySet()) {
            if (entry.getValue() != originalHash.get(entry.getKey())) {
                return false;
            }
        }
        return true;
    }

    private boolean compareHashRows(LinkedHashMap mappingIndex, LinkedHashMap originalHash) {

        for (Entry entry : mappingIndex.entrySet()) {
            if (entry.getValue() != originalHash.get(entry.getKey())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Set a element of the dataset.
     *
     * @param rowName
     * @param columnName
     * @param value
     * @return
     */
    public void setElement(R rowName, C columnName, double value) {

        Integer row = hashRows.get(rowName);
        Integer column = hashCols.get(columnName);

        if (row != null && column != null) {
            matrix.setQuick(row, column, value);
        } else {
            if (row == null) {
                throw new NoSuchElementException("Row not found: " + rowName.toString());
            } else {
                throw new NoSuchElementException("Column not found: " + columnName.toString());
            }

        }

    }

    /**
     * Get specific element.
     *
     * @param rowName
     * @param columnName
     * @return
     */
    public double getElement(R rowName, C columnName) {

        Integer row = hashRows.get(rowName);
        Integer column = hashCols.get(columnName);

        if (row != null && column != null) {
            return matrix.getQuick(row, column);
        } else {
            if (row == null) {
                throw new NoSuchElementException("Row not found: " + rowName.toString());
            } else {
                throw new NoSuchElementException("Column not found: " + columnName.toString());
            }
        }
    }

    /**
     * Get specific element.
     *
     * @param row
     * @param column
     * @return
     */
    public double getElement(int row, int column) {

        return matrix.get(row, column);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy