All downloads are free. Search and download functionality uses the official Maven repository.

org.bsc.langgraph4j.studio.LangGraphStreamingServer Maven / Gradle / Ivy

There is a newer version: 1.0-20241113
Show newest version
package org.bsc.langgraph4j.studio;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.ser.std.StdSerializer;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.*;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.serializer.plain_text.PlainTextStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.StateSnapshot;
import org.bsc.langgraph4j.studio.jetty.LangGraphStreamingServerJetty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import static java.util.Optional.ofNullable;


public interface LangGraphStreamingServer {

    Logger log = LoggerFactory.getLogger(LangGraphStreamingServer.class);

    CompletableFuture start() throws Exception;

    class NodeOutputSerializer extends StdSerializer  {
        Logger log = LangGraphStreamingServer.log;

        protected NodeOutputSerializer() {
            super( NodeOutput.class );
        }

        @Override
        public void serialize(NodeOutput nodeOutput, JsonGenerator gen, SerializerProvider serializerProvider) throws
                IOException {
            log.trace( "NodeOutputSerializer start! {}", nodeOutput.getClass() );
            gen.writeStartObject();
            if( nodeOutput instanceof StateSnapshot snapshot) {
                var checkpoint = snapshot.config().checkPointId();
                log.trace( "checkpoint: {}", checkpoint );
                if( checkpoint.isPresent() ) {
                    gen.writeStringField("checkpoint", checkpoint.get());
                }
            }
            gen.writeStringField("node", nodeOutput.node());
            gen.writeObjectField("state", nodeOutput.state().data());
            gen.writeEndObject();
        }
    }

    record PersistentConfig(String sessionId, String threadId) {
        public PersistentConfig {
            Objects.requireNonNull(sessionId);
        }

    }

    class GraphStreamServlet extends HttpServlet {
        Logger log = LangGraphStreamingServer.log;
        final BaseCheckpointSaver saver;

        final StateGraph stateGraph;
        final ObjectMapper objectMapper;
        final Map> graphCache = new HashMap<>();

        public GraphStreamServlet(StateGraph stateGraph,
                                  ObjectMapper objectMapper,
                                  BaseCheckpointSaver saver) {

            Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
            this.stateGraph = stateGraph;
            this.objectMapper = objectMapper;
            var module = new SimpleModule();
            module.addSerializer(NodeOutput.class, new NodeOutputSerializer());
            objectMapper.registerModule(module);
            this.saver = saver;
        }

        private CompileConfig compileConfig(PersistentConfig config) {
            return CompileConfig.builder()
                    .checkpointSaver(saver)
                    //.stateSerializer(stateSerializer)
                    .build();
        }

        RunnableConfig runnableConfig(PersistentConfig config ) {
            return RunnableConfig.builder()
                    .threadId(config.threadId())
                    .build();
        }

        @Override
        protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
            response.setHeader("Accept", "application/json");
            response.setContentType("text/plain");
            response.setCharacterEncoding("UTF-8");

            var session = request.getSession(true);
            Objects.requireNonNull(session, "session cannot be null");

            var threadId = ofNullable(request.getParameter("thread"))
                    .orElseThrow(() -> new IllegalStateException("Missing thread id!"));

            var resume = ofNullable(request.getParameter("resume"))
                    .map(Boolean::parseBoolean).orElse(false);

            final PrintWriter writer = response.getWriter();

            // Start asynchronous processing
            var asyncContext = request.startAsync();

            try {

                AsyncGenerator> generator = null;

                var persistentConfig = new PersistentConfig(session.getId(), threadId);

                var compiledGraph = graphCache.get(persistentConfig);

                final Map dataMap;
                if( resume && stateGraph.getStateSerializer() instanceof PlainTextStateSerializer textSerializer  ) {

                    dataMap = textSerializer.read( new InputStreamReader(request.getInputStream()) ).data();
                }
                else {

                    dataMap = objectMapper.readValue(request.getInputStream(), new TypeReference<>() {
                    });
                }

                if( resume ) {

                    log.trace( "RESUME REQUEST PREPARE" );

                    if (compiledGraph == null) {
                        throw new IllegalStateException( "Missing CompiledGraph in session!" );
                    }

                    var checkpointId = ofNullable(request.getParameter("checkpoint"))
                            .orElseThrow(() -> new IllegalStateException("Missing checkpoint id!"));

                    var node = request.getParameter("node");
                    var config = RunnableConfig.builder()
                            .threadId(threadId)
                            .checkPointId(checkpointId)
                            .build();

                    var stateSnapshot = compiledGraph.getState(config);

                    config = stateSnapshot.config();

                    log.trace( "RESUME UPDATE STATE FORM {} USING CONFIG {}\n{}", node, config, dataMap);

                    config = compiledGraph.updateState(config, dataMap, node );

                    log.trace( "RESUME REQUEST STREAM {}", config);

                    generator = compiledGraph.streamSnapshots(null, config);

                }
                else {

                    log.trace( "dataMap: {}", dataMap );

                    if (compiledGraph == null) {
                        compiledGraph = stateGraph.compile(compileConfig(persistentConfig));
                        graphCache.put(persistentConfig, compiledGraph);
                    }

                    generator = compiledGraph.streamSnapshots(dataMap, runnableConfig(persistentConfig));
                }

                generator.forEachAsync(s -> {
                            try {
                                try {
                                    writer.printf("[ \"%s\",", threadId);
                                    writer.println();
                                    var outputAsString = objectMapper.writeValueAsString(s);
                                    writer.println(outputAsString);
                                    writer.println( "]" );
                                } catch (IOException e) {
                                    log.warn("error serializing state", e);
                                }
                                writer.flush();
                                TimeUnit.SECONDS.sleep(1);
                            } catch ( InterruptedException e) {
                                throw new CompletionException(e);
                            }

                        })
                        .thenAccept(v -> writer.close())
                        .thenAccept(v -> asyncContext.complete())
                        .exceptionally(e -> {
                            log.error("Error streaming", e);
                            writer.close();
                            asyncContext.complete();
                            return null;
                        })
                ;

            } catch (Throwable e) {
                log.error("Error streaming", e);
                throw new ServletException(e);
            }
        }
    }

    record ArgumentMetadata(
            String type,
            boolean required) {
    }

    record ThreadEntry( String id, List> entries) {

    }

    record InitData(
            String title,
            String graph,
            Map args,
            List threads
    ) {

        public InitData( String title, String graph, Map args ) {
            this( title, graph, args, List.of(new ThreadEntry("default", List.of())));
        }
    }

    class InitDataSerializer extends StdSerializer {
        Logger log = LangGraphStreamingServer.log;

        protected InitDataSerializer(Class t) {
            super(t);
        }

        @Override
        public void serialize(InitData initData, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
            log.trace( "InitDataSerializer start!" );
            jsonGenerator.writeStartObject();

            jsonGenerator.writeStringField("graph", initData.graph());
            jsonGenerator.writeStringField("title", initData.title());
            jsonGenerator.writeObjectField("args", initData.args());


//            jsonGenerator.writeArrayFieldStart("nodes" );
//            for( var node : initData.nodes() ) {
//                jsonGenerator.writeString(node);
//            }
//            jsonGenerator.writeEndArray();

            jsonGenerator.writeArrayFieldStart("threads" );
            for( var thread : initData.threads() ) {
                jsonGenerator.writeStartArray();
                jsonGenerator.writeString(thread.id());
                jsonGenerator.writeStartArray( thread.entries() );
                jsonGenerator.writeEndArray();
                jsonGenerator.writeEndArray();
            }
            jsonGenerator.writeEndArray();

            jsonGenerator.writeEndObject();
        }
    }

    /**
     * return the graph representation in mermaid format
     */
    class GraphInitServlet extends HttpServlet {

        Logger log = LangGraphStreamingServer.log;

        final StateGraph stateGraph;
        final ObjectMapper objectMapper = new ObjectMapper();
        final InitData initData;

        public GraphInitServlet(StateGraph stateGraph, String title, Map args) {
            Objects.requireNonNull(stateGraph, "stateGraph cannot be null");
            this.stateGraph = stateGraph;

            var module = new SimpleModule();
            module.addSerializer(InitData.class, new InitDataSerializer(InitData.class));
            objectMapper.registerModule(module);

            var graph = stateGraph.getGraph(GraphRepresentation.Type.MERMAID, title, false);

            initData = new InitData( title, graph.getContent(), args );
        }

        @Override
        protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
            response.setContentType("application/json");
            response.setCharacterEncoding("UTF-8");

            String resultJson = objectMapper.writeValueAsString(initData);

            log.trace( "{}", resultJson);

            // Start asynchronous processing
            final PrintWriter writer = response.getWriter();
            writer.println(resultJson);
            writer.close();
        }
    }


}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy