All downloads are free. Search and download functionality uses the official Maven repository.

xyz.felh.openai.OpenAiService Maven / Gradle / Ivy

The newest version!
package xyz.felh.openai;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import io.reactivex.rxjava3.core.Single;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.jetbrains.annotations.Nullable;
import retrofit2.HttpException;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava3.RxJava3CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
import xyz.felh.openai.audio.AudioResponse;
import xyz.felh.openai.audio.CreateAudioTranscriptionRequest;
import xyz.felh.openai.audio.CreateAudioTranslationRequest;
import xyz.felh.openai.completion.Completion;
import xyz.felh.openai.completion.CreateCompletionRequest;
import xyz.felh.openai.completion.chat.ChatCompletion;
import xyz.felh.openai.completion.chat.CreateChatCompletionRequest;
import xyz.felh.openai.edit.CreateEditRequest;
import xyz.felh.openai.edit.Edit;
import xyz.felh.openai.embedding.CreateEmbeddingRequest;
import xyz.felh.openai.embedding.CreateEmbeddingResponse;
import xyz.felh.openai.file.RetrieveFileContentResponse;
import xyz.felh.openai.finetune.CreateFineTuneRequest;
import xyz.felh.openai.finetune.FineTune;
import xyz.felh.openai.finetune.FineTuneEvent;
import xyz.felh.openai.image.CreateImageRequest;
import xyz.felh.openai.image.ImageResponse;
import xyz.felh.openai.image.edit.CreateImageEditRequest;
import xyz.felh.openai.image.variation.CreateImageVariationRequest;
import xyz.felh.openai.model.Model;
import xyz.felh.openai.moderation.CreateModerationRequest;
import xyz.felh.openai.moderation.CreateModerationResponse;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.TimeUnit;


/**
 * OpenAi Service Class
 */
@Slf4j
public class OpenAiService {

    /** Root URL of the OpenAI REST API; used for the Retrofit base URL and the streaming endpoint. */
    private static final String BASE_URL = "https://api.openai.com";
    /** Timeout applied by the single-argument constructor when the caller does not supply one. */
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(30);
    /** Mapper used by {@code execute} to decode HTTP error bodies into {@code OpenAiError}. */
    private static final ObjectMapper errorMapper = defaultObjectMapper();

    /** API facade that the blocking request methods delegate to. */
    private final OpenAiApi api;

    /** HTTP client used to open the server-sent-event stream for streamed chat completions. */
    private final OkHttpClient client;

    /**
     * Builds a service for the given API token, applying {@code DEFAULT_TIMEOUT} as the
     * read timeout.
     *
     * @param token OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
     */
    public OpenAiService(final String token) {
        // Delegate to the (token, timeout) constructor with the class-wide default.
        this(token, DEFAULT_TIMEOUT);
    }

    /**
     * Creates a new OpenAiService that wraps OpenAiApi
     *
     * @param token   OpenAi token string "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
     * @param timeout http read timeout, Duration.ZERO means no timeout
     */
    public OpenAiService(final String token, final Duration timeout) {
        this(buildApi(token, timeout), defaultClient(token, timeout));
    }

    /**
     * Creates a new OpenAiService from pre-built collaborators.
     * Use this if you need more customization.
     *
     * @param api    OpenAiApi instance to use for all blocking methods
     * @param client OkHttpClient used to open the event stream for streamed chat completions
     */
    public OpenAiService(final OpenAiApi api, final OkHttpClient client) {
        this.client = client;
        this.api = api;
    }

    /**
     * Calls the Open AI api, returns the response, and parses error messages if the request fails.
     *
     * @param apiCall the pending API call to block on
     * @param <T>     type of the successful response
     * @return the value emitted by {@code apiCall}
     * @throws OpenAiHttpException if the call failed with an HTTP error whose body could be
     *                             decoded into an {@link OpenAiError}
     * @throws HttpException       if the call failed with an HTTP error that carries no
     *                             response or error body, or whose body could not be read or parsed
     */
    public static <T> T execute(Single<T> apiCall) {
        try {
            return apiCall.blockingGet();
        } catch (HttpException e) {
            // Nothing to decode: surface the original HTTP failure unchanged.
            if (e.response() == null || e.response().errorBody() == null) {
                throw e;
            }
            try {
                String errorBody = e.response().errorBody().string();
                OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
                throw new OpenAiHttpException(error, e, e.code());
            } catch (IOException ex) {
                // couldn't read or parse the OpenAI error payload; fall back to the raw HTTP error
                throw e;
            }
        }
    }

    /**
     * Assembles an {@code OpenAiApi} implementation from the default client, the default
     * object mapper and the default Retrofit configuration.
     *
     * @param token   API token handed to {@code defaultClient}
     * @param timeout timeout handed to {@code defaultClient}
     * @return the API instance created by Retrofit
     */
    public static OpenAiApi buildApi(String token, Duration timeout) {
        final OkHttpClient httpClient = defaultClient(token, timeout);
        return defaultRetrofit(httpClient, defaultObjectMapper()).create(OpenAiApi.class);
    }

    /**
     * Creates the Jackson mapper used by this service: unknown properties are ignored on
     * deserialization, null values are omitted on serialization, and property names use
     * snake_case.
     *
     * @return a newly created, configured mapper
     */
    public static ObjectMapper defaultObjectMapper() {
        return new ObjectMapper()
                .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                .setSerializationInclusion(JsonInclude.Include.NON_NULL)
                .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
    }

    /**
     * Builds the default HTTP client without an organization id.
     *
     * @param token   API token passed on to the three-argument overload
     * @param timeout read timeout passed on to the three-argument overload
     * @return the configured client
     */
    public static OkHttpClient defaultClient(String token, Duration timeout) {
        final String noOrganization = null;
        return defaultClient(token, noOrganization, timeout);
    }

    /**
     * Builds the default HTTP client: an {@code AuthenticationInterceptor} created from the
     * token and organization id, a dedicated connection pool, and the given read timeout.
     *
     * @param token   API token handed to the authentication interceptor
     * @param orgId   organization id handed to the authentication interceptor; the
     *                two-argument overload passes {@code null}
     * @param timeout read timeout, applied in milliseconds
     * @return the configured client
     */
    public static OkHttpClient defaultClient(String token, String orgId, Duration timeout) {
        final OkHttpClient.Builder builder = new OkHttpClient.Builder();
        builder.addInterceptor(new AuthenticationInterceptor(token, orgId));
        builder.connectionPool(new ConnectionPool(10, 4, TimeUnit.SECONDS));
        builder.readTimeout(timeout.toMillis(), TimeUnit.MILLISECONDS);
        return builder.build();
    }

    /**
     * Builds the Retrofit instance used to create the API: base URL {@code BASE_URL}, the
     * supplied HTTP client, Jackson conversion with the supplied mapper, and RxJava 3 call
     * adaptation.
     *
     * @param client HTTP client Retrofit should issue requests with
     * @param mapper mapper used by the Jackson converter factory
     * @return the configured Retrofit instance
     */
    public static Retrofit defaultRetrofit(OkHttpClient client, ObjectMapper mapper) {
        final Retrofit.Builder builder = new Retrofit.Builder();
        builder.baseUrl(BASE_URL);
        builder.client(client);
        builder.addConverterFactory(JacksonConverterFactory.create(mapper));
        builder.addCallAdapterFactory(RxJava3CallAdapterFactory.create());
        return builder.build();
    }

    /**
     * Lists the models returned by the API.
     *
     * @return the {@code data} list of the list-models response
     */
    public List<Model> listModels() {
        return execute(api.listModels()).getData();
    }

    /**
     * Fetches a single model by id.
     *
     * @param modelId id forwarded to the API's get-model call
     * @return the model returned by the API
     */
    public Model getModel(String modelId) {
        final Model model = execute(api.getModel(modelId));
        return model;
    }

    /**
     * Creates a completion by blocking on the API call.
     * <p>
     * Models: text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001,
     * text-ada-001, davinci, curie, babbage, ada
     *
     * @param request create completion request
     * @return completion
     */
    public Completion createCompletion(CreateCompletionRequest request) {
        final Completion completion = execute(api.createCompletion(request));
        return completion;
    }

    /**
     * Creates a chat completion by blocking on the API call. The request's stream flag is
     * set to {@code false} before it is sent, so the passed-in request object is modified.
     * <p>
     * Models: gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
     *
     * @param request create chat completion request
     * @return chat completion
     */
    public ChatCompletion createChatCompletion(CreateChatCompletionRequest request) {
        // This is the blocking variant, so streaming is switched off on the request itself.
        request.setStream(false);
        final ChatCompletion completion = execute(api.createChatCompletion(request));
        return completion;
    }

    /**
     * create chat completion by stream
     * <p>
     * Sets the request's stream flag to {@code true} (modifying the passed-in request), posts
     * it as JSON to the chat completions endpoint, and forwards every server-sent event to
     * the listener. The created {@link EventSource} is handed to the listener via
     * {@code setEventSource}.
     *
     * @param requestId request ID, every observer is unique
     * @param request   detail of request
     * @param listener  StreamChatCompletionListener
     * @throws RuntimeException if the request cannot be serialized to JSON
     */
    public void createSteamChatCompletion(String requestId,
                                          CreateChatCompletionRequest request,
                                          @NonNull StreamChatCompletionListener listener) {
        request.setStream(true);
        // One mapper per stream: it is created once here and reused for serializing the
        // request and for decoding every event, instead of building a new mapper per event.
        final ObjectMapper mapper = defaultObjectMapper();
        Request okHttpRequest;
        try {
            okHttpRequest = new Request.Builder().url(BASE_URL + "/v1/chat/completions")
                    .header("content-type", "text/event-stream")
                    .header("Accept", "text/event-stream")
                    .post(RequestBody.create(mapper.writeValueAsString(request),
                            MediaType.parse("application/json")))
                    .build();
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
        EventSource.Factory factory = EventSources.createFactory(client);
        EventSourceListener eventSourceListener = new EventSourceListener() {
            @Override
            public void onOpen(@NonNull EventSource eventSource, @NonNull Response response) {
                listener.onOpen(requestId, response);
            }

            @Override
            public void onEvent(@NonNull EventSource eventSource, @Nullable String id, @Nullable String type, @NonNull String data) {
                // "[DONE]" is the sentinel payload that marks the end of the stream.
                if (data.equals("[DONE]")) {
                    listener.onEventDone(requestId);
                } else {
                    try {
                        ChatCompletion chatCompletion = mapper.readValue(data, ChatCompletion.class);
                        listener.onEvent(requestId, chatCompletion);
                    } catch (JsonProcessingException e) {
                        throw new RuntimeException(e);
                    }
                }
            }

            @Override
            public void onClosed(@NonNull EventSource eventSource) {
                listener.onClosed(requestId);
            }

            @Override
            public void onFailure(@NonNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
                listener.onFailure(requestId, t, response);
            }
        };
        EventSource eventSource = factory.newEventSource(okHttpRequest, eventSourceListener);
        listener.setEventSource(eventSource);
    }

    /**
     * Creates an edit by blocking on the API call.
     * <p>
     * Models: text-davinci-edit-001, code-davinci-edit-001
     *
     * @param request create edit request
     * @return edit
     */
    public Edit createEdit(CreateEditRequest request) {
        final Edit edit = execute(api.createEdit(request));
        return edit;
    }

    /**
     * Creates an image by blocking on the API call.
     *
     * @param request create image request
     * @return image response returned by the API
     */
    public ImageResponse createImage(CreateImageRequest request) {
        final ImageResponse response = execute(api.createImage(request));
        return response;
    }

    /**
     * Creates an image edit. The image comes from the request's in-memory bytes when they are
     * present and non-empty, otherwise it is read from the request's image path. The mask is
     * optional and resolved the same way (bytes first, then a non-empty mask path).
     *
     * @param request create image edit request
     * @return image response returned by the API
     * @throws RuntimeException wrapping the {@link IOException} if a file cannot be read
     */
    public ImageResponse createImageEdit(CreateImageEditRequest request) {
        byte[] imageBytes;
        if (request.getImage() == null || request.getImage().length == 0) {
            // No inline bytes: fall back to the file referenced by the request.
            final File imageFile = new File(request.getImagePath());
            try {
                imageBytes = Files.readAllBytes(imageFile.toPath());
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        } else {
            imageBytes = request.getImage();
        }

        byte[] maskBytes = null;
        final boolean hasInlineMask = request.getMask() != null && request.getMask().length > 0;
        if (hasInlineMask) {
            maskBytes = request.getMask();
        } else if (request.getMaskPath() != null && request.getMaskPath().length() > 0) {
            final File maskFile = new File(request.getMaskPath());
            try {
                maskBytes = Files.readAllBytes(maskFile.toPath());
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        return createImageEdit(request, imageBytes, maskBytes);
    }

    /**
     * Builds the multipart form for an image edit and sends it. Parts are added in this
     * order: prompt, image, then n, size, response_format and mask when they are non-null.
     *
     * @param request source of the prompt and the optional n / size / response_format values
     * @param image   image bytes to upload
     * @param mask    mask bytes to upload, or {@code null} for no mask part
     * @return image response returned by the API
     */
    private ImageResponse createImageEdit(CreateImageEditRequest request, byte[] image, byte[] mask) {
        // NOTE(review): "image" has no subtype, so parse() presumably yields null here — confirm.
        final MediaType imageType = MediaType.parse("image");

        final MultipartBody.Builder form = new MultipartBody.Builder();
        form.setType(MediaType.get("multipart/form-data"));
        form.addFormDataPart("prompt", request.getPrompt());
        form.addFormDataPart("image", "image", RequestBody.create(image, imageType));

        if (request.getN() != null) {
            form.addFormDataPart("n", request.getN().toString());
        }
        if (request.getSize() != null) {
            form.addFormDataPart("size", request.getSize());
        }
        if (request.getResponseFormat() != null) {
            form.addFormDataPart("response_format", request.getResponseFormat());
        }
        if (mask != null) {
            form.addFormDataPart("mask", "mask", RequestBody.create(mask, MediaType.parse("image")));
        }

        final MultipartBody body = form.build();
        return execute(api.createImageEdit(body));
    }

    /**
     * Creates an image variation. The image comes from the request's in-memory bytes when
     * they are present and non-empty, otherwise it is read from the request's image path.
     * The n, size and response_format parts are added only when non-null.
     *
     * @param request create image variation request
     * @return image response returned by the API
     * @throws RuntimeException wrapping the {@link IOException} if the image file cannot be read
     */
    public ImageResponse createImageVariation(CreateImageVariationRequest request) {
        RequestBody imageBody;
        if (request.getImage() == null || request.getImage().length == 0) {
            // No inline bytes: load the image from the path carried by the request.
            final File imageFile = new File(request.getImagePath());
            try {
                final byte[] fileBytes = Files.readAllBytes(imageFile.toPath());
                imageBody = RequestBody.create(fileBytes, MediaType.parse("image"));
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        } else {
            imageBody = RequestBody.create(request.getImage(), MediaType.parse("image"));
        }

        final MultipartBody.Builder form = new MultipartBody.Builder();
        form.setType(MediaType.get("multipart/form-data"));
        form.addFormDataPart("image", "image", imageBody);

        if (request.getN() != null) {
            form.addFormDataPart("n", request.getN().toString());
        }
        if (request.getSize() != null) {
            form.addFormDataPart("size", request.getSize());
        }
        if (request.getResponseFormat() != null) {
            form.addFormDataPart("response_format", request.getResponseFormat());
        }

        final MultipartBody body = form.build();
        return execute(api.createImageVariation(body));
    }

    /**
     * Creates embeddings by blocking on the API call.
     * <p>
     * Models: text-embedding-ada-002, text-search-ada-doc-001
     *
     * @param request create Embedding request
     * @return create embedding response
     */
    public CreateEmbeddingResponse createEmbeddings(CreateEmbeddingRequest request) {
        final CreateEmbeddingResponse response = execute(api.createEmbedding(request));
        return response;
    }

    /**
     * whisper-1
     * <p>
     * TODO only support mp3 so far
     * <p>
     * The audio comes from the request's in-memory bytes when present and non-empty,
     * otherwise it is read from the request's file path. The prompt, response_format,
     * temperature and language parts are each added only when their own value is non-null.
     *
     * @param request create audio transcription request
     * @return audio response
     * @throws RuntimeException wrapping the {@link IOException} if the audio file cannot be read
     */
    public AudioResponse createAudioTranscription(CreateAudioTranscriptionRequest request) {
        byte[] fileBytes;
        if (request.getFile() != null && request.getFile().length > 0) {
            fileBytes = request.getFile();
        } else {
            File file = new File(request.getFilePath());
            try {
                fileBytes = Files.readAllBytes(file.toPath());
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        RequestBody audioBody = RequestBody.create(fileBytes, MediaType.parse("mp3"));
        MultipartBody.Builder builder = new MultipartBody.Builder()
                .setType(MediaType.get("multipart/form-data"))
                .addFormDataPart("model", request.getModel())
                .addFormDataPart("file", "mp3", audioBody);
        if (request.getPrompt() != null) {
            // Send the prompt itself (previously the model name was sent in its place).
            builder.addFormDataPart("prompt", request.getPrompt());
        }
        if (request.getResponseFormat() != null) {
            builder.addFormDataPart("response_format", request.getResponseFormat());
        }
        if (request.getTemperature() != null) {
            builder.addFormDataPart("temperature", request.getTemperature().toString());
        }
        if (request.getLanguage() != null) {
            // Guard on the language itself, independent of whether a prompt was given.
            builder.addFormDataPart("language", request.getLanguage());
        }
        return execute(api.createAudioTranscription(builder.build()));
    }

    /**
     * whisper-1
     * <p>
     * TODO only support mp3 so far
     * <p>
     * The audio comes from the request's in-memory bytes when present and non-empty,
     * otherwise it is read from the request's file path. The prompt, response_format and
     * temperature parts are each added only when non-null.
     *
     * @param request create audio translation request
     * @return audio
     * @throws RuntimeException wrapping the {@link IOException} if the audio file cannot be read
     */
    public AudioResponse createAudioTranslation(CreateAudioTranslationRequest request) {
        byte[] fileBytes;
        if (request.getFile() != null && request.getFile().length > 0) {
            fileBytes = request.getFile();
        } else {
            File file = new File(request.getFilePath());
            try {
                fileBytes = Files.readAllBytes(file.toPath());
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        RequestBody audioBody = RequestBody.create(fileBytes, MediaType.parse("mp3"));
        MultipartBody.Builder builder = new MultipartBody.Builder()
                .setType(MediaType.get("multipart/form-data"))
                .addFormDataPart("model", request.getModel())
                .addFormDataPart("file", "mp3", audioBody);
        if (request.getPrompt() != null) {
            // Send the prompt itself (previously the model name was sent in its place).
            builder.addFormDataPart("prompt", request.getPrompt());
        }
        if (request.getResponseFormat() != null) {
            builder.addFormDataPart("response_format", request.getResponseFormat());
        }
        if (request.getTemperature() != null) {
            builder.addFormDataPart("temperature", request.getTemperature().toString());
        }
        return execute(api.createAudioTranslation(builder.build()));
    }

    /**
     * Lists the files returned by the API.
     *
     * @return the {@code data} list of the list-files response
     */
    public List<xyz.felh.openai.file.File> listFiles() {
        return execute(api.listFiles()).getData();
    }

    /**
     * Uploads a local file as a multipart request with a {@code file} part and a purpose body.
     *
     * @param filepath path of the local file; also sent as the part's file name
     * @param purpose  purpose value sent alongside the file
     * @return the file object returned by the API
     */
    public xyz.felh.openai.file.File uploadFile(String filepath, String purpose) {
        File file = new File(filepath);
        RequestBody purposeBody = RequestBody.create(purpose, okhttp3.MultipartBody.FORM);
        RequestBody fileBody = RequestBody.create(file, MediaType.parse("text"));
        MultipartBody.Part body = MultipartBody.Part.createFormData("file", filepath, fileBody);
        return execute(api.uploadFile(body, purposeBody));
    }

    /**
     * Deletes a file.
     *
     * @param fileId file id
     * @return delete response returned by the API
     */
    public DeleteResponse deleteFile(String fileId) {
        return execute(api.deleteFile(fileId));
    }

    /**
     * Retrieves a file's metadata object.
     *
     * @param fileId file id
     * @return the file object returned by the API
     */
    public xyz.felh.openai.file.File retrieveFile(String fileId) {
        return execute(api.retrieveFile(fileId));
    }

    /**
     * @param fileId file id
     * @return file content
     */
    public RetrieveFileContentResponse retrieveFileContent(String fileId) {
        return execute(api.retrieveFileContent(fileId));
    }

    /**
     * Creates a fine-tune job.
     *
     * @param request create fine-tune request
     * @return the fine-tune returned by the API
     */
    public FineTune createFineTune(CreateFineTuneRequest request) {
        return execute(api.createFineTune(request));
    }

    /**
     * Lists fine-tunes.
     *
     * @return the {@code data} list of the list-fine-tunes response
     */
    public List<FineTune> listFineTunes() {
        return execute(api.listFineTunes()).getData();
    }

    /**
     * Retrieves a fine-tune by id.
     *
     * @param fineTuneId fine-tune id
     * @return the fine-tune returned by the API
     */
    public FineTune retrieveFineTune(String fineTuneId) {
        return execute(api.retrieveFineTune(fineTuneId));
    }

    /**
     * Cancels a fine-tune by id.
     *
     * @param fineTuneId fine-tune id
     * @return the fine-tune returned by the API
     */
    public FineTune cancelFineTune(String fineTuneId) {
        return execute(api.cancelFineTune(fineTuneId));
    }

    /**
     * Lists the events of a fine-tune.
     *
     * @param fineTuneId fine-tune id
     * @return the {@code data} list of the list-events response
     */
    public List<FineTuneEvent> listFineTuneEvents(String fineTuneId) {
        return execute(api.listFineTuneEvents(fineTuneId)).getData();
    }

    /**
     * Deletes a fine-tuned model. The identifier sent to the API is
     * {@code model + ":" + fineTuneId}.
     *
     * @param model      model name, used as the first half of the identifier
     * @param fineTuneId fine-tune id, used as the second half of the identifier
     * @return delete response returned by the API
     */
    public DeleteResponse deleteFineTune(String model, String fineTuneId) {
        return execute(api.deleteFineTune(String.format("%s:%s", model, fineTuneId)));
    }

    /**
     * text-moderation-stable, text-moderation-latest
     *
     * @param request create moderation request
     * @return moderation response
     */
    public CreateModerationResponse createModeration(CreateModerationRequest request) {
        return execute(api.createModeration(request));
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy