All Downloads are FREE. Search and download functionalities are using the official Maven repository.

breeze.optimize.linear.KuhnMunkres.scala Maven / Gradle / Ivy

There is a newer version: 1.0
Show newest version
package breeze.optimize.linear
/*
 Copyright 2010 David Hall, Daniel Ramage

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/


/**
 * Algorithm to find a minimum cost matching on a bipartite graph.
 *
 * Implements the Hungarian (Kuhn-Munkres) algorithm.
 *
 * @author dlwh
 */
object KuhnMunkres extends BipartiteMatching {
  // Step structure follows http://github.com/bmc/munkres/blob/master/munkres.py

  /**
   * Given a matrix of positive weights, finds the minimum weight bipartite matching between two arrays.
   * Returns a matching from the rows (the first index into the matrix) to the columns
   * (-1 for unmatched rows in the case of unbalanced entries) along
   * with the total score of the matching.
   *
   * The column count is read from the first row. If there are no rows, or the first row is
   * empty, every row is reported as unmatched (-1) and the total score is 0.0.
   *
   * @param costs cost matrix; `costs(i)(j)` is the cost of matching row `i` to column `j`
   * @return the column matched to each row (or -1), and the total cost of those matches
   */
  def extractMatching(costs: Seq[Seq[Double]]) = {
    val numRows = costs.length
    val numCols = if (costs.isEmpty) 0 else costs.head.length

    val (answers, cost) =
      if (numRows == 0 || numCols == 0) {
        // Nothing can be matched; avoid indexing costs(0) or padding an empty matrix.
        (Array.fill(numRows)(-1), 0.0)
      } else if (numRows > numCols) {
        // The solver expects rows <= cols: solve the transposed problem, then turn the
        // column -> row assignment back into a row -> column assignment.
        val (colToRow, total) = solveWide(transpose(costs, numCols))
        val rowToCol = Array.fill(numRows)(-1)
        for (j <- colToRow.indices if colToRow(j) != -1) {
          rowToCol(colToRow(j)) = j
        }
        (rowToCol, total)
      } else {
        solveWide(costs)
      }

    // Rebuilt as a literal so it is typed against the inherited result type, as before.
    (answers, cost)
  }

  /** Transposes the first `numCols` columns of `costs` into a `numCols` x `costs.length` matrix. */
  private def transpose(costs: Seq[Seq[Double]], numCols: Int): Seq[Seq[Double]] = {
    val transposed = Array.fill(numCols, costs.length)(0.0)
    for (i <- 0 until costs.length; j <- 0 until numCols) {
      transposed(j)(i) = costs(i)(j)
    }
    transposed.map(row => row.toSeq).toSeq
  }

  /**
   * Runs the Hungarian algorithm on a cost matrix with at least one row, a non-empty first row,
   * and no more rows than columns (columns counted from the first row). Returns the column
   * matched to each row (-1 if the row only got a padding column) and the total cost.
   */
  private def solveWide(costs: Seq[Seq[Double]]): (Array[Int], Double) = {
    // Square working copy of the costs; reduced in place by steps 1 and 6.
    val C: Array[Array[Double]] = padMatrix(costs)
    val n = C.length
    val rowCovered = Array.fill(n)(false)
    val colCovered = Array.fill(n)(false)

    // Primed zero found by step 4 whose row has no starred zero; step 5 starts its path here.
    var primeR = 0
    var primeC = 0
    // (row, col) pairs of the alternating path built in step 5.
    val path = Array.fill(2 * n, 2)(0)
    // 0 = unmarked, 1 = starred zero, 2 = primed zero.
    val marked = Array.fill(n, n)(0)

    /** Smallest entry of C lying in both an uncovered row and an uncovered column. */
    def findSmallestNotCovered(): Double = {
      val uncovered = for {
        i <- (0 until n).iterator if !rowCovered(i)
        j <- (0 until n).iterator if !colCovered(j)
      } yield C(i)(j)
      uncovered.reduceLeft(_ min _)
    }

    /**
     * Finds an uncovered zero: in the first row containing one, returns the LAST such column
     * (the inner scan is not cut short). Returns (-1, -1) when there is none.
     */
    def findZero(): (Int, Int) = {
      var row = -1
      var col = -1
      var i = 0
      var done = false

      while (!done && i < n) {
        var j = 0
        while (j < n) {
          if (C(i)(j) == 0 && !rowCovered(i) && !colCovered(j)) {
            row = i
            col = j
            done = true
          }
          j += 1
        }

        i += 1
      }

      (row, col)
    }

    def erasePrimes(): Unit = {
      for (i <- 0 until n; j <- 0 until n if marked(i)(j) == 2) {
        marked(i)(j) = 0
      }
    }

    def findStarInRow(row: Int): Int = marked(row).indexWhere(_ == 1)

    def findStarInCol(col: Int): Int = (0 until n).indexWhere(i => marked(i)(col) == 1)

    def findPrimeInRow(row: Int): Int = marked(row).indexWhere(_ == 2)

    /** Along path(0..count), unstars starred zeros and stars every other (primed) entry. */
    def convertPath(count: Int): Unit = {
      for (i <- 0 to count) {
        val r = path(i)(0)
        val c = path(i)(1)
        marked(r)(c) = if (marked(r)(c) == 1) 0 else 1
      }
    }

    // Step 1: subtract each row's minimum from every entry of that row.
    def step1(): Int = {
      for (i <- 0 until n) {
        val rowMin = C(i).reduceLeft(_ min _)
        for (j <- 0 until n) {
          C(i)(j) -= rowMin
        }
      }
      2
    }

    // Step 2: star each zero that has no starred zero in its row or column, then clear covers.
    def step2(): Int = {
      for {
        r <- 0 until n
        c <- 0 until n
        if C(r)(c) == 0 && !rowCovered(r) && !colCovered(c)
      } {
        marked(r)(c) = 1
        rowCovered(r) = true
        colCovered(c) = true
      }
      java.util.Arrays.fill(rowCovered, false)
      java.util.Arrays.fill(colCovered, false)
      3
    }

    // Step 3: cover every column containing a starred zero; n starred zeros means we are done.
    def step3(): Int = {
      var count = 0
      for (i <- 0 until n; j <- 0 until n if marked(i)(j) == 1) {
        colCovered(j) = true
        count += 1
      }
      if (count >= n) 7 else 4
    }

    // Step 4: prime uncovered zeros. If the primed zero's row holds a star, cover that row and
    // uncover the star's column; otherwise go to step 5. With no uncovered zero left, go to step 6.
    def step4(): Int = {
      var done = false
      var step = 0
      while (!done) {
        val (row, col) = findZero()
        if (row == -1) {
          done = true
          step = 6
        } else {
          marked(row)(col) = 2
          val starredCol = findStarInRow(row)
          if (starredCol == -1) {
            done = true
            primeR = row
            primeC = col
            step = 5
          } else {
            rowCovered(row) = true
            colCovered(starredCol) = false
          }
        }
      }
      step
    }

    // Step 5: build the alternating prime/star path from (primeR, primeC), flip stars along it,
    // then clear all covers and primes.
    def step5(): Int = {
      var count = 0
      path(count)(0) = primeR
      path(count)(1) = primeC
      var done = false
      while (!done) {
        val row = findStarInCol(path(count)(1))
        if (row >= 0) {
          count += 1
          path(count)(0) = row
          path(count)(1) = path(count - 1)(1)
        } else {
          done = true
        }

        if (!done) {
          val col = findPrimeInRow(path(count)(0))
          count += 1
          path(count)(0) = path(count - 1)(0)
          path(count)(1) = col
        }
      }

      convertPath(count)
      java.util.Arrays.fill(rowCovered, false)
      java.util.Arrays.fill(colCovered, false)
      erasePrimes()
      3
    }

    // Step 6: add the smallest uncovered value to covered rows and subtract it from uncovered columns.
    def step6(): Int = {
      val min = findSmallestNotCovered()
      for {
        r <- 0 until n
        c <- 0 until n
      } {
        if (rowCovered(r)) C(r)(c) += min
        if (!colCovered(c)) C(r)(c) -= min
      }
      4
    }

    var step = 1
    while (step < 7) {
      step = step match {
        case 1 => step1()
        case 2 => step2()
        case 3 => step3()
        case 4 => step4()
        case 5 => step5()
        case 6 => step6()
      }
    }

    val answers = Array.fill(costs.length)(-1)
    var cost = 0.0
    for (i <- answers.indices) {
      val j = marked(i).indexWhere(_ == 1)
      // A ragged row shorter than n may be starred in a padding column; report it as unmatched
      // instead of indexing past the end of the row.
      if (j >= 0 && j < costs(i).length) {
        cost += costs(i)(j)
        answers(i) = j
      }
    }

    (answers, cost)
  }

  /**
   * Copies `costs` into a square n x n array, where n = max(number of rows, length of the first
   * row). Positions missing from the input (extra rows, or the tail of a short row) are 0.0;
   * entries at column index >= n are dropped.
   */
  private def padMatrix(costs: Seq[Seq[Double]]): Array[Array[Double]] = {
    val rows = costs.length
    val cols = costs(0).length
    val n = rows max cols
    Array.tabulate(n, n) { (i, j) =>
      if (i >= rows) 0.0
      else if (j >= costs(i).length) 0.0
      else costs(i)(j)
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy