All downloads are free. Search and download functionality uses the official Maven repository.

us.ihmc.convexOptimization.linearProgram.DictionaryFormLinearProgramSolver Maven / Gradle / Ivy

There is a newer version: 0.17.22
Show newest version
package us.ihmc.convexOptimization.linearProgram;

import gnu.trove.list.array.TIntArrayList;
import org.apache.commons.math3.util.Precision;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.commons.time.Stopwatch;
import us.ihmc.convexOptimization.linearProgram.SolverStatistics.LinearProgramFailureReason;

import java.util.Arrays;

import static us.ihmc.convexOptimization.linearProgram.LinearProgramSolver.*;

/**
 * Solves a linear program given in dictionary form, using either the criss-cross method
 * ({@link #solveCrissCross(DMatrixRMaj)}) or the two-phase simplex method ({@link #solveSimplex(DMatrixRMaj)}).
 * <p>
 * Dictionary layout as used by this class: row 0 is the objective row and column 0 is the constant (right-hand-side)
 * column. Rows 1.. are associated with basic variables ({@code getBasisIndex(i)}) and columns 1.. with non-basic
 * variables ({@code getNonBasisIndex(j)}). All pivot choices break ties with Bland's rule (smallest lexical variable
 * index) to prevent cycling.
 * <p>
 * The simplex implementation borrows from org.apache.commons.math3.optim.linear.SimplexSolver and
 * doi.org/10.3929/ethz-b-000426221 (ch. 4).
 * <p>
 * Each solve mutates shared internal state (dictionary, solution vectors, timer, statistics) without
 * synchronization, so an instance should not be used from several threads at once.
 */
public class DictionaryFormLinearProgramSolver
{
   static final boolean debug = false;

   /** Maximum number of dictionary columns accepted by {@link #solveSimplex}; also the initial capacity of internal buffers. */
   static final int maxVariables = 200;
   static final int maxCrissCrossIterations = 50000;
   /** Iteration budget for a whole simplex solve (Phase I and Phase II share one counter). */
   static final int maxSimplexIterations = 1000;
   /** Sentinel returned by the pivot searches when no admissible row/column exists. */
   static final int nullMatrixIndex = -1;
   /** Tolerance used for sign tests on dictionary entries and for ratio-test ties. */
   static final double epsilon = 1e-6;
   static final double zeroCutoff = 1e-10;

   private final LinearProgramDictionary dictionary = new LinearProgramDictionary();
   private final DMatrixRMaj primalSolution = new DMatrixRMaj(maxVariables);
   private final DMatrixRMaj dualSolution = new DMatrixRMaj(maxVariables);

   private final Stopwatch timer = new Stopwatch();
   private final SolverStatistics simplexStatistics = new SolverStatistics();
   private final SolverStatistics crissCrossStatistics = new SolverStatistics();

   /**
    * Simplex phase. In Phase I the dictionary carries two objective rows, so the ratio test skips rows 0 and 1;
    * in Phase II only row 0 is an objective row.
    * <p>
    * Package-private (rather than private) so that tests in this package can call {@link #performSimplexPhase}.
    */
   enum SimplexPhase
   {
      PHASE_I, PHASE_II;

      /** @return the number of leading dictionary rows that are objective rows in this phase */
      int objectiveSize()
      {
         return this == PHASE_I ? 2 : 1;
      }
   }

   /////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////// SIMPLEX METHOD /////////////////////////////////////
   /////////////////////////////////////////////////////////////////////////////////////

   /**
    * Solves the given dictionary with the two-phase simplex method.
    * <p>
    * If {@code dictionary.initialize(...)} returns {@code true}, a Phase I solve runs first. Presumably this happens
    * when the starting dictionary is not feasible (TODO confirm against LinearProgramDictionary). Phase I is rejected
    * when its final constant term {@code d(0,0)} is below {@code -epsilon}; otherwise the Phase I variables are dropped
    * and Phase II optimizes the feasible dictionary.
    * <p>
    * The outcome and solve time are recorded in {@link #getSimplexStatistics()}. When Phase I fails, the method
    * returns without repacking the primal/dual solution buffers, so they keep the contents of the previous solve.
    * Callers should check {@code foundSolution()} first.
    *
    * @param startingDictionary dictionary-form LP; must have at most {@link #maxVariables} columns
    * @throws IllegalArgumentException if {@code startingDictionary} has more than {@link #maxVariables} columns
    */
   public void solveSimplex(DMatrixRMaj startingDictionary)
   {
      if (startingDictionary.getNumCols() > maxVariables)
      {
         throw new IllegalArgumentException("Simplex method has a maximum of " + maxVariables + " decision variables, " + startingDictionary.getNumCols() + " provided.");
      }

      timer.reset();
      simplexStatistics.clear();

      if (dictionary.initialize(startingDictionary, SolverMethod.SIMPLEX))
      {
         /* Phase I: compute feasible dictionary */
         performSimplexPhase(SimplexPhase.PHASE_I);

         if (!simplexStatistics.foundSolution())
         {
            // performSimplexPhase already recorded the failure reason; still report how long the attempt took.
            simplexStatistics.setSolveTime(timer.lapElapsed());
            return;
         }
         else if (dictionary.getEntry(0, 0) < -epsilon)
         {
            simplexStatistics.onSolverFailure(LinearProgramFailureReason.INVALID_PHASE_I_SOLUTION, true);
            simplexStatistics.setSolveTime(timer.lapElapsed());
            return;
         }

         dictionary.dropPhaseIVariables();
      }

      /* Phase II: optimize feasible dictionary */
      performSimplexPhase(SimplexPhase.PHASE_II);

      packSolution(simplexStatistics);
      simplexStatistics.setSolveTime(timer.lapElapsed());
   }

   /**
    * Runs simplex pivots on the current dictionary until it is optimal, no pivot row exists, or the iteration budget
    * is exhausted. The iteration counter lives in {@link #simplexStatistics} and is only cleared by
    * {@link #solveSimplex}, so Phase I and Phase II draw from the same {@link #maxSimplexIterations} budget.
    *
    * @param phase the current phase; it determines which leading rows are objective rows
    */
   /* package private for testing */
   void performSimplexPhase(SimplexPhase phase)
   {
      while (true)
      {
         if (simplexStatistics.getAndIncrementIterations() > maxSimplexIterations)
         {
            simplexStatistics.onSolverFailure(LinearProgramFailureReason.MAX_ITERATIONS_REACHED, phase == SimplexPhase.PHASE_I);
            break;
         }

         if (isSimplexOptimal())
         {
            simplexStatistics.onSolutionFound();
            break;
         }

         // isSimplexOptimal() returned false, so some objective entry exceeds epsilon and a column is always found.
         int s = computeSimplexPivotColumn();
         int r = computeSimplexPivotRow(s, phase);

         if (r == nullMatrixIndex)
         {
            // No row restricts the entering column.
            simplexStatistics.onSolverFailure(LinearProgramFailureReason.NO_CANDIDATE_PIVOT, phase == SimplexPhase.PHASE_I);
            break;
         }

         if (debug)
         {
            System.out.println("Pivoting on (" + r + "," + s + ")\n");
         }

         dictionary.performPivot(r, s);
      }
   }

   /**
    * Extracts primal and dual solutions from the current dictionary and records the smallest constant-column entry.
    * <p>
    * The primal vector has {@code columns - 1} entries. Basic rows whose lexical index satisfies
    * {@code isNonNegativeConstraint} are written at {@code toVariableIndex(lexicalIndex)}; all other entries stay zero.
    * The dual vector has {@code rows - 1} entries. Each non-basic column whose lexical index is not such a variable
    * receives the negated objective-row entry at {@code toConstraintIndex(...)}.
    *
    * @param solverStatistics receives the minimum constant-column entry over basic rows
    *                         ({@code +infinity} if there are none)
    */
   private void packSolution(SolverStatistics solverStatistics)
   {
      primalSolution.reshape(dictionary.getNumberOfColumns() - 1, 1);
      dualSolution.reshape(dictionary.getNumberOfRows() - 1, 1);

      Arrays.fill(primalSolution.getData(), 0.0);
      Arrays.fill(dualSolution.getData(), 0.0);

      double minDictionaryRHSColumnEntry = Double.POSITIVE_INFINITY;

      for (int i = 1; i < dictionary.getBasisSize(); i++)
      {
         int lexicalIndex = dictionary.getBasisIndex(i);
         double entry = dictionary.getEntry(i, 0);
         minDictionaryRHSColumnEntry = Math.min(entry, minDictionaryRHSColumnEntry);

         if (isNonNegativeConstraint(lexicalIndex, primalSolution.getNumRows()))
         {
            int variableIndex = toVariableIndex(lexicalIndex);
            primalSolution.set(variableIndex, entry);
         }
      }

      for (int i = 1; i < dictionary.getNonBasisSize(); i++)
      {
         int lexicalIndex = dictionary.getNonBasisIndex(i);
         double entry = dictionary.getEntry(0, i);

         if (!isNonNegativeConstraint(lexicalIndex, primalSolution.getNumRows()))
         {
            int constraintIndex = toConstraintIndex(lexicalIndex, primalSolution.getNumRows());
            dualSolution.set(constraintIndex, -entry);
         }
      }

      solverStatistics.setMinDictionaryRHSColumnEntry(minDictionaryRHSColumnEntry);
   }

   /**
    * Checks optimality assuming feasibility, so only the objective row needs to be checked.
    *
    * @return {@code true} if no objective-row entry (columns 1..) exceeds {@code epsilon}
    */
   private boolean isSimplexOptimal()
   {
      for (int j = 1; j < dictionary.getNumberOfColumns(); j++)
      {
         if (dictionary.getEntry(0, j) > epsilon)
         {
            return false;
         }
      }

      return true;
   }

   /**
    * Selects the entering column with Bland's rule.
    *
    * @return the column whose objective-row entry is at least {@code epsilon} and whose non-basic variable has the
    *         lowest lexical index, or {@link #nullMatrixIndex} if none qualifies
    */
   private int computeSimplexPivotColumn()
   {
      int minimumEntryIndex = Integer.MAX_VALUE;
      int column = nullMatrixIndex;

      for (int j = 1; j < dictionary.getNumberOfColumns(); j++)
      {
         double entry = dictionary.getEntry(0, j);
         if (entry < epsilon)
         {
            continue;
         }

         int index = dictionary.getNonBasisIndex(j);
         if (index < minimumEntryIndex)
         {
            minimumEntryIndex = index;
            column = j;
         }
      }

      return column;
   }

   /** Scratch buffer of rows tied for the minimum ratio; reused across calls to avoid allocation. */
   private final TIntArrayList minRatioIndices = new TIntArrayList(maxVariables + 1);

   /**
    * Minimum-ratio test for the entering column.
    * <p>
    * Only rows below the phase's objective rows with {@code d(i, column) < -epsilon} are candidates. Ratios
    * {@code |d(i,0) / d(i,column)|} within {@code epsilon} of each other count as ties. Ties are broken by Bland's rule
    * (smallest basic-variable lexical index).
    *
    * @param column entering column
    * @param phase  current phase; it determines the first candidate row
    * @return the leaving row, or {@link #nullMatrixIndex} if no row has a sufficiently negative entry in {@code column}
    */
   private int computeSimplexPivotRow(int column, SimplexPhase phase)
   {
      double minRatio = Double.MAX_VALUE;
      minRatioIndices.reset();

      for (int i = phase.objectiveSize(); i < dictionary.getNumberOfRows(); i++)
      {
         double d_ig = dictionary.getEntry(i, 0);
         double d_is = dictionary.getEntry(i, column);

         if (d_is > -epsilon)
         {
            continue;
         }

         double ratio = Math.abs(d_ig / d_is);
         int cmp = Precision.compareTo(ratio, minRatio, epsilon);

         if (cmp == 0)
         {
            minRatioIndices.add(i);
         }
         else if (cmp < 0)
         {
            minRatioIndices.reset();
            minRatioIndices.add(i);
            minRatio = ratio;
         }
      }

      if (minRatioIndices.isEmpty())
      {
         return nullMatrixIndex;
      }
      else if (minRatioIndices.size() > 1)
      {
         // (from apache impl...)
         // apply Bland's rule to prevent cycling:
         //    take the row for which the corresponding basic variable has the smallest index

         int minRowIndex = Integer.MAX_VALUE;
         int minRow = nullMatrixIndex;

         for (int i = 0; i < minRatioIndices.size(); i++)
         {
            int variableIndex = dictionary.getBasisIndex(minRatioIndices.get(i));
            if (variableIndex < minRowIndex)
            {
               minRowIndex = variableIndex;
               minRow = minRatioIndices.get(i);
            }
         }

         return minRow;
      }
      else
      {
         return minRatioIndices.get(0);
      }
   }

   /**
    * @return the primal solution written by the most recent solve that reached the packing step. The returned matrix
    *         is the internal buffer, which later solves overwrite.
    */
   public DMatrixRMaj getPrimalSolution()
   {
      return primalSolution;
   }

   /**
    * @return the dual solution written by the most recent solve that reached the packing step. The returned matrix is
    *         the internal buffer, which later solves overwrite.
    */
   public DMatrixRMaj getDualSolution()
   {
      return dualSolution;
   }

   /** Prints each entry of the current primal solution to standard output. */
   public void printSolution()
   {
      System.out.println("Solution:");
      for (int i = 0; i < primalSolution.getNumRows(); i++)
      {
         System.out.println("\t " + primalSolution.get(i));
      }
   }

   /////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////// CRISS CROSS METHOD /////////////////////////////////
   /////////////////////////////////////////////////////////////////////////////////////

   /**
    * Solves the given dictionary with the criss-cross method, using Bland's rule for every pivot choice.
    * <p>
    * Each iteration looks for a basic row with a constant term below {@code -epsilon} and a non-basic column with an
    * objective entry above {@code epsilon}. If neither exists, the dictionary is accepted as optimal. Otherwise the
    * candidate with the smaller lexical index is repaired with a pivot in its row or column.
    * <p>
    * If no repairing pivot exists, or the iteration limit is reached, the loop stops without calling
    * {@code onSolutionFound()}. The solution is packed from the final dictionary in every case.
    *
    * @param startingDictionary dictionary-form LP
    */
   public void solveCrissCross(DMatrixRMaj startingDictionary)
   {
      crissCrossStatistics.clear();
      timer.reset();

      dictionary.initialize(startingDictionary, SolverMethod.CRISS_CROSS);

      while (true)
      {
         if (crissCrossStatistics.getAndIncrementIterations() > maxCrissCrossIterations)
         {
            crissCrossStatistics.setFoundSolution(false);
            break;
         }

         int candidateBasisPivot = findNegativeColumnEntryWithBlandRule(0);
         int candidateNonBasisPivot = findPositiveRowEntryWithBlandRule(0);

         int basisPivot, nonBasisPivot;
         if (candidateBasisPivot == nullMatrixIndex && candidateNonBasisPivot == nullMatrixIndex)
         {
            crissCrossStatistics.onSolutionFound();
            break;
         }
         else if (candidateBasisPivot != nullMatrixIndex && (candidateNonBasisPivot == nullMatrixIndex
                                                             || dictionary.getBasisIndex(candidateBasisPivot) < dictionary.getNonBasisIndex(candidateNonBasisPivot)))
         {
            // Repair the negative constant term: pivot on a positive entry of that row.
            basisPivot = candidateBasisPivot;
            nonBasisPivot = findPositiveRowEntryWithBlandRule(basisPivot);

            if (nonBasisPivot == nullMatrixIndex)
            {
               // inconsistent: no positive entry in this row can repair its negative constant term
               break;
            }
         }
         else
         {
            // Repair the positive objective entry: pivot on a negative entry of that column.
            nonBasisPivot = candidateNonBasisPivot;
            basisPivot = findNegativeColumnEntryWithBlandRule(nonBasisPivot);

            if (basisPivot == nullMatrixIndex)
            {
               // dual inconsistent: no negative entry in this column can repair its positive objective entry
               break;
            }
         }

         dictionary.performPivot(basisPivot, nonBasisPivot);
      }

      packSolution(crissCrossStatistics);
      crissCrossStatistics.setSolveTime(timer.totalElapsed());
   }

   /**
    * @param column dictionary column to scan (0 scans the constant column)
    * @return the row (1..) whose entry in {@code column} is below {@code -epsilon} and whose basic variable has the
    *         lowest lexical index, or {@link #nullMatrixIndex} if none qualifies
    */
   private int findNegativeColumnEntryWithBlandRule(int column)
   {
      int minLexicalIndex = Integer.MAX_VALUE;
      int row = nullMatrixIndex;

      for (int i = 1; i < dictionary.getNumberOfRows(); i++)
      {
         double d_ig = dictionary.getEntry(i, column);
         int lexicalIndex = dictionary.getBasisIndex(i);

         if (d_ig < -epsilon && lexicalIndex < minLexicalIndex)
         {
            minLexicalIndex = lexicalIndex;
            row = i;
         }
      }

      return row;
   }

   /**
    * @param row dictionary row to scan (0 scans the objective row)
    * @return the column (1..) whose entry in {@code row} exceeds {@code epsilon} and whose non-basic variable has the
    *         lowest lexical index, or {@link #nullMatrixIndex} if none qualifies
    */
   private int findPositiveRowEntryWithBlandRule(int row)
   {
      int minLexicalIndex = Integer.MAX_VALUE;
      int column = nullMatrixIndex;

      for (int j = 1; j < dictionary.getNumberOfColumns(); j++)
      {
         double d_fj = dictionary.getEntry(row, j);
         int lexicalIndex = dictionary.getNonBasisIndex(j);

         if (d_fj > epsilon && lexicalIndex < minLexicalIndex)
         {
            minLexicalIndex = lexicalIndex;
            column = j;
         }
      }

      return column;
   }

   /** @return statistics (outcome, iterations, timing) of the most recent criss-cross solve */
   public SolverStatistics getCrissCrossStatistics()
   {
      return crissCrossStatistics;
   }

   /** @return statistics (outcome, iterations, timing) of the most recent simplex solve */
   public SolverStatistics getSimplexStatistics()
   {
      return simplexStatistics;
   }

   /** @return the dictionary's basis index list (owned by the dictionary, not a copy) */
   public TIntArrayList getBasisIndices()
   {
      return dictionary.getBasisIndices();
   }

   /** @return the dictionary's non-basis index list (owned by the dictionary, not a copy) */
   public TIntArrayList getNonBasisIndices()
   {
      return dictionary.getNonBasisIndices();
   }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy