Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
apoc.ml.aws.SageMaker Maven / Gradle / Ivy
package apoc.ml.aws;
import apoc.Description;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import apoc.util.Util;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import static apoc.ml.aws.AWSConfig.HEADERS_KEY;
import static apoc.ml.aws.AWSConfig.JSON_PATH;
import static apoc.ml.aws.SageMakerConfig.ENDPOINT_NAME_KEY;
import static apoc.util.JsonUtil.OBJECT_MAPPER;
@Extended
public class SageMaker {
@Context
public URLAccessChecker urlAccessChecker;
public record EmbeddingResult(long index, String text, List embedding) {}
@Procedure("apoc.ml.sagemaker.custom")
@Description("apoc.ml.sagemaker.chat(body, $conf) - To create a customizable SageMaker call")
public Stream custom(@Name(value = "body") Object body,
@Name(value = "configuration", defaultValue = "{}") Map configuration) {
AWSConfig conf = new SageMakerConfig(configuration);
return executeRequestReturningMap(body, conf)
.map(MapResult::new);
}
@Procedure("apoc.ml.sagemaker.chat")
@Description("apoc.ml.sagemaker.chat(messages, $conf) - Prompts the chat completion API")
public Stream chatCompletion(
@Name("messages") List> messages,
@Name(value = "configuration", defaultValue = "{}") Map configuration) {
var config = new HashMap<>(configuration);
config.putIfAbsent(ENDPOINT_NAME_KEY, "Endpoint-Distilbart-xsum-1-1-1");
config.putIfAbsent(HEADERS_KEY, Util.map("Content-Type", "application/x-text"));
AWSConfig conf = new SageMakerConfig(config);
return messages
.stream()
.flatMap(message -> {
// to emulate OpenAI behaviour, e.g `{content: 'text..'},
// otherwise we put all json message as a body (with other models)
Object body = message.containsKey("content")
? message.get("content")
: message;
return executeRequestReturningMap(body, conf)
.map(MapResult::new);
});
}
@Procedure("apoc.ml.sagemaker.completion")
@Description("apoc.ml.sagemaker.completion(prompt, $conf) - Prompts the completion API")
public Stream completion(@Name("prompt") String prompt,
@Name(value = "configuration", defaultValue = "{}") Map configuration) {
var config = new HashMap<>(configuration);
config.putIfAbsent(ENDPOINT_NAME_KEY, "Endpoint-GPT-2-1");
config.putIfAbsent(HEADERS_KEY, Map.of("Content-Type", "application/x-text"));
AWSConfig conf = new SageMakerConfig(config);
return executeRequestReturningMap(prompt, conf)
.map(MapResult::new);
}
@Procedure("apoc.ml.sagemaker.embedding")
@Description("apoc.ml.sagemaker.embedding([texts], $configuration) - Returns the embeddings for a given text")
public Stream embedding(@Name(value = "texts") List texts,
@Name(value = "configuration", defaultValue = "{}") Map configuration) {
var config = new HashMap<>(configuration);
config.putIfAbsent(ENDPOINT_NAME_KEY, "Endpoint-Jina-Embeddings-v2-Base-en-1");
config.putIfAbsent(JSON_PATH, "data[*]");
AWSConfig conf = new SageMakerConfig(config);
List> inputs = texts.stream().map(text -> Map.of("text", text)).toList();
Object data = Map.of("data", inputs);
AtomicInteger idx = new AtomicInteger();
return executeRequestCommon(data, conf)
.flatMap(v -> ((List>) v).stream())
.map(i -> {
int index = idx.getAndIncrement();
return new EmbeddingResult(index, texts.get(index), (List) i.get("embedding"));
});
}
private Stream> executeRequestReturningMap(Object body, AWSConfig config) {
return executeRequestCommon(body, config)
.map(i -> (Map) i);
}
private Stream executeRequestCommon(Object body, AWSConfig conf) {
try {
String bodyString = body instanceof String string
? string
: OBJECT_MAPPER.writeValueAsString(body);
Map headers = conf.getHeaders();
headers.putIfAbsent("Content-Type", "application/json");
headers.putIfAbsent("accept", "*/*");
if (!headers.containsKey("Authorization")) {
AwsSignatureV4Generator.calculateAuthorizationHeaders(conf, bodyString, headers, "sagemaker");
}
return JsonUtil.loadJson(conf.getEndpoint(), conf.getHeaders(), bodyString, conf.getJsonPath(), true, List.of(), urlAccessChecker);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}