
org.deeplearning4j.graph.Graph Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.graph;
import org.deeplearning4j.graph.api.BaseGraph;
import org.deeplearning4j.graph.api.Edge;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.exception.NoEdgesException;
import org.deeplearning4j.graph.vertexfactory.VertexFactory;
import java.lang.reflect.Array;
import java.util.*;
public class Graph extends BaseGraph {
private boolean allowMultipleEdges;
private List>[] edges; //edge[i].get(j).to = k, then edge from i -> k
private List> vertices;
public Graph(int numVertices, VertexFactory vertexFactory) {
this(numVertices, false, vertexFactory);
}
@SuppressWarnings("unchecked")
public Graph(int numVertices, boolean allowMultipleEdges, VertexFactory vertexFactory) {
if (numVertices <= 0)
throw new IllegalArgumentException();
this.allowMultipleEdges = allowMultipleEdges;
vertices = new ArrayList<>(numVertices);
for (int i = 0; i < numVertices; i++)
vertices.add(vertexFactory.create(i));
edges = (List>[]) Array.newInstance(List.class, numVertices);
}
@SuppressWarnings("unchecked")
public Graph(List> vertices, boolean allowMultipleEdges) {
this.vertices = new ArrayList<>(vertices);
this.allowMultipleEdges = allowMultipleEdges;
edges = (List>[]) Array.newInstance(List.class, vertices.size());
}
public Graph(List> vertices) {
this(vertices, false);
}
@Override
public int numVertices() {
return vertices.size();
}
@Override
public Vertex getVertex(int idx) {
if (idx < 0 || idx >= vertices.size())
throw new IllegalArgumentException("Invalid index: " + idx);
return vertices.get(idx);
}
@Override
public List> getVertices(int[] indexes) {
List> out = new ArrayList<>(indexes.length);
for (int i : indexes)
out.add(getVertex(i));
return out;
}
@Override
public List> getVertices(int from, int to) {
if (to < from || from < 0 || to >= vertices.size())
throw new IllegalArgumentException("Invalid range: from=" + from + ", to=" + to);
List> out = new ArrayList<>(to - from + 1);
for (int i = from; i <= to; i++)
out.add(getVertex(i));
return out;
}
@Override
public void addEdge(Edge edge) {
if (edge.getFrom() < 0 || edge.getTo() >= vertices.size())
throw new IllegalArgumentException("Invalid edge: " + edge + ", from/to indexes out of range");
List> fromList = edges[edge.getFrom()];
if (fromList == null) {
fromList = new ArrayList<>();
edges[edge.getFrom()] = fromList;
}
addEdgeHelper(edge, fromList);
if (edge.isDirected())
return;
//Add other way too (to allow easy lookup for undirected edges)
List> toList = edges[edge.getTo()];
if (toList == null) {
toList = new ArrayList<>();
edges[edge.getTo()] = toList;
}
addEdgeHelper(edge, toList);
}
@Override
@SuppressWarnings("unchecked")
public List> getEdgesOut(int vertex) {
if (edges[vertex] == null)
return Collections.emptyList();
return new ArrayList<>(edges[vertex]);
}
@Override
public int getVertexDegree(int vertex) {
if (edges[vertex] == null)
return 0;
return edges[vertex].size();
}
@Override
public Vertex getRandomConnectedVertex(int vertex, Random rng) throws NoEdgesException {
if (vertex < 0 || vertex >= vertices.size())
throw new IllegalArgumentException("Invalid vertex index: " + vertex);
if (edges[vertex] == null || edges[vertex].isEmpty())
throw new NoEdgesException("Cannot generate random connected vertex: vertex " + vertex
+ " has no outgoing/undirected edges");
int connectedVertexNum = rng.nextInt(edges[vertex].size());
Edge edge = edges[vertex].get(connectedVertexNum);
if (edge.getFrom() == vertex)
return vertices.get(edge.getTo()); //directed or undirected, vertex -> x
else
return vertices.get(edge.getFrom()); //Undirected edge, x -> vertex
}
@Override
public List> getConnectedVertices(int vertex) {
if (vertex < 0 || vertex >= vertices.size())
throw new IllegalArgumentException("Invalid vertex index: " + vertex);
if (edges[vertex] == null)
return Collections.emptyList();
List> list = new ArrayList<>(edges[vertex].size());
for (Edge edge : edges[vertex]) {
list.add(vertices.get(edge.getTo()));
}
return list;
}
@Override
public int[] getConnectedVertexIndices(int vertex) {
int[] out = new int[(edges[vertex] == null ? 0 : edges[vertex].size())];
if (out.length == 0)
return out;
for (int i = 0; i < out.length; i++) {
Edge e = edges[vertex].get(i);
out[i] = (e.getFrom() == vertex ? e.getTo() : e.getFrom());
}
return out;
}
private void addEdgeHelper(Edge edge, List> list) {
if (!allowMultipleEdges) {
//Check to avoid multiple edges
boolean duplicate = false;
if (edge.isDirected()) {
for (Edge e : list) {
if (e.getTo() == edge.getTo()) {
duplicate = true;
break;
}
}
} else {
for (Edge e : list) {
if ((e.getFrom() == edge.getFrom() && e.getTo() == edge.getTo())
|| (e.getTo() == edge.getFrom() && e.getFrom() == edge.getTo())) {
duplicate = true;
break;
}
}
}
if (!duplicate) {
list.add(edge);
}
} else {
//allow multiple/duplicate edges
list.add(edge);
}
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("Graph {");
sb.append("\nVertices {");
for (Vertex v : vertices) {
sb.append("\n\t").append(v);
}
sb.append("\n}");
sb.append("\nEdges {");
for (int i = 0; i < edges.length; i++) {
sb.append("\n\t");
if (edges[i] == null)
continue;
sb.append(i).append(":");
for (Edge e : edges[i]) {
sb.append(" ").append(e);
}
}
sb.append("\n}");
sb.append("\n}");
return sb.toString();
}
@Override
public boolean equals(Object o) {
if (!(o instanceof Graph))
return false;
Graph g = (Graph) o;
if (allowMultipleEdges != g.allowMultipleEdges)
return false;
if (edges.length != g.edges.length)
return false;
if (vertices.size() != g.vertices.size())
return false;
for (int i = 0; i < edges.length; i++) {
if (!edges[i].equals(g.edges[i]))
return false;
}
return vertices.equals(g.vertices);
}
@Override
public int hashCode() {
int result = 23;
result = 31 * result + (allowMultipleEdges ? 1 : 0);
result = 31 * result + Arrays.hashCode(edges);
result = 31 * result + vertices.hashCode();
return result;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy