All downloads are free. Search and download functionality uses the official Maven repository.

fi.evolver.ai.vaadin.admin.ChatReportGenerator Maven / Gradle / Ivy

There is a newer version: 1.5.5
Show newest version
package fi.evolver.ai.vaadin.admin;

import java.io.IOException;
import java.io.OutputStream;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.poi.ss.usermodel.*;
import org.apache.poi.xssf.usermodel.XSSFFont;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;

import fi.evolver.ai.vaadin.entity.Chat;
import fi.evolver.ai.vaadin.entity.ChatMessage;
import fi.evolver.ai.vaadin.entity.ChatMessage.ChatMessageRole;

/**
 * Builds an XLSX report of chats with two sheets.
 * <ul>
 * <li><b>Chat export</b>: one row per chat. Each row holds the first user message, the first
 * assistant response, the summary, the view (chat type), the LLM model, the full transcript,
 * the message counts, the token count and the rating.</li>
 * <li><b>Chat analytics</b>: one row per (view, LLM model) combination, with aggregated
 * counts, averages and totals for the reported period.</li>
 * </ul>
 * This is a static utility class and cannot be instantiated.
 */
public class ChatReportGenerator {
	/** Excel's maximum number of characters in a single cell; longer text is truncated. */
	private static final int MAX_CELL_TEXT_LENGTH = 32767;

	/** POI column widths are expressed in units of 1/256th of a character width. */
	private static final int COLUMN_WIDTH_UNITS_PER_CHAR = 256;

	private static final List<HeaderConfig> EXPORT_HEADERS = List.of(
			new HeaderConfig("Chat date time"),
			new HeaderConfig("First user message", 30),
			new HeaderConfig("First assistant response", 30),
			new HeaderConfig("Summary", 30),
			new HeaderConfig("View", 20),
			new HeaderConfig("LLM model", 14),
			new HeaderConfig("Chat messages", 30),
			new HeaderConfig("User message count"),
			new HeaderConfig("Assistant message count"),
			new HeaderConfig("Total message count"),
			new HeaderConfig("Token count"),
			new HeaderConfig("Chat rating")
			);

	private static final List<HeaderConfig> ANALYTICS_HEADERS = List.of(
			new HeaderConfig("Start date", 20),
			new HeaderConfig("End date", 20),
			new HeaderConfig("View", 20),
			new HeaderConfig("LLM model", 14),
			new HeaderConfig("Chat count"),
			new HeaderConfig("Distinct users"),
			new HeaderConfig("Average user message count"),
			new HeaderConfig("Average assistant message count"),
			new HeaderConfig("Average total message count"),
			new HeaderConfig("Total token count"),
			new HeaderConfig("Average chat rating")
			);


	private ChatReportGenerator() {}


	/**
	 * Writes a two-sheet XLSX report of the given chats to {@code out}.
	 *
	 * @param chats     the chats to report; each one becomes a row in the export sheet
	 * @param startDate start of the reported period, copied into every analytics row
	 *                  ({@code null} yields a blank cell)
	 * @param endDate   end of the reported period, copied into every analytics row
	 *                  ({@code null} yields a blank cell)
	 * @param out       destination stream; it is written to but not closed
	 * @throws IOException if writing the workbook to {@code out} fails
	 */
	public static void generateChatReport(List<Chat> chats, LocalDate startDate, LocalDate endDate, OutputStream out) throws IOException {
		try (Workbook workbook = new XSSFWorkbook()) {
			CellFormat cellFormat = new CellFormat(workbook);

			generateChatExportSheet(chats, workbook, cellFormat);
			generateChatAnalyticsSheet(chats, startDate, endDate, workbook, cellFormat);

			workbook.write(out);
		}
	}


	/** Adds the "Chat analytics" sheet: one aggregated row per (view, LLM model) pair. */
	private static void generateChatAnalyticsSheet(List<Chat> chats, LocalDate startDate, LocalDate endDate, Workbook workbook, CellFormat cellFormat) {
		Sheet sheet = workbook.createSheet("Chat analytics");
		generateHeaderRow(ANALYTICS_HEADERS, sheet, cellFormat.headerStyle());

		// Both grouping levels are sorted maps, so rows are ordered by view and then by
		// model. This keeps the report deterministic from run to run.
		Map<String, Map<String, List<Chat>>> chatsByTypeAndModel = chats.stream()
				.collect(Collectors.groupingBy(
						ChatReportGenerator::getChatType,
						TreeMap::new,
						Collectors.groupingBy(ChatReportGenerator::getLlmModel, TreeMap::new, Collectors.toList())));

		int rowIndex = 1; // row 0 holds the headers
		for (Map.Entry<String, Map<String, List<Chat>>> typeEntry : chatsByTypeAndModel.entrySet()) {
			for (Map.Entry<String, List<Chat>> modelEntry : typeEntry.getValue().entrySet()) {
				List<Chat> chatEntries = modelEntry.getValue();

				// Arrays.asList (unlike List.of) accepts nulls, e.g. an open-ended date range.
				createContentCells(
						sheet.createRow(rowIndex++),
						Arrays.asList(startDate,
								endDate,
								typeEntry.getKey(),
								modelEntry.getKey(),
								chatEntries.size(),
								countDistinctUsers(chatEntries),
								calculateAverageMessageCount(chatEntries, ChatMessageRole.USER),
								calculateAverageMessageCount(chatEntries, ChatMessageRole.ASSISTANT),
								// Averaged per chat, not the sum of two rounded averages.
								calculateAverageMessageCount(chatEntries, ChatMessageRole.USER, ChatMessageRole.ASSISTANT),
								calculateTotalTokenCount(chatEntries),
								calculateAverageChatRating(chatEntries)),
						cellFormat);
			}
		}

		autoSizeColumns(ANALYTICS_HEADERS, sheet);
	}


	/** Returns the number of distinct usernames among the given chats. */
	private static int countDistinctUsers(List<Chat> chats) {
		return chats.stream()
				.map(Chat::getUsername)
				.collect(Collectors.toSet())
				.size();
	}


	/**
	 * Sums the token counts of all user and assistant messages in the given chats.
	 * Messages without a token count contribute 0. The sum is a {@code long} so that
	 * large reporting periods cannot overflow.
	 */
	private static long calculateTotalTokenCount(List<Chat> chats) {
		return chats.stream()
				.flatMap(chat -> chat.getChatMessages().stream())
				.filter(cm -> cm.getRole() == ChatMessageRole.USER || cm.getRole() == ChatMessageRole.ASSISTANT)
				.mapToLong(ChatReportGenerator::tokenCountOrZero)
				.sum();
	}


	/**
	 * Averages, over the given chats, the number of messages per chat whose role is one of
	 * {@code roles}. The result is rounded to the nearest integer. Returns 0 for an empty list.
	 */
	private static int calculateAverageMessageCount(List<Chat> chats, ChatMessageRole... roles) {
		// Arrays.asList tolerates contains(null), so a message with a null role simply does not match.
		List<ChatMessageRole> countedRoles = Arrays.asList(roles);
		return (int) Math.round(chats.stream()
				.mapToLong(chat -> chat.getChatMessages().stream()
						.filter(cm -> countedRoles.contains(cm.getRole()))
						.count())
				.average()
				.orElse(0));
	}


	/** Returns the chat's type (shown as "View"), or "" when it has none, so grouping never sees a null key. */
	private static String getChatType(Chat chat) {
		return Objects.requireNonNullElse(chat.getChatType(), "");
	}


	/** Returns the model of the chat's first user message, or "" when there is none or it has no model. */
	private static String getLlmModel(Chat chat) {
		return chat.getChatMessages().stream()
				.filter(cm -> cm.getRole() == ChatMessageRole.USER)
				.findFirst()
				.map(ChatMessage::getModel)
				.orElse("");
	}


	/**
	 * Averages the non-null ratings of the given chats. Returns an empty Optional, which is
	 * written as a blank cell, when no chat has been rated. This matches how the export
	 * sheet shows a missing rating.
	 */
	private static Optional<Double> calculateAverageChatRating(List<Chat> chats) {
		OptionalDouble average = chats.stream()
				.map(Chat::getChatRating)
				.filter(Objects::nonNull)
				.mapToDouble(Integer::doubleValue)
				.average();
		return average.isPresent() ? Optional.of(average.getAsDouble()) : Optional.empty();
	}


	/** Adds the "Chat export" sheet: one row per chat, in the order of {@code chats}. */
	private static void generateChatExportSheet(List<Chat> chats, Workbook workbook, CellFormat cellFormat) {
		Sheet sheet = workbook.createSheet("Chat export");
		generateHeaderRow(EXPORT_HEADERS, sheet, cellFormat.headerStyle());

		for (int i = 0; i < chats.size(); i++) {
			Chat chat = chats.get(i);
			Row row = sheet.createRow(i + 1);
			row.setHeightInPoints(30);

			List<ChatMessage> userMessages = messagesWithRole(chat, ChatMessageRole.USER);
			List<ChatMessage> assistantMessages = messagesWithRole(chat, ChatMessageRole.ASSISTANT);
			ChatMessage firstUserMessage = userMessages.isEmpty() ? null : userMessages.get(0);
			ChatMessage firstAssistantMessage = assistantMessages.isEmpty() ? null : assistantMessages.get(0);

			// Arrays.asList (unlike List.of) accepts nulls such as a missing summary.
			// createContentCell writes those as blank cells instead of failing the whole report.
			createContentCells(
					row,
					Arrays.asList(firstUserMessage != null ? firstUserMessage.getSendTime() : "",
							firstUserMessage != null ? firstUserMessage.getMessage() : "",
							firstAssistantMessage != null ? firstAssistantMessage.getMessage() : "",
							chat.getSummary(),
							chat.getChatType(),
							getLlmModel(chat),
							printChatMessages(chat),
							userMessages.size(),
							assistantMessages.size(),
							userMessages.size() + assistantMessages.size(),
							calculateTokenCount(userMessages, assistantMessages),
							Optional.ofNullable(chat.getChatRating())),
					cellFormat);
		}

		autoSizeColumns(EXPORT_HEADERS, sheet);
	}


	/** Returns the chat's messages with the given role, in their original order. */
	private static List<ChatMessage> messagesWithRole(Chat chat, ChatMessageRole role) {
		return chat.getChatMessages().stream()
				.filter(cm -> cm.getRole() == role)
				.toList();
	}


	/** Writes {@code data} into consecutive cells of {@code row}, starting at column 0. */
	private static void createContentCells(Row row, List<?> data, CellFormat cellFormat) {
		for (int i = 0; i < data.size(); i++)
			createContentCell(row, data.get(i), i, cellFormat);
	}


	/** Renders the full transcript: each message is preceded by a role marker, and messages are separated by blank lines. */
	private static String printChatMessages(Chat chat) {
		return chat.getChatMessages().stream()
				.map(cm -> "%s\n%s".formatted(getRoleHeader(cm.getRole()), cm.getMessage()))
				.collect(Collectors.joining("\n\n"));
	}


	/** Returns the transcript marker for a message role. */
	private static String getRoleHeader(ChatMessageRole role) {
		return switch (role) {
			case ASSISTANT	-> "***ASSISTANT_MESSAGE***";
			case SYSTEM		-> "***SYSTEM_MESSAGE***";
			case USER		-> "***USER_MESSAGE***";
			case ERROR		-> "***ERROR_MESSAGE***";
			default			-> throw new IllegalArgumentException("Unknown ChatMessageRole role: %s".formatted(role));
		};
	}


	/** Sums the token counts of the given user and assistant messages; a missing count contributes 0. */
	private static int calculateTokenCount(List<ChatMessage> userMessages, List<ChatMessage> assistantMessages) {
		return Stream.concat(userMessages.stream(), assistantMessages.stream())
				.mapToInt(ChatReportGenerator::tokenCountOrZero)
				.sum();
	}


	/** Returns the message's token count, or 0 when it has none. */
	private static int tokenCountOrZero(ChatMessage message) {
		Integer tokenCount = message.getTokenCount();
		return tokenCount != null ? tokenCount : 0;
	}


	/**
	 * Writes one value into column {@code columnIndex} of {@code row}.
	 * <ul>
	 * <li>{@link Optional}: unwrapped first; an empty Optional is treated like {@code null}.</li>
	 * <li>{@link LocalDate} and {@link LocalDateTime}: date cells with a date format.</li>
	 * <li>{@link Number}: numeric cells.</li>
	 * <li>Anything else: text, truncated to Excel's cell limit. {@code null} gives a blank cell.</li>
	 * </ul>
	 */
	private static void createContentCell(Row row, Object value, int columnIndex, CellFormat cellFormat) {
		if (value instanceof Optional<?> optionalValue) {
			createContentCell(row, optionalValue.orElse(null), columnIndex, cellFormat);
			return;
		}

		Cell cell = row.createCell(columnIndex);

		if (value instanceof LocalDate date) {
			cell.setCellValue(date);
			cell.setCellStyle(cellFormat.dateStyle());
			return;
		}
		if (value instanceof LocalDateTime dateTime) {
			cell.setCellValue(dateTime);
			cell.setCellStyle(cellFormat.dateTimeStyle());
			return;
		}

		if (value instanceof Number number)
			cell.setCellValue(number.doubleValue());
		else
			cell.setCellValue(cleanTextValue(value)); // POI turns a null String into a blank cell

		cell.setCellStyle(cellFormat.contentStyle());
	}


	/**
	 * Converts a value to cell text, truncated to {@link #MAX_CELL_TEXT_LENGTH} characters.
	 * The cut never splits a UTF-16 surrogate pair. Returns {@code null} for {@code null}.
	 */
	private static String cleanTextValue(Object value) {
		if (value == null)
			return null;

		String text = String.valueOf(value);
		if (text.length() <= MAX_CELL_TEXT_LENGTH)
			return text;

		int end = MAX_CELL_TEXT_LENGTH;
		// A lone high surrogate at the cut point would leave malformed text in the cell.
		if (Character.isHighSurrogate(text.charAt(end - 1)))
			end--;
		return text.substring(0, end);
	}


	/** Writes the header row (row 0) and applies any fixed column widths from the header configs. */
	private static void generateHeaderRow(List<HeaderConfig> headers, Sheet sheet, CellStyle cellStyle) {
		Row header = sheet.createRow(0);

		for (int i = 0; i < headers.size(); i++) {
			HeaderConfig config = headers.get(i);

			Cell headerCell = header.createCell(i);
			headerCell.setCellValue(config.header());
			headerCell.setCellStyle(cellStyle);

			if (config.width() != null)
				sheet.setColumnWidth(i, config.width() * COLUMN_WIDTH_UNITS_PER_CHAR);
		}
	}


	/**
	 * Auto-sizes every column that has no fixed width. This must run after all rows are
	 * written, because autoSizeColumn only measures the cells that already exist.
	 */
	private static void autoSizeColumns(List<HeaderConfig> headers, Sheet sheet) {
		for (int i = 0; i < headers.size(); i++) {
			if (headers.get(i).width() == null)
				sheet.autoSizeColumn(i);
		}
	}


	/** The cell styles shared by both sheets, created once per workbook. */
	private record CellFormat(CellStyle headerStyle, CellStyle contentStyle, CellStyle dateStyle, CellStyle dateTimeStyle) {

		public CellFormat(Workbook workbook) {
			this(createHeaderStyle(workbook),
				createContentStyle(workbook),
				createDateStyle(workbook, "yyyy-mm-dd"),
				createDateStyle(workbook, "yyyy-mm-dd hh:mm:ss"));
		}


		/** Creates a top-aligned, wrapping style that uses the given Excel date format pattern. */
		private static CellStyle createDateStyle(Workbook workbook, String format) {
			CreationHelper creationHelper = workbook.getCreationHelper();
			CellStyle dateStyle = workbook.createCellStyle();
			dateStyle.setWrapText(true);
			dateStyle.setVerticalAlignment(VerticalAlignment.TOP);
			dateStyle.setDataFormat(creationHelper.createDataFormat().getFormat(format));

			return dateStyle;
		}


		/** Creates the top-aligned, wrapping style used for text and numeric cells. */
		private static CellStyle createContentStyle(Workbook workbook) {
			CellStyle contentStyle = workbook.createCellStyle();
			contentStyle.setWrapText(true);
			contentStyle.setVerticalAlignment(VerticalAlignment.TOP);

			return contentStyle;
		}


		/** Creates the header style (bold 13pt Arial). It works with any Workbook implementation. */
		private static CellStyle createHeaderStyle(Workbook workbook) {
			Font font = workbook.createFont();
			font.setFontName("Arial");
			font.setFontHeightInPoints((short) 13);
			font.setBold(true);

			CellStyle headerStyle = workbook.createCellStyle();
			headerStyle.setFont(font);

			return headerStyle;
		}

	}


	/** A column header, with an optional fixed width in characters; {@code null} means auto-size. */
	private record HeaderConfig(String header, Integer width) {
		public HeaderConfig(String header) {
			this(header, null);
		}
	}

}