// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// com.arcadedb.index.vector.HnswVectorIndex Maven / Gradle / Ivy
//
// There is a newer version: 24.11.1
// Show newest version
/*
 * Copyright 2023 Arcade Data Ltd
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.arcadedb.index.vector;

import com.arcadedb.database.DatabaseInternal;
import com.arcadedb.database.Identifiable;
import com.arcadedb.database.RID;
import com.arcadedb.engine.Component;
import com.arcadedb.engine.ComponentFactory;
import com.arcadedb.engine.ComponentFile;
import com.arcadedb.exception.RecordNotFoundException;
import com.arcadedb.exception.SchemaException;
import com.arcadedb.graph.Edge;
import com.arcadedb.graph.MutableVertex;
import com.arcadedb.graph.Vertex;
import com.arcadedb.index.IndexCursor;
import com.arcadedb.index.IndexException;
import com.arcadedb.index.IndexInternal;
import com.arcadedb.index.TypeIndex;
import com.arcadedb.index.lsm.LSMTreeIndexAbstract;
import com.arcadedb.index.vector.distance.DistanceFunctionFactory;
import com.arcadedb.log.LogManager;
import com.arcadedb.schema.IndexBuilder;
import com.arcadedb.schema.Schema;
import com.arcadedb.schema.Type;
import com.arcadedb.schema.VectorIndexBuilder;
import com.arcadedb.serializer.json.JSONObject;
import com.arcadedb.utility.FileUtils;
import com.arcadedb.utility.Pair;
import com.github.jelmerk.knn.DistanceFunction;
import com.github.jelmerk.knn.Index;
import com.github.jelmerk.knn.Item;
import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.util.Murmur3;
import org.eclipse.collections.api.list.primitive.MutableIntList;

import java.io.*;
import java.lang.reflect.*;
import java.util.*;
import java.util.concurrent.locks.*;
import java.util.logging.*;
import java.util.stream.*;

/**
 * This work is derived from the excellent work made by Jelmer Kuperus on https://github.com/jelmerk/hnswlib.
 * <p>
 * Implementation of {@link Index} that implements the hnsw algorithm.
 * <p>
 * TODO: Check if the global lock interferes with ArcadeDB's tx approach
 * <p>
 * NOTE(review): in this copy of the file the generic type arguments appear to have been stripped (for example
 * {@code new HnswVectorIndex<>(...)} is used below while the class declaration shows no type parameters, and several
 * fields are raw {@code Comparator}/{@code Set}/{@code Map}). Confirm against the upstream source before compiling.
 *
 * @author Luca Garulli (e-mail address obfuscated in this copy of the file)
 * @see <a href="https://arxiv.org/abs/1603.09320">
 * Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs</a>
 */
public class HnswVectorIndex extends Component implements com.arcadedb.index.Index, IndexInternal {
  /**
   * Callback receiving a vertex, its corresponding item and a running total.
   * Presumably invoked while the index is being built; the caller is not visible in this part of the file.
   */
  public interface BuildVectorIndexCallback {
    void onVertexIndexed(Vertex document, Item item, long totalIndexed);
  }

  /**
   * Predicate-style callback deciding whether a vertex should be ignored.
   * The call site is not visible in this part of the file.
   */
  public interface IgnoreVertexCallback {
    boolean ignoreVertex(Vertex v);
  }

  /** File extension constant for this index type. */
  public static final String FILE_EXT        = "hnswidx";
  /** Version passed to the {@link Component} super constructor when a new index is created. */
  public static final int    CURRENT_VERSION = 0;

  // Distance computation and ordering; the comparator pair is built in the constructors.
  private final DistanceFunction   distanceFunction;
  private final Comparator         distanceComparator;
  private final MaxValueComparator maxValueDistanceComparator;

  // Index configuration. On creation: maxM = m, maxM0 = m * 2, levelLambda = 1 / ln(m),
  // efConstruction = max(builder value, m). On load all of them are read back from the JSON file.
  private final int    dimensions;
  private final int    maxItemCount;
  private final int    m;
  private final int    maxM;
  private final int    maxM0;
  private final double levelLambda;
  private final int    ef;
  private final int    efConstruction;

  // Lock created once per index instance (see the class-level TODO about its interaction with transactions).
  private final ReentrantLock globalLock;
  private final Set           excludedCandidates = new HashSet<>();

  // Names of the schema types and properties this index works on.
  private final String vertexType;
  private final String edgeType;
  private final String vectorPropertyName;
  private final String idPropertyName;
  private final String deletedPropertyName;

  // Cache instance supplied by the builder on creation.
  private final Map    cache;
  private final String indexName;

  // Type index on the id property; created in the builder-based constructor and linked back to this index.
  private TypeIndex underlyingIndex;

  // Entry point state: the load-time constructor only stores the RID read from the file; how and when it is
  // resolved into `entryPoint` is not visible in this part of the file.
  public volatile RID    entryPointRIDToLoad;
  public volatile Vertex entryPoint;

  /**
   * Factory handler that creates a new {@link HnswVectorIndex} from an index builder.
   */
  public static class IndexFactoryHandler implements com.arcadedb.index.IndexFactoryHandler {
    /**
     * Creates the index from the given builder.
     *
     * @param builder the builder; must be a {@link VectorIndexBuilder}
     *
     * @return the newly created index
     *
     * @throws IndexException if the builder is not a {@link VectorIndexBuilder}
     */
    @Override
    public IndexInternal create(final IndexBuilder builder) {
      if (!(builder instanceof VectorIndexBuilder))
        throw new IndexException("Expected VectorIndexBuilder but received " + builder);

      return new HnswVectorIndex<>((VectorIndexBuilder) builder);
    }
  }

  /**
   * Component factory handler that instantiates the index from an existing file at load time.
   */
  public static class PaginatedComponentFactoryHandlerUnique implements ComponentFactory.PaginatedComponentFactoryHandler {
    /**
     * Delegates to the load-time constructor. The {@code mode} and {@code pageSize} arguments are not used.
     *
     * @throws IOException declared by the load-time constructor
     */
    @Override
    public Component createOnLoad(final DatabaseInternal database, final String name, final String filePath, final int id,
        final ComponentFile.MODE mode, final int pageSize, final int version) throws IOException {
      return new HnswVectorIndex(database, name, filePath, id, version);
    }
  }

  /**
   * Creation time: configures a new index from the settings held by the builder and creates the underlying
   * type index on the id property.
   *
   * @param builder source of all configuration values (database, file path, dimensions, distance function, ...)
   */
  protected HnswVectorIndex(final VectorIndexBuilder builder) {
    // NOTE(review): the file path is passed both as 2nd and 5th argument here, while the load-time constructor
    // passes the index name as 2nd argument — confirm this is intended.
    super(builder.getDatabase(), builder.getFilePath(), builder.getDatabase().getFileManager().newFileId(), CURRENT_VERSION,
        builder.getFilePath());

    this.dimensions = builder.getDimensions();
    this.maxItemCount = builder.getMaxItemCount();
    this.distanceFunction = builder.getDistanceFunction();
    this.distanceComparator = builder.getDistanceComparator();
    this.maxValueDistanceComparator = new MaxValueComparator<>(this.distanceComparator);

    this.m = builder.getM();
    this.maxM = m;
    this.maxM0 = m * 2;
    // NOTE(review): Math.log(1) is 0.0, so m == 1 would yield Infinity here — confirm the builder validates m.
    this.levelLambda = 1 / Math.log(this.m);
    // efConstruction is never smaller than m.
    this.efConstruction = Math.max(builder.getEfConstruction(), m);
    this.ef = builder.getEf();

    this.vertexType = builder.getVertexType();
    this.edgeType = builder.getEdgeType();
    this.vectorPropertyName = builder.getVectorPropertyName();
    this.idPropertyName = builder.getIdPropertyName();
    this.deletedPropertyName = builder.getDeletedPropertyName();
    this.cache = builder.getCache();

    // Requests a unique LSM_TREE type index on the id property of the vertex type (with the ignore-if-exists flag set)
    // and associates it with this vector index.
    this.underlyingIndex = builder.getDatabase().getSchema()
        .buildTypeIndex(builder.getVertexType(), new String[] { idPropertyName }).withUnique(true).withIgnoreIfExists(true)
        .withType(Schema.INDEX_TYPE.LSM_TREE).create();

    this.underlyingIndex.setAssociatedIndex(this);

    this.globalLock = new ReentrantLock();
    // Default name when none is given: vertexType[idPropertyName,vectorPropertyName]
    this.indexName = builder.getIndexName() != null ?
        builder.getIndexName() :
        vertexType + "[" + idPropertyName + "," + vectorPropertyName + "]";
  }

  /**
   * Load time.
*/ protected HnswVectorIndex(final DatabaseInternal database, final String indexName, final String filePath, final int id, final int version) throws IOException { super(database, indexName, id, version, filePath); final String fileContent = FileUtils.readFileAsString(new File(filePath)); final JSONObject json = new JSONObject(fileContent); this.distanceFunction = DistanceFunctionFactory.getImplementationByClassName(json.getString("distanceFunction")); if (distanceFunction == null) throw new IllegalArgumentException("distance function '" + json.getString("distanceFunction") + "' not supported"); this.dimensions = json.getInt("dimensions"); this.distanceComparator = (Comparator) Comparator.naturalOrder(); this.maxValueDistanceComparator = new MaxValueComparator<>(this.distanceComparator); this.maxItemCount = json.getInt("maxItemCount"); this.m = json.getInt("m"); this.maxM = json.getInt("maxM"); this.maxM0 = json.getInt("maxM0"); this.levelLambda = json.getDouble("levelLambda"); this.ef = json.getInt("ef"); this.efConstruction = json.getInt("efConstruction"); if (!json.getString("entryPoint").isEmpty()) { this.entryPointRIDToLoad = new RID(database, json.getString("entryPoint")); } else this.entryPointRIDToLoad = null; this.vertexType = json.getString("vertexType"); this.edgeType = json.getString("edgeType"); this.idPropertyName = json.getString("idPropertyName"); this.vectorPropertyName = json.getString("vectorPropertyName"); this.deletedPropertyName = json.has("deletedPropertyName") ? 
json.getString("deletedPropertyName") : "deleted"; this.globalLock = new ReentrantLock(); this.cache = null; this.indexName = json.getString("indexName"); } @Override public void onAfterSchemaLoad() { try { this.underlyingIndex = database.getSchema().buildTypeIndex(vertexType, new String[] { idPropertyName }) .withIgnoreIfExists(true).withUnique(true).withType(Schema.INDEX_TYPE.LSM_TREE).create(); this.underlyingIndex.setAssociatedIndex(this); // AFTER THE WHOLE SCHEMA IS LOADED INITIALIZE THE INDEX if (this.entryPointRIDToLoad != null) { try { this.entryPoint = this.entryPointRIDToLoad.asVertex(); } catch (RecordNotFoundException e) { // ENTRYPOINT DELETED, DROP THE INDEX LogManager.instance() .log(this, Level.WARNING, "HNSW index '" + indexName + "' has an invalid entrypoint. The index will be removed"); this.entryPointRIDToLoad = null; database.getSchema().dropIndex(indexName); } } } catch (Exception e) { LogManager.instance().log(this, Level.WARNING, "Error on loading of HNSW index '" + indexName + "'", e); } } @Override public void onAfterCommit() { if (entryPoint != null && !entryPoint.getIdentity().equals(entryPointRIDToLoad)) { // ENTRY POINT IS CHANGED: SAVE THE NEW CONFIGURATION TO DISK save(); entryPointRIDToLoad = entryPoint.getIdentity(); } } @Override public String getName() { return indexName; } public List> findNeighborsFromId(final TId id, final int k) { return findNeighborsFromId(id, k, null); } public List> findNeighborsFromId(final TId id, final int k, IgnoreVertexCallback ignoreVertexCallback) { final Vertex start = get(id); if (start == null) return Collections.emptyList(); return findNeighborsFromVertex(start, k, ignoreVertexCallback); } public List> findNeighborsFromVertex(final Vertex start, final int k, final IgnoreVertexCallback ignoreVertexCallback) { final RID startRID = start.getIdentity(); final TVector vector = getVectorFromVertex(start); final List> neighbors = findNearest(vector, k + 1, ignoreVertexCallback).stream()// 
.filter(result -> !result.item().getIdentity().equals(startRID))// .limit(k)// .collect(Collectors.toList()); final List> result = new ArrayList<>(neighbors.size()); for (SearchResult neighbor : neighbors) result.add(new Pair(neighbor.item(), neighbor.distance())); return result; } public List> findNeighborsFromVector(final TVector vector, final int k) { return findNeighborsFromVector(vector, k, null); } public List> findNeighborsFromVector(final TVector vector, final int k, final IgnoreVertexCallback ignoreVertexCallback) { final List> neighbors = findNearest(vector, k + 1, ignoreVertexCallback).stream().limit(k) .collect(Collectors.toList()); final List> result = new ArrayList<>(neighbors.size()); for (SearchResult neighbor : neighbors) result.add(new Pair(neighbor.item(), neighbor.distance())); return result; } public void addAll(final List> embeddings, final BuildVectorIndexCallback callback) { int indexed = 0; for (Item embedding : embeddings) { final IndexCursor existent = underlyingIndex.get(new Object[] { embedding.id() }); MutableVertex vertex; if (existent.hasNext()) { vertex = existent.next().asVertex().modify(); final Boolean deleted = vertex.getBoolean(deletedPropertyName); if (deleted != null && deleted) vertex.remove(deletedPropertyName); } else vertex = database.newVertex(vertexType); vertex.set(idPropertyName, embedding.id()).set(vectorPropertyName, embedding.vector()).save(); add(vertex); callback.onVertexIndexed(vertex, embedding, ++indexed); } } public boolean add(Vertex vertex) { final TVector vertexVector = getVectorFromVertex(vertex); if (Array.getLength(vertexVector) != dimensions) throw new IllegalArgumentException( "Item has dimensionality of " + Array.getLength(vertexVector) + " but the index was defined with " + dimensions + " dimensions"); final TId vertexId = getIdFromVertex(vertex); final int vertexMaxLevel = getMaxLevelFromVertex(vertex); final int randomLevel = assignLevel(vertexId, this.levelLambda); globalLock.lock(); try { final 
Boolean deleted = vertex.getBoolean(deletedPropertyName); if (deleted != null && deleted) { vertex = vertex.modify(); ((MutableVertex) vertex).remove(deletedPropertyName); ((MutableVertex) vertex).save(); } final long totalEdges = vertex.countEdges(Vertex.DIRECTION.OUT, getEdgeType(0)); if (totalEdges > 0) // ALREADY INSERTED return true; vertex = vertex.modify().set("vectorMaxLevel", randomLevel).save(); if (cache != null) cache.put(vertex.getIdentity(), vertex); final RID vertexRID = vertex.getIdentity(); synchronized (excludedCandidates) { excludedCandidates.add(vertexRID); } final Vertex entryPointCopy = entryPoint; try { if (entryPoint != null && randomLevel <= getMaxLevelFromVertex(entryPoint)) { globalLock.unlock(); } Vertex currObj = entryPointCopy; final int entryPointCopyMaxLevel = getMaxLevelFromVertex(entryPointCopy); if (currObj != null) { if (vertexMaxLevel < entryPointCopyMaxLevel) { final TVector vector = getVectorFromVertex(currObj); if (vector == null) { LogManager.instance().log(this, Level.WARNING, "Vector not found in vertex %s", currObj); throw new IndexException("Embeddings not found in object " + currObj); } TDistance curDist = distanceFunction.distance(vertexVector, vector); for (int activeLevel = entryPointCopyMaxLevel; activeLevel > vertexMaxLevel; activeLevel--) { boolean changed = true; while (changed) { changed = false; synchronized (currObj) { final Iterator candidateConnections = getConnectionsFromVertex(currObj, activeLevel); while (candidateConnections.hasNext()) { final Vertex candidateNode = candidateConnections.next(); final TVector candidateNodeVector = getVectorFromVertex(candidateNode); if (candidateNodeVector == null) { // INVALID LogManager.instance().log(this, Level.WARNING, "Vector not found in vertex %s", candidateNode); continue; } final TDistance candidateDistance = distanceFunction.distance(vertexVector, candidateNodeVector); if (lt(candidateDistance, curDist)) { curDist = candidateDistance; currObj = candidateNode; 
changed = true; } } } } } } for (int level = Math.min(randomLevel, entryPointCopyMaxLevel); level >= 0; level--) { final PriorityQueue> topCandidates = searchBaseLayer(currObj, vertexVector, efConstruction, level, null); final boolean entryPointDeleted = isDeletedFromVertex(entryPointCopy); if (entryPointDeleted) { TDistance distance = distanceFunction.distance(vertexVector, getVectorFromVertex(entryPointCopy)); topCandidates.add(new NodeIdAndDistance<>(entryPointCopy.getIdentity(), distance, maxValueDistanceComparator)); if (topCandidates.size() > efConstruction) topCandidates.poll(); } mutuallyConnectNewElement(vertex, topCandidates, level); } } // zoom out to the highest level if (entryPoint == null || vertexMaxLevel > entryPointCopyMaxLevel) // this is thread safe because we get the global lock when we add a level this.entryPoint = vertex; return true; } finally { synchronized (excludedCandidates) { excludedCandidates.remove(vertexRID); } } } finally { if (globalLock.isHeldByCurrentThread()) { globalLock.unlock(); } } } private Iterator getConnectionsFromVertex(final Vertex vertex, final int level) { return vertex.getVertices(Vertex.DIRECTION.OUT, edgeType + level).iterator(); } private int countConnectionsFromVertex(final Vertex vertex, final int level) { return (int) vertex.countEdges(Vertex.DIRECTION.OUT, edgeType + level); } private int getMaxLevelFromVertex(final Vertex vertex) { if (vertex == null) return 0; final Integer vectorMaxLevel = vertex.getInteger("vectorMaxLevel"); return vectorMaxLevel != null ? vectorMaxLevel : 0; } private Vertex loadVertexFromRID(final Identifiable rid) { if (rid instanceof Vertex) return (Vertex) rid; Vertex vertex = null; if (cache != null) vertex = cache.get(rid); if (vertex == null) vertex = rid.asVertex(); return vertex; } private void mutuallyConnectNewElement(final Vertex newNode, final PriorityQueue> topCandidates, final int level) { final int bestN = level == 0 ? 
this.maxM0 : this.maxM; final RID newNodeId = newNode.getIdentity(); final TVector newItemVector = getVectorFromVertex(newNode); getNeighborsByHeuristic2(topCandidates, m); while (!topCandidates.isEmpty()) { final RID selectedNeighbourId = topCandidates.poll().nodeId; synchronized (excludedCandidates) { if (excludedCandidates.contains(selectedNeighbourId)) { continue; } } // CREATE THE EDGE TYPE IF NOT PRESENT final String edgeTypeName = getEdgeType(level); database.getSchema().getOrCreateEdgeType(edgeTypeName); newNode.newEdge(edgeTypeName, selectedNeighbourId, false); final Vertex neighbourNode = loadVertexFromRID(selectedNeighbourId); final TVector neighbourVector = getVectorFromVertex(neighbourNode); final int neighbourConnectionsAtLevelTotal = countConnectionsFromVertex(neighbourNode, level); final Iterator neighbourConnectionsAtLevel = getConnectionsFromVertex(neighbourNode, level); if (neighbourConnectionsAtLevelTotal < bestN) { neighbourNode.newEdge(edgeTypeName, newNode, false); } else { // finding the "weakest" element to replace it with the new one final TDistance dMax = distanceFunction.distance(newItemVector, neighbourVector); final Comparator> comparator = Comparator.>naturalOrder() .reversed(); final PriorityQueue> candidates = new PriorityQueue<>(comparator); candidates.add(new NodeIdAndDistance<>(newNodeId, dMax, maxValueDistanceComparator)); neighbourConnectionsAtLevel.forEachRemaining(neighbourConnection -> { final TDistance dist = distanceFunction.distance(neighbourVector, getVectorFromVertex(neighbourConnection)); candidates.add(new NodeIdAndDistance<>(neighbourConnection.getIdentity(), dist, maxValueDistanceComparator)); }); getNeighborsByHeuristic2(candidates, bestN); while (!candidates.isEmpty()) { neighbourNode.newEdge(edgeTypeName, candidates.poll().nodeId, false); } } } } public TypeIndex getUnderlyingIndex() { return underlyingIndex; } public List> findNearest(final TVector destination, final int k, final IgnoreVertexCallback 
ignoreVertexCallback) { if (entryPoint == null) return Collections.emptyList(); final Vertex entryPointCopy = entryPoint; Vertex currObj = entryPointCopy; final TVector vector = getVectorFromVertex(currObj); if (vector == null) { LogManager.instance().log(this, Level.WARNING, "Vector not found in vertex %s", currObj); return Collections.emptyList(); } TDistance curDist = distanceFunction.distance(destination, vector); for (int activeLevel = getMaxLevelFromVertex(entryPointCopy); activeLevel > 0; activeLevel--) { boolean changed = true; while (changed) { changed = false; final Iterator candidateConnections = getConnectionsFromVertex(currObj, activeLevel); while (candidateConnections.hasNext()) { final Vertex candidateNode = candidateConnections.next(); TDistance candidateDistance = distanceFunction.distance(destination, getVectorFromVertex(candidateNode)); if (lt(candidateDistance, curDist)) { curDist = candidateDistance; currObj = candidateNode; changed = true; } } } } final PriorityQueue> topCandidates = searchBaseLayer(currObj, destination, Math.max(ef, k), 0, ignoreVertexCallback); while (topCandidates.size() > k) { topCandidates.poll(); } List> results = new ArrayList<>(topCandidates.size()); while (!topCandidates.isEmpty()) { NodeIdAndDistance pair = topCandidates.poll(); results.add(0, new SearchResult<>(loadVertexFromRID(pair.nodeId), pair.distance, maxValueDistanceComparator)); } return results; } private PriorityQueue> searchBaseLayer(final Vertex entryPointNode, final TVector destination, final int k, final int layer, final IgnoreVertexCallback ignoreVertexCallback) { final Set visitedNodes = new HashSet<>(); final PriorityQueue> topCandidates = new PriorityQueue<>( Comparator.>naturalOrder().reversed()); final PriorityQueue> candidateSet = new PriorityQueue<>(); TDistance lowerBound; if (!ignoreVertex(entryPointNode, ignoreVertexCallback)) { final TVector entryPointVector = getVectorFromVertex(entryPointNode); final TDistance distance = 
distanceFunction.distance(destination, entryPointVector); final NodeIdAndDistance pair = new NodeIdAndDistance<>(entryPointNode.getIdentity(), distance, maxValueDistanceComparator); topCandidates.add(pair); lowerBound = distance; candidateSet.add(pair); } else { lowerBound = MaxValueComparator.maxValue(); NodeIdAndDistance pair = new NodeIdAndDistance<>(entryPointNode.getIdentity(), lowerBound, maxValueDistanceComparator); candidateSet.add(pair); } visitedNodes.add(entryPointNode.getIdentity()); while (!candidateSet.isEmpty()) { final NodeIdAndDistance currentPair = candidateSet.poll(); if (gt(currentPair.distance, lowerBound)) break; final Vertex node = loadVertexFromRID(currentPair.nodeId); final Iterator candidates = getConnectionsFromVertex(node, layer); while (candidates.hasNext()) { final Vertex candidateNode = candidates.next(); if (!visitedNodes.contains(candidateNode.getIdentity())) { visitedNodes.add(candidateNode.getIdentity()); final TVector vector = getVectorFromVertex(candidateNode); if (vector == null) { // INVALID LogManager.instance().log(this, Level.WARNING, "Vector not found in vertex %s", candidateNode); continue; } final TDistance candidateDistance = distanceFunction.distance(destination, vector); if (topCandidates.size() < k || gt(lowerBound, candidateDistance)) { final NodeIdAndDistance candidatePair = new NodeIdAndDistance<>(candidateNode.getIdentity(), candidateDistance, maxValueDistanceComparator); candidateSet.add(candidatePair); if (!ignoreVertex(candidateNode, ignoreVertexCallback)) topCandidates.add(candidatePair); if (topCandidates.size() > k) topCandidates.poll(); if (!topCandidates.isEmpty()) lowerBound = topCandidates.peek().distance; } } } } return topCandidates; } /** * Returns the dimensionality of the items stored in this index. 
* * @return the dimensionality of the items stored in this index */ public int getDimensions() { return dimensions; } /** * Returns the number of bi-directional links created for every new element during construction. * * @return the number of bi-directional links created for every new element during construction */ public int getM() { return m; } /** * The size of the dynamic list for the nearest neighbors (used during the search) * * @return The size of the dynamic list for the nearest neighbors */ public int getEf() { return ef; } /** * Returns the parameter has the same meaning as ef, but controls the index time / index precision. * * @return the parameter has the same meaning as ef, but controls the index time / index precision */ public int getEfConstruction() { return efConstruction; } /** * Returns the distance function. * * @return the distance function */ public DistanceFunction getDistanceFunction() { return distanceFunction; } /** * Returns the comparator used to compare distances. * * @return the comparator used to compare distance */ public Comparator getDistanceComparator() { return distanceComparator; } /** * Returns the maximum number of items the index can hold. 
* * @return the maximum number of items the index can hold */ public int getMaxItemCount() { return maxItemCount; } public void save(OutputStream out) throws IOException { try (ObjectOutputStream oos = new ObjectOutputStream(out)) { oos.writeObject(this); } } private int assignLevel(final TId value, final double lambda) { // by relying on the external id to come up with the level, the graph construction should be a lot more stable // see : https://github.com/nmslib/hnswlib/issues/28 final int hashCode = value.hashCode(); final byte[] bytes = new byte[] { (byte) (hashCode >> 24), (byte) (hashCode >> 16), (byte) (hashCode >> 8), (byte) hashCode }; final double random = Math.abs((double) Murmur3.hash32(bytes) / (double) Integer.MAX_VALUE); final double r = -Math.log(random) * lambda; return (int) r; } private boolean lt(final TDistance x, final TDistance y) { return maxValueDistanceComparator.compare(x, y) < 0; } private boolean gt(final TDistance x, final TDistance y) { return maxValueDistanceComparator.compare(x, y) > 0; } public TId getIdFromVertex(final Vertex vertex) { return (TId) vertex.get(idPropertyName); } public TVector getVectorFromVertex(final Vertex vertex) { return (TVector) vertex.get(vectorPropertyName); } public boolean isDeletedFromVertex(final Vertex vertex) { final Boolean deleted = vertex.getBoolean(deletedPropertyName); return deleted != null && deleted; } public boolean ignoreVertex(final Vertex vertex, final IgnoreVertexCallback ignoreVertexCallback) { if (isDeletedFromVertex(vertex)) return true; if (ignoreVertexCallback != null) return ignoreVertexCallback.ignoreVertex(vertex); return false; } public int getDimensionFromVertex(final Vertex vertex) { return Array.getLength(getVectorFromVertex(vertex)); } public String getEdgeType(final int level) { return edgeType + level; } static class NodeIdAndDistance implements Comparable> { final RID nodeId; final TDistance distance; final Comparator distanceComparator; NodeIdAndDistance(final RID 
nodeId, final TDistance distance, final Comparator distanceComparator) { this.nodeId = nodeId; this.distance = distance; this.distanceComparator = distanceComparator; } @Override public int compareTo(NodeIdAndDistance o) { return distanceComparator.compare(distance, o.distance); } } static class MaxValueComparator implements Comparator, Serializable { private static final long serialVersionUID = 1L; private final Comparator delegate; MaxValueComparator(Comparator delegate) { this.delegate = delegate; } @Override public int compare(final TDistance o1, final TDistance o2) { return o1 == null ? o2 == null ? 0 : 1 : o2 == null ? -1 : delegate.compare(o1, o2); } static TDistance maxValue() { return null; } } public void save() { try { FileUtils.writeFile(new File(filePath), toJSON().toString()); } catch (IOException e) { throw new IndexException("Error on saving HNSW index '" + indexName + "'", e); } } @Override public JSONObject toJSON() { final JSONObject json = new JSONObject(); json.put("type", getType()); json.put("indexName", getName()); json.put("version", CURRENT_VERSION); json.put("dimensions", dimensions); json.put("distanceFunction", distanceFunction.getClass().getSimpleName()); json.put("distanceComparator", distanceComparator.getClass().getSimpleName()); json.put("maxItemCount", maxItemCount); json.put("m", m); json.put("maxM", maxM); json.put("maxM0", maxM0); json.put("levelLambda", levelLambda); json.put("ef", ef); json.put("efConstruction", efConstruction); json.put("levelLambda", levelLambda); json.put("entryPoint", entryPoint == null ? 
"" : entryPoint.getIdentity().toString()); json.put("vertexType", vertexType); json.put("edgeType", edgeType); json.put("idPropertyName", idPropertyName); json.put("vectorPropertyName", vectorPropertyName); return json; } @Override public IndexInternal getAssociatedIndex() { return null; } @Override public void drop() { // KEEP THE UNDERLYING INDEX ALIVE TO ALLOW THE REBUILD WITHOUT CALCULATING THE EMBEDDINGS // if (underlyingIndex != null) // database.getSchema().dropIndex(underlyingIndex.getName()); try { if (underlyingIndex != null) { database.transaction(() -> { final IndexCursor it = underlyingIndex.iterator(true); while (it.hasNext()) { try { final Identifiable next = it.next(); if (next != null) { final Vertex vertex = next.asVertex(); for (int level = 0; level <= getMaxLevelFromVertex(vertex); level++) { try { for (Edge e : vertex.getEdges(Vertex.DIRECTION.BOTH, getEdgeType(level))) e.delete(); } catch (RecordNotFoundException | SchemaException e) { // IGNORE IT } } } } catch (RecordNotFoundException e) { // IGNORE IT } } }); } } catch (Exception e) { LogManager.instance().log(this, Level.WARNING, "Error on scanning the vector index to delete edges", e); } final File cfg = new File(filePath); if (cfg.exists()) cfg.delete(); } @Override public Map getStats() { return underlyingIndex.getStats(); } @Override public LSMTreeIndexAbstract.NULL_STRATEGY getNullStrategy() { return underlyingIndex.getNullStrategy(); } @Override public void setNullStrategy(final LSMTreeIndexAbstract.NULL_STRATEGY nullStrategy) { underlyingIndex.setNullStrategy(nullStrategy); } @Override public boolean isUnique() { return true; } @Override public boolean supportsOrderedIterations() { return underlyingIndex.supportsOrderedIterations(); } @Override public boolean isAutomatic() { return underlyingIndex != null ? 
underlyingIndex.isAutomatic() : false; } @Override public int getPageSize() { return underlyingIndex.getPageSize(); } @Override public long build(final int buildIndexBatchSize, final BuildIndexCallback callback) { return underlyingIndex.build(buildIndexBatchSize, callback); } public long build(final HnswVectorIndexRAM origin, final int buildIndexBatchSize, final BuildVectorIndexCallback vertexCreationCallback, final BuildIndexCallback edgeCallback) { if (origin != null) { // IMPORT FROM RAM Index final RID[] pointersToRIDMapping = new RID[origin.nodeCount]; database.begin(); // SAVE ALL THE NODES AS VERTICES AND KEEP AN ARRAY OF RIDS TO BUILD EDGES LATER int maxLevel = 0; HnswVectorIndexRAM.ItemIterator iter = origin.iterateNodes(); for (int totalVertices = 0; iter.hasNext(); ++totalVertices) { final HnswVectorIndexRAM.Node node = iter.next(); final int nodeMaxLevel = node.maxLevel(); if (nodeMaxLevel > maxLevel) maxLevel = nodeMaxLevel; final MutableVertex vertex; final IndexCursor existent = underlyingIndex.get(new Object[] { node.item.id() }); if (existent.hasNext()) { vertex = existent.next().asVertex().modify(); final Boolean deleted = vertex.getBoolean(deletedPropertyName); if (deleted != null && deleted) vertex.remove(deletedPropertyName); } else vertex = database.newVertex(vertexType); vertex.set(idPropertyName, node.item.id()).set(vectorPropertyName, node.item.vector()); if (nodeMaxLevel > 0) // SAVE MAX LEVEL INTO THE VERTEX. 
IF NOT PRESENT, MEANS 0 vertex.set("vectorMaxLevel", nodeMaxLevel); vertex.save(); if (vertexCreationCallback != null) vertexCreationCallback.onVertexIndexed(vertex, node.item, totalVertices); pointersToRIDMapping[node.id] = vertex.getIdentity(); if (totalVertices % buildIndexBatchSize == 0) { database.commit(); database.begin(); } } database.commit(); final Integer entryPoint = origin.getEntryPoint(); if (entryPoint != null) this.entryPoint = pointersToRIDMapping[entryPoint].asVertex(); // BUILD ALL EDGE TYPES (ONE PER LEVEL) for (int level = 0; level <= maxLevel; level++) { // ASSURE THE EDGE TYPE IS CREATED IN THE DATABASE database.getSchema().getOrCreateEdgeType(getEdgeType(level), 1); } database.begin(); // BUILD THE EDGES long totalVertices = 0L; long totalEdges = 0L; iter = origin.iterateNodes(); for (int txCounter = 0; iter.hasNext(); ++txCounter) { final HnswVectorIndexRAM.Node node = iter.next(); final Vertex source = pointersToRIDMapping[node.id].asVertex(); ++totalVertices; final MutableIntList[] connections = node.connections(); for (int level = 0; level < connections.length; level++) { final String edgeTypeLevel = getEdgeType(level); final MutableIntList pointers = connections[level]; for (int i = 0; i < pointers.size(); i++) { final int pointer = pointers.get(i); final RID destination = pointersToRIDMapping[pointer]; if (destination == null) LogManager.instance().log(this, Level.WARNING, "Destination vertex %d is null", pointer); else { source.newEdge(edgeTypeLevel, destination, false); ++totalEdges; } } } if (txCounter % buildIndexBatchSize == 0) { database.commit(); database.begin(); } if (edgeCallback != null) edgeCallback.onDocumentIndexed(source, totalEdges); } database.commit(); save(); return totalVertices; } // TODO: NOT SUPPORTED WITHOUT RAM INDEX return 0L; } @Override public boolean equals(final Object obj) { if (!(obj instanceof HnswVectorIndex)) return false; return componentName.equals(((HnswVectorIndex) obj).componentName) && 
underlyingIndex.equals(obj); } public List getSubIndexes() { return underlyingIndex.getSubIndexes(); } @Override public int hashCode() { return Objects.hash(componentName, underlyingIndex.hashCode()); } @Override public String toString() { return indexName; } @Override public void setMetadata(final String name, final String[] propertyNames, final int associatedBucketId) { } @Override public boolean setStatus(INDEX_STATUS[] expectedStatuses, INDEX_STATUS newStatus) { return false; } @Override public Component getComponent() { return this; } @Override public Type[] getKeyTypes() { return underlyingIndex.getKeyTypes(); } @Override public byte[] getBinaryKeyTypes() { return underlyingIndex.getBinaryKeyTypes(); } @Override public List getFileIds() { if (underlyingIndex == null) // NOT PROPERLY BUILT YET return Collections.emptyList(); return underlyingIndex.getFileIds(); } @Override public void setTypeIndex(final TypeIndex typeIndex) { throw new UnsupportedOperationException("setTypeIndex"); } @Override public TypeIndex getTypeIndex() { return null; } @Override public int getAssociatedBucketId() { return -1; } public void addIndexOnBucket(final IndexInternal index) { underlyingIndex.addIndexOnBucket(index); } public void removeIndexOnBucket(final IndexInternal index) { underlyingIndex.removeIndexOnBucket(index); } public IndexInternal[] getIndexesOnBuckets() { return underlyingIndex.getIndexesOnBuckets(); } public List getIndexesByKeys(final Object[] keys) { return underlyingIndex.getIndexesByKeys(keys); } public IndexCursor iterator(final boolean ascendingOrder) { return underlyingIndex.iterator(ascendingOrder); } public IndexCursor iterator(final boolean ascendingOrder, final Object[] fromKeys, final boolean inclusive) { return underlyingIndex.iterator(ascendingOrder, fromKeys, inclusive); } public IndexCursor range(final boolean ascending, final Object[] beginKeys, final boolean beginKeysInclusive, final Object[] endKeys, boolean endKeysInclusive) { return 
underlyingIndex.range(ascending, beginKeys, beginKeysInclusive, endKeys, endKeysInclusive); } @Override public IndexCursor get(final Object[] keys) { return underlyingIndex.get(keys); } @Override public IndexCursor get(final Object[] keys, final int limit) { return underlyingIndex.get(keys, limit); } @Override public void put(final Object[] keys, RID[] rid) { underlyingIndex.put(keys, rid); } @Override public void remove(final Object[] keys) { globalLock.lock(); try { final IndexCursor cursor = underlyingIndex.get(new Object[] { keys[0] }); if (!cursor.hasNext()) return; final Vertex vertex = loadVertexFromRID(cursor.next()); vertex.modify().set(deletedPropertyName, true).save(); //underlyingIndex.remove(keys); } finally { globalLock.unlock(); } } @Override public void remove(final Object[] keys, final Identifiable rid) { globalLock.lock(); try { final IndexCursor cursor = underlyingIndex.get(new Object[] { keys[0] }); if (!cursor.hasNext()) return; final Identifiable itemRID = cursor.next(); if (!itemRID.equals(rid)) return; final Vertex vertex = loadVertexFromRID(itemRID); vertex.modify().set(deletedPropertyName, true).save(); // underlyingIndex.remove(keys, rid); } finally { globalLock.unlock(); } } @Override public long countEntries() { return underlyingIndex.countEntries(); } @Override public boolean compact() throws IOException, InterruptedException { return underlyingIndex.compact(); } @Override public boolean isCompacting() { return underlyingIndex.isCompacting(); } @Override public boolean isValid() { return underlyingIndex.isValid(); } @Override public boolean scheduleCompaction() { return underlyingIndex.scheduleCompaction(); } @Override public String getMostRecentFileName() { return underlyingIndex.getMostRecentFileName(); } @Override public Schema.INDEX_TYPE getType() { return Schema.INDEX_TYPE.HSNW; } @Override public String getTypeName() { return vertexType; } @Override public List getPropertyNames() { return List.of(idPropertyName, 
vectorPropertyName); } @Override public void close() { underlyingIndex.close(); } private Vertex get(final Object id) { globalLock.lock(); try { final IndexCursor cursor = underlyingIndex.get(new Object[] { id }); if (!cursor.hasNext()) return null; return loadVertexFromRID(cursor.next()); } finally { globalLock.unlock(); } } private void getNeighborsByHeuristic2(final PriorityQueue> topCandidates, final int m) { if (topCandidates.size() < m) return; final PriorityQueue> queueClosest = new PriorityQueue<>(); final List> returnList = new ArrayList<>(); while (!topCandidates.isEmpty()) { queueClosest.add(topCandidates.poll()); } while (!queueClosest.isEmpty()) { if (returnList.size() >= m) break; final NodeIdAndDistance currentPair = queueClosest.poll(); final TDistance distToQuery = currentPair.distance; boolean good = true; for (NodeIdAndDistance secondPair : returnList) { final TDistance curdist = distanceFunction.distance(// getVectorFromVertex(loadVertexFromRID(secondPair.nodeId)),// getVectorFromVertex(loadVertexFromRID(currentPair.nodeId))); if (lt(curdist, distToQuery)) { good = false; break; } } if (good) { returnList.add(currentPair); } } topCandidates.addAll(returnList); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy