All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.tjake.jlama.net.openai.OpenAIChatService Maven / Gradle / Ivy

There is a newer version: 0.8.3
Show newest version
/*
 * Copyright 2024 T Jake Luciani
 *
 * The Jlama Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.github.tjake.jlama.net.openai;

import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.net.openai.model.*;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import jakarta.validation.Valid;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

@RestController
@Validated
public class OpenAIChatService {

    private static final Logger logger = LoggerFactory.getLogger(OpenAIChatService.class);
    private static final String JLAMA_SESSION_HEADER = "X-Jlama-Session";

    @Autowired
    private Generator model;

    /**
     * POST /chat/completions : Creates a model response for the given chat conversation.
     *
     * @param request  (required)
     * @return OK (status code 200)
     */
    @RequestMapping(method = RequestMethod.POST, value = "/chat/completions", produces = { "application/json",
        "text/event-stream" }, consumes = { "application/json" })
    Object createChatCompletion(@RequestHeader Map headers, @Valid @RequestBody CreateChatCompletionRequest request) {

        List messages = request.getMessages();

        if (model.promptSupport().isEmpty()) {
            return new ResponseEntity<>(HttpStatus.BAD_GATEWAY);
        }

        UUID id = UUID.randomUUID();

        if (headers.containsKey(JLAMA_SESSION_HEADER)) {
            try {
                id = UUID.fromString(headers.get(JLAMA_SESSION_HEADER));
            } catch (IllegalArgumentException e) {
                return new ResponseEntity<>(HttpStatus.BAD_REQUEST);
            }
        }

        UUID sessionId = id;

        PromptSupport.Builder builder = model.promptSupport().get().builder();

        for (ChatCompletionRequestMessage m : messages) {

            if (m.getActualInstance() instanceof ChatCompletionRequestUserMessage) {
                ChatCompletionRequestUserMessageContent content = m.getChatCompletionRequestUserMessage().getContent();

                if (content.getActualInstance() instanceof String) {
                    builder.addUserMessage(content.getString());
                } else {
                    for (ChatCompletionRequestMessageContentPart p : content.getListChatCompletionRequestMessageContentPart()) {
                        if (p.getActualInstance() instanceof ChatCompletionRequestMessageContentPartText) {
                            builder.addUserMessage(p.getChatCompletionRequestMessageContentPartText().getText());
                        } else {
                            // We don't support other types of content... yet...
                            return new ResponseEntity<>(HttpStatus.NOT_IMPLEMENTED);
                        }
                    }
                }
            } else if (m.getActualInstance() instanceof ChatCompletionRequestSystemMessage) {
                builder.addSystemMessage(m.getChatCompletionRequestSystemMessage().getContent());
            } else if (m.getActualInstance() instanceof ChatCompletionRequestAssistantMessage) {
                builder.addAssistantMessage(m.getChatCompletionRequestAssistantMessage().getContent());
            } else {
                return new ResponseEntity<>(HttpStatus.NOT_IMPLEMENTED);
            }
        }

        float temperature = request.getTemperature() == null ? 0.3f : request.getTemperature().floatValue();
        int maxTokens = request.getMaxTokens() == null ? 1024 : request.getMaxTokens();

        AtomicInteger index = new AtomicInteger(0);
        if (request.getStream() != null && request.getStream()) {
            SseEmitter emitter = new SseEmitter(-1L);
            CompletableFuture.supplyAsync(() -> model.generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> {
                try {
                    emitter.send(
                        new CreateChatCompletionStreamResponse().id(sessionId.toString())
                            .choices(
                                List.of(
                                    new CreateChatCompletionStreamResponseChoicesInner().index(index.getAndIncrement())
                                        .delta(new ChatCompletionStreamResponseDelta().content(t))
                                )
                            )
                    );
                } catch (IOException e) {
                    emitter.completeWithError(e);
                }
            })).handle((r, ex) -> {
                try {
                    emitter.send(
                        new CreateChatCompletionStreamResponse().id(sessionId.toString())
                            .choices(
                                List.of(
                                    new CreateChatCompletionStreamResponseChoicesInner().finishReason(
                                        CreateChatCompletionStreamResponseChoicesInner.FinishReasonEnum.STOP
                                    )
                                )
                            )
                    );

                    emitter.complete();

                    logger.info("Completed streaming response {} tok/sec", r.generatedTokens / (r.generatedTokens / 1000f));
                } catch (IOException e) {
                    emitter.completeWithError(e);
                }

                return null;
            });

            return emitter;
        } else {
            Generator.Response r = model.generate(sessionId, builder.build(), temperature, maxTokens, (s, f) -> {});

            CreateChatCompletionResponse out = new CreateChatCompletionResponse().id(sessionId.toString())
                .choices(
                    List.of(
                        new CreateChatCompletionResponseChoicesInner().finishReason(
                            CreateChatCompletionResponseChoicesInner.FinishReasonEnum.STOP
                        ).message(new ChatCompletionResponseMessage().content(r.responseText))
                    )
                );

            return new ResponseEntity<>(out, HttpStatus.OK);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy