All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.spotify.scio.extra.voyager.Voyager.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2023 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.spotify.scio.extra.voyager

import com.spotify.scio.util.{RemoteFileUtil, ScioUtil}
import com.spotify.scio.values.SideInput
import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType}
import com.spotify.voyager.jni.{Index, StringIndex}
import org.apache.beam.sdk.transforms.DoFn
import org.apache.beam.sdk.values.PCollectionView

import java.net.URI
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, Paths}
import java.util.concurrent.ConcurrentHashMap

/**
 * Base URI of a voyager index, located on either a local or a remote file system. The URI points
 * at the directory that holds the `index.hnsw` and `names.json` files.
 *
 * For remote file systems the value is expected to look like 'scheme://<bucket>/<path>/'; for
 * local files it is expected to look like '/<path>/'.
 */
final case class VoyagerUri(value: URI) extends AnyVal {

  import VoyagerUri._

  /**
   * Checks whether at least one of the voyager files is already present under this URI.
   *
   * @param remoteFileUtil
   *   used to probe files when the URI is not local
   * @return
   *   `true` as soon as any of the voyager files is found, `false` when none of them is
   */
  def exists(implicit remoteFileUtil: RemoteFileUtil): Boolean = {
    // decide once how a single file is probed, then apply it to every voyager file
    val isPresent: String => Boolean =
      if (ScioUtil.isLocalUri(value)) { fileName =>
        Paths.get(value.resolve(fileName)).toFile.exists()
      } else { fileName =>
        remoteFileUtil.remoteExists(value.resolve(fileName))
      }

    VoyagerFiles.exists(isPresent)
  }
}

object VoyagerUri {

  /** File name of the serialized index inside the voyager directory. */
  private[voyager] val IndexFile = "index.hnsw"

  /** File name of the JSON list of item names inside the voyager directory. */
  private[voyager] val NamesFile = "names.json"

  /** Every file that makes up a voyager index. */
  private[voyager] val VoyagerFiles: Seq[String] = Seq(IndexFile, NamesFile)

  /**
   * Builds a [[VoyagerUri]] from a string, making sure the resulting URI ends with exactly one
   * trailing '/' appended after stripping an existing one, so that file names resolve inside the
   * directory.
   */
  def apply(value: String): VoyagerUri = {
    val directory = s"${value.stripSuffix("/")}/"
    new VoyagerUri(URI.create(directory))
  }

}

/**
 * A single match returned by a voyager query.
 *
 * @param name
 *   the string label of the matched item
 * @param distance
 *   the distance reported by the index for that item
 */
final case class VoyagerResult(
  name: String,
  distance: Float
)

/**
 * Builds a voyager index from named vectors and writes it, together with the list of names, to
 * the given files.
 *
 * @param indexFile
 *   destination of the serialized index
 * @param namesFile
 *   destination of the JSON array of item names
 * @param indexSettings
 *   settings forwarded to the voyager `Index` constructor
 */
class VoyagerWriter private[voyager] (
  indexFile: Path,
  namesFile: Path,
  indexSettings: VoyagerWriter.IndexSettings
) {
  import VoyagerWriter._

  /**
   * Adds all `vectors` to a fresh index and persists the index and the names.
   *
   * Items are numbered by their position in `vectors`; the names file lists the names in that
   * same order. Both output streams and the index are always closed, even when building or
   * writing fails.
   *
   * @param vectors
   *   pairs of item name and embedding
   */
  def write(vectors: Iterable[(String, Array[Float])]): Unit = {
    val indexOutputStream = Files.newOutputStream(indexFile)
    try {
      val namesOutputStream = Files.newOutputStream(namesFile)
      try {
        val names = List.newBuilder[String]
        val index = new Index(
          indexSettings.space,
          indexSettings.numDimensions,
          indexSettings.indexM,
          indexSettings.efConstruction,
          indexSettings.randomSeed,
          indexSettings.maxElements,
          indexSettings.storageDataType
        )

        try {
          // iterate lazily so the input is only materialized one chunk at a time
          vectors.iterator.zipWithIndex
            .map { case ((name, vector), idx) => (name, vector, idx.toLong) }
            .grouped(ChunkSize)
            .foreach { chunk =>
              val (ns, vs, is) = chunk.unzip3
              names ++= ns
              // NOTE(review): -1 is the third addItems argument, presumably the thread count
              // with voyager picking a default — confirm against the voyager API
              index.addItems(vs.toArray, is.toArray, -1)
            }

          // save index
          index.saveIndex(indexOutputStream)
        } finally {
          index.close()
        }

        // save names; each name is quoted individually so an empty input yields "[]"
        val json = names.result().iterator.map(jsonString).mkString("[", ",", "]")
        namesOutputStream.write(json.getBytes(StandardCharsets.UTF_8))
      } finally {
        namesOutputStream.close()
      }
    } finally {
      indexOutputStream.close()
    }
  }

  /** Renders `s` as a JSON string literal, escaping quotes, backslashes and control characters. */
  private def jsonString(s: String): String = {
    val sb = new StringBuilder(s.length + 2)
    sb.append('"')
    s.foreach {
      case '"'  => sb.append('\\').append('"')
      case '\\' => sb.append('\\').append('\\')
      case '\n' => sb.append('\\').append('n')
      case '\r' => sb.append('\\').append('r')
      case '\t' => sb.append('\\').append('t')
      case '\b' => sb.append('\\').append('b')
      case '\f' => sb.append('\\').append('f')
      case c if c < ' ' =>
        // remaining control characters need the six-character unicode escape form
        sb.append('\\').append('u').append(f"${c.toInt}%04x")
      case c => sb.append(c)
    }
    sb.append('"')
    sb.toString
  }
}

private[voyager] object VoyagerWriter {

  // number of vectors handed to the index per addItems call: 2^15 = 32768
  private val ChunkSize: Int = 1 << 15

  // same default as voyager Index constructor
  val DefaultIndexM: Long = 16L
  val DefaultEfConstruction: Long = 200L
  val DefaultRandomSeed: Long = 1L
  val DefaultMaxElements: Long = 1L
  val DefaultStorageDataType: StorageDataType = StorageDataType.Float32

  /**
   * Settings used to construct the voyager index.
   *
   * @param space
   *   the space type passed to the index
   * @param numDimensions
   *   the number of dimensions passed to the index
   * @param indexM
   *   defaults to [[DefaultIndexM]]
   * @param efConstruction
   *   defaults to [[DefaultEfConstruction]]
   * @param randomSeed
   *   defaults to [[DefaultRandomSeed]]
   * @param maxElements
   *   defaults to [[DefaultMaxElements]]
   * @param storageDataType
   *   defaults to [[DefaultStorageDataType]]
   */
  final case class IndexSettings(
    space: SpaceType,
    numDimensions: Int,
    indexM: Long = DefaultIndexM,
    efConstruction: Long = DefaultEfConstruction,
    randomSeed: Long = DefaultRandomSeed,
    maxElements: Long = DefaultMaxElements,
    storageDataType: StorageDataType = DefaultStorageDataType
  )

}

/**
 * Reader for nearest-neighbor lookups against a voyager index. Given a query vector it returns
 * the string labels of the closest items along with their distances.
 *
 * @param indexFile
 *   The `index.hnsw` file.
 * @param namesFile
 *   The `names.json` file.
 * @param indexSettings
 *   The voyager index settings, either provided explicitly or taken from the index v2 metadata
 */
class VoyagerReader private[voyager] (
  indexFile: Path,
  namesFile: Path,
  indexSettings: VoyagerReader.IndexSettings
) {
  // loaded on first use only; which load overload is called depends on the settings variant
  @transient private lazy val index: StringIndex = {
    val indexPath = indexFile.toString
    val namesPath = namesFile.toString

    indexSettings match {
      case VoyagerReader.MetadataSettings =>
        StringIndex.load(indexPath, namesPath)
      case VoyagerReader.ProvidedSettings(space, numDimensions, storageDataType) =>
        StringIndex.load(indexPath, namesPath, space, numDimensions, storageDataType)
    }
  }

  /**
   * Looks up the nearest neighbors of `v`.
   *
   * @param v
   *   the query vector
   * @param maxNumResults
   *   the number of neighbors requested from the index
   * @param ef
   *   the size of the dynamic list for the nearest neighbors during search
   * @return
   *   one [[VoyagerResult]] per returned name/distance pair, in the order given by the index
   */
  def getNearest(v: Array[Float], maxNumResults: Int, ef: Int): Array[VoyagerResult] = {
    val found = index.query(v, maxNumResults, ef)
    val labels = found.getNames
    val distances = found.getDistances

    // pair labels and distances positionally, stopping at the shorter of the two arrays
    Array.tabulate(math.min(labels.length, distances.length)) { i =>
      VoyagerResult(labels(i), distances(i))
    }
  }
}

private[voyager] object VoyagerReader {

  /** How a [[VoyagerReader]] obtains the settings needed to load its index. */
  sealed trait IndexSettings

  /** No explicit settings: the index is loaded from the index and names files alone. */
  case object MetadataSettings extends IndexSettings

  /** Settings supplied by the caller and passed along when the index is loaded. */
  final case class ProvidedSettings(
    space: SpaceType,
    numDimensions: Int,
    storageDataType: StorageDataType
  ) extends IndexSettings

}

/**
 * Side input exposing a [[VoyagerReader]]. Readers are kept in a shared concurrent map keyed by
 * [[VoyagerUri]], so a reader is created at most once per URI.
 */
private[voyager] class VoyagerSideInput(
  indexSettings: VoyagerReader.IndexSettings,
  remoteFileUtil: RemoteFileUtil,
  override val view: PCollectionView[VoyagerUri]
) extends SideInput[VoyagerReader] {

  import VoyagerSideInput._

  /** Creates a reader for `uri`, downloading the voyager files first when the URI is remote. */
  private def createReader(uri: VoyagerUri): VoyagerReader = {
    val isLocal = ScioUtil.isLocalUri(uri.value)

    // turn one of the voyager file names into a path on the local file system
    def localPath(fileName: String): Path = {
      val fileUri = uri.value.resolve(fileName)
      if (isLocal) Paths.get(fileUri) else remoteFileUtil.download(fileUri)
    }

    val localIndex = localPath(VoyagerUri.IndexFile)
    val localNames = localPath(VoyagerUri.NamesFile)
    new VoyagerReader(localIndex, localNames, indexSettings)
  }

  /** Returns the reader for the URI carried by the side input view, creating it if absent. */
  override def get[I, O](context: DoFn[I, O]#ProcessContext): VoyagerReader = {
    val voyagerUri = context.sideInput(view)
    VoyagerReaderSharedCache.computeIfAbsent(
      voyagerUri,
      (key: VoyagerUri) => createReader(key)
    )
  }
}

private object VoyagerSideInput {
  // JVM-wide VoyagerUri -> VoyagerReader cache: workers running several
  // voyager side-input steps share one loaded index per URI
  @transient private lazy val VoyagerReaderSharedCache: ConcurrentHashMap[VoyagerUri, VoyagerReader] =
    new ConcurrentHashMap[VoyagerUri, VoyagerReader]()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy