All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.neo4j.gds.CommunityHelper Maven / Gradle / Ivy

There is a newer version: 2.11.0
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds;

import org.neo4j.gds.collections.ha.HugeLongArray;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

public final class CommunityHelper {

    private CommunityHelper() {}

    public static void assertCommunities(Map actual, long[]... expectedCommunities) {
        for (long[] expectedCommunity : expectedCommunities) {
            assertThat(Arrays.stream(expectedCommunity).map(actual::get).distinct())
                .withFailMessage("Expected %s to be in the same community. But actual communities are: %s", Arrays.toString(expectedCommunity), actual)
                .hasSize(1);
        }
    }

    public static void assertCommunities(HugeLongArray communityData, long[]... communities) {
        assertCommunities(communityData.toArray(), communities);
    }

    /**
     * Helper method that checks if the result of a community algorithm has the expected communities.
     * It only tests if members are in the same or different communities, given the input and
     * disregards specific community values.
     *
     * @param actual   The output of a community detection algorithm.
     * @param expected The expected membership of communities. Elements within an array are
     *                 expected to be in the same community, whereas all elements of different
     *                 arrays are expected to be in different communities.
     */
    public static void assertCommunities(long[] actual, long[]... expected) {
        List actualList = Arrays.stream(actual).boxed().collect(toList());
        List> expectedList = Arrays.stream(expected).map(
            a -> Arrays.stream(a).boxed().collect(toList())
        ).collect(toList());

        assertCommunities(actualList, expectedList);
    }

    public static void assertCommunities(List actualCommunityData, List> expectedCommunities) {
        for (List community : expectedCommunities) {
            assertSameCommunity(actualCommunityData, community);
        }

        for (int i = 0; i < expectedCommunities.size(); i++) {
            for (int j = i + 1; j < expectedCommunities.size(); j++) {
                int expected = expectedCommunities.get(i).get(0).intValue();
                int actual = expectedCommunities.get(j).get(0).intValue();
                assertNotEquals(
                    actualCommunityData.get(expected),
                    actualCommunityData.get(actual),
                    formatWithLocale(
                        "Expected node %d to be in a different community than node %d",
                        expected,
                        actual
                    )
                );
            }
        }
    }

    private static void assertSameCommunity(List communities, List members) {
        long expectedCommunity = communities.get(members.get(0).intValue());

        for (int i = 1; i < members.size(); i++) {
            Long member = members.get(i);
            long actualCommunity = communities.get(member.intValue());
            assertEquals(
                expectedCommunity,
                actualCommunity,
                formatWithLocale(
                    "Expected node %d (community %d) to have the same community as node %d (community %d)",
                    member,
                    actualCommunity,
                    members.get(0),
                    expectedCommunity
                )
            );
        }
    }

    public static void assertCommunitiesWithLabels(HugeLongArray communityData, Map expectedCommunities) {
        assertCommunitiesWithLabels(communityData.toArray(), expectedCommunities);
    }

    private static void assertCommunitiesWithLabels(long[] communityData, Map expectedCommunities) {
        List communityDataList = Arrays.stream(communityData).boxed().collect(toList());
        List> expectedCommunitiesList = expectedCommunities
            .values()
            .stream()
            .map(l -> Arrays.stream(l).boxed().collect(toList()))
            .collect(toList());

        assertCommunities(communityDataList, expectedCommunitiesList);
        for (Map.Entry entry : expectedCommunities.entrySet()) {
            long label = entry.getKey();
            long[] community = entry.getValue();

            for (Long nodeId : community) {
                assertEquals(label, communityData[nodeId.intValue()],
                    formatWithLocale(
                        "Expected node %d to be in community %d, but was %d",
                        nodeId,
                        label,
                        communityData[nodeId.intValue()]
                    )
                );
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy