All Downloads are FREE. Search and download functionalities are using the official Maven repository.

us.ihmc.convexOptimization.quadraticProgram.BlockDiagSquareMatrix Maven / Gradle / Ivy

There is a newer version: 0.17.22
Show newest version
package us.ihmc.convexOptimization.quadraticProgram;

import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.factory.LinearSolverFactory_DDRM;
import org.ejml.interfaces.linsol.LinearSolverDense;

/**
 * Square matrix organised as a sequence of square blocks along the diagonal, stored in the
 * dense row-major layout inherited from {@link DMatrixRMaj}.
 * <p>
 * The block layout is fixed at construction time. The block-aware operations of this class
 * ({@link #setBlock}, {@link #packBlock}, {@link #packInverse}, {@link #mult},
 * {@link #multTransB}) only read and write entries inside the diagonal blocks; nothing here
 * enforces that the off-diagonal entries stay zero.
 * <p>
 * Instances reuse internal scratch matrices between calls.
 */
public class BlockDiagSquareMatrix extends DMatrixRMaj
{
   private static final long serialVersionUID = 8813856249678942997L;

   /** Side length of each diagonal block, in block order. */
   int[] blockSizes;
   /**
    * {@code blockStarts[i]} is the first row/column index of block {@code i}; the extra last
    * entry holds the total matrix dimension.
    */
   int[] blockStarts;

   /** One scratch matrix per block, sized {@code blockSizes[i] x blockSizes[i]}. */
   DMatrixRMaj[] tmpMatrix;

   /** Scratch matrix holding the rows of {@code b} that face the current block in {@link #mult}. */
   DMatrixRMaj multTempB = new DMatrixRMaj(0);
   /** Scratch matrix holding the product for the current block in {@link #mult}. */
   DMatrixRMaj multTempC = new DMatrixRMaj(0);

   /**
    * Creates a zero matrix whose dimension is the sum of the given block sizes.
    *
    * @param blockSizes side length of each diagonal block, in order; the array is copied so
    *                   later changes made by the caller cannot desynchronise the layout
    */
   public BlockDiagSquareMatrix(int... blockSizes)
   {
      super(0);
      this.blockSizes = blockSizes.clone();

      // Use the array length directly rather than the overridable getNumBlocks().
      int numBlocks = this.blockSizes.length;
      blockStarts = new int[numBlocks + 1];
      tmpMatrix = new DMatrixRMaj[numBlocks];

      int matrixRows = 0;
      for (int i = 0; i < numBlocks; i++)
      {
         int size = this.blockSizes[i];
         tmpMatrix[i] = new DMatrixRMaj(size, size);
         blockStarts[i] = matrixRows;
         matrixRows += size;
      }

      blockStarts[numBlocks] = matrixRows;
      super.reshape(matrixRows, matrixRows);
   }

   /**
    * @return the number of diagonal blocks
    */
   public int getNumBlocks()
   {
      return blockSizes.length;
   }

   /**
    * Copies {@code srcBlock} into the diagonal block {@code blockId} of this matrix.
    *
    * @param srcBlock values to copy
    * @param blockId  index of the destination block
    */
   public void setBlock(DMatrixRMaj srcBlock, int blockId)
   {
      setBlock(srcBlock, blockId, this);
   }

   /**
    * Copies {@code srcBlock} into {@code dstMatrix} at the position that block {@code blockId}
    * occupies in this matrix. {@code dstMatrix} is first reshaped to the dimensions of this
    * matrix.
    *
    * @param srcBlock  values to copy
    * @param blockId   index of the block whose position is used
    * @param dstMatrix matrix receiving the values
    */
   public void setBlock(DMatrixRMaj srcBlock, int blockId, DMatrixRMaj dstMatrix)
   {
      dstMatrix.reshape(numRows, numCols);
      int startIndex = blockStarts[blockId];
      CommonOps_DDRM.insert(srcBlock, dstMatrix, startIndex, startIndex);
   }

   /**
    * Copies the diagonal block {@code blockId} of this matrix into {@code dstBlock}.
    * <p>
    * NOTE(review): {@code destX0} is forwarded to {@code CommonOps_DDRM.extract} as the
    * destination row offset and {@code destY0} as the destination column offset, which is the
    * opposite of EJML's usual X = column / Y = row naming. Behaviour is kept as is because
    * existing callers may depend on it; all callers in this class pass {@code 0, 0}.
    *
    * @param dstBlock matrix receiving the block
    * @param blockId  index of the block to copy
    * @param destX0   offset in {@code dstBlock} (used as the row offset, see note)
    * @param destY0   offset in {@code dstBlock} (used as the column offset, see note)
    */
   public void packBlock(DMatrixRMaj dstBlock, int blockId, int destX0, int destY0)
   {
      int startIndex = blockStarts[blockId];
      int endIndex = blockStarts[blockId + 1];
      CommonOps_DDRM.extract(this, startIndex, endIndex, startIndex, endIndex, dstBlock, destX0, destY0);
   }

   /**
    * Inverts every diagonal block with {@code solver} and stores the results in the matching
    * blocks of {@code matrixToPack}. Entries of {@code matrixToPack} outside the blocks are not
    * touched.
    * <p>
    * The value returned by {@code solver.setA} is not checked.
    *
    * @param solver       solver used to invert each block
    * @param matrixToPack block-diagonal matrix receiving the inverted blocks; may be this matrix
    */
   public void packInverse(LinearSolverDense<DMatrixRMaj> solver, BlockDiagSquareMatrix matrixToPack)
   {
      for (int i = 0; i < blockSizes.length; i++)
      {
         matrixToPack.setBlock(invertBlock(solver, i), i);
      }
   }

   /**
    * Inverts every diagonal block with {@code solver} and writes the result into
    * {@code matrixToPack}, which is reshaped to the dimensions of this matrix and has every entry
    * outside the blocks set to zero.
    * <p>
    * All blocks are inverted before the destination is modified, so the call is safe when
    * {@code matrixToPack} is this matrix. The destination is reshaped <em>before</em> it is
    * zeroed because zeroing only clears the elements of the current shape.
    * <p>
    * The value returned by {@code solver.setA} is not checked.
    *
    * @param solver       solver used to invert each block
    * @param matrixToPack dense matrix receiving the block-diagonal inverse
    */
   public void packInverse(LinearSolverDense<DMatrixRMaj> solver, DMatrixRMaj matrixToPack)
   {
      // Pass 1: compute every inverse into its own scratch matrix while the source is intact.
      for (int i = 0; i < blockSizes.length; i++)
      {
         invertBlock(solver, i);
      }

      // Pass 2: size the destination first, then clear it, then drop the inverted blocks in.
      matrixToPack.reshape(numRows, numCols);
      matrixToPack.zero();
      for (int i = 0; i < blockSizes.length; i++)
      {
         setBlock(tmpMatrix[i], i, matrixToPack);
      }
   }

   /**
    * Extracts block {@code blockId} into its scratch matrix and inverts it in place.
    *
    * @return the scratch matrix, now holding the output of {@code solver.invert}
    */
   private DMatrixRMaj invertBlock(LinearSolverDense<DMatrixRMaj> solver, int blockId)
   {
      DMatrixRMaj block = tmpMatrix[blockId];
      block.reshape(blockSizes[blockId], blockSizes[blockId]);
      packBlock(block, blockId, 0, 0);
      solver.setA(block);
      solver.invert(block);
      return block;
   }

   /**
    * Computes {@code c = this * b^T}, using only the diagonal blocks of this matrix.
    * <p>
    * {@code c} is not resized: one value is written for every row of this matrix and every
    * column of {@code c}, and column {@code j} of {@code c} is computed from row {@code j} of
    * {@code b}. The caller is therefore expected to pass {@code c} already sized
    * {@code numRows x b.numRows}.
    *
    * @param b right-hand operand, used transposed; its column count should match the dimension
    *          of this matrix
    * @param c result matrix, written in place
    */
   public void multTransB(DMatrixRMaj b, DMatrixRMaj c)
   {
      for (int i = 0; i < blockSizes.length; i++)
      {
         int blockStart = blockStarts[i];
         int blockEnd = blockStarts[i + 1];
         int blockSize = blockSizes[i];

         for (int crow = blockStart; crow < blockEnd; crow++)
         {
            // Flat index of the first in-block element of this row.
            int aIndex0 = getIndex(crow, blockStart);
            for (int ccol = 0; ccol < c.numCols; ccol++)
            {
               double val = 0.0;
               int aIndex = aIndex0;
               // Row ccol of b is column ccol of b^T; only the in-block columns contribute.
               int bIndex = b.getIndex(ccol, blockStart);
               int bEnd = bIndex + blockSize;
               while (bIndex < bEnd)
               {
                  val += data[aIndex++] * b.data[bIndex++];
               }

               c.set(crow, ccol, val);
            }
         }
      }
   }

   /**
    * Computes {@code c = alpha * this * b}, using only the diagonal blocks of this matrix.
    * <p>
    * {@code c} is not resized here; each block's product is inserted at the block's row offset,
    * column 0, so the caller is expected to pass {@code c} already sized
    * {@code numRows x b.numCols}.
    *
    * @param alpha scalar factor applied to the product
    * @param b     right-hand operand; its row count should match the dimension of this matrix
    * @param c     result matrix, written in place
    */
   public void mult(double alpha, DMatrixRMaj b, DMatrixRMaj c)
   {
      for (int i = 0; i < blockSizes.length; i++)
      {
         tmpMatrix[i].reshape(blockSizes[i], blockSizes[i]);
         packBlock(tmpMatrix[i], i, 0, 0);
         multTempB.reshape(blockSizes[i], b.numCols);
         multTempC.reshape(blockSizes[i], c.numCols);
         // Only the rows of b that face block i contribute to the matching rows of c.
         CommonOps_DDRM.extract(b, blockStarts[i], blockStarts[i + 1], 0, b.numCols, multTempB, 0, 0);
         CommonOps_DDRM.mult(alpha, tmpMatrix[i], multTempB, multTempC);
         CommonOps_DDRM.insert(multTempC, c, blockStarts[i], 0);
      }
   }

   /**
    * Small manual demo: builds a matrix with a 1x1 and a 2x2 block, inverts it in place and
    * prints the matrix and its blocks.
    */
   public static void main(String[] arg)
   {
      BlockDiagSquareMatrix m = new BlockDiagSquareMatrix(1, 2);
      DMatrixRMaj b1 = new DMatrixRMaj(1, 1, true, 1);
      DMatrixRMaj b2 = new DMatrixRMaj(2, 2, true, 2, 3, 4, 5);

      m.setBlock(b1, 0);
      m.setBlock(b2, 1);

      System.out.println(m);

      m.packInverse(LinearSolverFactory_DDRM.general(m.numRows, m.numCols), m);
      b1.zero();
      b2.zero();

      m.packBlock(b1, 0, 0, 0);
      m.packBlock(b2, 1, 0, 0);

      System.out.println(b1);
      System.out.println(b2);
      System.out.println("m=\n" + m);
   }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy