All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.hakenadu.javalangchains.chains.data.retrieval.ElasticsearchRetrievalChain Maven / Gradle / Ivy

package com.github.hakenadu.javalangchains.chains.data.retrieval;

import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.apache.http.HttpHost;
import org.apache.lucene.search.Query;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.hakenadu.javalangchains.util.PromptConstants;

/**
 * This {@link RetrievalChain} retrieves documents from an elasticsearch index
 */
public class ElasticsearchRetrievalChain extends RetrievalChain implements Closeable {

	/**
	 * elasticsearch index name
	 */
	private final String index;

	/**
	 * elasticsearch low level {@link RestClient}
	 */
	private final RestClient restClient;

	/**
	 * this {@link Function} accepts the user's question and provides the
	 * {@link Query} which is executed against the Elasticsearch _search API
	 */
	private final Function queryCreator;

	/**
	 * Consumes an elasticsearch hit and the question and creates a document as a
	 * result
	 */
	private final BiFunction> documentCreator;

	/**
	 * {@link ObjectMapper} used for query creation and document deserialization
	 */
	private final ObjectMapper objectMapper;

	/**
	 * Creates an instance of {@link ElasticsearchRetrievalChain}
	 * 
	 * @param index            {@link #index}
	 * @param restClient       {@link #restClient}
	 * @param maxDocumentCount {@link #getMaxDocumentCount()}
	 * @param objectMapper     {@link #objectMapper}
	 * @param queryCreator     {@link #queryCreator}
	 * @param documentCreator  {@link #documentCreator}
	 */
	public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount,
			final ObjectMapper objectMapper, final Function queryCreator,
			final BiFunction> documentCreator) {
		super(maxDocumentCount);
		this.index = index;
		this.restClient = restClient;
		this.objectMapper = objectMapper;
		this.queryCreator = queryCreator;
		this.documentCreator = documentCreator;
	}

	/**
	 * Creates an instance of {@link ElasticsearchRetrievalChain}
	 * 
	 * @param index            {@link #index}
	 * @param restClient       {@link #restClient}
	 * @param maxDocumentCount {@link #getMaxDocumentCount()}
	 * @param objectMapper     {@link #objectMapper}
	 * @param queryCreator     {@link #queryCreator}
	 */
	public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount,
			final ObjectMapper objectMapper, final Function queryCreator) {
		this(index, restClient, maxDocumentCount, objectMapper, queryCreator, defaultDocumentCreator(objectMapper));
	}

	/**
	 * Creates an instance of {@link ElasticsearchRetrievalChain}
	 * 
	 * @param index            {@link #index}
	 * @param restClient       {@link #restClient}
	 * @param maxDocumentCount {@link #getMaxDocumentCount}
	 * @param objectMapper     {@link #objectMapper}
	 */
	public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount,
			final ObjectMapper objectMapper) {
		this(index, restClient, maxDocumentCount, objectMapper, question -> createQuery(objectMapper, question));
	}

	/**
	 * Creates an instance of {@link ElasticsearchRetrievalChain}
	 * 
	 * @param index            {@link #index}
	 * @param restClient       {@link #restClient}
	 * @param maxDocumentCount {@link #getMaxDocumentCount}
	 */
	public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount) {
		this(index, restClient, maxDocumentCount, new ObjectMapper());
	}

	/**
	 * Creates an instance of {@link ElasticsearchRetrievalChain}
	 * 
	 * @param index      {@link #index}
	 * @param restClient {@link #restClient}
	 */
	public ElasticsearchRetrievalChain(final String index, final RestClient restClient) {
		this(index, restClient, 4);
	}

	/**
	 * Creates an instance of {@link ElasticsearchRetrievalChain}
	 * 
	 * @param index {@link #index}
	 */
	public ElasticsearchRetrievalChain(final String index) {
		this(index, RestClient.builder(new HttpHost("localhost", 9200)).build());
	}

	@Override
	public Stream> run(final String input) {
		final ObjectNode query = queryCreator.apply(input);

		final String requestJson = objectMapper.createObjectNode().put("size", getMaxDocumentCount())
				.set("query", query).toString();

		final Request searchRequest = new Request("GET", String.format("/%s/_search", index));
		searchRequest.setJsonEntity(requestJson);

		final Response searchResponse;
		try {
			searchResponse = restClient.performRequest(searchRequest);
		} catch (final IOException ioException) {
			throw new IllegalStateException("error executing search with request " + requestJson, ioException);
		}

		final ObjectNode response;
		try (final InputStream responseInputStream = searchResponse.getEntity().getContent()) {
			response = (ObjectNode) objectMapper.readTree(responseInputStream);
		} catch (final IOException ioException) {
			throw new IllegalStateException("error parsing search response", ioException);
		}

		final ArrayNode hits = Optional.of(response).map(o -> o.get("hits")).map(ObjectNode.class::cast)
				.map(o -> o.get("hits")).map(ArrayNode.class::cast).orElse(null);

		if (hits == null) {
			return Stream.empty();
		}

		return StreamSupport.stream(Spliterators.spliteratorUnknownSize(hits.iterator(), Spliterator.ORDERED), false)
				.map(ObjectNode.class::cast).map(hitNode -> documentCreator.apply(hitNode, input));
	}

	@Override
	public void close() throws IOException {
		this.restClient.close();
	}

	/**
	 * @param objectMapper {@link ObjectMapper} used for {@link ObjectNode} creation
	 * @param question     the question used for retrieval
	 * @return {"match": {"content": question}}
	 */
	private static ObjectNode createQuery(final ObjectMapper objectMapper, final String question) {
		final ObjectNode query = objectMapper.createObjectNode();
		query.putObject("match").put(PromptConstants.CONTENT, question);
		return query;
	}

	/**
	 * creates the default {@link #queryCreator}
	 * 
	 * @param objectMapper the {@link ObjectMapper} used for json operations
	 * @return {@link BiFunction} which consumes a hit node and the question and
	 *         produces a document consisting of all (key, value)-pairs of the hit's
	 *         _source object
	 */
	public static BiFunction> defaultDocumentCreator(
			final ObjectMapper objectMapper) {
		return (hitObjectNode, question) -> {
			final ObjectNode source = (ObjectNode) hitObjectNode.get("_source");

			final Map sourceMap = objectMapper.convertValue(source,
					new TypeReference>() {
						// noop
					});

			final Map document = new HashMap<>();
			document.put(PromptConstants.QUESTION, question);

			for (final Entry sourceEntry : sourceMap.entrySet()) {
				document.put(sourceEntry.getKey(), sourceEntry.getValue().toString());
			}

			return document;
		};
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy