org.devlive.sdk.openai.DefaultClient Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of openai-java-sdk Show documentation
Show all versions of openai-java-sdk Show documentation
Provides an easy-to-use SDK for Java developers to interact with the APIs of open AI models.
package org.devlive.sdk.openai;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MultipartBody;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.lang3.ObjectUtils;
import org.devlive.sdk.openai.entity.AudioEntity;
import org.devlive.sdk.openai.entity.ChatEntity;
import org.devlive.sdk.openai.entity.CompletionEntity;
import org.devlive.sdk.openai.entity.EditEntity;
import org.devlive.sdk.openai.entity.EmbeddingEntity;
import org.devlive.sdk.openai.entity.FileEntity;
import org.devlive.sdk.openai.entity.ImageEntity;
import org.devlive.sdk.openai.entity.ModelEntity;
import org.devlive.sdk.openai.entity.ModerationEntity;
import org.devlive.sdk.openai.entity.UserKeyEntity;
import org.devlive.sdk.openai.entity.google.MessageEntity;
import org.devlive.sdk.openai.exception.RequestException;
import org.devlive.sdk.openai.mixin.IgnoreUnknownMixin;
import org.devlive.sdk.openai.model.ProviderModel;
import org.devlive.sdk.openai.model.UrlModel;
import org.devlive.sdk.openai.response.AudioResponse;
import org.devlive.sdk.openai.response.ChatResponse;
import org.devlive.sdk.openai.response.CompleteResponse;
import org.devlive.sdk.openai.response.EditResponse;
import org.devlive.sdk.openai.response.EmbeddingResponse;
import org.devlive.sdk.openai.response.FileResponse;
import org.devlive.sdk.openai.response.ImageResponse;
import org.devlive.sdk.openai.response.ModelResponse;
import org.devlive.sdk.openai.response.ModerationResponse;
import org.devlive.sdk.openai.response.UserKeyResponse;
import org.devlive.sdk.openai.utils.MultipartBodyUtils;
import org.devlive.sdk.openai.utils.ProviderUtils;
/**
 * Base class for the provider-specific clients of this SDK.
 *
 * <p>Each public operation resolves its endpoint for the configured {@link ProviderModel}
 * through {@link ProviderUtils#getUrl} (unless noted otherwise), issues the request through
 * {@link DefaultApi} and blocks until the result is available. When {@link #listener} is set,
 * {@link #createCompletion} and {@link #createChatCompletion} stream server-sent events to it
 * instead of blocking.
 *
 * <p>None of the fields are assigned in this class; subclasses are expected to initialize them
 * (presumably from their builders — confirm in the concrete clients).
 */
@Slf4j
public abstract class DefaultClient
        implements AutoCloseable
{
    /**
     * JSON mapper used to serialize streaming request bodies. It is fully configured here and
     * only read afterwards, so one shared instance serves every request.
     */
    private static final ObjectMapper STREAM_MAPPER = createObjectMapper();

    /** API facade for all non-streaming calls; its results are resolved with {@code blockingGet()}. */
    protected DefaultApi api;
    /** Provider used to resolve the endpoint URL of each operation. */
    protected ProviderModel provider;
    /** HTTP client used for streaming requests and released by {@link #close()}. */
    protected OkHttpClient client;
    /** Base host that streaming requests prefix to the provider-relative URL. */
    protected String apiHost;
    /** When non-null, completion and chat-completion requests are streamed to this listener. */
    protected EventSourceListener listener;

    /**
     * Lists the models available from the provider.
     *
     * @return the models response
     */
    public ModelResponse getModels()
    {
        return this.api.fetchModels(ProviderUtils.getUrl(provider, UrlModel.FETCH_MODELS))
                .blockingGet();
    }

    /**
     * Fetches a single model by name. Unlike most operations, the endpoint is not resolved
     * through the provider here; it is left to {@code DefaultApi.fetchModel}.
     *
     * @param model model identifier
     * @return the model description
     */
    public ModelEntity getModel(String model)
    {
        return this.api.fetchModel(model)
                .blockingGet();
    }

    /**
     * Creates a text completion.
     *
     * <p>If {@link #listener} is set, the request is switched to streaming mode (this mutates
     * {@code configure}), events are delivered to the listener and {@code null} is returned.
     *
     * @param configure completion request
     * @return the completion, or {@code null} when streaming to {@link #listener}
     * @throws RequestException if the streaming request cannot be created
     */
    public CompleteResponse createCompletion(CompletionEntity configure)
    {
        String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS);
        if (this.listener != null) {
            configure.setStream(true);
            this.createEventSource(url, configure);
            return null;
        }
        else {
            return this.api.fetchCompletions(url, configure)
                    .blockingGet();
        }
    }

    /**
     * Sends a Google PaLM completion request to the provider's completions endpoint.
     *
     * @param configure PaLM completion request
     * @return the completion response
     */
    public CompleteResponse createPaLMCompletion(org.devlive.sdk.openai.entity.google.CompletionEntity configure)
    {
        return this.api.fetchPaLMCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
                .blockingGet();
    }

    /**
     * Sends a Google PaLM chat request to the provider's completions endpoint.
     *
     * <p>Before sending, a message with content {@code "NEXT REQUEST"} is appended to
     * {@code configure.getPrompt().getMessages()}. This mutates the caller's entity, so reusing
     * the same entity appends one more such message per call.
     * NOTE(review): the purpose of this placeholder message is not visible in this class —
     * confirm against the PaLM request mapping before changing it.
     *
     * @param configure PaLM chat request; its prompt and message list must be non-null and mutable
     * @return the chat response
     */
    public CompleteResponse createPaLMChat(org.devlive.sdk.openai.entity.google.ChatEntity configure)
    {
        MessageEntity message = MessageEntity.builder()
                .content("NEXT REQUEST")
                .build();
        configure.getPrompt().getMessages()
                .add(message);
        return this.api.fetchPaLMChat(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
                .blockingGet();
    }

    /**
     * Creates a chat completion.
     *
     * <p>If {@link #listener} is set, the request is switched to streaming mode (this mutates
     * {@code configure}), events are delivered to the listener and {@code null} is returned.
     *
     * @param configure chat request
     * @return the chat completion, or {@code null} when streaming to {@link #listener}
     * @throws RequestException if the streaming request cannot be created
     */
    public ChatResponse createChatCompletion(ChatEntity configure)
    {
        String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_CHAT_COMPLETIONS);
        if (this.listener != null) {
            configure.setStream(true);
            this.createEventSource(url, configure);
            return null;
        }
        else {
            return this.api.fetchChatCompletions(url, configure)
                    .blockingGet();
        }
    }

    /**
     * Lists the user's API keys (endpoint chosen by {@code DefaultApi}, not the provider).
     *
     * @return the keys response
     */
    public UserKeyResponse getKeys()
    {
        return this.api.fetchUserAPIKeys()
                .blockingGet();
    }

    /**
     * Creates a user API key (endpoint chosen by {@code DefaultApi}, not the provider).
     *
     * @param configure key creation request
     * @return the created key response
     */
    public UserKeyResponse createUserAPIKey(UserKeyEntity configure)
    {
        return this.api.fetchCreateUserAPIKey(configure)
                .blockingGet();
    }

    /**
     * Generates images. The variation and edit flags of {@code configure} are cleared (set to
     * {@code null}) before sending — presumably so they are not serialized for this endpoint.
     *
     * @param configure image request; mutated as described above
     * @return the generated images
     */
    public ImageResponse createImages(ImageEntity configure)
    {
        configure.setIsVariation(null);
        configure.setIsEdit(null);
        return this.api.fetchImagesGenerations(ProviderUtils.getUrl(provider, UrlModel.FETCH_IMAGES_GENERATIONS), configure)
                .blockingGet();
    }

    /**
     * Edits an image. The image is sent as multipart part {@code "image"}, the mask (only when
     * present) as part {@code "mask"}, and the remaining fields from {@code configure.convertMap()}.
     *
     * @param configure image edit request
     * @return the edited images
     */
    public ImageResponse editImages(ImageEntity configure)
    {
        MultipartBody.Part imageBody = MultipartBodyUtils.getPart(configure.getImage(), "image");
        MultipartBody.Part maskBody = null;
        if (ObjectUtils.isNotEmpty(configure.getMask())) {
            maskBody = MultipartBodyUtils.getPart(configure.getMask(), "mask");
        }
        return this.api.fetchImagesEdits(ProviderUtils.getUrl(provider, UrlModel.FETCH_IMAGES_EDITS),
                        imageBody,
                        maskBody,
                        configure.convertMap())
                .blockingGet();
    }

    /**
     * Creates variations of an image sent as multipart part {@code "image"}, together with the
     * fields from {@code configure.convertMap()}.
     *
     * @param configure image variation request
     * @return the generated variations
     */
    public ImageResponse variationsImages(ImageEntity configure)
    {
        MultipartBody.Part imageBody = MultipartBodyUtils.getPart(configure.getImage(), "image");
        return this.api.fetchImagesVariations(ProviderUtils.getUrl(provider, UrlModel.FETCH_IMAGES_VARIATIONS),
                        imageBody,
                        configure.convertMap())
                .blockingGet();
    }

    /**
     * Creates embeddings.
     *
     * @param configure embedding request
     * @return the embeddings response
     */
    public EmbeddingResponse createEmbeddings(EmbeddingEntity configure)
    {
        return this.api.fetchEmbeddings(ProviderUtils.getUrl(provider, UrlModel.FETCH_EMBEDDINGS),
                        configure)
                .blockingGet();
    }

    /**
     * Transcribes audio sent as multipart part {@code "file"}, together with the fields from
     * {@code configure.convertMap()}.
     *
     * @param configure transcription request
     * @return the transcription response
     */
    public AudioResponse audioTranscriptions(AudioEntity configure)
    {
        MultipartBody.Part fileBody = MultipartBodyUtils.getPart(configure.getFile(), "file");
        return this.api.fetchAudioTranscriptions(ProviderUtils.getUrl(provider, UrlModel.FETCH_AUDIO_TRANSCRIPTIONS),
                        fileBody,
                        configure.convertMap())
                .blockingGet();
    }

    /**
     * Runs a moderation check.
     *
     * @param configure moderation request
     * @return the moderation response
     */
    public ModerationResponse moderations(ModerationEntity configure)
    {
        return this.api.fetchModerations(ProviderUtils.getUrl(provider, UrlModel.FETCH_MODERATIONS), configure)
                .blockingGet();
    }

    /**
     * Requests an edit.
     *
     * @param configure edit request
     * @return the edit response
     */
    public EditResponse edit(EditEntity configure)
    {
        return this.api.fetchEdits(ProviderUtils.getUrl(provider, UrlModel.FETCH_EDITS), configure)
                .blockingGet();
    }

    /**
     * Lists uploaded files.
     *
     * @return the files response
     */
    public FileResponse files()
    {
        return this.api.fetchFiles(ProviderUtils.getUrl(provider, UrlModel.FETCH_FILES))
                .blockingGet();
    }

    /**
     * Uploads a file sent as multipart part {@code "file"}, together with the fields from
     * {@code configure.convertMap()}.
     *
     * @param configure upload request
     * @return the uploaded file description
     */
    public FileEntity uploadFile(FileEntity configure)
    {
        MultipartBody.Part fileBody = MultipartBodyUtils.getPart(configure.getFile(), "file");
        return this.api.fetchUploadFile(ProviderUtils.getUrl(provider, UrlModel.FETCH_FILES),
                        fileBody,
                        configure.convertMap())
                .blockingGet();
    }

    /**
     * Deletes a file ({@code <files endpoint>/<id>}).
     *
     * @param id file identifier
     * @return the deletion response
     */
    public FileResponse deleteFile(String id)
    {
        return this.api.fetchDeleteFile(this.fileUrl(id))
                .blockingGet();
    }

    /**
     * Retrieves a file's metadata ({@code <files endpoint>/<id>}).
     *
     * @param id file identifier
     * @return the file description
     */
    public FileEntity retrieveFile(String id)
    {
        return this.api.fetchRetrieveFile(this.fileUrl(id))
                .blockingGet();
    }

    /**
     * Retrieves a file's content ({@code <files endpoint>/<id>/content}).
     *
     * @param id file identifier
     * @return the raw content as returned by {@code DefaultApi}
     */
    public Object retrieveFileContent(String id)
    {
        return this.api.fetchRetrieveFileContent(this.fileUrl(id, "content"))
                .blockingGet();
    }

    /** Joins the provider's files endpoint and {@code segments} with {@code /}. */
    private String fileUrl(String... segments)
    {
        return String.join("/", ProviderUtils.getUrl(provider, UrlModel.FETCH_FILES), String.join("/", segments));
    }

    /** Builds the mapper for streaming bodies, with {@link IgnoreUnknownMixin} applied to every type. */
    private static ObjectMapper createObjectMapper()
    {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.addMixIn(Object.class, IgnoreUnknownMixin.class);
        return objectMapper;
    }

    /**
     * POSTs {@code configure} as JSON to {@code apiHost + "/" + url} and streams the server-sent
     * events to {@link #listener}.
     *
     * @throws RequestException if serialization or request construction fails; the original
     *                          exception is attached as the cause
     */
    private void createEventSource(String url, Object configure)
    {
        try {
            EventSource.Factory factory = EventSources.createFactory(this.client);
            Request request = new Request.Builder()
                    .url(String.join("/", this.apiHost, url))
                    .post(RequestBody.create(MultipartBodyUtils.JSON, STREAM_MAPPER.writeValueAsString(configure)))
                    .build();
            // The returned EventSource handle is discarded; close() cancelling all dispatcher
            // calls is the only visible way to stop the stream from this class.
            factory.newEventSource(request, this.listener);
        }
        catch (Exception e) {
            RequestException failure = new RequestException(String.format("Failed to create event source: %s", e.getMessage()));
            // Keep the real failure (and its stack trace) reachable via getCause().
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Cancels every queued or running call on the client's dispatcher (including event streams
     * opened by this client), evicts its pooled connections and drops the client reference.
     * Calling it again is a no-op.
     *
     * <p>NOTE(review): the dispatcher's executor service is not shut down, so idle OkHttp
     * threads may linger until their keep-alive expires; left as-is in case the
     * {@link OkHttpClient} is shared with other code — confirm.
     */
    @Override
    public void close()
    {
        if (this.client != null) {
            this.client.dispatcher().cancelAll();
            this.client.connectionPool().evictAll();
            this.client = null;
        }
    }
}