All downloads are free. Search and download functionality uses the official Maven repository.

us.ihmc.convexOptimization.quadraticProgram.SimpleEfficientActiveSetQPSolverWithInactiveVariables Maven / Gradle / Ivy

There is a newer version: 0.17.22
Show newest version
package us.ihmc.convexOptimization.quadraticProgram;

import org.ejml.data.DMatrix;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;

import us.ihmc.commons.MathTools;
import us.ihmc.matrixlib.MatrixTools;
import us.ihmc.matrixlib.NativeMatrix;

/**
 * Active-set QP solver that allows individual decision variables to be marked inactive.
 * <p>
 * Problem data passed to the setters is stored in full-size "original" copies. On every
 * {@link #solve(DMatrix)} call, the parent's working matrices are rebuilt from those copies with every
 * inactive variable removed (its row/column of the cost, its column of the constraint matrices, and its
 * entry of the bounds). The reduced problem is solved by {@link SimpleEfficientActiveSetQPSolver}, and the
 * reduced solution is expanded back to the full variable vector, with every inactive variable reported
 * as {@code 0.0}.
 * <p>
 * A variable is active when its entry in the active-variable vector is exactly {@code 1.0}; any other
 * value marks it inactive. Variables whose index lies beyond the end of that vector are treated as active.
 */
public class SimpleEfficientActiveSetQPSolverWithInactiveVariables extends SimpleEfficientActiveSetQPSolver
      implements NativeActiveSetQPSolverWithInactiveVariablesInterface
{
   // Full-size copies of the problem exactly as provided by the caller (cost stored symmetrized).
   // The parent's working matrices are rebuilt from these, minus inactive variables, on every solve.
   private final NativeMatrix originalQuadraticCostQMatrix = new NativeMatrix(0, 0);
   private final NativeMatrix originalQuadraticCostQVector = new NativeMatrix(0, 0);

   private final NativeMatrix originalLinearEqualityConstraintsAMatrix = new NativeMatrix(0, 0);
   private final NativeMatrix originalLinearEqualityConstraintsBVector = new NativeMatrix(0, 0);

   private final NativeMatrix originalLinearInequalityConstraintsCMatrixO = new NativeMatrix(0, 0);
   private final NativeMatrix originalLinearInequalityConstraintsDVectorO = new NativeMatrix(0, 0);

   private final NativeMatrix originalVariableLowerBounds = new NativeMatrix(0, 0);
   private final NativeMatrix originalVariableUpperBounds = new NativeMatrix(0, 0);

   // Column vector of per-variable flags: exactly 1.0 means active, anything else means inactive.
   private final DMatrixRMaj activeVariables = new DMatrixRMaj(0, 0);
   // Solution of the reduced problem (active variables only, in their original order).
   private final DMatrixRMaj activeVariableSolution = new DMatrixRMaj(0, 0);

   /** Copies every original (full-size) input into the parent's working matrices. */
   private void setMatricesFromOriginal()
   {
      quadraticCostQMatrix.set(originalQuadraticCostQMatrix);
      quadraticCostQVector.set(originalQuadraticCostQVector);

      linearEqualityConstraintsAMatrix.set(originalLinearEqualityConstraintsAMatrix);
      linearEqualityConstraintsBVector.set(originalLinearEqualityConstraintsBVector);

      linearInequalityConstraintsCMatrixO.set(originalLinearInequalityConstraintsCMatrixO);
      linearInequalityConstraintsDVectorO.set(originalLinearInequalityConstraintsDVectorO);

      variableLowerBounds.set(originalVariableLowerBounds);
      variableUpperBounds.set(originalVariableUpperBounds);
   }

   /**
    * Returns whether the given variable takes part in the optimization.
    * Indices beyond the end of the active-variable vector are considered active.
    */
   private boolean isVariableActive(int variableIndex)
   {
      return variableIndex >= activeVariables.getNumRows() || activeVariables.get(variableIndex) == 1.0;
   }

   /**
    * Rebuilds the parent's working problem from the original data, dropping every inactive variable.
    * Constraint rows left with only (near-)zero coefficients are then removed, and constraint matrices
    * that end up without rows are given a {@code 0 x numberOfActiveVariables} shape.
    */
   private void removeInactiveVariables()
   {
      setMatricesFromOriginal();

      // Only indices that are both flagged and present in the problem can be removed; anything past the
      // end of activeVariables is active by convention.
      int numberOfFlaggedVariables = Math.min(activeVariables.getNumRows(), originalQuadraticCostQMatrix.getNumRows());

      // Iterate backwards so that removing an index does not shift the indices still to be visited.
      for (int variableIndex = numberOfFlaggedVariables - 1; variableIndex >= 0; variableIndex--)
      {
         if (isVariableActive(variableIndex))
            continue;

         quadraticCostQMatrix.removeRow(variableIndex);
         quadraticCostQMatrix.removeColumn(variableIndex);

         quadraticCostQVector.removeRow(variableIndex);

         if (linearEqualityConstraintsAMatrix.getNumElements() > 0)
            linearEqualityConstraintsAMatrix.removeColumn(variableIndex);
         if (linearInequalityConstraintsCMatrixO.getNumElements() > 0)
            linearInequalityConstraintsCMatrixO.removeColumn(variableIndex);

         if (variableLowerBounds.getNumElements() > 0)
            variableLowerBounds.removeRow(variableIndex);
         if (variableUpperBounds.getNumElements() > 0)
            variableUpperBounds.removeRow(variableIndex);
      }

      // Remove constraint rows that no longer involve any remaining variable BEFORE normalizing empty
      // matrices. Doing it the other way round could collapse an (m x 0) matrix to (0 x 0) while its
      // right-hand side still has m rows.
      removeZeroRowsFromConstraints(linearEqualityConstraintsAMatrix, linearEqualityConstraintsBVector);
      removeZeroRowsFromConstraints(linearInequalityConstraintsCMatrixO, linearInequalityConstraintsDVectorO);

      // Constraint-free problems still need a column count that matches the reduced cost.
      int numVars = quadraticCostQMatrix.getNumRows();
      if (linearEqualityConstraintsAMatrix.getNumRows() == 0)
         linearEqualityConstraintsAMatrix.reshape(0, numVars);
      if (linearInequalityConstraintsCMatrixO.getNumRows() == 0)
         linearInequalityConstraintsCMatrixO.reshape(0, numVars);
   }

   /**
    * Removes every row of {@code matrix}, together with the matching entry of {@code vector}, whose sum of
    * absolute coefficients is within 1e-12 of zero.
    * <p>
    * Only {@code matrix} is inspected: a zero row is dropped even when its right-hand side is non-zero.
    *
    * @param matrix constraint matrix, modified in place.
    * @param vector right-hand side with one row per row of {@code matrix}, modified in place.
    */
   private static void removeZeroRowsFromConstraints(NativeMatrix matrix, NativeMatrix vector)
   {
      // Backwards, so that removals do not shift rows that are still to be visited.
      for (int rowIndex = vector.getNumRows() - 1; rowIndex >= 0; rowIndex--)
      {
         double sumOfRowElements = 0.0;

         for (int columnIndex = 0; columnIndex < matrix.getNumCols(); columnIndex++)
         {
            sumOfRowElements += Math.abs(matrix.get(rowIndex, columnIndex));
         }

         boolean isZeroRow = MathTools.epsilonEquals(sumOfRowElements, 0.0, 1e-12);
         if (isZeroRow)
         {
            matrix.removeRow(rowIndex);
            vector.removeRow(rowIndex);
         }
      }
   }

   /**
    * Expands the reduced solution into the full-size solution. Inactive variables are set to 0.0, and active
    * ones receive the reduced-solution entries in order. If the reduced solution contains a NaN, every entry
    * of the full solution is set to NaN.
    *
    * @param solutionToPack full-size column vector to fill.
    * @param reducedSolution solution of the reduced (active-variables-only) problem.
    */
   private void copyActiveVariableSolutionToAllVariables(DMatrix solutionToPack, DMatrixRMaj reducedSolution)
   {
      if (MatrixTools.containsNaN(reducedSolution))
      {
         for (int i = 0; i < solutionToPack.getNumRows(); i++)
            solutionToPack.set(i, 0, Double.NaN);
         return;
      }

      int activeVariableIndex = 0;
      for (int variableIndex = 0; variableIndex < solutionToPack.getNumRows(); variableIndex++)
      {
         // Uses the same convention as removeInactiveVariables: unflagged trailing variables are active.
         if (!isVariableActive(variableIndex))
         {
            solutionToPack.set(variableIndex, 0, 0.0);
            continue;
         }

         solutionToPack.set(variableIndex, 0, reducedSolution.get(activeVariableIndex, 0));
         activeVariableIndex++;
      }
   }

   /**
    * Stores a copy of the full-size variable lower bounds.
    *
    * @throws RuntimeException if the row count differs from the current cost matrix size.
    */
   @Override
   public void setLowerBounds(DMatrix variableLowerBounds)
   {
      if (variableLowerBounds.getNumRows() != originalQuadraticCostQMatrix.getNumRows())
         throw new RuntimeException("variableLowerBounds.getNumRows() != quadraticCostQMatrix.getNumRows()");

      originalVariableLowerBounds.set(variableLowerBounds);
   }

   /** Returns the internal full-size lower-bound storage; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getLowerBoundsUnsafe()
   {
      return originalVariableLowerBounds;
   }

   /**
    * Stores a copy of the full-size variable upper bounds.
    *
    * @throws RuntimeException if the row count differs from the current cost matrix size.
    */
   @Override
   public void setUpperBounds(DMatrix variableUpperBounds)
   {
      if (variableUpperBounds.getNumRows() != originalQuadraticCostQMatrix.getNumRows())
         throw new RuntimeException("variableUpperBounds.getNumRows() != quadraticCostQMatrix.getNumRows()");

      originalVariableUpperBounds.set(variableUpperBounds);
   }

   /** Returns the internal full-size upper-bound storage; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getUpperBoundsUnsafe()
   {
      return originalVariableUpperBounds;
   }

   /**
    * Sets the cost. The stored Hessian is the symmetric part {@code 0.5 * (H + H^T)} of
    * {@code costQuadraticMatrix}. This call also marks every variable active again.
    *
    * @throws RuntimeException if the matrix is not square, the vector is not a column vector, or their row
    *       counts differ.
    */
   @Override
   public void setQuadraticCostFunction(DMatrix costQuadraticMatrix, DMatrix costLinearVector, double quadraticCostScalar)
   {
      if (costLinearVector.getNumCols() != 1)
         throw new RuntimeException("costLinearVector.getNumCols() != 1");
      if (costQuadraticMatrix.getNumRows() != costLinearVector.getNumRows())
         throw new RuntimeException("costQuadraticMatrix.getNumRows() != costLinearVector.getNumRows()");
      if (costQuadraticMatrix.getNumRows() != costQuadraticMatrix.getNumCols())
         throw new RuntimeException("costQuadraticMatrix.getNumRows() != costQuadraticMatrix.getNumCols()");

      this.costQuadraticMatrix.set(costQuadraticMatrix);

      originalQuadraticCostQMatrix.transpose(this.costQuadraticMatrix);
      originalQuadraticCostQMatrix.addEquals(this.costQuadraticMatrix);
      originalQuadraticCostQMatrix.scale(0.5);
      originalQuadraticCostQVector.set(costLinearVector);
      this.quadraticCostScalar = quadraticCostScalar;

      setAllVariablesActive();
   }

   /** Returns the internal full-size (symmetrized) cost Hessian; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getCostHessianUnsafe()
   {
      return originalQuadraticCostQMatrix;
   }

   /** Returns the internal full-size linear cost term; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getCostGradientUnsafe()
   {
      return originalQuadraticCostQVector;
   }

   /**
    * Evaluates the objective on the full (unreduced) problem data:
    * {@code 0.5 * x'Qx + q'x + scalar}, assuming NativeMatrix.multQuad computes {@code x'Qx}.
    */
   @Override
   public double getObjectiveCost(DMatrixRMaj x)
   {
      nativexSolutionMatrix.set(x);
      computedObjectiveFunctionValue.multQuad(nativexSolutionMatrix, originalQuadraticCostQMatrix);
      computedObjectiveFunctionValue.scale(0.5);
      computedObjectiveFunctionValue.multAddTransA(originalQuadraticCostQVector, nativexSolutionMatrix);
      return computedObjectiveFunctionValue.get(0, 0) + quadraticCostScalar;
   }

   /**
    * Stores a copy of the full-size equality constraints {@code A x = b}.
    *
    * @throws RuntimeException on dimension mismatch with each other or with the cost matrix.
    */
   @Override
   public void setLinearEqualityConstraints(DMatrix linearEqualityConstraintsAMatrix, DMatrix linearEqualityConstraintsBVector)
   {
      if (linearEqualityConstraintsBVector.getNumCols() != 1)
         throw new RuntimeException("linearEqualityConstraintsBVector.getNumCols() != 1");
      if (linearEqualityConstraintsAMatrix.getNumRows() != linearEqualityConstraintsBVector.getNumRows())
         throw new RuntimeException("linearEqualityConstraintsAMatrix.getNumRows() != linearEqualityConstraintsBVector.getNumRows()");
      if (linearEqualityConstraintsAMatrix.getNumCols() != originalQuadraticCostQMatrix.getNumCols())
         throw new RuntimeException("linearEqualityConstraintsAMatrix.getNumCols() != quadraticCostQMatrix.getNumCols()");

      originalLinearEqualityConstraintsBVector.set(linearEqualityConstraintsBVector);
      originalLinearEqualityConstraintsAMatrix.set(linearEqualityConstraintsAMatrix);
   }

   /** Returns the internal full-size equality matrix; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getAeqUnsafe()
   {
      return originalLinearEqualityConstraintsAMatrix;
   }

   /** Returns the internal full-size equality right-hand side; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getBeqUnsafe()
   {
      return originalLinearEqualityConstraintsBVector;
   }

   /**
    * Stores a copy of the full-size inequality constraints (matrix C, right-hand side d).
    *
    * @throws RuntimeException on dimension mismatch with each other or with the cost matrix.
    */
   @Override
   public void setLinearInequalityConstraints(DMatrix linearInequalityConstraintCMatrix, DMatrix linearInequalityConstraintDVector)
   {
      if (linearInequalityConstraintDVector.getNumCols() != 1)
         throw new RuntimeException("linearInequalityConstraintDVector.getNumCols() != 1");
      if (linearInequalityConstraintCMatrix.getNumRows() != linearInequalityConstraintDVector.getNumRows())
         throw new RuntimeException("linearInequalityConstraintCMatrix.getNumRows() != linearInequalityConstraintDVector.getNumRows()");
      if (linearInequalityConstraintCMatrix.getNumCols() != originalQuadraticCostQMatrix.getNumCols())
         throw new RuntimeException("linearInequalityConstraintCMatrix.getNumCols() != quadraticCostQMatrix.getNumCols()");

      originalLinearInequalityConstraintsDVectorO.set(linearInequalityConstraintDVector);
      originalLinearInequalityConstraintsCMatrixO.set(linearInequalityConstraintCMatrix);
   }

   /** Returns the internal full-size inequality matrix; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getAinUnsafe()
   {
      return originalLinearInequalityConstraintsCMatrixO;
   }

   /** Returns the internal full-size inequality right-hand side; modifications affect subsequent solves. */
   @Override
   public NativeMatrix getBinUnsafe()
   {
      return originalLinearInequalityConstraintsDVectorO;
   }

   /**
    * Replaces all activity flags (exactly 1.0 = active, anything else = inactive).
    *
    * @throws RuntimeException if the row count differs from the current cost matrix size.
    */
   @Override
   public void setActiveVariables(DMatrix activeVariables)
   {
      if (activeVariables.getNumRows() != originalQuadraticCostQMatrix.getNumRows())
         throw new RuntimeException("activeVariables.getNumRows() != quadraticCostQMatrix.getNumRows()");

      this.activeVariables.set(activeVariables);
   }

   /**
    * Marks one variable active.
    *
    * @throws RuntimeException if the index is outside the current number of variables.
    */
   @Override
   public void setVariableActive(int variableIndex)
   {
      if (variableIndex < 0 || variableIndex >= originalQuadraticCostQMatrix.getNumRows())
         throw new RuntimeException("variable index is outside the number of variables: " + variableIndex);

      if (variableIndex >= activeVariables.getNumRows())
         return; // Variables past the end of the flag vector are already treated as active.

      activeVariables.set(variableIndex, 0, 1.0);
   }

   /**
    * Marks one variable inactive, growing the flag vector if needed. Variables exposed by the growth, other
    * than the requested one, stay active.
    *
    * @throws RuntimeException if the index is outside the current number of variables.
    */
   @Override
   public void setVariableInactive(int variableIndex)
   {
      if (variableIndex < 0 || variableIndex >= originalQuadraticCostQMatrix.getNumRows())
         throw new RuntimeException("variable index is outside the number of variables: " + variableIndex);

      int previousNumberOfFlags = activeVariables.getNumRows();
      if (variableIndex >= previousNumberOfFlags)
      {
         activeVariables.reshape(variableIndex + 1, 1, true);
         // reshape(..., true) only preserves the existing entries. The newly exposed ones were implicitly
         // active before the growth, so they must be set to 1.0 explicitly to remain active.
         for (int i = previousNumberOfFlags; i < variableIndex; i++)
            activeVariables.set(i, 0, 1.0);
      }

      activeVariables.set(variableIndex, 0, 0.0);
   }

   /** Resizes the flag vector to the current number of variables and marks all of them active. */
   @Override
   public void setAllVariablesActive()
   {
      activeVariables.reshape(originalQuadraticCostQMatrix.getNumRows(), 1);
      CommonOps_DDRM.fill(activeVariables, 1.0);
   }

   /**
    * Clears the parent solver's state.
    * <p>
    * NOTE(review): the original (full-size) problem data and the active-variable flags are NOT reset here.
    * Resetting them was disabled in this version, so constraints or bounds from a previous problem survive
    * clear() unless they are set again. Confirm that this is intended.
    */
   @Override
   public void clear()
   {
      super.clear();
   }

   /**
    * Solves the problem restricted to the active variables and writes the full-size solution. Inactive
    * variables are reported as 0.0; if the reduced solution contains NaN, every entry is NaN.
    *
    * @param solutionToPack column vector with one row per variable.
    * @return the iteration count reported by the parent solver.
    * @throws IllegalArgumentException if {@code solutionToPack} has the wrong dimensions.
    */
   @Override
   public int solve(DMatrix solutionToPack)
   {
      // Validate first so that a bad argument leaves the solver's working state untouched.
      if (solutionToPack.getNumRows() != originalQuadraticCostQMatrix.getNumRows() || solutionToPack.getNumCols() != 1)
         throw new IllegalArgumentException("Invalid matrix dimensions.");

      removeInactiveVariables();

      activeVariableSolution.reshape(quadraticCostQVector.getNumRows(), 1);
      int numberOfIterations = super.solve(activeVariableSolution);

      copyActiveVariableSolutionToAllVariables(solutionToPack, activeVariableSolution);

      return numberOfIterations;
   }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy