// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.arcadedb.integration.importer.vector.TextEmbeddingsImporter (Maven / Gradle / Ivy)
//
// There is a newer version: 24.11.2
// Show newest version
/*
 * Copyright 2023 Arcade Data Ltd
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package com.arcadedb.integration.importer.vector;

import com.arcadedb.database.Database;
import com.arcadedb.database.DatabaseFactory;
import com.arcadedb.database.DatabaseInternal;
import com.arcadedb.index.vector.HnswVectorIndexRAM;
import com.arcadedb.index.vector.VectorUtils;
import com.arcadedb.index.vector.distance.DistanceFunctionFactory;
import com.arcadedb.integration.importer.ConsoleLogger;
import com.arcadedb.integration.importer.ImporterContext;
import com.arcadedb.integration.importer.ImporterSettings;
import com.arcadedb.schema.Type;
import com.arcadedb.utility.CodeUtils;
import com.arcadedb.utility.DateUtils;
import com.github.jelmerk.knn.DistanceFunction;

import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.*;
import java.util.stream.*;

/**
 * Imports text embeddings of arbitrary dimensions.
 * <p>
 * Every input line is split on single spaces: the first token is the word (used as the embedding id), the remaining
 * tokens are parsed as the {@code float} components of the vector. The parsed embeddings are first indexed in RAM with
 * a {@link HnswVectorIndexRAM} and then stored in the database as a persistent index made of vertices and edges.
 *
 * @author Luca Garulli (l.garulli@arcadedata.com)
 */
public class TextEmbeddingsImporter {
  private final InputStream      inputStream;
  private final ImporterSettings settings;
  private final ConsoleLogger    logger;
  private       int              m                    = 16;
  private       int              ef                   = 256;
  private       int              efConstruction       = 256;
  private       boolean          normalizeVectors     = false;
  private       String           databasePath;
  private       boolean          overwriteDatabase    = false;
  private       long             errors               = 0L;
  private       long             warnings             = 0L;
  private       DatabaseFactory  factory;
  private       Database         database;
  private       long             beginTime;
  private       boolean          error                = false;
  private       ImporterContext  context              = new ImporterContext();
  private       String           vectorTypeName       = "Float";
  private       String           distanceFunctionName = "InnerProduct";
  private       String           vectorPropertyName   = "vector";
  private       String           idPropertyName       = "name";
  private       String           deletedPropertyName  = "deleted";

  // PROGRESS COUNTERS. THEY ARE INCREMENTED FROM CALLBACKS (THE INDEXING ONE IS INVOKED WHILE MULTIPLE WORKERS ARE REQUESTED), SO A PLAIN
  // `++` ON A VOLATILE FIELD WOULD LOSE UPDATES: ATOMIC COUNTERS ARE USED INSTEAD
  private final AtomicLong embeddingsParsed  = new AtomicLong();
  private final AtomicLong indexedEmbedding  = new AtomicLong();
  private final AtomicLong verticesCreated   = new AtomicLong();
  private final AtomicLong verticesConnected = new AtomicLong();

  /**
   * Creates the importer and reads its configuration from {@code settings.options}.
   * <p>
   * Supported options: {@code distanceFunction}, {@code vectorType}, {@code vectorProperty}, {@code idProperty},
   * {@code deletedProperty}, {@code m}, {@code ef}, {@code efConstruction} and {@code normalizeVectors}. Options that
   * are not present keep the default value of the corresponding field.
   *
   * @param database    target database; its path is read immediately, so it must not be {@code null}
   * @param inputStream stream containing one embedding per line; it is closed once the content has been loaded
   * @param settings    importer settings (verbosity, options, type names, skip and limit values)
   *
   * @throws ClassNotFoundException declared in the signature; not thrown by the code of this constructor
   * @throws IllegalArgumentException if the {@code distanceFunction} or {@code vectorType} option is empty
   * @throws NumberFormatException    if {@code m}, {@code ef} or {@code efConstruction} is not a valid integer
   */
  public TextEmbeddingsImporter(final DatabaseInternal database, final InputStream inputStream, final ImporterSettings settings) throws ClassNotFoundException {
    this.settings = settings;
    this.database = database;
    this.databasePath = database.getDatabasePath();
    this.inputStream = inputStream;
    this.logger = new ConsoleLogger(settings.verboseLevel);

    // NOTE(review): the default "InnerProduct" is mixed case, while a user supplied value is normalized to "Innerproduct".
    // Confirm that DistanceFunctionFactory resolves both spellings.
    if (settings.options.containsKey("distanceFunction"))
      this.distanceFunctionName = capitalize("distanceFunction", settings.options.get("distanceFunction"));

    // THE VECTOR TYPE IS COMPARED AGAINST CAPITALIZED NAMES ("Float", "Double", ...)
    if (settings.options.containsKey("vectorType"))
      this.vectorTypeName = capitalize("vectorType", settings.options.get("vectorType"));

    if (settings.options.containsKey("vectorProperty"))
      this.vectorPropertyName = settings.options.get("vectorProperty");

    if (settings.options.containsKey("idProperty"))
      this.idPropertyName = settings.options.get("idProperty");

    if (settings.options.containsKey("deletedProperty"))
      this.deletedPropertyName = settings.options.get("deletedProperty");

    if (settings.options.containsKey("m"))
      this.m = Integer.parseInt(settings.options.get("m"));

    if (settings.options.containsKey("ef"))
      this.ef = Integer.parseInt(settings.options.get("ef"));

    if (settings.options.containsKey("efConstruction"))
      this.efConstruction = Integer.parseInt(settings.options.get("efConstruction"));

    if (settings.options.containsKey("normalizeVectors"))
      this.normalizeVectors = Boolean.parseBoolean(settings.options.get("normalizeVectors"));
  }

  /**
   * Executes the import: loads the embeddings, drops the first {@code settings.documentsSkipEntries} entries (if set),
   * builds the HNSW index in RAM and persists it in the database. A summary is logged at the end.
   *
   * @return the database the embeddings were imported into, or {@code null} if the database could not be created
   *
   * @throws IOException              if the input stream cannot be read
   * @throws IllegalArgumentException if the configured vector type is not supported
   */
  public Database run() throws IOException, ClassNotFoundException, InterruptedException {
    if (!createDatabase())
      return null;

    final DistanceFunction distanceFunction = DistanceFunctionFactory.getImplementationByName(vectorTypeName + distanceFunctionName);

    beginTime = System.currentTimeMillis();

    final List<TextFloatsEmbedding> texts = loadFromFile();

    if (settings.documentsSkipEntries != null) {
      // REMOVE THE LEADING ENTRIES IN ONE SHOT. THE AMOUNT IS CLAMPED TO THE LIST SIZE, SO SKIPPING MORE ENTRIES THAN AVAILABLE JUST
      // EMPTIES THE LIST INSTEAD OF THROWING AN EXCEPTION
      final int toSkip = (int) Math.max(0L, Math.min(settings.documentsSkipEntries.longValue(), (long) texts.size()));
      texts.subList(0, toSkip).clear();
    }

    if (!texts.isEmpty()) {
      // FAIL BEFORE THE (EXPENSIVE) INDEXING IF THE VECTOR TYPE IS NOT SUPPORTED
      final Type vectorPropertyType = resolveVectorPropertyType();

      // THE FIRST EMBEDDING DETERMINES THE DIMENSIONS OF THE INDEX
      final int dimensions = texts.get(0).dimensions();

      logger.logLine(2, "- Parsed %,d embeddings with %,d dimensions in RAM", texts.size(), dimensions);

      final HnswVectorIndexRAM hnswIndex = HnswVectorIndexRAM.newBuilder(dimensions, distanceFunction,
          texts.size()).withM(m).withEf(ef).withEfConstruction(efConstruction).build();

      hnswIndex.addAll(texts, Runtime.getRuntime().availableProcessors(), (workDone, max) -> indexedEmbedding.incrementAndGet(), 1);

      hnswIndex.createPersistentIndex(database)//
          .withVertexType(settings.vertexTypeName).withEdgeType(settings.edgeTypeName).withVectorProperty(vectorPropertyName, vectorPropertyType)
          .withIdProperty(idPropertyName)//
          .withDeletedProperty(deletedPropertyName)//
          .withVertexCreationCallback((record, item, total) -> verticesCreated.incrementAndGet())//
          .withCallback((record, total) -> verticesConnected.incrementAndGet())//
          .withBatchSize(1000).create();
    }

    logger.logLine(1, "***************************************************************************************************");
    logger.logLine(1, "Import of Text Embeddings database completed in %s with %,d errors and %,d warnings.",
        DateUtils.formatElapsed((System.currentTimeMillis() - beginTime)), errors, warnings);
    logger.logLine(1, "\nSUMMARY\n");
    logger.logLine(1, "- Embeddings.................................: %,d", texts.size());
    logger.logLine(1, "***************************************************************************************************");
    logger.logLine(1, "");

    if (database != null) {
      logger.logLine(1, "NOTES:");
      logger.logLine(1, "- you can find your new ArcadeDB database in '" + database.getDatabasePath() + "'");
    }

    return database;
  }

  /**
   * Logs a one line progress report: estimated percentage, counters reached so far and elapsed time.
   * <p>
   * The percentage is weighted by phase: indexing in RAM counts for the first 10%, vertex creation for the next 30% and
   * vertex connection for the remaining 60%. It stays at 0 until at least one embedding has been parsed.
   */
  public void printProgress() {
    // READ EVERY COUNTER ONCE, SO THE WHOLE REPORT IS BUILT FROM THE SAME VALUES
    final long parsed = embeddingsParsed.get();
    final long indexed = indexedEmbedding.get();
    final long created = verticesCreated.get();
    final long connected = verticesConnected.get();

    float progressPerc = 0F;
    if (parsed > 0) {
      // GUARDED BY parsed > 0 TO AVOID PRINTING NaN/Infinity
      if (connected > 0)
        progressPerc = 40F + (connected * 60F / parsed); // 60% OF THE TOTAL PROCESS
      else if (created > 0)
        progressPerc = 10F + (created * 30F / parsed); // 30% OF THE TOTAL PROCESS
      else if (indexed > 0)
        progressPerc = indexed * 10F / parsed; // 10% OF THE TOTAL PROCESS
    }

    final StringBuilder result = new StringBuilder(String.format("- %.2f%%", progressPerc));

    if (parsed > 0)
      result.append(String.format(" - %,d embeddings parsed", parsed));
    if (indexed > 0)
      result.append(String.format(" - %,d embeddings indexed", indexed));
    if (created > 0)
      result.append(String.format(" - %,d vertices created", created));
    if (connected > 0)
      result.append(String.format(" - %,d vertices connected", connected));

    result.append(" (elapsed ").append(DateUtils.formatElapsed(System.currentTimeMillis() - beginTime)).append(")");

    logger.logLine(2, result.toString());
  }

  /**
   * Makes sure a database is available. If none is set, it is created at {@code databasePath}; an existing database is
   * dropped and recreated only when {@code overwriteDatabase} is {@code true}, otherwise an error is logged and counted.
   * <p>
   * NOTE(review): the only visible constructor always assigns a non-null database, so the creation branch looks
   * unreachable from this class alone. Confirm before relying on it.
   *
   * @return {@code true} if a database is available, {@code false} if it already exists and must not be overwritten
   */
  private boolean createDatabase() {
    if (database == null) {
      factory = new DatabaseFactory(databasePath);
      if (factory.exists()) {
        if (!overwriteDatabase) {
          logger.errorLine("Database already exists on path '%s'", databasePath);
          ++errors;
          return false;
        } else {
          database = factory.open();
          logger.errorLine("Found existent database at '%s', dropping it and recreate a new one", databasePath);
          database.drop();
        }
      }

      // CREATE THE DATABASE
      database = factory.create();
    }
    return true;
  }

  /**
   * Returns the value of the internal error flag.
   */
  public boolean isError() {
    return error;
  }

  /**
   * Returns the importer context currently associated with this importer.
   */
  public ImporterContext getContext() {
    return context;
  }

  /**
   * Replaces the importer context.
   *
   * @param context the new context
   *
   * @return this importer, to allow call chaining
   */
  public TextEmbeddingsImporter setContext(final ImporterContext context) {
    this.context = context;
    return this;
  }

  /**
   * Reads the whole input stream and converts every non blank line into an embedding. If
   * {@code settings.parsingLimitEntries} is greater than zero, at most that many lines are read.
   *
   * @return a mutable list with the parsed embeddings, in file order
   *
   * @throws IOException if the stream cannot be read
   */
  private List<TextFloatsEmbedding> loadFromFile() throws IOException {
    // EXPLICIT CHARSET: THE RESULT MUST NOT DEPEND ON THE PLATFORM DEFAULT ENCODING
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8))) {
      Stream<String> lines = reader.lines();

      // limit() RETURNS A NEW STREAM: THE RESULT MUST BE USED, THE ORIGINAL STREAM CANNOT BE OPERATED UPON ANYMORE
      if (settings.parsingLimitEntries > 0)
        lines = lines.limit(settings.parsingLimitEntries);

      // ESTIMATED NUMBER OF TOKENS PER LINE, REFINED AFTER EVERY PARSED LINE
      final AtomicInteger vectorSize = new AtomicInteger(301);

      return lines.filter(line -> !line.trim().isEmpty())//
          .map(line -> parseEmbedding(line, vectorSize))//
          // AN ArrayList IS REQUESTED EXPLICITLY BECAUSE THE CALLER REMOVES THE SKIPPED ENTRIES FROM THE RESULT
          .collect(Collectors.toCollection(ArrayList::new));
    } catch (final UncheckedIOException e) {
      // BufferedReader.lines() WRAPS I/O ERRORS: UNWRAP THEM TO HONOR THE DECLARED CHECKED EXCEPTION
      throw e.getCause();
    }
  }

  /**
   * Parses one line in the format {@code <word> <v1> <v2> ... <vN>} (tokens separated by a single space).
   *
   * @param line       the line to parse
   * @param vectorSize size hint passed to the tokenizer; updated with the length of the parsed vector
   *
   * @return the embedding for the line, normalized if {@code normalizeVectors} is enabled
   *
   * @throws NumberFormatException if a vector component is not a valid float
   */
  private TextFloatsEmbedding parseEmbedding(final String line, final AtomicInteger vectorSize) {
    embeddingsParsed.incrementAndGet();

    final List<String> tokens = CodeUtils.split(line, ' ', -1, vectorSize.get());

    final String word = tokens.get(0);

    // IGNORE EMPTY TOKENS AT THE END OF THE LINE (PRODUCED BY A TRAILING SEPARATOR, IF THE TOKENIZER KEEPS THEM)
    int end = tokens.size();
    while (end > 1 && tokens.get(end - 1).isEmpty())
      --end;

    // TOKEN i (1-BASED, TOKEN 0 IS THE WORD) IS THE COMPONENT i-1 OF THE VECTOR
    float[] vector = new float[end - 1];
    for (int i = 1; i < end; i++)
      vector[i - 1] = Float.parseFloat(tokens.get(i));

    vectorSize.set(vector.length);

    if (normalizeVectors)
      // FOR INNER PRODUCT SEARCH NORMALIZE VECTORS
      vector = VectorUtils.normalize(vector);

    return new TextFloatsEmbedding(word, vector);
  }

  /**
   * Maps the configured vector type name to the property type used to store the vectors.
   *
   * @throws IllegalArgumentException if the vector type name is not one of Short, Integer, Long, Float or Double
   */
  private Type resolveVectorPropertyType() {
    switch (vectorTypeName) {
    case "Short":
      return Type.ARRAY_OF_SHORTS;
    case "Integer":
      return Type.ARRAY_OF_INTEGERS;
    case "Long":
      return Type.ARRAY_OF_LONGS;
    case "Float":
      return Type.ARRAY_OF_FLOATS;
    case "Double":
      return Type.ARRAY_OF_DOUBLES;
    default:
      throw new IllegalArgumentException("Type '" + vectorTypeName + "' not supported");
    }
  }

  /**
   * Returns {@code value} with the first character in upper case and the rest in lower case.
   *
   * @param optionName name of the option the value comes from, used only in the error message
   * @param value      the option value
   *
   * @throws IllegalArgumentException if the value is {@code null} or empty
   */
  private static String capitalize(final String optionName, final String value) {
    if (value == null || value.isEmpty())
      throw new IllegalArgumentException("Option '" + optionName + "' cannot be empty");
    return Character.toUpperCase(value.charAt(0)) + value.substring(1).toLowerCase(Locale.ENGLISH);
  }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy