All downloads are free. Search and download functionality uses the official Maven repository.

cc.unitmesh.store.ElasticsearchStore.kt Maven / Gradle / Ivy

Go to download

Chocolate Factory is a cutting-edge LLM toolkit designed to empower you in creating your very own AI assistant.

There is a newer version: 1.0.0
Show newest version
package cc.unitmesh.store

import cc.unitmesh.cf.core.utils.IdUtil
import cc.unitmesh.nlp.embedding.Embedding
import cc.unitmesh.rag.document.Document
import cc.unitmesh.rag.store.EmbeddingMatch
import cc.unitmesh.rag.store.EmbeddingStore
import co.elastic.clients.elasticsearch.ElasticsearchClient
import co.elastic.clients.elasticsearch._types.InlineScript
import co.elastic.clients.elasticsearch._types.Script
import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty
import co.elastic.clients.elasticsearch._types.mapping.Property
import co.elastic.clients.elasticsearch._types.mapping.TextProperty
import co.elastic.clients.elasticsearch._types.mapping.TypeMapping
import co.elastic.clients.elasticsearch._types.query_dsl.MatchAllQuery
import co.elastic.clients.elasticsearch._types.query_dsl.Query
import co.elastic.clients.elasticsearch._types.query_dsl.ScriptScoreQuery
import co.elastic.clients.elasticsearch.core.BulkRequest
import co.elastic.clients.elasticsearch.core.SearchRequest
import co.elastic.clients.elasticsearch.core.SearchResponse
import co.elastic.clients.elasticsearch.core.bulk.BulkOperation
import co.elastic.clients.elasticsearch.core.bulk.IndexOperation
import co.elastic.clients.elasticsearch.indices.CreateIndexRequest
import co.elastic.clients.elasticsearch.indices.ExistsRequest
import co.elastic.clients.json.JsonData
import co.elastic.clients.json.jackson.JacksonJsonpMapper
import co.elastic.clients.transport.ElasticsearchTransport
import co.elastic.clients.transport.rest_client.RestClientTransport
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.http.Header
import org.apache.http.HttpHost
import org.apache.http.auth.AuthScope
import org.apache.http.auth.UsernamePasswordCredentials
import org.apache.http.client.CredentialsProvider
import org.apache.http.impl.client.BasicCredentialsProvider
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder
import org.apache.http.message.BasicHeader
import org.elasticsearch.client.RestClient
import java.io.IOException
import java.util.*
import kotlin.streams.toList

/**
 * An [EmbeddingStore] implementation that keeps embeddings, together with their optional source
 * [Document], in an Elasticsearch index.
 *
 * All constructor parameters are optional:
 * - `serverUrl`: URL of the Elasticsearch server. Defaults to `http://localhost:9200`.
 * - `indexName`: name of the index to read from and write to. Defaults to `chocolate-code`.
 * - `username` / `password`: when `username` is not blank, these are registered as the default
 *   credentials of the underlying HTTP client.
 * - `apiKey`: when not blank, it is sent as an `Authorization` default header on every request.
 *
 * The index is created on the first non-empty insert if it does not exist yet, with a `text` field
 * and a `vector` dense-vector field whose dimension is taken from the first embedding of that insert.
 * Searches rank documents with `(cosineSimilarity + 1.0) / 2`.
 *
 * ```kotlin
 * val store: ElasticsearchStore = ElasticsearchStore(elasticsearchUrl)
 * ```
 */
class ElasticsearchStore(
    private val serverUrl: String = "http://localhost:9200",
    private val indexName: String = "chocolate-code",
    private val username: String? = null,
    private val password: String? = null,
    private val apiKey: String? = null,
) : EmbeddingStore<Document> {
    private val client: ElasticsearchClient
    private val objectMapper: ObjectMapper

    init {
        val restClientBuilder = RestClient.builder(HttpHost.create(serverUrl))

        // Username/password credentials are only wired in when a username was actually supplied.
        if (!username.isNullOrBlank()) {
            val provider: CredentialsProvider = BasicCredentialsProvider()
            provider.setCredentials(AuthScope.ANY, UsernamePasswordCredentials(username, password))
            restClientBuilder.setHttpClientConfigCallback { httpClientBuilder: HttpAsyncClientBuilder ->
                httpClientBuilder.setDefaultCredentialsProvider(provider)
            }
        }

        // The API key, when present, travels as a default header on every request.
        if (!apiKey.isNullOrBlank()) {
            restClientBuilder.setDefaultHeaders(
                arrayOf<Header>(BasicHeader("Authorization", "Apikey $apiKey"))
            )
        }

        val transport: ElasticsearchTransport =
            RestClientTransport(restClientBuilder.build(), JacksonJsonpMapper())
        client = ElasticsearchClient(transport)
        objectMapper = ObjectMapper()
    }

    /**
     * Stores [embedding] under a freshly generated id.
     *
     * @return the generated id.
     */
    override fun add(embedding: Embedding): String {
        val id: String = IdUtil.uuid()
        add(id, embedding)
        return id
    }

    /** Stores [embedding] under the caller-supplied [id], without an associated document. */
    override fun add(id: String, embedding: Embedding) {
        addInternal(id, embedding, null)
    }

    /**
     * Stores [embedding] together with its source [document] under a freshly generated id.
     *
     * @return the generated id.
     */
    override fun add(embedding: Embedding, document: Document): String {
        val id: String = IdUtil.uuid()
        addInternal(id, embedding, document)
        return id
    }

    /**
     * Stores all [embeddings] in one bulk request, each under a freshly generated id.
     *
     * @return the generated ids, in the same order as [embeddings].
     */
    override fun addAll(embeddings: List<Embedding>): List<String> {
        val ids = embeddings.map { IdUtil.uuid() }
        addAllInternal(ids, embeddings, null)
        return ids
    }

    /**
     * Stores all [embeddings] with their matching [embedded] documents in one bulk request.
     *
     * @return the generated ids, in the same order as [embeddings].
     * @throws IllegalArgumentException if [embeddings] and [embedded] differ in size.
     */
    override fun addAll(embeddings: List<Embedding>, embedded: List<Document>): List<String> {
        val ids = embeddings.map { IdUtil.uuid() }
        addAllInternal(ids, embeddings, embedded)
        return ids
    }

    /**
     * Searches the index for the documents most similar to [referenceEmbedding].
     *
     * @param maxResults maximum number of hits requested from Elasticsearch.
     * @param minScore minimum script score a hit must reach (converted to `Float` for the query).
     * @return the matching documents; hits without a source or with an empty text are dropped.
     * @throws ElasticsearchRequestFailedException if the request fails with an [IOException].
     */
    override fun findRelevant(
        referenceEmbedding: Embedding,
        maxResults: Int,
        minScore: Double,
    ): List<EmbeddingMatch<Document>> {
        return try {
            // Use Script Score and cosineSimilarity to calculate
            // see https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions-cosine
            val scriptScoreQuery = buildDefaultScriptScoreQuery(referenceEmbedding, minScore.toFloat())
            val response: SearchResponse<Document> = client.search(
                SearchRequest.of { s: SearchRequest.Builder ->
                    s.query { n: Query.Builder -> n.scriptScore(scriptScoreQuery) }
                        .size(maxResults)
                },
                Document::class.java
            )
            toEmbeddingMatch(response)
        } catch (e: IOException) {
            throw requestFailed(e)
        }
    }

    /** Single-item convenience wrapper around [addAllInternal]. */
    private fun addInternal(id: String, embedding: Embedding, embedded: Document?) {
        addAllInternal(
            listOf(id),
            listOf(embedding),
            embedded?.let { listOf(it) }
        )
    }

    /**
     * Validates the inputs, makes sure the index exists and bulk-indexes the embeddings.
     *
     * Nothing is sent when [ids] or [embeddings] is empty.
     *
     * @throws IllegalArgumentException if the list sizes do not line up.
     * @throws ElasticsearchRequestFailedException if a request fails with an [IOException].
     */
    private fun addAllInternal(ids: List<String>, embeddings: List<Embedding>, embedded: List<Document>?) {
        if (ids.isEmpty() || embeddings.isEmpty()) {
            log.info("[do not add empty embeddings to elasticsearch]")
            return
        }
        require(ids.size == embeddings.size) { "ids size is not equal to embeddings size" }
        require(embedded == null || embeddings.size == embedded.size) {
            "embeddings size is not equal to embedded size"
        }

        try {
            // The vector dimension of the index mapping is taken from the first embedding.
            createIndexIfNotExist(embeddings[0].size)
            bulk(ids, embeddings, embedded)
        } catch (e: IOException) {
            throw requestFailed(e)
        }
    }

    /**
     * Logs [e] and wraps it for rethrowing.
     *
     * Falls back to the exception's string form when it carries no message, so a message-less
     * [IOException] is never turned into a [NullPointerException].
     */
    private fun requestFailed(e: IOException): ElasticsearchRequestFailedException {
        log.error("[ElasticSearch encounter I/O Exception]", e)
        return ElasticsearchRequestFailedException(e.message ?: e.toString())
    }

    /** Creates the index with the default mappings for vectors of dimension [dim] unless it already exists. */
    @Throws(IOException::class)
    private fun createIndexIfNotExist(dim: Int) {
        val response = client.indices().exists { c: ExistsRequest.Builder -> c.index(indexName) }
        if (!response.value()) {
            client.indices().create { c: CreateIndexRequest.Builder ->
                c.index(indexName)
                    .mappings(getDefaultMappings(dim))
            }
        }
    }

    /** Builds the index mapping: a `text` field plus a `vector` dense-vector field of dimension [dim]. */
    private fun getDefaultMappings(dim: Int): TypeMapping {
        // do this like LangChain do
        val properties: Map<String, Property> = mapOf(
            "text" to Property.of { p: Property.Builder ->
                p.text(TextProperty.of { t: TextProperty.Builder? -> t })
            },
            "vector" to Property.of { p: Property.Builder ->
                p.denseVector(DenseVectorProperty.of { d: DenseVectorProperty.Builder -> d.dims(dim) })
            },
        )
        return TypeMapping.of { c: TypeMapping.Builder -> c.properties(properties) }
    }

    /**
     * Indexes every embedding (with its document's text and metadata, when given) in one bulk request.
     *
     * @throws ElasticsearchRequestFailedException for the first item Elasticsearch reports an error on.
     */
    @Throws(IOException::class)
    private fun bulk(ids: List<String>, embeddings: List<Embedding>, embedded: List<Document>?) {
        val bulkBuilder = BulkRequest.Builder()
        ids.forEachIndexed { i, id ->
            // Without a source document, an empty text and empty metadata are stored with the vector.
            val source = embedded?.get(i)
            val document = Document(
                text = source?.text ?: "",
                metadata = source?.metadata ?: HashMap(),
                vector = embeddings[i]
            )
            bulkBuilder.operations { op: BulkOperation.Builder ->
                op.index { idx: IndexOperation.Builder<Document> ->
                    idx
                        .index(indexName)
                        .id(id)
                        .document(document)
                }
            }
        }

        val response = client.bulk(bulkBuilder.build())
        if (response.errors()) {
            for (item in response.items()) {
                val error = item.error() ?: continue
                throw ElasticsearchRequestFailedException(
                    "type: ${error.type()}, reason: ${error.reason()}"
                )
            }
        }
    }

    /** Builds a script-score query over all documents that ranks them against [vector]. */
    @Throws(JsonProcessingException::class)
    private fun buildDefaultScriptScoreQuery(vector: Embedding, minScore: Float): ScriptScoreQuery {
        val queryVector = toJsonData(vector)
        return ScriptScoreQuery.of { q: ScriptScoreQuery.Builder ->
            q
                .minScore(minScore)
                .query(Query.of { qu: Query.Builder -> qu.matchAll { m: MatchAllQuery.Builder? -> m } })
                .script { s: Script.Builder ->
                    s.inline(InlineScript.of { i: InlineScript.Builder ->
                        i
                            // The script adds 1.0 to the cosine similarity to prevent the score from being negative.
                            // divided by 2 to keep the score in the range [0, 1]
                            .source("(cosineSimilarity(params.query_vector, 'vector') + 1.0) / 2")
                            .params("query_vector", queryVector)
                    })
                }
        }
    }

    /** Serializes [rawData] to a JSON string with the object mapper and wraps it as [JsonData]. */
    @Throws(JsonProcessingException::class)
    private fun <T> toJsonData(rawData: T): JsonData {
        return JsonData.fromJson(objectMapper.writeValueAsString(rawData))
    }

    /**
     * Converts search hits into embedding matches.
     *
     * Hits that have no source or whose text is empty are skipped.
     */
    private fun toEmbeddingMatch(response: SearchResponse<Document>): List<EmbeddingMatch<Document>> {
        return response.hits().hits().mapNotNull { hit ->
            val document = hit.source()?.takeIf { it.text.isNotEmpty() } ?: return@mapNotNull null
            val score = checkNotNull(hit.score()) { "Elasticsearch returned a hit without a score" }
            EmbeddingMatch(score, hit.id(), document.vector, Document(document.text, document.metadata))
        }
    }

    companion object {
        private val log = org.slf4j.LoggerFactory.getLogger(ElasticsearchStore::class.java)
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy