All downloads are free. The search and download functionality uses the official Maven repository.

com.liferay.search.experiences.internal.ml.embedding.text.HuggingFaceInferenceAPITextEmbeddingProvider Maven / Gradle / Ivy

/**
 * SPDX-FileCopyrightText: (c) 2000 Liferay, Inc. https://liferay.com
 * SPDX-License-Identifier: LGPL-2.1-or-later OR LicenseRef-Liferay-DXP-EULA-2.0.0-2023-06
 */

package com.liferay.search.experiences.internal.ml.embedding.text;

import com.liferay.petra.reflect.ReflectionUtil;
import com.liferay.petra.string.StringPool;
import com.liferay.portal.kernel.json.JSONArray;
import com.liferay.portal.kernel.json.JSONFactory;
import com.liferay.portal.kernel.json.JSONObject;
import com.liferay.portal.kernel.json.JSONUtil;
import com.liferay.portal.kernel.log.Log;
import com.liferay.portal.kernel.log.LogFactoryUtil;
import com.liferay.portal.kernel.servlet.HttpHeaders;
import com.liferay.portal.kernel.util.ContentTypes;
import com.liferay.portal.kernel.util.Http;
import com.liferay.portal.kernel.util.MapUtil;
import com.liferay.portal.kernel.util.StringUtil;
import com.liferay.portal.kernel.util.Validator;
import com.liferay.search.experiences.rest.dto.v1_0.EmbeddingProviderConfiguration;

import java.net.HttpURLConnection;

import java.util.List;
import java.util.Map;

import org.osgi.service.component.annotations.Component;
import org.osgi.service.component.annotations.Reference;

/**
 * Text embedding provider that obtains embeddings by posting text to the
 * Hugging Face Inference API
 * (<code>https://api-inference.huggingface.co/models/{model}</code>).
 *
 * <p>
 * Settings are read from the attributes map of the
 * {@link EmbeddingProviderConfiguration}:
 * </p>
 *
 * <ul>
 * <li><code>accessToken</code> - sent as a bearer token; when the key is
 * absent no request is made</li>
 * <li><code>model</code> - appended to the API base URL</li>
 * <li><code>maxCharacterCount</code> - passed to
 * <code>extractSentences</code>, default 1000</li>
 * <li><code>textTruncationStrategy</code> - passed to
 * <code>extractSentences</code>, default "beginning"</li>
 * <li><code>modelTimeout</code> - default 30; multiplied by 1000 and used as
 * the request timeout when retrying after an HTTP 503 response</li>
 * </ul>
 *
 * @author Petteri Karttunen
 */
@Component(
	enabled = false,
	property = "search.experiences.text.embedding.provider.name=huggingFaceInferenceAPI",
	service = TextEmbeddingProvider.class
)
public class HuggingFaceInferenceAPITextEmbeddingProvider
	extends BaseTextEmbeddingProvider implements TextEmbeddingProvider {

	/**
	 * Returns the embedding for the given text.
	 *
	 * <p>
	 * An empty array is returned, without calling the remote API, when the
	 * configuration attributes are missing, do not contain an
	 * <code>accessToken</code> key, or when <code>extractSentences</code>
	 * yields a blank string.
	 * </p>
	 *
	 * @param  embeddingProviderConfiguration the configuration whose
	 *         attributes hold the provider settings
	 * @param  text the text to embed
	 * @return the embedding values, or an empty array
	 */
	public Double[] getEmbedding(
		EmbeddingProviderConfiguration embeddingProviderConfiguration,
		String text) {

		// The attributes are handled as a string-keyed map; the cast is
		// unchecked because the element types cannot be verified at runtime.

		@SuppressWarnings("unchecked")
		Map<String, Object> attributes =
			(Map<String, Object>)embeddingProviderConfiguration.getAttributes();

		if ((attributes == null) || !attributes.containsKey("accessToken")) {
			if (_log.isDebugEnabled()) {
				_log.debug("Attributes do not contain access token");
			}

			return new Double[0];
		}

		String sentences = extractSentences(
			MapUtil.getInteger(attributes, "maxCharacterCount", 1000), text,
			MapUtil.getString(
				attributes, "textTruncationStrategy", "beginning"));

		if (Validator.isBlank(sentences)) {
			return new Double[0];
		}

		return _getEmbedding(attributes, sentences);
	}

	/**
	 * Posts the text to the Hugging Face Inference API and converts the JSON
	 * response into an array of doubles.
	 *
	 * <p>
	 * When the first response code is HTTP 503, the request is repeated once
	 * with the <code>x-wait-for-model</code> header set to <code>true</code>
	 * and the timeout derived from the <code>modelTimeout</code> attribute.
	 * </p>
	 *
	 * <p>
	 * Any exception, including the {@link IllegalArgumentException} raised for
	 * an unusable response, is handed to
	 * {@link ReflectionUtil#throwException(Throwable)}.
	 * </p>
	 *
	 * @param  attributes the provider settings
	 * @param  text the text to embed
	 * @return the embedding values
	 */
	private Double[] _getEmbedding(
		Map<String, Object> attributes, String text) {

		try {
			Http.Options options = new Http.Options();

			JSONObject jsonObject = JSONUtil.put("inputs", text);

			options.addHeader(
				HttpHeaders.AUTHORIZATION,
				"Bearer " + MapUtil.getString(attributes, "accessToken"));
			options.addHeader(
				HttpHeaders.CONTENT_TYPE, ContentTypes.APPLICATION_JSON);
			options.setBody(
				jsonObject.toString(), ContentTypes.APPLICATION_JSON,
				StringPool.UTF8);
			options.setCookieSpec(Http.CookieSpec.STANDARD);
			options.setLocation(
				"https://api-inference.huggingface.co/models/" +
					MapUtil.getString(attributes, "model"));
			options.setPost(true);

			String responseJSON = _http.URLtoString(options);

			Http.Response response = options.getResponse();

			// Retry once on 503, asking the API to wait for the model

			if (response.getResponseCode() ==
					HttpURLConnection.HTTP_UNAVAILABLE) {

				options.addHeader("x-wait-for-model", "true");
				options.setTimeout(
					MapUtil.getInteger(attributes, "modelTimeout", 30) * 1000);

				responseJSON = _http.URLtoString(options);
			}

			// Anything that is not a JSON array is reported verbatim

			if (!isJSONArray(responseJSON)) {
				throw new IllegalArgumentException(responseJSON);
			}

			if (!_isValidResponse(responseJSON)) {
				if (_log.isDebugEnabled()) {
					_log.debug("Invalid response: " + responseJSON);
				}

				throw new IllegalArgumentException(
					"The selected model is not valid for creating text " +
						"embedding");
			}

			// The element type must be declared: with a raw List, toArray
			// would be typed as Object[] and could not be returned as
			// Double[].

			List<Double> list = JSONUtil.toDoubleList(
				_getJSONArray(_jsonFactory.createJSONArray(responseJSON)));

			return list.toArray(new Double[0]);
		}
		catch (Exception exception) {
			return ReflectionUtil.throwException(exception);
		}
	}

	/**
	 * Descends through first elements for as long as the first element is
	 * itself a JSON array, and returns the innermost array reached.
	 *
	 * @param  jsonArray1 the array to unwrap
	 * @return the innermost array, or <code>jsonArray1</code> itself when
	 *         <code>getJSONArray(0)</code> returns <code>null</code>
	 */
	private JSONArray _getJSONArray(JSONArray jsonArray1) {
		JSONArray jsonArray2 = jsonArray1.getJSONArray(0);

		if (jsonArray2 != null) {
			return _getJSONArray(jsonArray2);
		}

		return jsonArray1;
	}

	/**
	 * Returns <code>true</code> if the response string starts with
	 * <code>[[</code> and ends with <code>]]</code>.
	 *
	 * @param  s the raw response body
	 * @return whether the response is textually wrapped in double brackets
	 */
	private boolean _isValidResponse(String s) {
		return StringUtil.startsWith(s, "[[") && StringUtil.endsWith(s, "]]");
	}

	private static final Log _log = LogFactoryUtil.getLog(
		HuggingFaceInferenceAPITextEmbeddingProvider.class);

	@Reference
	private Http _http;

	@Reference
	private JSONFactory _jsonFactory;

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy