org.neo4j.gds.CommunityHelper Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of test-utils Show documentation
Show all versions of test-utils Show documentation
Neo4j Graph Data Science :: Test Utils
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds;
import org.neo4j.gds.collections.ha.HugeLongArray;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
public final class CommunityHelper {
private CommunityHelper() {}
public static void assertCommunities(Map actual, long[]... expectedCommunities) {
for (long[] expectedCommunity : expectedCommunities) {
assertThat(Arrays.stream(expectedCommunity).map(actual::get).distinct())
.withFailMessage("Expected %s to be in the same community. But actual communities are: %s", Arrays.toString(expectedCommunity), actual)
.hasSize(1);
}
}
public static void assertCommunities(HugeLongArray communityData, long[]... communities) {
assertCommunities(communityData.toArray(), communities);
}
/**
* Helper method that checks if the result of a community algorithm has the expected communities.
* It only tests if members are in the same or different communities, given the input and
* disregards specific community values.
*
* @param actual The output of a community detection algorithm.
* @param expected The expected membership of communities. Elements within an array are
* expected to be in the same community, whereas all elements of different
* arrays are expected to be in different communities.
*/
public static void assertCommunities(long[] actual, long[]... expected) {
List actualList = Arrays.stream(actual).boxed().collect(toList());
List> expectedList = Arrays.stream(expected).map(
a -> Arrays.stream(a).boxed().collect(toList())
).collect(toList());
assertCommunities(actualList, expectedList);
}
public static void assertCommunities(List actualCommunityData, List> expectedCommunities) {
for (List community : expectedCommunities) {
assertSameCommunity(actualCommunityData, community);
}
for (int i = 0; i < expectedCommunities.size(); i++) {
for (int j = i + 1; j < expectedCommunities.size(); j++) {
int expected = expectedCommunities.get(i).get(0).intValue();
int actual = expectedCommunities.get(j).get(0).intValue();
assertNotEquals(
actualCommunityData.get(expected),
actualCommunityData.get(actual),
formatWithLocale(
"Expected node %d to be in a different community than node %d",
expected,
actual
)
);
}
}
}
private static void assertSameCommunity(List communities, List members) {
long expectedCommunity = communities.get(members.get(0).intValue());
for (int i = 1; i < members.size(); i++) {
Long member = members.get(i);
long actualCommunity = communities.get(member.intValue());
assertEquals(
expectedCommunity,
actualCommunity,
formatWithLocale(
"Expected node %d (community %d) to have the same community as node %d (community %d)",
member,
actualCommunity,
members.get(0),
expectedCommunity
)
);
}
}
public static void assertCommunitiesWithLabels(HugeLongArray communityData, Map expectedCommunities) {
assertCommunitiesWithLabels(communityData.toArray(), expectedCommunities);
}
private static void assertCommunitiesWithLabels(long[] communityData, Map expectedCommunities) {
List communityDataList = Arrays.stream(communityData).boxed().collect(toList());
List> expectedCommunitiesList = expectedCommunities
.values()
.stream()
.map(l -> Arrays.stream(l).boxed().collect(toList()))
.collect(toList());
assertCommunities(communityDataList, expectedCommunitiesList);
for (Map.Entry entry : expectedCommunities.entrySet()) {
long label = entry.getKey();
long[] community = entry.getValue();
for (Long nodeId : community) {
assertEquals(label, communityData[nodeId.intValue()],
formatWithLocale(
"Expected node %d to be in community %d, but was %d",
nodeId,
label,
communityData[nodeId.intValue()]
)
);
}
}
}
}