All downloads are free. The search and download functionality uses the official Maven repository.

org.vertexium.accumulo.AccumuloFindPathStrategy Maven / Gradle / Ivy

There is a newer version: 4.10.0
Show newest version
package org.vertexium.accumulo;

import org.apache.accumulo.core.client.IteratorSetting;
import org.apache.accumulo.core.client.ScannerBase;
import org.apache.accumulo.core.data.Key;
import org.apache.accumulo.core.data.Value;
import org.apache.accumulo.core.trace.Span;
import org.apache.accumulo.core.trace.Trace;
import org.vertexium.*;
import org.vertexium.accumulo.iterator.ConnectedVertexIdsIterator;
import org.vertexium.accumulo.util.RangeUtils;
import org.vertexium.util.IterableUtils;
import org.vertexium.util.VertexiumLogger;
import org.vertexium.util.VertexiumLoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;

import static org.vertexium.util.StreamUtils.stream;

public class AccumuloFindPathStrategy {
    private static final VertexiumLogger LOGGER = VertexiumLoggerFactory.getLogger(AccumuloFindPathStrategy.class);
    private final AccumuloGraph graph;
    private final FindPathOptions options;
    private final ProgressCallback progressCallback;
    private final Authorizations authorizations;
    private final String[] deflatedLabels;
    private final String[] deflatedExcludedLabels;

    public AccumuloFindPathStrategy(
            AccumuloGraph graph,
            FindPathOptions options,
            ProgressCallback progressCallback,
            Authorizations authorizations
    ) {
        this.graph = graph;
        this.options = options;
        this.progressCallback = progressCallback;
        this.authorizations = authorizations;
        this.deflatedLabels = deflateLabels(graph.getNameSubstitutionStrategy(), options.getLabels());
        this.deflatedExcludedLabels = deflateLabels(graph.getNameSubstitutionStrategy(), options.getExcludedLabels());
    }

    private static String[] deflateLabels(AccumuloNameSubstitutionStrategy nameSubstitutionStrategy, String[] labels) {
        if (labels == null) {
            return null;
        }
        String[] results = new String[labels.length];
        for (int i = 0; i < labels.length; i++) {
            String label = labels[i];
            results[i] = nameSubstitutionStrategy.deflate(label);
        }
        return results;
    }

    public Iterable findPaths() {
        progressCallback.progress(0, ProgressCallback.Step.FINDING_PATH);

        List foundPaths = new ArrayList<>();
        if (options.getMaxHops() < 1) {
            throw new IllegalArgumentException("maxHops cannot be less than 1");
        } else if (options.getMaxHops() == 1) {
            Set sourceConnectedVertexIds = getConnectedVertexIds(options.getSourceVertexId());
            if (sourceConnectedVertexIds.contains(options.getDestVertexId())) {
                foundPaths.add(new Path(options.getSourceVertexId(), options.getDestVertexId()));
            }
        } else if (options.getMaxHops() == 2) {
            findPathsSetIntersection(foundPaths);
        } else {
            findPathsBreadthFirst(foundPaths, options.getSourceVertexId(), options.getDestVertexId(), options.getMaxHops());
        }

        progressCallback.progress(1, ProgressCallback.Step.COMPLETE);
        return foundPaths;
    }

    private void findPathsSetIntersection(List foundPaths) {
        String sourceVertexId = options.getSourceVertexId();
        String destVertexId = options.getDestVertexId();

        Set vertexIds = new HashSet<>();
        vertexIds.add(sourceVertexId);
        vertexIds.add(destVertexId);
        Map> connectedVertexIds = getConnectedVertexIds(vertexIds);

        progressCallback.progress(0.1, ProgressCallback.Step.SEARCHING_SOURCE_VERTEX_EDGES);
        Set sourceVertexConnectedVertexIds = connectedVertexIds.get(sourceVertexId);
        if (sourceVertexConnectedVertexIds == null) {
            return;
        }

        progressCallback.progress(0.3, ProgressCallback.Step.SEARCHING_DESTINATION_VERTEX_EDGES);
        Set destVertexConnectedVertexIds = connectedVertexIds.get(destVertexId);
        if (destVertexConnectedVertexIds == null) {
            return;
        }

        if (sourceVertexConnectedVertexIds.contains(destVertexId)) {
            foundPaths.add(new Path(sourceVertexId, destVertexId));
            if (options.isGetAnyPath()) {
                return;
            }
        }

        progressCallback.progress(0.6, ProgressCallback.Step.MERGING_EDGES);
        sourceVertexConnectedVertexIds.retainAll(destVertexConnectedVertexIds);

        progressCallback.progress(0.9, ProgressCallback.Step.ADDING_PATHS);
        foundPaths.addAll(
                sourceVertexConnectedVertexIds.stream()
                        .map(connectedVertexId -> new Path(sourceVertexId, connectedVertexId, destVertexId))
                        .collect(Collectors.toList())
        );
    }

    private void findPathsBreadthFirst(List foundPaths, String sourceVertexId, String destVertexId, int hops) {
        Map> connectedVertexIds = getConnectedVertexIds(sourceVertexId, destVertexId);
        // start at 2 since we already got the source and dest vertex connected vertex ids
        for (int i = 2; i < hops; i++) {
            progressCallback.progress((double) i / (double) hops, ProgressCallback.Step.FINDING_PATH);
            Set vertexIdsToSearch = new HashSet<>();
            for (Map.Entry> entry : connectedVertexIds.entrySet()) {
                vertexIdsToSearch.addAll(entry.getValue());
            }
            vertexIdsToSearch.removeAll(connectedVertexIds.keySet());
            Map> r = getConnectedVertexIds(vertexIdsToSearch);
            connectedVertexIds.putAll(r);
        }
        progressCallback.progress(0.9, ProgressCallback.Step.ADDING_PATHS);
        Set seenVertices = new HashSet<>();
        Path currentPath = new Path(sourceVertexId);
        findPathsRecursive(connectedVertexIds, foundPaths, sourceVertexId, destVertexId, hops, seenVertices, currentPath, progressCallback);
    }

    private void findPathsRecursive(
            Map> connectedVertexIds,
            List foundPaths,
            final String sourceVertexId,
            String destVertexId,
            int hops,
            Set seenVertices,
            Path currentPath,
            @SuppressWarnings("UnusedParameters") ProgressCallback progressCallback
    ) {
        if (options.isGetAnyPath() && foundPaths.size() == 1) {
            return;
        }
        seenVertices.add(sourceVertexId);
        if (sourceVertexId.equals(destVertexId)) {
            foundPaths.add(currentPath);
        } else if (hops > 0) {
            Set vertexIds = connectedVertexIds.get(sourceVertexId);
            if (vertexIds != null) {
                for (String childId : vertexIds) {
                    if (!seenVertices.contains(childId)) {
                        findPathsRecursive(connectedVertexIds, foundPaths, childId, destVertexId, hops - 1, seenVertices, new Path(currentPath, childId), progressCallback);
                    }
                }
            }
        }
        seenVertices.remove(sourceVertexId);
    }

    private Set getConnectedVertexIds(String vertexId) {
        Set vertexIds = new HashSet<>();
        vertexIds.add(vertexId);
        Map> results = getConnectedVertexIds(vertexIds);
        Set vertexIdResults = results.get(vertexId);
        if (vertexIdResults == null) {
            return new HashSet<>();
        }
        return vertexIdResults;
    }

    private Map> getConnectedVertexIds(String vertexId1, String vertexId2) {
        Set vertexIds = new HashSet<>();
        vertexIds.add(vertexId1);
        vertexIds.add(vertexId2);
        return getConnectedVertexIds(vertexIds);
    }

    private Map> getConnectedVertexIds(Set vertexIds) {
        Span trace = Trace.start("getConnectedVertexIds");
        try {
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace("getConnectedVertexIds:\n  %s", IterableUtils.join(vertexIds, "\n  "));
            }

            if (vertexIds.size() == 0) {
                return new HashMap<>();
            }

            List ranges = new ArrayList<>();
            for (String vertexId : vertexIds) {
                ranges.add(RangeUtils.createRangeFromString(vertexId));
            }

            int maxVersions = 1;
            Long startTime = null;
            Long endTime = null;
            ScannerBase scanner = graph.createElementScanner(
                    FetchHints.EDGE_REFS,
                    ElementType.VERTEX,
                    maxVersions,
                    startTime,
                    endTime,
                    ranges,
                    false,
                    authorizations
            );

            IteratorSetting connectedVertexIdsIteratorSettings = new IteratorSetting(
                    1000,
                    ConnectedVertexIdsIterator.class.getSimpleName(),
                    ConnectedVertexIdsIterator.class
            );
            ConnectedVertexIdsIterator.setLabels(connectedVertexIdsIteratorSettings, deflatedLabels);
            ConnectedVertexIdsIterator.setExcludedLabels(connectedVertexIdsIteratorSettings, deflatedExcludedLabels);
            scanner.addScanIterator(connectedVertexIdsIteratorSettings);

            final long timerStartTime = System.currentTimeMillis();
            try {
                Map> results = new HashMap<>();
                for (Map.Entry row : scanner) {
                    try {
                        Map verticesExist = graph.doVerticesExist(ConnectedVertexIdsIterator.decodeValue(row.getValue()), authorizations);
                        Set rowVertexIds = stream(verticesExist.keySet())
                                .filter(key -> verticesExist.getOrDefault(key, false))
                                .collect(Collectors.toSet());
                        results.put(row.getKey().getRow().toString(), rowVertexIds);
                    } catch (IOException e) {
                        throw new VertexiumException("Could not decode vertex ids for row: " + row.getKey().toString(), e);
                    }
                }
                return results;
            } finally {
                scanner.close();
                AccumuloGraph.GRAPH_LOGGER.logEndIterator(System.currentTimeMillis() - timerStartTime);
            }
        } finally {
            trace.stop();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy