// us.ihmc.convexOptimization.linearProgram.DictionaryFormLinearProgramSolver Maven / Gradle / Ivy
package us.ihmc.convexOptimization.linearProgram;
import gnu.trove.list.array.TIntArrayList;
import org.apache.commons.math3.util.Precision;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.commons.time.Stopwatch;
import us.ihmc.convexOptimization.linearProgram.SolverStatistics.LinearProgramFailureReason;
import java.util.Arrays;
import static us.ihmc.convexOptimization.linearProgram.LinearProgramSolver.*;
/**
* Solves a dictionary form LP using the criss-cross or simplex methods.
* Simplex implementation borrows from org.apache.commons.math3.optim.linear.SimplexSolver and doi.org/10.3929/ethz-b-000426221 (ch. 4)
*/
public class DictionaryFormLinearProgramSolver
{
/* When true, performSimplexPhase prints every pivot (row, column) to standard out. */
static final boolean debug = false;
/* Upper bound on the number of columns solveSimplex accepts; also used to pre-size the solution matrices and the ratio-test index list. */
static final int maxVariables = 200;
/* Iteration budget checked at the top of each solveCrissCross loop pass. */
static final int maxCrissCrossIterations = 50000;
/* Iteration budget checked at the top of each performSimplexPhase loop pass. The statistics are only cleared in solveSimplex, so Phase I and Phase II share the counter. */
static final int maxSimplexIterations = 1000;
/* Sentinel returned by the pivot row/column searches when no entry qualifies. */
static final int nullMatrixIndex = -1;
/* Tolerance used for all sign tests on dictionary entries and for comparing ratios in the simplex ratio test. */
static final double epsilon = 1e-6;
/* NOTE(review): not referenced anywhere in this class; presumably read by other classes in the package - confirm before removing. */
static final double zeroCutoff = 1e-10;
/* Working dictionary; re-initialized from the caller's matrix at the start of every solve and pivoted in place. */
private final LinearProgramDictionary dictionary = new LinearProgramDictionary();
/* Output buffers; reshaped, zeroed and filled by packSolution. */
private final DMatrixRMaj primalSolution = new DMatrixRMaj(maxVariables);
private final DMatrixRMaj dualSolution = new DMatrixRMaj(maxVariables);
/* Shared by both solve methods to measure the solve time written into the statistics. */
private final Stopwatch timer = new Stopwatch();
/* Separate statistics per method; each is cleared at the start of its own solve call. */
private final SolverStatistics simplexStatistics = new SolverStatistics();
private final SolverStatistics crissCrossStatistics = new SolverStatistics();
/**
 * The two phases of the simplex method. Each phase knows how many leading dictionary rows
 * are objective rows; the ratio test in computeSimplexPivotRow starts at that row index.
 */
private enum SimplexPhase
{
   PHASE_I(2),
   PHASE_II(1);

   private final int objectiveRows;

   SimplexPhase(int objectiveRows)
   {
      this.objectiveRows = objectiveRows;
   }

   /** @return 2 for PHASE_I, 1 for PHASE_II */
   int objectiveSize()
   {
      return objectiveRows;
   }
}
/////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////// SIMPLEX METHOD /////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////
/**
 * Solves the given dictionary-form LP with the two-phase simplex method.
 * <p>
 * If {@code dictionary.initialize} returns true, Phase I is run first. When Phase I fails, or ends with
 * a dictionary entry (0, 0) below {@code -epsilon}, the solve is abandoned: the failure is left in the
 * statistics and the primal/dual solution buffers are NOT repacked (they keep their previous contents).
 * Otherwise the Phase I variables are dropped, Phase II is run and the solution is packed.
 * <p>
 * The elapsed solve time is written to the statistics on every exit path, including the Phase I
 * failure exits, so a failed solve never reports the time left behind by {@code clear()}.
 *
 * @param startingDictionary dictionary matrix to solve; must have at most {@code maxVariables} columns
 * @throws IllegalArgumentException if {@code startingDictionary} has more than {@code maxVariables} columns
 */
public void solveSimplex(DMatrixRMaj startingDictionary)
{
   if (startingDictionary.getNumCols() > maxVariables)
   {
      throw new IllegalArgumentException("Simplex method has a maximum of " + maxVariables + " decision variables, " + startingDictionary.getNumCols() + " provided.");
   }

   timer.reset();
   simplexStatistics.clear();

   if (dictionary.initialize(startingDictionary, SolverMethod.SIMPLEX))
   {
      /* Phase I: compute feasible dictionary */
      performSimplexPhase(SimplexPhase.PHASE_I);

      if (!simplexStatistics.foundSolution())
      {
         // performSimplexPhase already recorded the failure reason; only the timing is missing
         simplexStatistics.setSolveTime(timer.lapElapsed());
         return;
      }
      else if (dictionary.getEntry(0, 0) < -epsilon)
      {
         simplexStatistics.onSolverFailure(LinearProgramFailureReason.INVALID_PHASE_I_SOLUTION, true);
         simplexStatistics.setSolveTime(timer.lapElapsed());
         return;
      }

      dictionary.dropPhaseIVariables();
   }

   /* Phase II: optimize feasible dictionary */
   performSimplexPhase(SimplexPhase.PHASE_II);
   packSolution(simplexStatistics);
   simplexStatistics.setSolveTime(timer.lapElapsed());
}
/* package private for testing */
/**
 * Pivots the dictionary until one of three things happens, each of which is recorded in
 * {@code simplexStatistics}: the iteration budget is exceeded, the objective row passes the
 * optimality check, or the ratio test yields no pivot row.
 *
 * @param phase phase being run; selects the first row considered by the ratio test and is reported
 *              with any failure
 */
void performSimplexPhase(SimplexPhase phase)
{
   final boolean inPhaseI = phase == SimplexPhase.PHASE_I;

   for (;;)
   {
      if (simplexStatistics.getAndIncrementIterations() > maxSimplexIterations)
      {
         simplexStatistics.onSolverFailure(LinearProgramFailureReason.MAX_ITERATIONS_REACHED, inPhaseI);
         return;
      }

      if (isSimplexOptimal())
      {
         simplexStatistics.onSolutionFound();
         return;
      }

      // entering column first, then the leaving row chosen by the ratio test on that column
      int pivotColumn = computeSimplexPivotColumn();
      int pivotRow = computeSimplexPivotRow(pivotColumn, phase);

      if (pivotRow == nullMatrixIndex)
      {
         simplexStatistics.onSolverFailure(LinearProgramFailureReason.NO_CANDIDATE_PIVOT, inPhaseI);
         return;
      }

      if (debug)
      {
         System.out.println("Pivoting on (" + pivotRow + "," + pivotColumn + ")\n");
      }

      dictionary.performPivot(pivotRow, pivotColumn);
   }
}
/**
 * Copies the current dictionary into the primal and dual solution buffers.
 * <p>
 * Both buffers are reshaped (columns - 1 and rows - 1 respectively) and zeroed. Each basis row whose
 * lexical index is a non-negativity constraint writes its column-0 entry into the primal solution;
 * each non-basis column whose lexical index is not one writes the negated row-0 entry into the dual
 * solution. The smallest column-0 entry over the basis rows is reported to the statistics.
 *
 * @param solverStatistics statistics object that receives the minimum column-0 entry
 */
private void packSolution(SolverStatistics solverStatistics)
{
   primalSolution.reshape(dictionary.getNumberOfColumns() - 1, 1);
   dualSolution.reshape(dictionary.getNumberOfRows() - 1, 1);

   Arrays.fill(primalSolution.getData(), 0.0);
   Arrays.fill(dualSolution.getData(), 0.0);

   // fixed by the reshape above
   final int numberOfVariables = primalSolution.getNumRows();

   double smallestRhs = Double.POSITIVE_INFINITY;
   final int basisSize = dictionary.getBasisSize();
   for (int row = 1; row < basisSize; row++)
   {
      int lexicalIndex = dictionary.getBasisIndex(row);
      double rhs = dictionary.getEntry(row, 0);
      smallestRhs = Math.min(rhs, smallestRhs);

      if (!isNonNegativeConstraint(lexicalIndex, numberOfVariables))
      {
         continue;
      }
      primalSolution.set(toVariableIndex(lexicalIndex), rhs);
   }

   final int nonBasisSize = dictionary.getNonBasisSize();
   for (int col = 1; col < nonBasisSize; col++)
   {
      int lexicalIndex = dictionary.getNonBasisIndex(col);
      double objectiveEntry = dictionary.getEntry(0, col);

      if (isNonNegativeConstraint(lexicalIndex, numberOfVariables))
      {
         continue;
      }
      dualSolution.set(toConstraintIndex(lexicalIndex, numberOfVariables), -objectiveEntry);
   }

   solverStatistics.setMinDictionaryRHSColumnEntry(smallestRhs);
}
/* Checks optimality assuming feasibility, so only the objective row needs to be checked */
/**
 * Optimality test that assumes the dictionary is already feasible, so only row 0 is inspected.
 *
 * @return true when no entry of row 0 (columns 1 and up) exceeds {@code epsilon}
 */
private boolean isSimplexOptimal()
{
   final int columnCount = dictionary.getNumberOfColumns();

   int col = 1;
   // advance past every entry that is not strictly above the tolerance
   while (col < columnCount && !(dictionary.getEntry(0, col) > epsilon))
   {
      col++;
   }

   return col >= columnCount;
}
// use Bland pivot rule, which finds the positive objective row entry with the lowest corresponding variable (lexical) index
/**
 * Picks the entering column with Bland's rule: among the row-0 entries that are not below
 * {@code epsilon}, the column whose non-basis (lexical) index is smallest.
 *
 * @return the chosen column, or {@code nullMatrixIndex} when no column qualifies
 */
private int computeSimplexPivotColumn()
{
   int bestColumn = nullMatrixIndex;
   int lowestLexicalIndex = Integer.MAX_VALUE;

   final int columnCount = dictionary.getNumberOfColumns();
   for (int col = 1; col < columnCount; col++)
   {
      boolean belowTolerance = dictionary.getEntry(0, col) < epsilon;
      if (!belowTolerance)
      {
         int lexicalIndex = dictionary.getNonBasisIndex(col);
         if (lexicalIndex < lowestLexicalIndex)
         {
            lowestLexicalIndex = lexicalIndex;
            bestColumn = col;
         }
      }
   }

   return bestColumn;
}
/* Scratch list of the rows currently tied for the minimum ratio; reused across calls. */
private final TIntArrayList minRatioIndices = new TIntArrayList(maxVariables + 1);

/**
 * Minimum-ratio test for the leaving row. Only rows at or after {@code phase.objectiveSize()} whose
 * entry in {@code column} is at most {@code -epsilon} are considered; their ratio is
 * |entry(row, 0) / entry(row, column)|. Ratios within {@code epsilon} of the current minimum count as
 * ties, and ties are broken by the smallest basis index (Bland's rule, as in the Apache implementation).
 *
 * @param column entering column
 * @param phase  current phase, which determines the first row considered
 * @return the leaving row, or {@code nullMatrixIndex} when no row qualifies
 */
private int computeSimplexPivotRow(int column, SimplexPhase phase)
{
   minRatioIndices.reset();
   double smallestRatio = Double.MAX_VALUE;

   final int rowCount = dictionary.getNumberOfRows();
   for (int row = phase.objectiveSize(); row < rowCount; row++)
   {
      double rhs = dictionary.getEntry(row, 0);
      double pivotCandidate = dictionary.getEntry(row, column);

      if (pivotCandidate > -epsilon)
      {
         continue;
      }

      double ratio = Math.abs(rhs / pivotCandidate);
      int comparison = Precision.compareTo(ratio, smallestRatio, epsilon);

      if (comparison < 0)
      {
         // strictly better ratio: start a fresh tie list
         minRatioIndices.reset();
         minRatioIndices.add(row);
         smallestRatio = ratio;
      }
      else if (comparison == 0)
      {
         minRatioIndices.add(row);
      }
   }

   if (minRatioIndices.isEmpty())
   {
      return nullMatrixIndex;
   }

   if (minRatioIndices.size() == 1)
   {
      return minRatioIndices.get(0);
   }

   // several rows tie: take the one whose basic variable has the smallest index
   int chosenRow = nullMatrixIndex;
   int lowestBasisIndex = Integer.MAX_VALUE;
   for (int k = 0; k < minRatioIndices.size(); k++)
   {
      int candidateRow = minRatioIndices.get(k);
      int basisIndex = dictionary.getBasisIndex(candidateRow);
      if (basisIndex < lowestBasisIndex)
      {
         lowestBasisIndex = basisIndex;
         chosenRow = candidateRow;
      }
   }

   return chosenRow;
}
/**
 * Returns the primal solution buffer filled by the last call to packSolution.
 * The internal matrix itself is returned, not a copy; later solves overwrite it.
 *
 * @return the solver's primal solution matrix
 */
public DMatrixRMaj getPrimalSolution()
{
return primalSolution;
}
/**
 * Returns the dual solution buffer filled by the last call to packSolution.
 * The internal matrix itself is returned, not a copy; later solves overwrite it.
 *
 * @return the solver's dual solution matrix
 */
public DMatrixRMaj getDualSolution()
{
return dualSolution;
}
/**
 * Prints the primal solution to standard out: a "Solution:" header followed by one
 * tab-indented line per row of the primal solution matrix.
 */
public void printSolution()
{
   System.out.println("Solution:");

   final int rowCount = primalSolution.getNumRows();
   int row = 0;
   while (row < rowCount)
   {
      System.out.println("\t " + primalSolution.get(row));
      row++;
   }
}
/////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////// CRISS CROSS METHOD /////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////
/**
 * Solves the given dictionary-form LP with the criss-cross method.
 * <p>
 * Each iteration looks for the basis row with a negative column-0 entry and the non-basis column with
 * a positive row-0 entry, each chosen by smallest lexical index. If neither exists the dictionary is
 * optimal. Otherwise the candidate with the smaller lexical index drives the pivot: a row candidate is
 * paired with the lowest-index positive entry in that row, a column candidate with the lowest-index
 * negative entry in that column.
 * <p>
 * Failure exits are reported through {@code onSolverFailure}, matching the simplex path:
 * {@code MAX_ITERATIONS_REACHED} when the iteration budget is exceeded and {@code NO_CANDIDATE_PIVOT}
 * when the driving row/column has no admissible partner (primal or dual inconsistency). The phase flag
 * is always false because criss-cross has no Phase I.
 * <p>
 * The solution buffers are packed and the solve time recorded on every exit, including failures.
 *
 * @param startingDictionary dictionary matrix to solve
 */
public void solveCrissCross(DMatrixRMaj startingDictionary)
{
   crissCrossStatistics.clear();
   timer.reset();
   dictionary.initialize(startingDictionary, SolverMethod.CRISS_CROSS);

   while (true)
   {
      if (crissCrossStatistics.getAndIncrementIterations() > maxCrissCrossIterations)
      {
         crissCrossStatistics.onSolverFailure(LinearProgramFailureReason.MAX_ITERATIONS_REACHED, false);
         // original explicit flag write kept so the found-solution state is unchanged from before
         crissCrossStatistics.setFoundSolution(false);
         break;
      }

      int candidateBasisPivot = findNegativeColumnEntryWithBlandRule(0);
      int candidateNonBasisPivot = findPositiveRowEntryWithBlandRule(0);
      int basisPivot, nonBasisPivot;

      if (candidateBasisPivot == nullMatrixIndex && candidateNonBasisPivot == nullMatrixIndex)
      {
         // primal and dual feasible: optimal
         crissCrossStatistics.onSolutionFound();
         break;
      }
      else if (candidateBasisPivot != nullMatrixIndex && (candidateNonBasisPivot == nullMatrixIndex
            || dictionary.getBasisIndex(candidateBasisPivot) < dictionary.getNonBasisIndex(candidateNonBasisPivot)))
      {
         basisPivot = candidateBasisPivot;
         nonBasisPivot = findPositiveRowEntryWithBlandRule(basisPivot);

         if (nonBasisPivot == nullMatrixIndex)
         {
            // inconsistent
            crissCrossStatistics.onSolverFailure(LinearProgramFailureReason.NO_CANDIDATE_PIVOT, false);
            break;
         }
      }
      else
      {
         nonBasisPivot = candidateNonBasisPivot;
         basisPivot = findNegativeColumnEntryWithBlandRule(nonBasisPivot);

         if (basisPivot == nullMatrixIndex)
         {
            // dual inconsistent
            crissCrossStatistics.onSolverFailure(LinearProgramFailureReason.NO_CANDIDATE_PIVOT, false);
            break;
         }
      }

      dictionary.performPivot(basisPivot, nonBasisPivot);
   }

   packSolution(crissCrossStatistics);
   crissCrossStatistics.setSolveTime(timer.totalElapsed());
}
/**
 * Scans the given column over rows 1 and up for entries below {@code -epsilon} and returns the row
 * whose basis (lexical) index is smallest.
 *
 * @param column dictionary column to scan
 * @return the selected row, or {@code nullMatrixIndex} when the column has no such entry
 */
private int findNegativeColumnEntryWithBlandRule(int column)
{
   int selectedRow = nullMatrixIndex;
   int lowestLexicalIndex = Integer.MAX_VALUE;

   final int rowCount = dictionary.getNumberOfRows();
   for (int row = 1; row < rowCount; row++)
   {
      double value = dictionary.getEntry(row, column);
      int lexicalIndex = dictionary.getBasisIndex(row);

      boolean negative = value < -epsilon;
      if (!negative || lexicalIndex >= lowestLexicalIndex)
      {
         continue;
      }

      lowestLexicalIndex = lexicalIndex;
      selectedRow = row;
   }

   return selectedRow;
}
/**
 * Scans the given row over columns 1 and up for entries above {@code epsilon} and returns the column
 * whose non-basis (lexical) index is smallest.
 *
 * @param row dictionary row to scan
 * @return the selected column, or {@code nullMatrixIndex} when the row has no such entry
 */
private int findPositiveRowEntryWithBlandRule(int row)
{
   int selectedColumn = nullMatrixIndex;
   int lowestLexicalIndex = Integer.MAX_VALUE;

   final int columnCount = dictionary.getNumberOfColumns();
   for (int col = 1; col < columnCount; col++)
   {
      double value = dictionary.getEntry(row, col);
      int lexicalIndex = dictionary.getNonBasisIndex(col);

      boolean positive = value > epsilon;
      if (!positive || lexicalIndex >= lowestLexicalIndex)
      {
         continue;
      }

      lowestLexicalIndex = lexicalIndex;
      selectedColumn = col;
   }

   return selectedColumn;
}
/**
 * Returns the statistics object updated by solveCrissCross. The same instance is reused and
 * cleared at the start of every criss-cross solve.
 *
 * @return statistics of the criss-cross solver
 */
public SolverStatistics getCrissCrossStatistics()
{
return crissCrossStatistics;
}
/**
 * Returns the statistics object updated by solveSimplex. The same instance is reused and
 * cleared at the start of every simplex solve.
 *
 * @return statistics of the simplex solver
 */
public SolverStatistics getSimplexStatistics()
{
return simplexStatistics;
}
/**
 * Returns the basis index list of the working dictionary (delegates to the dictionary's
 * getBasisIndices()).
 * NOTE(review): presumably the dictionary's live list rather than a copy - confirm before mutating.
 *
 * @return basis indices of the current dictionary
 */
public TIntArrayList getBasisIndices()
{
return dictionary.getBasisIndices();
}
/**
 * Returns the non-basis index list of the working dictionary (delegates to the dictionary's
 * getNonBasisIndices()).
 * NOTE(review): presumably the dictionary's live list rather than a copy - confirm before mutating.
 *
 * @return non-basis indices of the current dictionary
 */
public TIntArrayList getNonBasisIndices()
{
return dictionary.getNonBasisIndices();
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy