All downloads are free. Search and download functionality uses the official Maven repository.

com.alibaba.dashscope.aigc.conversation.Conversation Maven / Gradle / Ivy

// Copyright (c) Alibaba, Inc. and its affiliates.
package com.alibaba.dashscope.aigc.conversation;

import static com.alibaba.dashscope.utils.ApiKeywords.EVENT;

import com.alibaba.dashscope.BaseConversation;
import com.alibaba.dashscope.common.*;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.protocol.*;
import com.google.common.collect.Lists;
import io.reactivex.Flowable;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;

/**
 * Conversation client that sends a {@code ConversationParam} to the service and keeps the running
 * prompt/response history between calls.
 *
 * <p>Three call styles are offered: callback based ({@link #call(ConversationParam,
 * ResultCallback)}), reactive ({@link #streamCall(ConversationParam)}) and blocking ({@link
 * #call(ConversationParam)}). After a successful exchange each of them appends one user message
 * (the prompt) and one bot message (the last reply payload) to {@code history}.
 */
@Slf4j
public final class Conversation implements BaseConversation {

  // Passed to ServiceFacility for the blocking and reactive calls.
  // NOTE(review): the unit is not visible in this file - presumably seconds; confirm.
  @Setter private long responseTimeout = 30;

  // NOTE(review): declared without a type argument in this listing; every element added by this
  // class is a ChatMessage - confirm the intended generic parameter against BaseConversation.
  @Getter private final List history = Lists.newCopyOnWriteArrayList();

  private final Protocol protocol;

  /** Creates a conversation that talks to the service over {@code Protocol.HTTP}. */
  public Conversation() {
    protocol = Protocol.HTTP;
  }

  /**
   * Creates a conversation that uses the given protocol.
   *
   * @param protocol The protocol name, resolved through {@code Protocol.of}.
   */
  public Conversation(String protocol) {
    this.protocol = Protocol.of(protocol);
  }

  /**
   * Call the server to get the result in the callback function.
   *
   * <p>If {@code param} carries no history, this conversation's own history is attached before the
   * request is sent. When the stream completes and at least one result was delivered, the prompt
   * and the last reply are appended to the history.
   *
   * @param param The input param of class `ConversationParam`.
   * @param callback The callback to receive response, the template class is `ConversationResult`.
   */
  @Override
  public void call(ConversationParam param, ResultCallback callback) {
    final AtomicReference<ConversationResult> lastResult = new AtomicReference<>();

    // Adapter that converts raw responses into ConversationResult, remembers the latest one and
    // forwards every lifecycle notification to the caller's callback.
    class ReactCallback extends ResultCallback {

      public void onOpen(Status status) {
        callback.onOpen(status);
      }

      @Override
      public void onEvent(Response message) {
        try {
          ConversationResult result =
              (ConversationResult) ServiceFacility.prepareResult(protocol, param, message);
          // For non-HTTP protocols the TASK_STARTED event is not surfaced to the caller.
          if (protocol == Protocol.HTTP
              || !WebSocketEventType.TASK_STARTED
                  .getValue()
                  .equals(result.getHeaders().get(EVENT))) {
            lastResult.set(result);
            callback.onEvent(result);
          }
        } catch (Exception e) {
          callback.onError(new ApiException(e));
        }
      }

      @Override
      public void onComplete() {
        ConversationResult last = lastResult.get();
        if (last != null) {
          recordExchange(param, last);
        }
        callback.onComplete();
      }

      public void onClose(Status status) {
        callback.onClose(status);
      }

      @Override
      public void onError(Exception e) {
        callback.onError(e);
      }

      public void doClose(Status status) {
        callback.doClose(status);
      }
    }
    if (param.getHistory() == null) {
      param.setHistory(history);
    }

    ServiceFacility.streamingOutWithCallback(
        protocol,
        ServiceFacility.prepareUrl(protocol, param),
        ServiceFacility.prepareHeaders(protocol, param),
        ServiceFacility.prepareRequest(protocol, param, WebSocketEventType.RUN_TASK),
        HttpMethod.POST,
        param.getMode(),
        new ReactCallback());
  }

  /**
   * Call the server to get the result by stream.
   *
   * <p>The param's mode is forced to {@code StreamingMode.OUT}. If {@code param} carries no
   * history, this conversation's own history is attached. When the returned stream completes and at
   * least one result was emitted, the prompt and the last reply are appended to the history.
   *
   * @param param The input param of class `ConversationParam`.
   * @return A `Flowable` of the output structure.
   */
  @Override
  public Flowable streamCall(ConversationParam param) {
    final AtomicReference<ConversationResult> lastResult = new AtomicReference<>();
    if (param.getHistory() == null) {
      param.setHistory(history);
    }
    param.setMode(StreamingMode.OUT);
    return ServiceFacility.streamCall(protocol, null, param, responseTimeout)
        .map(
            message -> {
              ConversationResult result = (ConversationResult) message;
              lastResult.set(result);
              return result;
            })
        .doOnComplete(
            () -> {
              ConversationResult last = lastResult.get();
              if (last != null) {
                recordExchange(param, last);
              }
            });
  }

  /**
   * Call the server to get the whole result.
   *
   * <p>The param's mode is forced to {@code StreamingMode.NONE}. This conversation's history is
   * attached only when {@code param} carries none and the stored history is non-empty. After the
   * call returns, the prompt and the reply are appended to the history.
   *
   * @param param The input param of class `ConversationParam`.
   * @return The output structure of `ConversationResult`.
   * @throws ApiException If the request fails; any other exception is wrapped in an `ApiException`.
   */
  @Override
  public ConversationResult call(ConversationParam param) throws ApiException {
    try {
      if (param.getHistory() == null && !history.isEmpty()) {
        param.setHistory(history);
      }
      param.setMode(StreamingMode.NONE);
      ConversationResult result =
          (ConversationResult)
              ServiceFacility.call(protocol, null, param, HttpMethod.POST, responseTimeout);
      recordExchange(param, result);
      return result;
    } catch (Exception e) {
      throw e instanceof ApiException ? (ApiException) e : new ApiException(e);
    }
  }

  /**
   * Appends one user/bot message pair for a finished exchange to the history.
   *
   * <p>Both messages are built before either is added, so a failure while building leaves the
   * history untouched. A result without a message is recorded with a {@code null} bot payload.
   *
   * @param param The request whose prompt becomes the user message.
   * @param result The last result received for that request.
   */
  private void recordExchange(ConversationParam param, ConversationResult result) {
    ChatMessage botMessage =
        ChatMessage.builder()
            .role(Role.BOT.getValue())
            .payload(result.getMessage() == null ? null : result.getMessage().getPayload())
            .build();
    ChatMessage userMessage =
        ChatMessage.builder().role(Role.USER.getValue()).payload(param.getPrompt()).build();
    history.add(userMessage);
    history.add(botMessage);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy