xyz.felh.openai.OpenAiService Maven / Gradle / Ivy
package xyz.felh.openai;
import com.alibaba.fastjson2.JSONObject;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import io.reactivex.rxjava3.core.Single;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.jetbrains.annotations.Nullable;
import retrofit2.HttpException;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava3.RxJava3CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
import retrofit2.converter.scalars.ScalarsConverterFactory;
import retrofit2.http.*;
import xyz.felh.StreamListener;
import xyz.felh.openai.assistant.Assistant;
import xyz.felh.openai.assistant.CreateAssistantRequest;
import xyz.felh.openai.assistant.ModifyAssistantRequest;
import xyz.felh.openai.assistant.message.CreateMessageRequest;
import xyz.felh.openai.assistant.message.Message;
import xyz.felh.openai.assistant.stream.message.MessageDelta;
import xyz.felh.openai.assistant.message.ModifyMessageRequest;
import xyz.felh.openai.assistant.run.*;
import xyz.felh.openai.assistant.runstep.RunStep;
import xyz.felh.openai.assistant.stream.runstep.RunStepDelta;
import xyz.felh.openai.assistant.thread.CreateThreadRequest;
import xyz.felh.openai.assistant.thread.ModifyThreadRequest;
import xyz.felh.openai.assistant.thread.Thread;
import xyz.felh.openai.assistant.vector.store.CreateVectorStoreRequest;
import xyz.felh.openai.assistant.vector.store.ModifyVectorStoreRequest;
import xyz.felh.openai.assistant.vector.store.VectorStore;
import xyz.felh.openai.assistant.vector.store.file.CreateVectorStoreFileRequest;
import xyz.felh.openai.assistant.vector.store.file.VectorStoreFile;
import xyz.felh.openai.assistant.vector.store.file.batch.CreateVectorStoreFileBatchRequest;
import xyz.felh.openai.assistant.vector.store.file.batch.VectorStoreFileBatch;
import xyz.felh.openai.audio.AudioResponse;
import xyz.felh.openai.audio.CreateAudioTranscriptionRequest;
import xyz.felh.openai.audio.CreateAudioTranslationRequest;
import xyz.felh.openai.audio.CreateSpeechRequest;
import xyz.felh.openai.batch.Batch;
import xyz.felh.openai.batch.CreateBatchRequest;
import xyz.felh.openai.bean.StreamToolCallsRequest;
import xyz.felh.openai.chat.ChatCompletion;
import xyz.felh.openai.chat.CreateChatCompletionRequest;
import xyz.felh.openai.embedding.CreateEmbeddingRequest;
import xyz.felh.openai.embedding.CreateEmbeddingResponse;
import xyz.felh.openai.fineTuning.CreateFineTuningJobRequest;
import xyz.felh.openai.fineTuning.FineTuningJob;
import xyz.felh.openai.fineTuning.FineTuningJobEvent;
import xyz.felh.openai.image.CreateImageRequest;
import xyz.felh.openai.image.ImageResponse;
import xyz.felh.openai.image.edit.CreateEditRequest;
import xyz.felh.openai.image.variation.CreateVariationRequest;
import xyz.felh.openai.interceptor.AuthenticationInterceptor;
import xyz.felh.openai.model.Model;
import xyz.felh.openai.moderation.CreateModerationRequest;
import xyz.felh.openai.moderation.CreateModerationResponse;
import xyz.felh.utils.Preconditions;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Function;
import static xyz.felh.openai.constant.OpenAiConstants.BASE_URL;
/**
* OpenAi Service Class
*/
@Slf4j
public class OpenAiService {
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(30);
private static final ObjectMapper errorMapper = defaultObjectMapper();
private final OpenAiApi api;
private final OkHttpClient client;
/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
*/
public OpenAiService(final String token) {
this(token, DEFAULT_TIMEOUT);
}
/**
* Creates a new OpenAiService that wraps OpenAiApi
*
* @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenAiService(final String token, final Duration timeout) {
this(buildApi(token, timeout), defaultClient(token, timeout));
}
/**
* Creates a new OpenAiService that wraps OpenAiApi.
* Use this if you need more customization.
*
* @param api OpenAiApi instance to use for all methods
*/
public OpenAiService(final OpenAiApi api, final OkHttpClient client) {
this.api = api;
this.client = client;
}
/**
* Calls the Open AI api, returns the response, and parses error messages if the request fails
*/
public static T execute(Single apiCall) {
try {
return apiCall.blockingGet();
} catch (HttpException e) {
try {
if (e.response() == null || e.response().errorBody() == null) {
throw e;
}
String errorBody = e.response().errorBody().string();
OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
throw new OpenAiHttpException(error, e, e.code());
} catch (IOException ex) {
// couldn't parse OpenAI error
throw e;
}
}
}
public static OpenAiApi buildApi(String token, Duration timeout) {
ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);
return retrofit.create(OpenAiApi.class);
}
public static ObjectMapper defaultObjectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
mapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
return mapper;
}
public static OkHttpClient defaultClient(String token, Duration timeout) {
return defaultClient(token, null, null, timeout);
}
public static OkHttpClient defaultClient(String token, String orgId, String projectId, Duration timeout) {
return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token, orgId, projectId))
.connectionPool(new ConnectionPool(10, 4, TimeUnit.SECONDS))
.readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS)
.build();
}
public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) {
return new Retrofit.Builder()
.baseUrl(BASE_URL)
.client(client)
.addConverterFactory(ScalarsConverterFactory.create())
.addConverterFactory(JacksonConverterFactory.create(mapper))
.addCallAdapterFactory(RxJava3CallAdapterFactory.create())
.build();
}
public List listModels() {
return execute(api.listModels()).getData();
}
public Model getModel(String modelId) {
return execute(api.getModel(modelId));
}
/**
* gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
* gpt-3.5-turbo-1006, gpt-4-1106-preview
*
* @param request create chat completion request
* @return chat completion
*/
public ChatCompletion createChatCompletion(CreateChatCompletionRequest request) {
return createChatCompletion(request, null);
}
/**
* @param request create chat completion request
* @param toolCallsHandler handle tool calls and then build the next request
* @return chat completion
*/
public ChatCompletion createChatCompletion(CreateChatCompletionRequest request,
Function toolCallsHandler) {
request.setStream(false);
ChatCompletion chatCompletion = execute(api.createChatCompletion(request));
if (Preconditions.isBlank(toolCallsHandler)) {
return chatCompletion;
} else {
if (Preconditions.isNotBlank(chatCompletion.getChoices())
&& Preconditions.isNotBlank(chatCompletion.getChoices().get(0).getMessage())
&& Preconditions.isNotBlank(chatCompletion.getChoices().get(0).getMessage().getToolCalls())) {
CreateChatCompletionRequest newRequest = toolCallsHandler.apply(chatCompletion);
return createChatCompletion(newRequest);
} else {
return chatCompletion;
}
}
}
/**
* create chat completion by stream, user-side handled if there is tool_calls
*
* @param requestId request ID, every observer is unique
* @param request detail of request
* @param listener StreamChatCompletionListener
*/
public void createSteamChatCompletion(String requestId,
CreateChatCompletionRequest request,
@NonNull StreamListener listener) {
createSteamChatCompletion(requestId, request, listener, null);
}
/**
* create chat completion by stream, sdk-side handled if there is tool_calls
*
* @param requestId request ID, every observer is unique
* @param request detail of request
* @param listener StreamChatCompletionListener
* @param toolCallsHandler handle tool calls and then build the next request
*/
public void createSteamChatCompletion(String requestId,
CreateChatCompletionRequest request,
@NonNull StreamListener listener,
BiFunction toolCallsHandler) {
request.setStream(true);
Request okHttpRequest;
try {
okHttpRequest = new Request.Builder().url(BASE_URL + "/v1/chat/completions")
.header("content-type", "text/event-stream")
.header("Accept", "text/event-stream")
.post(RequestBody.create(defaultObjectMapper().writeValueAsString(request),
MediaType.parse("application/json")))
.build();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
final StreamToolCallsReceiver streamToolCallsReceiver;
final CountDownLatch countDownLatch;
if (Preconditions.isBlank(toolCallsHandler)) {
streamToolCallsReceiver = null;
countDownLatch = null;
} else {
countDownLatch = new CountDownLatch(1);
streamToolCallsReceiver = new StreamToolCallsReceiver(this, requestId, toolCallsHandler, listener, countDownLatch);
}
EventSource.Factory factory = EventSources.createFactory(client);
EventSourceListener eventSourceListener = new EventSourceListener() {
@Override
public void onOpen(@NonNull EventSource eventSource, @NonNull Response response) {
listener.onOpen(requestId, response);
}
@Override
public void onEvent(@NonNull EventSource eventSource, @Nullable String id, @Nullable String type, @NonNull String data) {
if (data.equals("[DONE]")) {
if (Preconditions.isBlank(toolCallsHandler)) {
listener.onEventDone(requestId);
} else {
assert streamToolCallsReceiver != null;
if (!streamToolCallsReceiver.receiveDone(requestId)) {
listener.onEventDone(requestId);
}
}
} else {
try {
ChatCompletion chatCompletion = defaultObjectMapper().readValue(data, ChatCompletion.class);
if (Preconditions.isBlank(toolCallsHandler)) {
listener.onEvent(requestId, chatCompletion);
} else {
assert streamToolCallsReceiver != null;
if (!streamToolCallsReceiver.receive(chatCompletion)) {
listener.onEvent(requestId, chatCompletion);
}
}
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}
@Override
public void onClosed(@NonNull EventSource eventSource) {
if (Preconditions.isBlank(toolCallsHandler)) {
listener.onClosed(requestId);
} else {
assert streamToolCallsReceiver != null;
if (streamToolCallsReceiver.getActive()) {
try {
countDownLatch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
listener.onClosed(requestId);
}
}
@Override
public void onFailure(@NonNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
// log.error("createSteamChatCompletion error {}", response, t);
listener.onFailure(requestId, t, response);
}
};
EventSource eventSource = factory.newEventSource(okHttpRequest, eventSourceListener);
listener.setEventSource(eventSource);
}
public ImageResponse createImage(CreateImageRequest request) {
return execute(api.createImage(request));
}
public ImageResponse createImageEdit(CreateEditRequest request) {
byte[] imageBytes;
if (request.getImage() != null && request.getImage().length > 0) {
imageBytes = request.getImage();
} else {
File image = new File(request.getImagePath());
try {
imageBytes = Files.readAllBytes(image.toPath());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
byte[] maskBytes = null;
if (request.getMask() != null && request.getMask().length > 0) {
maskBytes = request.getMask();
} else if (request.getMaskPath() != null && !request.getMaskPath().isEmpty()) {
File mask = new File(request.getMaskPath());
try {
maskBytes = Files.readAllBytes(mask.toPath());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return createImageEdit(request, imageBytes, maskBytes);
}
private ImageResponse createImageEdit(CreateEditRequest request, byte[] image, byte[] mask) {
RequestBody imageBody = RequestBody.create(image, MediaType.parse("image"));
MultipartBody.Builder builder = new MultipartBody.Builder()
.setType(MediaType.get("multipart/form-data"))
.addFormDataPart("prompt", request.getPrompt())
.addFormDataPart("image", "image", imageBody);
if (request.getN() != null) {
builder.addFormDataPart("n", request.getN().toString());
}
if (request.getSize() != null) {
builder.addFormDataPart("size", request.getSize().value());
}
if (request.getResponseFormat() != null) {
builder.addFormDataPart("response_format", request.getResponseFormat().value());
}
if (mask != null) {
RequestBody maskBody = RequestBody.create(mask, MediaType.parse("image"));
builder.addFormDataPart("mask", "mask", maskBody);
}
return execute(api.createImageEdit(builder.build()));
}
public ImageResponse createImageVariation(CreateVariationRequest request) {
RequestBody imageBody;
if (request.getImage() != null && request.getImage().length > 0) {
imageBody = RequestBody.create(request.getImage(), MediaType.parse("image"));
} else {
File mask = new File(request.getImagePath());
try {
imageBody = RequestBody.create(Files.readAllBytes(mask.toPath()), MediaType.parse("image"));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
MultipartBody.Builder builder = new MultipartBody.Builder()
.setType(MediaType.get("multipart/form-data"))
.addFormDataPart("image", "image", imageBody);
if (request.getN() != null) {
builder.addFormDataPart("n", request.getN().toString());
}
if (request.getSize() != null) {
builder.addFormDataPart("size", request.getSize().value());
}
if (request.getResponseFormat() != null) {
builder.addFormDataPart("response_format", request.getResponseFormat().value());
}
return execute(api.createImageVariation(builder.build()));
}
/**
* text-embedding-ada-002, text-search-ada-doc-001
*
* @param request create Embedding request
* @return create embedding response
*/
public CreateEmbeddingResponse createEmbeddings(CreateEmbeddingRequest request) {
return execute(api.createEmbedding(request));
}
/**
* Generates audio from the input text.
*
* @param request create speech request
* @return audio file
*/
@POST("/v1/audio/speech")
public byte[] createSpeech(CreateSpeechRequest request) throws IOException {
return execute(api.createSpeech(request)).bytes();
}
/**
* Transcribes audio into the input language.
* whisper-1
*
* Supported formats: ['flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'mpga', 'oga', 'ogg', 'wav', 'webm']
*
* @param request create audio transcription request
* @return audio response
*/
public AudioResponse createAudioTranscription(CreateAudioTranscriptionRequest request) {
byte[] fileBytes;
if (request.getFile() != null && request.getFile().length > 0) {
fileBytes = request.getFile();
} else {
File file = new File(request.getFilePath());
try {
fileBytes = Files.readAllBytes(file.toPath());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
MultipartBody.Builder builder = new MultipartBody.Builder()
.setType(MultipartBody.FORM)
.addFormDataPart("file", request.getFileName(),
RequestBody.create(fileBytes, MediaType.parse("application/octet-stream")))
.addFormDataPart("model", request.getModel());
if (request.getPrompt() != null) {
builder.addFormDataPart("prompt", request.getModel());
}
if (request.getResponseFormat() != null) {
builder.addFormDataPart("response_format", request.getResponseFormat().value());
}
if (request.getTemperature() != null) {
builder.addFormDataPart("temperature", request.getTemperature().toString());
}
if (request.getPrompt() != null) {
builder.addFormDataPart("language", request.getLanguage());
}
return execute(api.createAudioTranscription(builder.build()));
}
/**
* whisper-1
*
* Supported formats: ['flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'mpga', 'oga', 'ogg', 'wav', 'webm']
*
* @param request create audio traslation request
* @return audio
*/
public AudioResponse createAudioTranslation(CreateAudioTranslationRequest request) {
byte[] fileBytes;
if (request.getFile() != null && request.getFile().length > 0) {
fileBytes = request.getFile();
} else {
File file = new File(request.getFilePath());
try {
fileBytes = Files.readAllBytes(file.toPath());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
MultipartBody.Builder builder = new MultipartBody.Builder()
.setType(MultipartBody.FORM)
.addFormDataPart("file", request.getFileName(),
RequestBody.create(fileBytes, MediaType.parse("application/octet-stream")))
.addFormDataPart("model", request.getModel());
if (request.getPrompt() != null) {
builder.addFormDataPart("prompt", request.getModel());
}
if (request.getResponseFormat() != null) {
builder.addFormDataPart("response_format", request.getResponseFormat().value());
}
if (request.getTemperature() != null) {
builder.addFormDataPart("temperature", request.getTemperature().toString());
}
return execute(api.createAudioTranslation(builder.build()));
}
public List listFiles() {
return execute(api.listFiles()).getData();
}
public xyz.felh.openai.file.File uploadFile(File file, xyz.felh.openai.file.File.Purpose purpose) {
RequestBody purposeBody = RequestBody.create(purpose.value(), okhttp3.MultipartBody.FORM);
RequestBody fileBody = RequestBody.create(file, MediaType.parse("text"));
MultipartBody.Part body = MultipartBody.Part.createFormData("file", file.getName(), fileBody);
return execute(api.uploadFile(body, purposeBody));
}
public xyz.felh.openai.file.File uploadFile(String filepath, xyz.felh.openai.file.File.Purpose purpose) {
File file = new File(filepath);
return uploadFile(file, purpose);
}
public DeleteResponse deleteFile(String fileId) {
return execute(api.deleteFile(fileId));
}
public xyz.felh.openai.file.File retrieveFile(String fileId) {
return execute(api.retrieveFile(fileId));
}
/**
* @param fileId file id
* @return file content
*/
public String retrieveFileContent(String fileId) {
return execute(api.retrieveFileContent(fileId));
}
/**
* text-moderation-stable, text-moderation-latest
*
* @param request create moderation request
* @return moderation response
*/
public CreateModerationResponse createModeration(CreateModerationRequest request) {
return execute(api.createModeration(request));
}
/**
* Creates a job that fine-tunes a specified model from a given dataset.
*
* @param request CreateFineTuningJobRequest
* @return A fine-tuning.job object.
*/
public FineTuningJob createFineTuningJob(CreateFineTuningJobRequest request) {
return execute(api.createFineTuningJob(request));
}
/**
* Get info about a fine-tuning job.
*
* @param fineTuningJobId fineTuningJobId
* @return The fine-tuning object with the given ID.
*/
public FineTuningJob retrieveFineTuningJob(String fineTuningJobId) {
return execute(api.retrieveFineTuningJob(fineTuningJobId));
}
/**
* Immediately cancel a fine-tune job.
*
* @param fineTuningJobId fineTuningJobId
* @return The cancelled fine-tuning object.
*/
public FineTuningJob cancelFineTuningJob(String fineTuningJobId) {
return execute(api.cancelFineTuningJob(fineTuningJobId));
}
/**
* Get status updates for a fine-tuning job.
*
* @param fineTuningJobId fineTuningJobId
* @param after Identifier for the last event from the previous pagination request.
* @param limit Number of events to retrieve.
* @return A list of fine-tuning event objects.
*/
public List listFineTuningEvents(String fineTuningJobId,
String after,
Integer limit) {
return execute(api.listFineTuningEvents(fineTuningJobId, after, limit)).getData();
}
/**
* Creates and executes a batch from an uploaded file of requests
*
* @param request create batch request
* @return a batch object
*/
public Batch createBatch(CreateBatchRequest request) {
return execute(api.createBatch(request));
}
/**
* Retrieves a batch.
*
* @param batchId batch id
* @return a batch object
*/
public Batch retrieveBatch(String batchId) {
return execute(api.retrieveBatch(batchId));
}
/**
* Cancels an in-progress batch.
*
* @param batchId batch id
* @return a batch object
*/
public Batch cancelBatch(String batchId) {
return execute(api.cancelBatch(batchId));
}
/**
* List your organization's batches.
*
* @param after A cursor for use in pagination
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @return a list of batch
*/
public List listBatches(String after,
Integer limit) {
return execute(api.listBatches(after, limit)).getData();
}
/***************** Assistant BETA ****************/
/**
* {@literal POST https://api.openai.com/v1/assistants}
*
* Create an assistant with a model and instructions.
*
* @param request Request body
* @return An {@link Assistant} object.
*/
public Assistant createAssistant(CreateAssistantRequest request) {
return execute(api.createAssistant(request));
}
/**
* Retrieves an assistant.
*
* @param assistantId The ID of the assistant to retrieve.
* @return The {@link Assistant} object matching the specified ID.
*/
public Assistant retrieveAssistant(String assistantId) {
return execute(api.retrieveAssistant(assistantId));
}
/**
* {@literal POST https://api.openai.com/v1/assistants/{assistant_id}}
*
* Modifies an assistant.
*
* @param assistantId The ID of the assistant to modify.
* @param request Request body
* @return The modified {@link Assistant} object.
*/
public Assistant modifyAssistant(String assistantId, ModifyAssistantRequest request) {
return execute(api.modifyAssistant(assistantId, request));
}
/**
* {@literal DELETE https://api.openai.com/v1/assistants/{assistant_id}}
*
* Delete an assistant.
*
* @param assistantId The ID of the assistant to delete.
* @return Deletion status
*/
public DeleteResponse deleteAssistant(String assistantId) {
return execute(api.deleteAssistant(assistantId));
}
/**
* {@literal GET https://api.openai.com/v1/assistants}
*
* Returns a list of assistants.
*
* @param order Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
* @param after A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
* @param before A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @return A list of {@link Assistant} objects.
*/
public OpenAiApiListResponse listAssistants(Integer limit, String order, String after, String before) {
return execute(api.listAssistants(limit, order, after, before));
}
public OpenAiApiListResponse listAssistants() {
return listAssistants(null, null, null, null);
}
/**
* {@literal POST https://api.openai.com/v1/threads}
*
* Create thread
*
* @param request Request body
* @return An {@link Thread} object.
*/
public Thread createThread(CreateThreadRequest request) {
return execute(api.createThread(request));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}}
*
* Retrieve thread
*
* @param threadId The ID of the thread to retrieve.
* @return The {@link Thread} object matching the specified ID.
*/
public Thread retrieveThread(String threadId) {
return execute(api.retrieveThread(threadId));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}}
*
* Modify thread
*
* @param threadId The ID of the thread to modify. Only the metadata can be modified.
* @param request Request body
* @return The modified {@link java.lang.Thread} object matching the specified ID.
*/
public Thread modifyThread(String threadId, ModifyThreadRequest request) {
return execute(api.modifyThread(threadId, request));
}
/**
* {@literal DELETE https://api.openai.com/v1/threads/{thread_id}}
*
* Delete thread
*
* @param threadId The ID of the thread to delete.
* @return Deletion status
*/
public DeleteResponse deleteThread(String threadId) {
return execute(api.deleteThread(threadId));
}
/********************* Messages BETA *************/
/**
* {@literal POST https://api.openai.com/v1/threads/{thread_id}/messages}
*
* Create message
*
* @param threadId The ID of the {@link Thread} to create a message for.
* @param request Request body
* @return An {@link Message} object.
*/
public Message createThreadMessage(String threadId, CreateMessageRequest request) {
return execute(api.createThreadMessage(threadId, request));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/messages/{message_id}}
*
* Retrieve message
*
* @param threadId The ID of the {@link Thread} to which this message belongs.
* @param messageId The ID of the message to retrieve.
* @return The {@link Message} object matching the specified ID.
*/
public Message retrieveThreadMessage(String threadId, String messageId) {
return execute(api.retrieveThreadMessage(threadId, messageId));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/messages/{message_id}}
*
* Modify message
*
* @param threadId The ID of the thread to which this message belongs.
* @param messageId The ID of the message to modify.
* @param request Request body
* @return The modified {@link Message} object.
*/
public Message modifyThreadMessage(String threadId, String messageId, ModifyMessageRequest request) {
return execute(api.modifyThreadMessage(threadId, messageId, request));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/messages}
*
* List messages
*
* @param threadId The ID of the {@link Thread} the messages belong to.
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @param order Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
* @param after A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
* @param before A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
* @param runId Filter messages by the run ID that generated them.
* @return A list of {@link Message} objects.
*/
public OpenAiApiListResponse listThreadMessages(String threadId,
Integer limit,
String order,
String after,
String before,
String runId) {
return execute(api.listThreadMessages(threadId, limit, order, after, before, runId));
}
public OpenAiApiListResponse listThreadMessages(String threadId, String runId) {
return listThreadMessages(threadId, null, null, null, null, runId);
}
/********************* Runs BETA *************/
/**
* {@literal POST https://api.openai.com/v1/threads/{thread_id}/runs}
*
* Create a run.
*
* @param threadId The ID of the thread to run.
* @param request Request body
* @return An {@link Run} object.
*/
public Run createThreadRun(String threadId, CreateRunRequest request) {
return execute(api.createThreadRun(threadId, request));
}
/**
* {@literal POST https://api.openai.com/v1/threads/{thread_id}/runs}
*
* Create a run with stream = true.
*
* @param threadId The ID of the thread to run.
* @param request Request body
* @param listener listener
* @return An {@link Run} object.
*/
public void createThreadRun(String requestId,
String threadId,
CreateRunRequest request,
@NonNull StreamListener listener) {
request.setStream(true);
Request okHttpRequest;
try {
okHttpRequest = new Request.Builder().url(
String.format("%s/v1/threads/%s/runs", BASE_URL, threadId))
.header("content-type", "text/event-stream")
.header("Accept", "text/event-stream")
.post(RequestBody.create(defaultObjectMapper().writeValueAsString(request),
MediaType.parse("application/json")))
.build();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
createAssistantStreamEvent(requestId, okHttpRequest, listener);
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}}
*
* Retrieves a run.
*
* @param threadId The ID of the {@link Thread} that was run.
* @param runId The ID of the run to retrieve.
* @return The {@link Run} object matching the specified ID.
*/
public Run retrieveThreadRun(String threadId, String runId) {
return execute(api.retrieveThreadRun(threadId, runId));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}}
*
* Modify message
*
* @param threadId The ID of the {@link Thread} that was run.
* @param runId The ID of the run to modify.
* @param request Request body
* @return The modified {@link Run} object.
*/
public Run modifyThreadRun(String threadId, String runId, ModifyRunRequest request) {
return execute(api.modifyThreadRun(threadId, runId, request));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/runs}
*
* List messages
*
* @param threadId The ID of the {@link Thread} the messages belong to.
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @param order Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
* @param after A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
* @param before A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
* @return A list of {@link Run} objects.
*/
public OpenAiApiListResponse listThreadRuns(
String threadId,
Integer limit,
String order,
String after,
String before) {
return execute(api.listThreadRuns(threadId, limit, order, after, before));
}
public OpenAiApiListResponse listThreadRuns(String threadId) {
return listThreadRuns(threadId, null, null, null, null);
}
/**
* {@literal POST https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs}
*
* @param threadId The ID of the {@link Thread} to which this run belongs.
* @param runId The ID of the run that requires the tool output submission.
* @param request Request body
* @return The modified {@link Run} object matching the specified ID.
*/
public Run submitToolOutputs(String threadId, String runId, SubmitToolOutputsRequest request) {
return execute(api.submitToolOutputs(threadId, runId, request));
}
/**
* {@literal POST https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs}
*
* @param threadId The ID of the {@link Thread} to which this run belongs.
* @param runId The ID of the run that requires the tool output submission.
* @param request Request body
* @param listener stream event listener
* @return
*/
public void submitToolOutputs(String requestId,
String threadId,
String runId,
SubmitToolOutputsRequest request,
@NonNull StreamListener listener) {
request.setStream(true);
Request okHttpRequest;
try {
okHttpRequest = new Request.Builder().url(
String.format("%sv1/threads/%s/runs/%s/submit_tool_outputs", BASE_URL, threadId, runId))
.header("content-type", "text/event-stream")
.header("Accept", "text/event-stream")
.post(RequestBody.create(defaultObjectMapper().writeValueAsString(request),
MediaType.parse("application/json")))
.build();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
createAssistantStreamEvent(requestId, okHttpRequest, listener);
}
/**
* {@literal POST https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/cancel}
*
* Cancels a run that is in_progress.
*
* @param threadId The ID of the thread to which this run belongs.
* @param runId The ID of the run to cancel.
* @return The modified {@link Run} object matching the specified ID.
*/
public Run cancelThreadRun(String threadId, String runId) {
return execute(api.cancelThreadRun(threadId, runId));
}
/**
* {@literal POST https://api.openai.com/v1/threads/runs}
*
* Create a thread and run it in one request.
*
* @param request Request body
* @return A {@link Run} object.
*/
public Run createThreadAndRun(CreateThreadAndRunRequest request) {
return execute(api.createThreadAndRun(request));
}
/**
* {@literal POST https://api.openai.com/v1/threads/runs}
*
* Create a thread and run it in one request with stream = true.
*
* @param request Request body
* @return A {@link Run} object.
*/
public void createThreadAndRun(String requestId,
CreateThreadAndRunRequest request,
@NonNull StreamListener listener) {
request.setStream(true);
Request okHttpRequest;
try {
okHttpRequest = new Request.Builder().url(
String.format("%s/v1/threads/runs", BASE_URL))
.header("content-type", "text/event-stream")
.header("Accept", "text/event-stream")
.post(RequestBody.create(defaultObjectMapper().writeValueAsString(request),
MediaType.parse("application/json")))
.build();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
createAssistantStreamEvent(requestId, okHttpRequest, listener);
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}}
*
* Retrieves a run step.
*
* @param threadId The ID of the thread to which the run and run step belongs.
* @param runId The ID of the run to which the run step belongs.
* @param stepId The ID of the run step to retrieve.
* @return The {@link RunStep} object matching the specified ID.
*/
public RunStep retrieveThreadRunStep(String threadId, String runId, String stepId) {
return execute(api.retrieveThreadRunStep(threadId, runId, stepId));
}
/**
* {@literal GET https://api.openai.com/v1/threads/{thread_id}/runs/{run_id}/steps}
*
* List run steps
*
* @param threadId The ID of the {@link Thread} the messages belong to.
* @param runId The ID of the run steps belong to.
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @param order Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
* @param after A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
* @param before A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
* @return A list of {@link RunStep} objects.
*/
public OpenAiApiListResponse listThreadRunSteps(String threadId,
String runId,
Integer limit,
String order,
String after,
String before) {
return execute(api.listThreadRunSteps(threadId, runId, limit, order, after, before));
}
public OpenAiApiListResponse listThreadRunSteps(String threadId, String runId) {
return listThreadRunSteps(threadId, runId, null, null, null, null);
}
/**
* init assistant stream event SSE
*
* @param requestId request id
* @param okHttpRequest request
* @param listener listener
*/
private void createAssistantStreamEvent(String requestId, Request okHttpRequest, StreamListener listener) {
EventSource.Factory factory = EventSources.createFactory(client);
EventSourceListener eventSourceListener = new EventSourceListener() {
@Override
public void onOpen(@NonNull EventSource eventSource, @NonNull Response response) {
listener.onOpen(requestId, response);
}
@Override
public void onEvent(@NonNull EventSource eventSource, @Nullable String id, @Nullable String type, @NonNull String data) {
/**
* event: thread.created
* data: {"id": "thread_123", "object": "thread", ...}
*/
if (Preconditions.isNotBlank(type)) {
switch (type) {
case "thread.created" ->
listener.onEvent(requestId, JSONObject.parseObject(data, Thread.class));
case "thread.run.created", "thread.run.cancelled", "thread.run.completed",
"thread.run.in_progress", "thread.run.requires_action", "thread.run.failed",
"thread.run.queued", "thread.run.expired", "thread.run.cancelling" ->
listener.onEvent(requestId, JSONObject.parseObject(data, Run.class));
case "thread.run.step.created", "thread.run.step.in_progress", "thread.run.step.completed",
"thread.run.step.cancelled", "thread.run.step.expired", "thread.run.step.failed" ->
listener.onEvent(requestId, JSONObject.parseObject(data, RunStep.class));
case "thread.run.step.delta" ->
listener.onEvent(requestId, JSONObject.parseObject(data, RunStepDelta.class));
case "thread.message.created", "thread.message.in_progress", "thread.message.completed",
"thread.message.incomplete" ->
listener.onEvent(requestId, JSONObject.parseObject(data, Message.class));
case "thread.message.delta" ->
listener.onEvent(requestId, JSONObject.parseObject(data, MessageDelta.class));
case "error" ->
listener.onEvent(requestId, JSONObject.parseObject(data, OpenAiError.ErrorDetail.class));
case "done" -> {
if (data.equals("[DONE]")) {
listener.onEventDone(requestId);
}
}
default -> log.warn("not match any type");
}
}
}
@Override
public void onClosed(@NonNull EventSource eventSource) {
listener.onClosed(requestId);
}
@Override
public void onFailure(@NonNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
listener.onFailure(requestId, t, response);
}
};
EventSource eventSource = factory.newEventSource(okHttpRequest, eventSourceListener);
listener.setEventSource(eventSource);
}
/*******************************8 Vector Stores ****************/
/**
* {@linkplain POST https://api.openai.com/v1/vector_stores
*
* Create a vector store.
*
* @param request Request body
* @return An {@link VectorStore } object.
*/
public VectorStore createVectorStore(CreateVectorStoreRequest request) {
return execute(api.createVectorStore(request));
}
/**
* {@linkplain GET https://api.openai.com/v1/vector_stores/{vector_store_id}
*
* Retrieves a vector store.
*
* @param vectorStoreId The ID of the vector store to retrieve.
* @return The {@link VectorStore} object matching the specified ID.
*/
public VectorStore retrieveVectorStore(String vectorStoreId) {
return execute(api.retrieveVectorStore(vectorStoreId));
}
/**
* {@linkplain GET https://api.openai.com/v1/vector_stores/{vector_store_id}
*
* Modifies a vector store.
*
* @param vectorStoreId The ID of the vector store to modify.
* @param request Request body
* @return The modified {@link VectorStore} object.
*/
public VectorStore modifyVectorStore(String vectorStoreId,
ModifyVectorStoreRequest request) {
return execute(api.modifyVectorStore(vectorStoreId, request));
}
/**
* {@linkplain DELETE https://api.openai.com/v1/vector_stores/{vector_store_id}
*
* Delete a vector store.
*
* @param vectorStoreId The ID of the vector store to delete.
* @return Deletion status
*/
public DeleteResponse deleteVectorStore(String vectorStoreId) {
return execute(api.deleteVectorStore(vectorStoreId));
}
/**
* {@linkplain GET https://api.openai.com/v1/vector_stores
*
* List vector stores
*
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @param order Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
* @param after A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
* @param before A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
* @return A list of {@link VectorStore} objects.
*/
public OpenAiApiListResponse listVectorStores(
Integer limit,
String order,
String after,
String before) {
return execute(api.listVectorStores(limit, order, after, before));
}
public OpenAiApiListResponse listVectorStores() {
return this.listVectorStores(null, null, null, null);
}
/****************** Vector Store Files ******************/
/**
* {@linkplain POST https://api.openai.com/v1/vector_stores/{vector_store_id}/files
*
* Create vector store file
*
* @param vectorStoreId The ID of the {@link VectorStore} for which to create a File.
* @param request Request body
* @return An {@link VectorStoreFile } object.
*/
public VectorStoreFile createVectorStoreFile(String vectorStoreId,
CreateVectorStoreFileRequest request) {
return execute(api.createVectorStoreFile(vectorStoreId, request));
}
/**
* {@linkplain DELETE https://api.openai.com/v1/vector_stores/{vector_store_id}/files/{file_id}
*
* Delete vector store file
*
* @param vectorStoreId The ID of the vector store that the file belongs to.
* @param fileId The ID of the file to delete.
* @return Deletion status
*/
public DeleteResponse deleteVectorStoreFile(String vectorStoreId,
String fileId) {
return execute(api.deleteVectorStoreFile(vectorStoreId, fileId));
}
/**
* {@linkplain GET https://api.openai.com/v1/vector_stores/{vector_store_id}/files
*
* List vector store files
*
* @param vectorStoreId The ID of the vector store that the files belong to.
* @param filter Filter by file status. One of in_progress, completed, failed, cancelled of {@link xyz.felh.openai.assistant.vector.store.file.VectorStoreFile.Status}.
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @param order Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
* @param after A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
* @param before A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
* @return A list of {@link VectorStoreFile} objects.
*/
public OpenAiApiListResponse listVectorStoreFiles(
String vectorStoreId,
Integer limit,
String order,
String after,
String before,
String filter) {
return execute(api.listVectorStoreFiles(vectorStoreId, limit, order, after, before, filter));
}
public OpenAiApiListResponse listVectorStoreFiles(
String vectorStoreId, String filter) {
return this.listVectorStoreFiles(vectorStoreId, null, null, null, null, filter);
}
/****************** Vector Store File Batch **********************/
/**
* {@linkplain POST https://api.openai.com/v1/vector_stores/{vector_store_id}/file_batches
*
* Create a vector store file batch.
*
* @param vectorStoreId The ID of the vector store for which to create a File Batch.
* @param request Request body
* @return An {@link VectorStoreFileBatch } object.
*/
public VectorStoreFileBatch createVectorStoreFileBatch(String vectorStoreId,
CreateVectorStoreFileBatchRequest request) {
return execute(api.createVectorStoreFileBatch(vectorStoreId, request));
}
/**
* {@linkplain GET https://api.openai.com/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}
*
* Retrieves a vector store file batch.
*
* @param vectorStoreId The ID of the vector store that the file batch belongs to.
* @param batchId The ID of the file batch being retrieved.
* @return The {@link VectorStoreFileBatch} object matching the specified ID.
*/
public VectorStoreFileBatch retrieveVectorStoreFileBatch(String vectorStoreId,
String batchId) {
return execute(api.retrieveVectorStoreFileBatch(vectorStoreId, batchId));
}
/**
* {@linkplain POST https://api.openai.com/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel
*
* Cancel vector store file batch
*
* @param vectorStoreId The ID of the vector store that the file batch belongs to.
* @param batchId The ID of the file batch to cancel.
* @return The modified {@link VectorStoreFileBatch} object.
*/
public VectorStoreFileBatch cancelVectorStoreFileBatch(String vectorStoreId,
String batchId) {
return execute(api.cancelVectorStoreFileBatch(vectorStoreId, batchId));
}
/**
* {@linkplain GET https://api.openai.com/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files
*
* List vector store files in a batchBeta
*
* @param vectorStoreId The ID of the vector store that the files belong to.
* @param limit A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
* @param order Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
* @param after A cursor for use in pagination. after is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
* @param before A cursor for use in pagination. before is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
* @param batchId The ID of the file batch that the files belong to.
* @param filter Filter by file status. One of in_progress, completed, failed, cancelled.{@link xyz.felh.openai.assistant.vector.store.file.batch.VectorStoreFileBatch.Status}
* @return A list of {@link VectorStoreFileBatch} objects.
*/
public OpenAiApiListResponse listVectorStoreFileBatches(
String vectorStoreId,
String batchId,
Integer limit,
String order,
String after,
String before,
String filter) {
return execute(api.listVectorStoreFileBatches(vectorStoreId, batchId, limit, order, after, before, filter));
}
public OpenAiApiListResponse listVectorStoreFileBatches(
String vectorStoreId,
String batchId,
String filter) {
return this.listVectorStoreFileBatches(vectorStoreId, batchId, null, null, null, null, filter);
}
}