fi.evolver.ai.vaadin.entity.ChatMessage Maven / Gradle / Ivy
package fi.evolver.ai.vaadin.entity;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import com.fasterxml.jackson.annotation.JsonIgnore;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.util.TokenUtils;
import fi.evolver.utils.GzipUtils;
import jakarta.persistence.*;
@Entity
@Table(name="chat_message")
public class ChatMessage {
public enum ChatMessageRole {
SYSTEM,
USER,
ASSISTANT,
ERROR
}
@Id
@GeneratedValue(strategy=GenerationType.IDENTITY)
private long id;
@Column(name="send_time")
private LocalDateTime sendTime;
@Enumerated(EnumType.STRING)
@Column(name="role")
private ChatMessageRole role;
@Column(name="message")
private byte[] message;
@Column(name="model")
private String model;
@Column(name="token_count")
private Integer tokenCount;
@ManyToOne(fetch=FetchType.LAZY)
@JoinColumn(name="chat_id")
private Chat chat;
@ManyToOne(fetch=FetchType.EAGER)
@JoinColumn(name="prompt_id")
private Prompt prompt;
public ChatMessage() {}
public ChatMessage(ChatMessageRole role, String message, Prompt prompt, Model> model) {
this.sendTime = LocalDateTime.now();
this.role = role;
this.message = compressData(message);
this.prompt = prompt;
this.model = model != null ? model.name() : null;
this.tokenCount = calculateTokenCount(role, message, prompt, model);
}
public long getId() {
return id;
}
public LocalDateTime getSendTime() {
return sendTime;
}
public ChatMessageRole getRole() {
return role;
}
public String getMessage() {
return readCompressed(message);
}
public String getModel() {
return model;
}
public Integer getTokenCount() {
return tokenCount;
}
public void setTokenCount(Integer tokenCount) {
this.tokenCount = tokenCount;
}
@JsonIgnore
public Chat getChat() {
return chat;
}
public void setChat(Chat chat) {
this.chat = chat;
}
public Prompt getPrompt() {
return prompt;
}
public void setPrompt(Prompt chatPrompt) {
this.prompt = chatPrompt;
}
private static byte[] compressData(String textData) {
return GzipUtils.zip(textData, StandardCharsets.UTF_8);
}
private static String readCompressed(byte[] data) {
return GzipUtils.unzip(data, StandardCharsets.UTF_8);
}
@Override
public String toString() {
return "ChatMessage [id=" + id + ", sendTime=" + sendTime + ", role=" + role + ", message=" +
getMessage() + ", model=" + model + ", tokenCount=" + tokenCount + ", chat=" + chat.getId() +
", prompt=" + prompt + "]";
}
private static Integer calculateTokenCount(ChatMessageRole role, String message, Prompt prompt, Model> model) {
return switch(role) {
case ASSISTANT -> model != null ? TokenUtils.calculateTokens(message, model) : null;
case USER -> prompt != null ? prompt.getTokenCount() : null;
default -> null;
};
}
}