All downloads are free. Search and download functionality uses the official Maven repository.

us.ihmc.scs2.simulation.physicsEngine.YoMatrix Maven / Gradle / Ivy

There is a newer version: 17-0.28.3
Show newest version
package us.ihmc.scs2.simulation.physicsEngine;

import org.ejml.data.DMatrix;
import org.ejml.data.Matrix;
import org.ejml.data.MatrixType;
import org.ejml.data.ReshapeMatrix;
import org.ejml.ops.MatrixIO;

import us.ihmc.yoVariables.registry.YoRegistry;
import us.ihmc.yoVariables.variable.YoDouble;
import us.ihmc.yoVariables.variable.YoInteger;

/**
 * Dense {@link DMatrix} whose entries and dimensions are stored in YoVariables registered with a
 * {@link YoRegistry}.
 * <p>
 * Storage is allocated once at construction for {@code maxNumberOfRows x maxNumberOfColumns} entries.
 * {@link #reshape(int, int)} only changes the logical size within those bounds and sets every entry
 * outside the new logical size to {@link Double#NaN}.
 * </p>
 */
public class YoMatrix implements DMatrix, ReshapeMatrix
{
   // TODO: eventually consolidate YoMatrix implementations

   private static final long serialVersionUID = 2156411740647948028L;

   /** Capacity of the backing storage; the logical size can never exceed these values. */
   private final int maxNumberOfRows, maxNumberOfColumns;

   /** Current logical size, stored as YoVariables registered alongside the entries. */
   private final YoInteger numberOfRows, numberOfColumns;
   /** Backing entries indexed as {@code variables[row][column]}, always allocated at full capacity. */
   private final YoDouble[][] variables;

   /** Creates a matrix with index-based entry names and no description. */
   public YoMatrix(String name, int maxNumberOfRows, int maxNumberOfColumns, YoRegistry registry)
   {
      this(name, null, maxNumberOfRows, maxNumberOfColumns, null, null, registry);
   }

   /** Creates a column-vector matrix whose entries are named after {@code rowNames}, with no description. */
   public YoMatrix(String name, int maxNumberOfRows, int maxNumberOfColumns, String[] rowNames, YoRegistry registry)
   {
      this(name, null, maxNumberOfRows, maxNumberOfColumns, rowNames, null, registry);
   }

   /** Creates a matrix whose entries are named after {@code rowNames} and {@code columnNames}, with no description. */
   public YoMatrix(String name, int maxNumberOfRows, int maxNumberOfColumns, String[] rowNames, String[] columnNames, YoRegistry registry)
   {
      this(name, null, maxNumberOfRows, maxNumberOfColumns, rowNames, columnNames, registry);
   }

   /** Creates a matrix with index-based entry names and the given description. */
   public YoMatrix(String name, String description, int maxNumberOfRows, int maxNumberOfColumns, YoRegistry registry)
   {
      this(name, description, maxNumberOfRows, maxNumberOfColumns, null, null, registry);
   }

   /** Creates a column-vector matrix whose entries are named after {@code rowNames}, with the given description. */
   public YoMatrix(String name, String description, int maxNumberOfRows, int maxNumberOfColumns, String[] rowNames, YoRegistry registry)
   {
      this(name, description, maxNumberOfRows, maxNumberOfColumns, rowNames, null, registry);
   }

   /**
    * Creates a matrix of the given capacity and registers all of its variables.
    * <p>
    * Entry variables are named {@code name_row_column} when no names are given, {@code name_rowName}
    * when only row names are given (column vectors only), and {@code name_rowName_columnName} when both
    * row and column names are given. The size variables are named {@code nameNumRows} and
    * {@code nameNumCols}. The matrix is initially sized to its full capacity.
    * </p>
    *
    * @param name               prefix for every variable created by this matrix.
    * @param description        description given to each entry variable, may be {@code null}.
    * @param maxNumberOfRows    row capacity, must be non-negative.
    * @param maxNumberOfColumns column capacity, must be non-negative.
    * @param rowNames           optional row names; when provided, must hold at least {@code maxNumberOfRows} elements.
    * @param columnNames        optional column names; requires {@code rowNames} and must hold at least
    *                           {@code maxNumberOfColumns} elements when provided.
    * @param registry           registry the variables are created in.
    * @throws IllegalArgumentException if a capacity is negative or the provided names are inconsistent with the capacity.
    */
   public YoMatrix(String name, String description, int maxNumberOfRows, int maxNumberOfColumns, String[] rowNames, String[] columnNames, YoRegistry registry)
   {
      if (maxNumberOfRows < 0 || maxNumberOfColumns < 0)
         throw new IllegalArgumentException("The maximum number of rows and columns must be non-negative, was: (" + maxNumberOfRows + ", " + maxNumberOfColumns
               + ")");

      // Validate the naming scheme once, before any variable is created, so that invalid arguments
      // do not leave partially created variables behind in the registry.
      NamesProvided namesProvided = checkNames(rowNames, columnNames);
      validateNames(namesProvided, maxNumberOfRows, maxNumberOfColumns, rowNames, columnNames);

      this.maxNumberOfRows = maxNumberOfRows;
      this.maxNumberOfColumns = maxNumberOfColumns;

      this.numberOfRows = new YoInteger(name + "NumRows", registry);
      this.numberOfColumns = new YoInteger(name + "NumCols", registry);

      this.numberOfRows.set(maxNumberOfRows);
      this.numberOfColumns.set(maxNumberOfColumns);

      variables = new YoDouble[maxNumberOfRows][maxNumberOfColumns];

      for (int row = 0; row < maxNumberOfRows; row++)
      {
         for (int column = 0; column < maxNumberOfColumns; column++)
         {
            String variableName = createEntryName(name, namesProvided, row, column, rowNames, columnNames);
            variables[row][column] = new YoDouble(variableName, description, registry);
         }
      }
   }

   /** Naming scheme used for the entry variables, derived from which name arrays were provided. */
   private enum NamesProvided
   {
      NONE, ROWS, ROWS_AND_COLUMNS
   }

   /**
    * Determines which naming scheme applies to the entry variables.
    *
    * @throws IllegalArgumentException if column names are provided without row names.
    */
   private static NamesProvided checkNames(String[] rowNames, String[] columnNames)
   {
      if (rowNames == null)
      {
         if (columnNames != null)
            throw new IllegalArgumentException("Column names cannot be provided without row names.");
         return NamesProvided.NONE;
      }
      return columnNames == null ? NamesProvided.ROWS : NamesProvided.ROWS_AND_COLUMNS;
   }

   /**
    * Checks that the provided names can produce one name per entry for the given capacity.
    *
    * @throws IllegalArgumentException if only row names are given for a matrix with more than one
    *                                  column, or if a name array is shorter than the matching capacity.
    */
   private static void validateNames(NamesProvided namesProvided, int maxNumberOfRows, int maxNumberOfColumns, String[] rowNames, String[] columnNames)
   {
      if (namesProvided == NamesProvided.NONE)
         return;

      if (namesProvided == NamesProvided.ROWS && maxNumberOfColumns > 1)
         throw new IllegalArgumentException("The YoMatrix must be a column vector if only row names are provided, else unique names cannot be generated.");

      if (rowNames.length < maxNumberOfRows)
         throw new IllegalArgumentException("Not enough row names. Expected at least " + maxNumberOfRows + ", was " + rowNames.length);

      if (namesProvided == NamesProvided.ROWS_AND_COLUMNS && columnNames.length < maxNumberOfColumns)
         throw new IllegalArgumentException("Not enough column names. Expected at least " + maxNumberOfColumns + ", was " + columnNames.length);
   }

   /** Builds the variable name of the entry at ({@code row}, {@code column}) according to the naming scheme. */
   private static String createEntryName(String name, NamesProvided namesProvided, int row, int column, String[] rowNames, String[] columnNames)
   {
      switch (namesProvided)
      {
         case ROWS:
            // Names are the row names, with no column identifier (column vectors only).
            return name + "_" + rowNames[row];
         case ROWS_AND_COLUMNS:
            return name + "_" + rowNames[row] + "_" + columnNames[column];
         case NONE:
         default:
            // Names are simply the row and column indices.
            return name + "_" + row + "_" + column;
      }
   }

   /**
    * Changes the logical size of this matrix. Every entry outside the new logical size is set to
    * {@link Double#NaN}; entries inside it keep their current values.
    *
    * @throws IllegalArgumentException if a dimension is negative or exceeds the capacity.
    */
   @Override
   public void reshape(int numRows, int numCols)
   {
      if (numRows > maxNumberOfRows)
         throw new IllegalArgumentException("Too many rows. Expected less or equal to " + maxNumberOfRows + ", was " + numRows);
      else if (numCols > maxNumberOfColumns)
         throw new IllegalArgumentException("Too many columns. Expected less or equal to " + maxNumberOfColumns + ", was " + numCols);
      else if (numRows < 0 || numCols < 0)
         throw new IllegalArgumentException("Cannot reshape with a negative number of rows or columns.");

      numberOfRows.set(numRows);
      numberOfColumns.set(numCols);

      // Unused columns of the rows that remain in use.
      for (int row = 0; row < numRows; row++)
      {
         for (int col = numCols; col < maxNumberOfColumns; col++)
         {
            unsafe_set(row, col, Double.NaN);
         }
      }

      // Rows that are entirely unused.
      for (int row = numRows; row < maxNumberOfRows; row++)
      {
         for (int col = 0; col < maxNumberOfColumns; col++)
         {
            unsafe_set(row, col, Double.NaN);
         }
      }
   }

   /**
    * Reshapes this matrix to the size of {@code original} and copies its entries.
    *
    * @param original the matrix to copy; must be a {@link DMatrix} that fits within this matrix's capacity.
    * @throws IllegalArgumentException if {@code original} is not a {@link DMatrix} or exceeds this matrix's capacity.
    */
   @Override
   public void set(Matrix original)
   {
      if (!(original instanceof DMatrix))
         throw new IllegalArgumentException("Can only copy a DMatrix, was: " + (original == null ? "null" : original.getClass().getSimpleName()));

      DMatrix otherMatrix = (DMatrix) original;
      reshape(otherMatrix.getNumRows(), otherMatrix.getNumCols());
      for (int row = 0; row < getNumRows(); row++)
      {
         for (int col = 0; col < getNumCols(); col++)
         {
            // Bounds are guaranteed by the reshape above.
            unsafe_set(row, col, otherMatrix.unsafe_get(row, col));
         }
      }
   }

   /**
    * Sets the entry at ({@code row}, {@code col}).
    *
    * @throws IllegalArgumentException if the index is outside the current logical size.
    */
   @Override
   public void set(int row, int col, double val)
   {
      if (col < 0 || col >= getNumCols() || row < 0 || row >= getNumRows())
         throw new IllegalArgumentException("Specified element is out of bounds: (" + row + " , " + col + ")");
      unsafe_set(row, col, val);
   }

   /** Sets the entry at ({@code row}, {@code col}) without checking it against the logical size. */
   @Override
   public void unsafe_set(int row, int col, double val)
   {
      variables[row][col].set(val);
   }

   /**
    * Returns the entry at ({@code row}, {@code col}).
    *
    * @throws IllegalArgumentException if the index is outside the current logical size.
    */
   @Override
   public double get(int row, int col)
   {
      if (col < 0 || col >= getNumCols() || row < 0 || row >= getNumRows())
         throw new IllegalArgumentException("Specified element is out of bounds: (" + row + " , " + col + ")");
      return unsafe_get(row, col);
   }

   /** Returns the entry at ({@code row}, {@code col}) without checking it against the logical size. */
   @Override
   public double unsafe_get(int row, int col)
   {
      return variables[row][col].getValue();
   }

   /** Sets every entry within the current logical size to zero. */
   @Override
   public void zero()
   {
      for (int row = 0; row < getNumRows(); row++)
      {
         for (int col = 0; col < getNumCols(); col++)
         {
            variables[row][col].set(0.0);
         }
      }
   }

   /**
    * Reshapes this matrix to {@code numRows x numCols} and sets every entry, in use or not, to
    * {@link Double#NaN}.
    */
   public void setToNaN(int numRows, int numCols)
   {
      reshape(numRows, numCols);
      for (int row = 0; row < numRows; row++)
      {
         for (int col = 0; col < numCols; col++)
         {
            unsafe_set(row, col, Double.NaN);
         }
      }
   }

   /**
    * Adds {@code otherMatrix} to this matrix element-wise, in place.
    *
    * @throws IllegalArgumentException if the two matrices do not have the same size.
    */
   public void add(DMatrix otherMatrix)
   {
      if (otherMatrix.getNumRows() != getNumRows() || otherMatrix.getNumCols() != getNumCols())
         throw new IllegalArgumentException("Incompatible matrix sizes.");
      for (int row = 0; row < getNumRows(); row++)
      {
         for (int col = 0; col < getNumCols(); col++)
         {
            unsafe_set(row, col, unsafe_get(row, col) + otherMatrix.unsafe_get(row, col));
         }
      }
   }

   /** Returns the current logical number of rows. */
   @Override
   public int getNumRows()
   {
      return numberOfRows.getValue();
   }

   /** Returns the current logical number of columns. */
   @Override
   public int getNumCols()
   {
      return numberOfColumns.getValue();
   }

   /** Returns the number of entries within the current logical size. */
   @Override
   public int getNumElements()
   {
      return getNumRows() * getNumCols();
   }

   @Override
   public MatrixType getType()
   {
      return MatrixType.UNSPECIFIED;
   }

   @Override
   public void print()
   {
      MatrixIO.printFancy(System.out, this, MatrixIO.DEFAULT_LENGTH);
   }

   @Override
   public void print(String format)
   {
      MatrixIO.print(System.out, this, format);
   }

   /** Not supported: a YoMatrix cannot be created without a name and a registry. */
   @Override
   public <T extends Matrix> T createLike()
   {
      throw new UnsupportedOperationException();
   }

   /** Not supported: a YoMatrix cannot be created without a name and a registry. */
   @Override
   public <T extends Matrix> T create(int numRows, int numCols)
   {
      throw new UnsupportedOperationException();
   }

   /** Not supported: a YoMatrix cannot be created without a name and a registry. */
   @Override
   public <T extends Matrix> T copy()
   {
      throw new UnsupportedOperationException();
   }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy