All downloads are free. Search and download functionality uses the official Maven repository.

org.opensearch.ml.common.conversation.Interaction Maven / Gradle / Ivy

There is a newer version: 2.17.1.0
Show newest version
/*
 * Copyright 2023 Aryn
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.opensearch.ml.common.conversation;

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.search.SearchHit;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;

/**
 * A single interaction belonging to a conversation.
 * <p>
 * An instance can be created from an index document ({@link #fromMap(String, Map)},
 * {@link #fromSearchHit(SearchHit)}), read from or written to a stream
 * ({@link #fromStream(StreamInput)}, {@link #writeTo(StreamOutput)}) and rendered
 * as XContent ({@link #toXContent(XContentBuilder, ToXContentObject.Params)}).
 * <p>
 * {@code additionalInfo}, {@code parentInteractionId} and {@code traceNum} may be
 * {@code null}; they are serialized as optional values and omitted from the
 * XContent output when absent.
 */
@Builder
@AllArgsConstructor
public class Interaction implements Writeable, ToXContentObject {

    /** The Interaction id. */
    @Getter
    private String id;
    /** Creation timestamp of this interaction. */
    @Getter
    private Instant createTime;
    /** Id of the conversation this interaction belongs to. */
    @Getter
    private String conversationId;
    /** Value stored in the interaction's input field. */
    @Getter
    private String input;
    /** Value stored in the interaction's prompt template field. */
    @Getter
    private String promptTemplate;
    /** Value stored in the interaction's response field. */
    @Getter
    private String response;
    /** Value stored in the interaction's origin field. */
    @Getter
    private String origin;
    /** Optional string-to-string metadata; may be {@code null}. */
    @Getter
    private Map<String, String> additionalInfo;
    /** Optional id of a parent interaction; may be {@code null}. */
    @Getter
    private String parentInteractionId;
    /** Optional trace number; may be {@code null}. */
    @Getter
    private Integer traceNum;

    /**
     * Creates an Interaction without a parent interaction or trace number;
     * {@code parentInteractionId} and {@code traceNum} are both set to {@code null}.
     *
     * @param id the Interaction id
     * @param createTime creation timestamp
     * @param conversationId id of the owning conversation
     * @param input the input value
     * @param promptTemplate the prompt template value
     * @param response the response value
     * @param origin the origin value
     * @param additionalInfo optional metadata, may be {@code null}
     */
    @Builder(toBuilder = true)
    public Interaction(
        String id,
        Instant createTime,
        String conversationId,
        String input,
        String promptTemplate,
        String response,
        String origin,
        Map<String, String> additionalInfo
    ) {
        this.id = id;
        this.createTime = createTime;
        this.conversationId = conversationId;
        this.input = input;
        this.promptTemplate = promptTemplate;
        this.response = response;
        this.origin = origin;
        this.additionalInfo = additionalInfo;
        this.parentInteractionId = null;
        this.traceNum = null;
    }

    /**
     * Creates an Interaction object from a map of fields in the OS index
     * @param id the Interaction id
     * @param fields the field mapping from the OS document; the create-time entry must be
     *               present because it is handed straight to {@link Instant#parse(CharSequence)}
     * @return a new Interaction object representing the OS document
     * @throws java.time.format.DateTimeParseException if the create-time value cannot be parsed
     * @throws ClassCastException if a field holds a value of an unexpected type
     */
    public static Interaction fromMap(String id, Map<String, ?> fields) {
        Instant createTime = Instant.parse((String) fields.get(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD));
        String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD);
        String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD);
        String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD);
        String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD);
        String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD);
        // The element types cannot be checked at runtime; the stream (de)serialization
        // of this class treats additionalInfo as String -> String, so the same is assumed here.
        @SuppressWarnings("unchecked")
        Map<String, String> additionalInfo = (Map<String, String>) fields
            .get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
        // Optional fields: get() already yields null when the key is absent.
        String parentInteractionId = (String) fields.get(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD);
        Integer traceNum = (Integer) fields.get(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
        return new Interaction(
            id,
            createTime,
            conversationId,
            input,
            promptTemplate,
            response,
            origin,
            additionalInfo,
            parentInteractionId,
            traceNum
        );
    }

    /**
     * Creates an Interaction object from a search hit
     * @param hit the search hit from the interactions index
     * @return a new Interaction object representing the search hit
     */
    public static Interaction fromSearchHit(SearchHit hit) {
        return fromMap(hit.getId(), hit.getSourceAsMap());
    }

    /**
     * Creates a new Interaction object from a stream.
     * <p>
     * Fields are read in exactly the order {@link #writeTo(StreamOutput)} writes them.
     * Note that when the stream marks {@code additionalInfo} as absent, the result holds
     * an empty map rather than {@code null}.
     *
     * @param in stream to read from; assumes Interaction.writeTo was called on it
     * @return a new Interaction
     * @throws IOException if the stream can't be read or isn't positioned at an interaction
     */
    public static Interaction fromStream(StreamInput in) throws IOException {
        String id = in.readString();
        Instant createTime = in.readInstant();
        String conversationId = in.readString();
        String input = in.readString();
        String promptTemplate = in.readString();
        String response = in.readString();
        String origin = in.readString();
        // A leading boolean flags whether a map follows; absent maps become empty maps.
        Map<String, String> additionalInfo = new HashMap<>();
        if (in.readBoolean()) {
            additionalInfo = in.readMap(StreamInput::readString, StreamInput::readString);
        }
        String parentInteractionId = in.readOptionalString();
        Integer traceNum = in.readOptionalInt();
        return new Interaction(
            id,
            createTime,
            conversationId,
            input,
            promptTemplate,
            response,
            origin,
            additionalInfo,
            parentInteractionId,
            traceNum
        );
    }

    /**
     * Writes this Interaction to a stream in the layout {@link #fromStream(StreamInput)} expects.
     * <p>
     * {@code additionalInfo} is preceded by a presence flag; {@code parentInteractionId} and
     * {@code traceNum} use the optional writers. The remaining fields use the non-optional
     * writers (presumably they must be non-null; confirm against {@code StreamOutput}).
     *
     * @param out stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(id);
        out.writeInstant(createTime);
        out.writeString(conversationId);
        out.writeString(input);
        out.writeString(promptTemplate);
        out.writeString(response);
        out.writeString(origin);
        if (additionalInfo != null) {
            out.writeBoolean(true);
            out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString);
        } else {
            out.writeBoolean(false);
        }
        out.writeOptionalString(parentInteractionId);
        out.writeOptionalInt(traceNum);
    }

    /**
     * Renders this Interaction as a single XContent object.
     * {@code additionalInfo}, {@code parentInteractionId} and {@code traceNum} are only
     * emitted when they are non-null.
     *
     * @param builder the builder to write into
     * @param params rendering parameters (not used by this implementation)
     * @return the same builder, for chaining
     * @throws IOException if the builder fails to write
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException {
        builder.startObject();
        builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId);
        builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id);
        builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime);
        builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
        builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
        builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
        builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
        if (additionalInfo != null) {
            builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
        }
        if (parentInteractionId != null) {
            builder.field(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
        }
        if (traceNum != null) {
            builder.field(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNum);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Two Interactions are equal when all ten fields are equal. The comparison is
     * null-safe: a field that is {@code null} on one side only makes the objects
     * unequal instead of throwing.
     *
     * @param other the object to compare with
     * @return {@code true} if {@code other} is an Interaction with equal field values
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof Interaction)) {
            return false;
        }
        Interaction that = (Interaction) other;
        return Objects.equals(id, that.id)
            && Objects.equals(conversationId, that.conversationId)
            && Objects.equals(createTime, that.createTime)
            && Objects.equals(input, that.input)
            && Objects.equals(promptTemplate, that.promptTemplate)
            && Objects.equals(response, that.response)
            && Objects.equals(origin, that.origin)
            && Objects.equals(additionalInfo, that.additionalInfo)
            && Objects.equals(parentInteractionId, that.parentInteractionId)
            && Objects.equals(traceNum, that.traceNum);
    }

    /**
     * Hash code over the same fields {@link #equals(Object)} compares, so that equal
     * Interactions land in the same bucket of hash-based collections.
     *
     * @return the combined hash of all fields
     */
    @Override
    public int hashCode() {
        return Objects
            .hash(
                id,
                conversationId,
                createTime,
                input,
                promptTemplate,
                response,
                origin,
                additionalInfo,
                parentInteractionId,
                traceNum
            );
    }

    /**
     * Returns a single-line, human-readable dump of all fields.
     * The key labels (including their spelling) are kept unchanged so existing
     * output stays stable.
     *
     * @return string representation of this Interaction
     */
    @Override
    public String toString() {
        return "Interaction{"
            + "id="
            + id
            + ",cid="
            + conversationId
            + ",create_time="
            + createTime
            + ",origin="
            + origin
            + ",input="
            + input
            + ",promt_template="
            + promptTemplate
            + ",response="
            + response
            + ",additional_info="
            + additionalInfo
            + ",parentInteractionId="
            + parentInteractionId
            + ",traceNum="
            + traceNum
            + "}";
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy