All downloads are free. The search and download functionality uses the official Maven repository.

io.github.qudtlib.algorithm.AssignmentProblem Maven / Gradle / Ivy

There is a newer version: 6.7.0
Show newest version
package io.github.qudtlib.algorithm;

import java.util.Arrays;
import java.util.stream.DoubleStream;

/**
 * Solver for the linear assignment problem: given a weight matrix with at most as many rows as
 * columns, pick one distinct column per row such that the sum of the picked weights is minimal.
 *
 * <p>Use {@link #instance(double[][])} to create a problem instance and {@link Instance#solve()}
 * to obtain a solution.
 */
public class AssignmentProblem {
    /**
     * A problem instance, defined by its weight matrix. The matrix may not have more rows than
     * columns. A valid solution is the assignment of one column per row, such that no columns are
     * selected more than once and the sum of the assignment is minimal. Only one such solution is
     * generated.
     */
    public abstract static class Instance {
        protected final double[][] weights;
        protected final int rows;
        protected final int cols;

        /**
         * Creates an instance backed by the given matrix. The array is stored as-is (not copied);
         * the column count is taken from the first row.
         *
         * @param weights the weight matrix; must contain at least one row
         */
        public Instance(double[][] weights) {
            this.weights = weights;
            this.rows = weights.length;
            this.cols = weights[0].length;
        }

        /**
         * Returns the weight matrix backing this instance (the internal array, not a copy).
         *
         * @return the weight matrix
         */
        public double[][] getWeights() {
            return weights;
        }

        /**
         * Sums the weights selected by a (possibly partial) assignment.
         *
         * @param selection for each row {@code i} (starting at row 0) the selected column index
         * @return the sum of {@code weights[i][selection[i]]}, or {@code null} if {@code
         *     selection} is {@code null} or empty
         */
        public Double weightOfAssignment(int[] selection) {
            if (selection == null || selection.length == 0) {
                return null;
            }
            double sum = 0;
            for (int i = 0; i < selection.length; i++) {
                sum += weights[i][selection[i]];
            }
            return sum;
        }

        /**
         * Computes a solution for this instance.
         *
         * @return the solution found
         */
        public abstract Solution solve();
    }

    /**
     * Depth-first branch-and-bound solver. Rows are assigned top to bottom; for each row the free
     * columns are tried in ascending weight order, and a branch is abandoned as soon as a lower
     * bound on its final weight cannot beat the best complete solution found so far.
     */
    public static class NaiveAlgorithmInstance extends Instance {
        private Solution currentBestSolution;

        public NaiveAlgorithmInstance(double[][] weights) {
            super(weights);
        }

        /**
         * Tells whether the given weight is strictly lower than the weight of the best complete
         * solution known so far.
         *
         * @param weightToTest the weight to compare
         * @return {@code true} if there is no complete best solution yet, or if its weight is
         *     strictly greater than {@code weightToTest}
         */
        public boolean isLowerThanBestWeight(double weightToTest) {
            if (currentBestSolution == null) {
                return true;
            }
            if (!currentBestSolution.isComplete()) {
                return true;
            }
            return currentBestSolution.weight > weightToTest;
        }

        /**
         * Replaces the best known solution with {@code candidate} if there is none yet or if the
         * candidate is strictly better.
         *
         * @param candidate a complete solution
         */
        public void updateBestSolutionIfPossible(Solution candidate) {
            if (currentBestSolution == null
                    || candidate.isBetterSolutionThan(currentBestSolution)) {
                this.currentBestSolution = candidate;
            }
        }

        @Override
        public Solution solve() {
            solve(0, new Solution(this));
            return this.currentBestSolution;
        }

        /**
         * Extends {@code solution} (which covers rows {@code 0..row-1}) by every promising column
         * of {@code row} and recurses.
         */
        private void solve(int row, Solution solution) {
            if (row >= this.rows) {
                updateBestSolutionIfPossible(solution);
                return;
            }
            int[] assigned = solution.getAssignment();
            double weightSoFar = solution.isEmpty() ? 0.0 : solution.getWeight();
            // Lower bound for everything below the current row: the cheapest still-free column of
            // each remaining row. Including it keeps the pruning valid when weights are negative;
            // comparing only the partial weight would discard branches that a later negative
            // weight makes optimal.
            double boundBelow = sum(minPerRow(row + 1, assigned));
            for (ValueWithIndex candidate : rowSortedAscending(this.weights, row, assigned)) {
                if (!isLowerThanBestWeight(weightSoFar + candidate.value + boundBelow)) {
                    // Candidates are sorted ascending and the best weight never increases, so no
                    // later candidate of this row can pass the test either.
                    break;
                }
                solve(row + 1, solution.assignColumnInNextRow(candidate.index));
            }
        }

        /**
         * For each row from {@code startRow} to the last row, finds the smallest weight among the
         * columns not contained in {@code skipCols}.
         *
         * @return one minimum per row; {@link Double#MAX_VALUE} for a row without a smaller value
         */
        private double[] minPerRow(int startRow, int[] skipCols) {
            double[] ret = new double[weights.length - startRow];
            for (int r = startRow; r < weights.length; r++) {
                double min = Double.MAX_VALUE;
                for (int c = 0; c < weights[r].length; c++) {
                    if (!containsValue(skipCols, c) && min > weights[r][c]) {
                        min = weights[r][c];
                    }
                }
                ret[r - startRow] = min;
            }
            return ret;
        }

        private double sum(double[] arr) {
            return DoubleStream.of(arr).sum();
        }
    }

    /** An immutable, possibly partial, assignment of columns to the first rows of an instance. */
    public static class Solution {
        /**
         * For each assigned row in the instance, contains the index of the column assigned for
         * this solution.
         */
        private final int[] assignment;

        /** Sum of the assigned weights; {@code null} while nothing is assigned. */
        private final Double weight;

        private final Instance instance;

        /** Creates an empty solution (no row assigned yet) for the given instance. */
        public Solution(Instance instance) {
            this(instance, new int[0]);
        }

        /**
         * Creates a solution from an explicit assignment and computes its weight.
         *
         * @param instance the problem instance
         * @param assignment for each row (starting at row 0) the assigned column index
         */
        public Solution(Instance instance, int[] assignment) {
            this.assignment = assignment;
            this.instance = instance;
            this.weight = instance.weightOfAssignment(assignment);
        }

        /**
         * Returns the assigned column index per row (the internal array, not a copy).
         *
         * @return the assignment
         */
        public int[] getAssignment() {
            return assignment;
        }

        /**
         * Returns the sum of the assigned weights.
         *
         * @return the weight, or {@code null} if nothing is assigned
         */
        public Double getWeight() {
            return weight;
        }

        /**
         * Returns a new solution that additionally assigns {@code col} to the first unassigned
         * row; this solution is not modified.
         *
         * @param col the column to assign
         * @return the extended solution
         * @throws IllegalStateException if every row already has a column
         */
        public Solution assignColumnInNextRow(int col) {
            if (isComplete()) {
                throw new IllegalStateException("Solution is already complete");
            }
            return new Solution(this.instance, append(this.assignment, col));
        }

        /** Tells whether every row of the instance has a column assigned. */
        private boolean isComplete() {
            return this.assignment.length >= instance.rows;
        }

        /**
         * Tells whether no row has been assigned yet.
         *
         * @return {@code true} if the assignment is {@code null} or has length 0
         */
        public boolean isEmpty() {
            return this.assignment == null || this.assignment.length == 0;
        }

        /**
         * Compares two complete solutions by weight.
         *
         * @param other the solution to compare with
         * @return {@code true} if this solution's weight is strictly lower
         * @throws IllegalStateException if either solution is incomplete
         */
        public boolean isBetterSolutionThan(Solution other) {
            if (!(this.isComplete() && other.isComplete())) {
                throw new IllegalStateException("Cannot compare incomplete solutions");
            }
            return this.weight < other.weight;
        }
    }

    /**
     * Creates a problem instance from a copy of the given weight matrix.
     *
     * @param weights the weight matrix; must have at least one row and not more rows than columns
     * @return a new instance
     * @throws IllegalArgumentException if the matrix has no rows or more rows than columns
     */
    public static Instance instance(double[][] weights) {
        int rows = weights.length;
        if (rows == 0) {
            throw new IllegalArgumentException("Cannot create instance with 0x0 weights matrix");
        }
        int cols = weights[0].length;
        if (rows > cols) {
            throw new IllegalArgumentException(
                    "The weights matrix may not have more rows than columns");
        }
        return new NaiveAlgorithmInstance(copy(weights));
    }

    /** Copies the matrix; the column count is taken from the first row. */
    private static double[][] copy(double[][] weights) {
        int rows = weights.length;
        int cols = weights[0].length;
        double[][] ret = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            System.arraycopy(weights[i], 0, ret[i], 0, cols);
        }
        return ret;
    }

    /** Returns the transpose of the matrix. */
    private static double[][] flipMatrix(double[][] mat) {
        int rows = mat.length;
        int cols = mat[0].length;
        double[][] ret = new double[cols][rows];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                ret[j][i] = mat[i][j];
            }
        }
        return ret;
    }

    /** Returns a new array consisting of {@code arr} followed by {@code val}. */
    private static int[] append(int[] arr, int val) {
        int[] ret = new int[arr.length + 1];
        System.arraycopy(arr, 0, ret, 0, arr.length);
        ret[ret.length - 1] = val;
        return ret;
    }

    /** A matrix value together with the row or column index it was read from. */
    private static class ValueWithIndex {
        private final int index;
        private final double value;

        public ValueWithIndex(double value, int index) {
            this.index = index;
            this.value = value;
        }
    }

    /** Orders by value using {@link Double#compare(double, double)}. */
    private static int compareByValue(ValueWithIndex left, ValueWithIndex right) {
        return Double.compare(left.value, right.value);
    }

    /**
     * Returns the values of the given row whose column is not in {@code skipCols}, sorted by
     * ascending value, each paired with its column index.
     */
    private static ValueWithIndex[] rowSortedAscending(
            double[][] weights, int row, int[] skipCols) {
        int cols = weights[row].length;
        // Collect into a full-width buffer and trim afterwards, so the result size follows the
        // columns actually kept rather than skipCols.length.
        ValueWithIndex[] buffer = new ValueWithIndex[cols];
        int count = 0;
        for (int c = 0; c < cols; c++) {
            if (!containsValue(skipCols, c)) {
                buffer[count++] = new ValueWithIndex(weights[row][c], c);
            }
        }
        ValueWithIndex[] sorted = Arrays.copyOf(buffer, count);
        Arrays.sort(sorted, AssignmentProblem::compareByValue);
        return sorted;
    }

    /**
     * Returns up to {@code n} smallest values of the given row whose column is not in {@code
     * skipCols}, sorted by ascending value. Fewer than {@code n} entries are returned when fewer
     * columns are available.
     */
    private static ValueWithIndex[] rowNMin(double[][] mat, int row, int n, int[] skipCols) {
        ValueWithIndex[] nMin = new ValueWithIndex[n];
        for (int i = 0; i < mat[row].length; i++) {
            if (!containsValue(skipCols, i)) {
                replaceIfLower(nMin, mat[row][i], i);
            }
        }
        // replaceIfLower fills free slots front to back, so unfilled slots form a null tail;
        // drop it before sorting to avoid handing nulls to the comparator.
        int filled = 0;
        while (filled < nMin.length && nMin[filled] != null) {
            filled++;
        }
        ValueWithIndex[] result = Arrays.copyOf(nMin, filled);
        Arrays.sort(result, AssignmentProblem::compareByValue);
        return result;
    }

    /** Linear search for {@code val} in {@code arr}. */
    private static boolean containsValue(int[] arr, int val) {
        for (int element : arr) {
            if (element == val) {
                return true;
            }
        }
        return false;
    }

    /**
     * Stores the value in the first free ({@code null}) slot of {@code nMin}; if there is none,
     * overwrites the largest stored entry provided it is strictly greater than {@code value}.
     */
    private static void replaceIfLower(ValueWithIndex[] nMin, double value, int index) {
        int maxValueIndex = -1;
        double maxValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < nMin.length; i++) {
            if (nMin[i] == null) {
                nMin[i] = new ValueWithIndex(value, index);
                return;
            }
            if (maxValueIndex < 0 || nMin[i].value > maxValue) {
                maxValue = nMin[i].value;
                maxValueIndex = i;
            }
        }
        // maxValueIndex stays -1 only for an empty array; never index with it.
        if (maxValueIndex >= 0 && maxValue > value) {
            nMin[maxValueIndex] = new ValueWithIndex(value, index);
        }
    }

    /** Returns a sorted copy of the given indices, or an empty array for {@code null}. */
    private static int[] sortedCopy(int[] indices) {
        if (indices == null) {
            return new int[0];
        }
        int[] copy = indices.clone();
        Arrays.sort(copy);
        return copy;
    }

    /**
     * Finds the largest value of the given row among the columns not in {@code skipCols}. The
     * caller's array is left untouched.
     *
     * @return the maximum with its column index; index {@code -1} and value {@code
     *     -Double.MAX_VALUE} if no value exceeds {@code -Double.MAX_VALUE}
     */
    private static ValueWithIndex rowMax(double[][] mat, int row, int[] skipCols) {
        int[] skipped = sortedCopy(skipCols);
        double max = -Double.MAX_VALUE;
        int index = -1;
        int cols = mat[row].length;
        for (int i = 0; i < cols; i++) {
            if (Arrays.binarySearch(skipped, i) >= 0) {
                continue;
            }
            double cur = mat[row][i];
            if (cur > max) {
                index = i;
                max = cur;
            }
        }
        return new ValueWithIndex(max, index);
    }

    /**
     * Finds the largest value of the given column among the rows not in {@code skipRows}. The
     * caller's array is left untouched.
     *
     * @return the maximum with its row index; index {@code -1} and value {@code
     *     -Double.MAX_VALUE} if no value exceeds {@code -Double.MAX_VALUE}
     */
    private static ValueWithIndex colMax(double[][] mat, int col, int[] skipRows) {
        int[] skipped = sortedCopy(skipRows);
        double max = -Double.MAX_VALUE;
        int index = -1;
        int rows = mat.length;
        for (int i = 0; i < rows; i++) {
            if (Arrays.binarySearch(skipped, i) >= 0) {
                continue;
            }
            double cur = mat[i][col];
            if (cur > max) {
                index = i;
                max = cur;
            }
        }
        return new ValueWithIndex(max, index);
    }

    /** Returns a copy of the matrix with column {@code col} removed. */
    private static double[][] copyWithoutCol(double[][] mat, int col) {
        int rows = mat.length;
        int cols = mat[0].length;
        double[][] ret = new double[rows][cols - 1];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols - 1; c++) {
                // Columns at or after the removed one are read one position further right.
                int readCol = c >= col ? c + 1 : c;
                ret[r][c] = mat[r][readCol];
            }
        }
        return ret;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy