All Downloads are FREE. Search and download functionalities are using the official Maven repository.

commonMain.io.data2viz.sankey.SankeyLayout.kt Maven / Gradle / Ivy

/*
 * Copyright (c) 2018-2021. data2viz sàrl.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 */

package io.data2viz.sankey

import io.data2viz.shape.link.LinkBuilder
import io.data2viz.shape.link.linkBuilderH
import io.data2viz.shape.link.linkBuilderV
import kotlin.math.floor
import kotlin.math.max
import kotlin.math.min

/**
 * Strategy used by [SankeyLayout] to choose the column of each node.
 */
public enum class SankeyAlignment {
    /** Nodes without incoming links are placed just before their nearest target; others use their depth. */
    CENTER,

    /** Nodes without outgoing links are placed in the last column; others use their depth. */
    JUSTIFY,

    /** Nodes are placed according to their height, counted from the last column. */
    RIGHT,

    /** Nodes are placed according to their depth, counted from the first column. */
    LEFT
}

///// LINKS VISUALIZATIONS

/**
 * Link shape builder for a Sankey diagram whose columns are laid out along the x axis.
 *
 * The link starts at the right edge of its source node (`source.x1`) and ends at the left edge of its
 * target node (`target.x0`); the cross-axis coordinates are the link's own `y0` / `y1`.
 */
public val sankeyLinkHorizontal: LinkBuilder<SankeyLink<*>> = linkBuilderH<SankeyLink<*>> {
    x0 = { it.source.x1 }
    y0 = { it.y0 }
    x1 = { it.target.x0 }
    y1 = { it.y1 }
}

/**
 * Link shape builder for a Sankey diagram whose columns are laid out along the y axis.
 *
 * The layout's coordinates are transposed: the link's `y0` / `y1` become the x coordinates and the node
 * column positions become the y coordinates.
 *
 * NOTE(review): the link starts at `source.x0` and ends at `target.x1`, whereas the horizontal builder uses
 * `source.x1` and `target.x0`. Confirm this asymmetry is intended before relying on it.
 */
public val sankeyLinkVertical: LinkBuilder<SankeyLink<*>> = linkBuilderV<SankeyLink<*>> {
    x0 = { it.y0 }
    y0 = { it.source.x0 }
    x1 = { it.y1 }
    y1 = { it.target.x1 }
}

/**
 * A weighted, directed connection between two [SankeyNode]s.
 *
 * @param D the type of the user data carried by the connected nodes
 * @property source the link's source node
 * @property target the link's target node
 * @property index the zero-based index of the link within the list of links
 * @property value the link's numeric value
 * @property y0 the centre of the link's band where it leaves the source node
 * @property y1 the centre of the link's band where it enters the target node
 * @property width the link's width (proportional to [value])
 */
public data class SankeyLink<D>(
    val source: SankeyNode<D>,
    val target: SankeyNode<D>,
    var index: Int,
    var value: Double,
    var y0: Double = .0,
    var y1: Double = .0,
    var width: Double = .0
) {

    /**
     * Hash of the link's own fields and of the indices of its end points.
     *
     * The generated implementation hashes [source] and [target], whose own hashes include their link lists,
     * which contain this link again: it never terminates once the link is registered on its nodes.
     * Links that are equal have end points with equal indices, so this stays consistent with `equals`.
     */
    override fun hashCode(): Int {
        var result = source.index
        result = 31 * result + target.index
        result = 31 * result + index
        result = 31 * result + value.hashCode()
        result = 31 * result + y0.hashCode()
        result = 31 * result + y1.hashCode()
        result = 31 * result + width.hashCode()
        return result
    }

    /**
     * Prints the end points by index only; printing the nodes themselves would recurse through their link
     * lists back into this link.
     */
    override fun toString(): String =
        "SankeyLink(source=#${source.index}, target=#${target.index}, index=$index, value=$value, " +
            "y0=$y0, y1=$y1, width=$width)"
}

/**
 * A node of a Sankey diagram, wrapping one user-supplied datum.
 *
 * @param D the type of the wrapped user data
 * @property data the user datum this node represents
 * @property index the node's zero-based index within the list of nodes
 * @property sourceLinks the outgoing links, i.e. those which have this node as their source
 * @property targetLinks the incoming links, i.e. those which have this node as their target
 * @property value the node's value; the larger of the summed values of its outgoing links and the summed
 * values of its incoming links
 * @property depth the node's zero-based graph depth, derived from the graph topology
 * @property height the node's zero-based graph height, derived from the graph topology
 * @property x0 the node's minimum horizontal position
 * @property x1 the node's maximum horizontal position (`x0` plus the layout's node width)
 * @property y0 the node's minimum vertical position
 * @property y1 the node's maximum vertical position (`y1 - y0` is proportional to [value])
 */
public data class SankeyNode<D>(
    val data: D,
    var index: Int,
    val sourceLinks: MutableList<SankeyLink<D>> = mutableListOf(),
    val targetLinks: MutableList<SankeyLink<D>> = mutableListOf(),
    var value: Double = .0,
    var depth: Int = 0,
    var height: Int = 0,
    var x0: Double = .0,
    var x1: Double = .0,
    var y0: Double = .0,
    var y1: Double = .0
)

/**
 * The result of a Sankey layout: the positioned nodes and the links connecting them.
 *
 * @param D the type of the user data carried by the nodes
 * @property nodes the laid-out nodes
 * @property links the laid-out links between [nodes]
 */
public data class SankeyGraph<D>(
    val nodes: List<SankeyNode<D>>,
    val links: List<SankeyLink<D>>
)

/**
 * Computes the positions of the nodes and links of a Sankey diagram.
 *
 * Configure the drawing area ([width] / [height] or [extent]), the node geometry ([nodeWidth],
 * [nodePadding]), the column [align]ment and the number of relaxation [iterations], then call [sankey].
 *
 * The instance is stateful: [nodes] and [links] hold the result of the most recent [sankey] call.
 *
 * @param D the type of the user data each node represents
 */
public class SankeyLayout<D> {

    // Extent of the drawing area: [x0, x1] horizontally, [y0, y1] vertically.
    private var x0 = .0
    private var y0 = .0
    private var x1 = 1.0
    private var y1 = 1.0

    /** Vertical size of the drawing area. Setting it resets the vertical extent to `[0, value]`. */
    public var height: Double
        get() = y1 - y0
        set(value) {
            y0 = .0
            y1 = value
        }

    /** Horizontal size of the drawing area. Setting it resets the horizontal extent to `[0, value]`. */
    public var width: Double
        get() = x1 - x0
        set(value) {
            x0 = .0
            x1 = value
        }

    /**
     * Sets the drawing area explicitly.
     *
     * @param x0 left bound
     * @param x1 right bound
     * @param y0 top bound
     * @param y1 bottom bound
     */
    public fun extent(x0: Double, x1: Double, y0: Double, y1: Double) {
        this.x0 = x0
        this.x1 = x1
        this.y0 = y0
        this.y1 = y1
    }

    /** Horizontal thickness given to every node. */
    public var nodeWidth: Double = 24.0

    /** Minimum vertical gap kept between two nodes of the same column. */
    public var nodePadding: Double = 8.0

    /** How the column of each node is chosen. */
    public var align: SankeyAlignment = SankeyAlignment.JUSTIFY

    /** The number of relaxation iterations when generating the layout. */
    public var iterations: Int = 32

    /** Nodes produced by the last [sankey] call; cleared and refilled on every call. */
    public val nodes: MutableList<SankeyNode<D>> = mutableListOf()

    /** Links produced by the last [sankey] call; cleared and refilled on every call. */
    public val links: MutableList<SankeyLink<D>> = mutableListOf()

    /**
     * Builds and lays out the Sankey graph of [data].
     *
     * [flow] is evaluated for every ordered pair of items (including an item with itself); a link is created
     * whenever it returns a value strictly greater than zero.
     *
     * The returned graph references this layout's [nodes] and [links] lists directly, so calling [sankey]
     * again also changes the content of a previously returned graph.
     *
     * @param data the items to turn into nodes, in node-index order
     * @param flow the value flowing from the first item to the second, or `null` / non-positive for no link
     * @return the laid-out graph
     * @throws IllegalArgumentException if the links described by [flow] form a cycle
     */
    public fun sankey(data: List<D>, flow: (from: D, to: D) -> Double?): SankeyGraph<D> {
        nodes.clear()
        links.clear()
        computeNodeLinks(data, flow)
        computeNodeValues()
        computeNodeDepths()
        computeNodeBreadths()
        computeLinkBreadths()
        return SankeyGraph(nodes, links)
    }

    /**
     * Stacks the links of each node along the node's side and stores the centre of each link band in
     * `link.y0` (source side) and `link.y1` (target side).
     */
    private fun computeLinkBreadths() {
        // Order links by the vertical position of the node at their other end to limit crossings;
        // the link index breaks ties so that the result is deterministic.
        nodes.forEach { node ->
            node.sourceLinks.sortWith(compareBy({ it.target.y0 }, { it.index }))
            node.targetLinks.sortWith(compareBy({ it.source.y0 }, { it.index }))
        }
        nodes.forEach { node ->
            var outgoingTop = node.y0
            node.sourceLinks.forEach { link ->
                link.y0 = outgoingTop + link.width / 2.0
                outgoingTop += link.width
            }
            var incomingTop = node.y0
            node.targetLinks.forEach { link ->
                link.y1 = incomingTop + link.width / 2.0
                incomingTop += link.width
            }
        }
    }

    /**
     * Assigns the depth, the height and the horizontal position (`x0` / `x1`) of every node.
     *
     * The depth of a node is the length of the longest chain of links leading to it, its height the length
     * of the longest chain of links leaving it. The column is then chosen according to [align].
     *
     * @throws IllegalArgumentException if the links form a cycle
     */
    private fun computeNodeDepths() {
        assignLayers({ node -> node.sourceLinks.map { it.target } }) { node, layer -> node.depth = layer }
        val layerCount =
            assignLayers({ node -> node.targetLinks.map { it.source } }) { node, layer -> node.height = layer }

        val lastColumn = layerCount - 1
        // With a single column there is nothing to spread; dividing by zero here would yield
        // Infinity and then NaN positions (0 * Infinity).
        val kx = if (lastColumn > 0) (width - nodeWidth) / lastColumn else .0
        nodes.forEach { node ->
            val column = when (align) {
                SankeyAlignment.JUSTIFY -> justify(node, layerCount)
                SankeyAlignment.CENTER -> center(node, layerCount)
                SankeyAlignment.RIGHT -> right(node, layerCount)
                SankeyAlignment.LEFT -> left(node, layerCount)
            }
            node.x0 = x0 + max(0, min(lastColumn, column)) * kx
            node.x1 = node.x0 + nodeWidth
        }
    }

    /**
     * Walks the graph layer by layer, starting from all nodes, and reports to [assign] the layer in which
     * each node is reached; the last report for a node is the length of the longest chain reaching it.
     *
     * @param neighbours the nodes to visit in the next layer for a given node
     * @param assign receives each visited node with the current layer number
     * @return the number of layers visited
     * @throws IllegalArgumentException if the walk does not end, i.e. the links form a cycle
     */
    private fun assignLayers(
        neighbours: (SankeyNode<D>) -> List<SankeyNode<D>>,
        assign: (SankeyNode<D>, Int) -> Unit
    ): Int {
        var current: List<SankeyNode<D>> = nodes.toList()
        var layer = 0
        while (current.isNotEmpty()) {
            // An acyclic graph of n nodes has at most n layers: going further means a cycle.
            require(layer < nodes.size) { "Sankey layout requires an acyclic graph: circular link detected" }
            // A plain list is used on purpose: nodes and links are data classes referencing each other,
            // so hashing a node (as a hash set would) recurses through its links.
            val next = mutableListOf<SankeyNode<D>>()
            for (node in current) {
                assign(node, layer)
                for (neighbour in neighbours(node)) {
                    if (neighbour !in next) next.add(neighbour)
                }
            }
            layer++
            current = next
        }
        return layer
    }

    /**
     * Moves every node that has incoming links towards the weighted centre of its sources.
     *
     * @param columns the nodes grouped by column, ordered from left to right
     * @param alpha the fraction of the distance to the weighted centre to move by
     */
    private fun relaxLeftToRight(columns: Map<Double, List<SankeyNode<D>>>, alpha: Double) {
        columns.values.forEach { column ->
            column.forEach { node ->
                if (node.targetLinks.isNotEmpty()) {
                    val weightedCenter =
                        node.targetLinks.sumOf { weightedSource(it) } / node.targetLinks.sumOf { it.value }
                    shift(node, (weightedCenter - nodeCenter(node)) * alpha)
                }
            }
        }
    }

    /**
     * Moves every node that has outgoing links towards the weighted centre of its targets.
     *
     * @param columns the nodes grouped by column, ordered from left to right (visited in reverse)
     * @param alpha the fraction of the distance to the weighted centre to move by
     */
    private fun relaxRightToLeft(columns: Map<Double, List<SankeyNode<D>>>, alpha: Double) {
        columns.values.reversed().forEach { column ->
            column.forEach { node ->
                if (node.sourceLinks.isNotEmpty()) {
                    val weightedCenter =
                        node.sourceLinks.sumOf { weightedTarget(it) } / node.sourceLinks.sumOf { it.value }
                    shift(node, (weightedCenter - nodeCenter(node)) * alpha)
                }
            }
        }
    }

    /** Vertical centre of the link's target, weighted by the link value. */
    private fun weightedTarget(link: SankeyLink<D>): Double = nodeCenter(link.target) * link.value

    /** Vertical centre of the link's source, weighted by the link value. */
    private fun weightedSource(link: SankeyLink<D>): Double = nodeCenter(link.source) * link.value

    /** Vertical centre of [node]. */
    private fun nodeCenter(node: SankeyNode<D>): Double = (node.y0 + node.y1) / 2.0

    /** Translates [node] vertically by [dy], keeping its size. */
    private fun shift(node: SankeyNode<D>, dy: Double) {
        node.y0 += dy
        node.y1 += dy
    }

    /**
     * Removes vertical overlaps inside each column while keeping at least [nodePadding] between nodes.
     *
     * Nodes are first pushed down from the top bound; if the column then overflows the bottom bound, the
     * last node is moved back inside and the nodes above it are pushed up as needed.
     */
    private fun resolveCollisions(columns: Map<Double, List<SankeyNode<D>>>) {
        columns.values.forEach { column ->
            val ordered = column.sortedBy { it.y0 }

            // Push any overlapping nodes down.
            var cursor = y0
            for (node in ordered) {
                val overlap = cursor - node.y0
                if (overlap > 0) shift(node, overlap)
                cursor = node.y1 + nodePadding
            }

            // If the bottommost node goes outside the bounds, push it back up.
            val overflow = cursor - nodePadding - y1
            if (overflow > 0) {
                val bottom = ordered.last()
                shift(bottom, -overflow)
                cursor = bottom.y0

                // Push any overlapping nodes back up.
                for (i in ordered.size - 2 downTo 0) {
                    val node = ordered[i]
                    val overlap = node.y1 + nodePadding - cursor
                    if (overlap > 0) shift(node, -overlap)
                    cursor = node.y0
                }
            }
        }
    }

    /**
     * Computes the vertical position of every node: initial stacking, then [iterations] rounds of
     * relaxation (right-to-left, then left-to-right), each followed by collision resolution.
     */
    private fun computeNodeBreadths() {
        if (nodes.isEmpty()) return

        // groupBy keeps the encounter order of the nodes, which is unrelated to their position:
        // sort the columns by x so that the relaxation sweeps really go left-to-right / right-to-left.
        val columns: Map<Double, List<SankeyNode<D>>> = nodes
            .groupBy { it.x0 }
            .entries
            .sortedBy { it.key }
            .associate { it.key to it.value }

        initializeNodeBreadth(columns)
        resolveCollisions(columns)

        var alpha = 1.0
        repeat(iterations) {
            alpha *= 0.99
            relaxRightToLeft(columns, alpha)
            resolveCollisions(columns)
            relaxLeftToRight(columns, alpha)
            resolveCollisions(columns)
        }
    }

    /**
     * Chooses the vertical scale and gives every node its initial position and every link its width.
     *
     * The scale is the largest one for which the most loaded column still fits in [height] once the
     * padding between its nodes is accounted for.
     */
    private fun initializeNodeBreadth(columns: Map<Double, List<SankeyNode<D>>>) {
        // Columns without any value cannot constrain the scale (and would divide by zero);
        // when no column carries a value every node and link simply gets a zero size.
        val ky = columns.values
            .filter { column -> column.sumOf { it.value } > .0 }
            .minOfOrNull { column ->
                (height - (column.size - 1) * nodePadding) / column.sumOf { it.value }
            }
            ?: .0

        columns.values.forEach { column ->
            column.forEachIndexed { i, node ->
                // Only the order matters here: resolveCollisions spreads the nodes afterwards.
                node.y0 = i.toDouble()
                node.y1 = node.y0 + node.value * ky
            }
        }

        links.forEach { link ->
            link.width = link.value * ky
        }
    }

    /**
     * Compute the value (size) of each node: the larger of the sum of its outgoing link values and the sum
     * of its incoming link values.
     */
    private fun computeNodeValues() {
        nodes.forEach { node ->
            val outgoing = node.sourceLinks.sumOf { it.value }
            val incoming = node.targetLinks.sumOf { it.value }
            node.value = max(outgoing, incoming)
        }
    }

    /**
     * Creates one node per item of [data] and one link per ordered pair of items for which [flow] returns a
     * strictly positive value, registering each link in [links] and on both of its end nodes.
     */
    private fun computeNodeLinks(data: List<D>, flow: (from: D, to: D) -> Double?) {
        nodes.addAll(data.mapIndexed { index, d -> SankeyNode(d, index) })
        var linkIndex = 0
        data.forEachIndexed { sourceIndex, from ->
            data.forEachIndexed { targetIndex, to ->
                val linkValue = flow(from, to)
                if (linkValue != null && linkValue > .0) {
                    val source = nodes[sourceIndex]
                    val target = nodes[targetIndex]
                    val link = SankeyLink(source, target, linkIndex, linkValue)
                    links.add(link)
                    source.sourceLinks.add(link)
                    target.targetLinks.add(link)
                    linkIndex++
                }
            }
        }
    }

    ///// ALIGNMENTS

    /** Column for [SankeyAlignment.JUSTIFY]: the last column for nodes without outgoing links, else the depth. */
    private fun justify(node: SankeyNode<*>, size: Int): Int =
        if (node.sourceLinks.isEmpty()) size - 1 else node.depth

    /** Column for [SankeyAlignment.LEFT]: the node depth. */
    private fun left(node: SankeyNode<*>, size: Int): Int = node.depth

    /** Column for [SankeyAlignment.RIGHT]: the node height counted back from the last column. */
    private fun right(node: SankeyNode<*>, size: Int): Int = size - 1 - node.height

    /**
     * Column for [SankeyAlignment.CENTER]: nodes without incoming links sit one column before their
     * shallowest target (or in the first column when they have no link at all); others use their depth.
     */
    private fun center(node: SankeyNode<*>, size: Int): Int = when {
        node.targetLinks.isNotEmpty() -> node.depth
        node.sourceLinks.isEmpty() -> 0
        else -> node.sourceLinks.minOf { it.target.depth } - 1
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy