All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.xlrit.gears.engine.snel.JpqlQueryBuilder Maven / Gradle / Ivy

package com.xlrit.gears.engine.snel;

import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.xlrit.gears.base.snel.SnelBaseVisitor;
import com.xlrit.gears.base.snel.SnelLexer;
import com.xlrit.gears.base.snel.SnelParser;
import com.xlrit.gears.engine.meta.*;
import com.xlrit.gears.engine.search.SearchAdapter;
import jakarta.persistence.EntityManager;
import jakarta.persistence.Query;
import jakarta.persistence.Tuple;
import jakarta.persistence.TypedQuery;
import lombok.RequiredArgsConstructor;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.xlrit.gears.base.snel.Operator.*;

/**
 * Translates SNEL expressions into a JPQL query string.
 *
 * <p>A query is assembled step by step: {@link #from} sets the root entity, {@link #innerJoin}
 * adds explicit joins, the {@code select*}, {@link #count} and {@link #distinct} methods shape the
 * SELECT clause, the {@code where*} methods add restrictions (rendered AND-ed together) and
 * {@link #sort} sets the ORDER BY clause. {@link #toQlString()} renders the statement;
 * {@link #build} and {@link #register} hand it to an {@link EntityManager}.
 *
 * <p>SNEL identifiers are resolved against the current {@link From} (the root, or a join made
 * current). Entity references reached while translating expressions are LEFT JOINed on demand.
 * Instances are mutable and unsynchronized; use {@link #copy()} to branch a builder.
 */
@RequiredArgsConstructor
public class JpqlQueryBuilder {
	private static final Logger LOG = LoggerFactory.getLogger(JpqlQueryBuilder.class);
	private static final String SEP = " ";

	private final SearchAdapter searchAdapter;
	/** SNEL function name to translator, used for function-call expressions. */
	private final Map<String, SnelFunction> functions = SnelFunctions.asMap;

	private Root root;
	/** Entity against which unqualified SNEL identifiers are resolved: the root or a join. */
	private From current;
	private SelectionMode selectionMode = SelectionMode.STANDARD;
	/** Aliases handed out so far, so that every FROM/JOIN alias is unique within the query. */
	private final Set<String> aliasses = new HashSet<>();
	/** Joins keyed by their JPQL path expression; insertion order is rendering order. */
	private final Map<String, Join> joins = new LinkedHashMap<>();
	private final List<String> selections = new ArrayList<>();
	private final List<String> restrictions = new ArrayList<>();
	private String orderBy = null;

	/**
	 * Copy constructor backing {@link #copy()}. The collections are copied, so later changes to
	 * either builder do not affect the other; the {@link Root}/{@link Join} records are shared.
	 *
	 * @param builder the builder to copy
	 */
	protected JpqlQueryBuilder(JpqlQueryBuilder builder) {
		this.searchAdapter = builder.searchAdapter;
		this.root = builder.root;
		this.current = builder.current;
		this.selectionMode = builder.selectionMode;
		this.aliasses.addAll(builder.aliasses);
		this.joins.putAll(builder.joins);
		this.selections.addAll(builder.selections);
		this.restrictions.addAll(builder.restrictions);
		this.orderBy = builder.orderBy;
	}

	/**
	 * Sets the FROM root of the query and makes it the current entity.
	 *
	 * @param entityInfo the root entity
	 * @return this builder
	 */
	public JpqlQueryBuilder from(EntityInfo entityInfo) {
		String alias = deriveUniqueAlias(entityInfo.getEntityName());
		root = new Root(entityInfo, alias);
		current = root;
		return this;
	}

	/**
	 * Adds an INNER JOIN along a dot-separated path of fields, starting at the current entity.
	 * If a join for the same path expression already exists, that join is reused as-is.
	 *
	 * @param path      dot-separated field names, e.g. {@code "order.customer"}
	 * @param asCurrent whether the joined entity becomes the current entity
	 * @return this builder
	 * @throws IllegalStateException if no current entity has been set (see {@link #from})
	 */
	public JpqlQueryBuilder innerJoin(String path, boolean asCurrent) {
		Preconditions.checkState(current != null, "Current must be set");
		StringBuilder joinExpr = new StringBuilder(current.alias());
		EntityInfo entityInfo = current.entityInfo();

		for (String part : path.split("\\.")) {
			BaseField field = entityInfo.getField(part);
			entityInfo = field.getAssociatedEntityInfo();
			joinExpr.append('.').append(field.getPropertyName());
		}

		getJoin(joinExpr.toString(), JoinKind.INNER, entityInfo, asCurrent);
		return this;
	}

	/** Returns (creating if needed) the LEFT JOIN of entity-reference {@code field} on {@code left}. */
	private Join getFieldJoin(String left, BaseField field, boolean asCurrent) {
		Preconditions.checkArgument(field.isEntityReference());
		String joinExpr = left + "." + field.getPropertyName();
		EntityInfo entityInfo = field.getAssociatedEntityInfo();
		return getJoin(joinExpr, JoinKind.LEFT, entityInfo, asCurrent);
	}

	/**
	 * Returns the join registered for {@code joinExpr}, creating it with a fresh alias if absent.
	 * An existing join keeps its original kind, even when a different {@code joinKind} is requested.
	 */
	private Join getJoin(String joinExpr, JoinKind joinKind, EntityInfo entityInfo, boolean asCurrent) {
		Join join = joins.get(joinExpr);
		if (join == null) {
			String alias = deriveUniqueAlias(entityInfo.getEntityName());
			join = new Join(joinExpr, alias, joinKind, entityInfo);
			joins.put(joinExpr, join);
		}

		if (asCurrent) current = join;
		return join;
	}

	/**
	 * Adds the translated SNEL {@code expression} to the SELECT clause as {@code alias}.
	 *
	 * @return this builder
	 */
	public JpqlQueryBuilder selectAs(String expression, String alias) {
		addSelection(translateExpr(expression) + " AS " + alias);
		return this;
	}

	/**
	 * Adds one aliased selection per entry.
	 *
	 * @param aliassedExpressions maps alias (key) to SNEL expression (value)
	 * @return this builder
	 */
	public JpqlQueryBuilder selectAs(Map<String, String> aliassedExpressions) {
		for (Map.Entry<String, String> entry : aliassedExpressions.entrySet()) {
			String alias = entry.getKey();
			String expression = entry.getValue();
			String selection = translateExpr(expression);
			addSelection(selection + " AS " + alias);
		}
		return this;
	}

	/** Adds each translated SNEL expression to the SELECT clause. */
	public JpqlQueryBuilder select(String... expressions) {
		return select(Arrays.asList(expressions));
	}

	/** Adds each translated SNEL expression to the SELECT clause. */
	public JpqlQueryBuilder select(List<String> expressions) {
		for (String expression : expressions) {
			addSelection(translateExpr(expression));
		}
		return this;
	}

	/**
	 * Selects the current entity itself (its alias).
	 *
	 * @throws IllegalStateException if no current entity has been set
	 */
	public JpqlQueryBuilder selectRoot() {
		Preconditions.checkState(current != null, "Current must be set");
		addSelection(current.alias());
		return this;
	}

	/**
	 * Wraps the selections in {@code COUNT(...)}, or {@code COUNT(DISTINCT ...)} when
	 * {@code distinct} is true. Replaces any previously chosen selection mode.
	 */
	public JpqlQueryBuilder count(boolean distinct) {
		this.selectionMode = distinct ? SelectionMode.COUNT_DISTINCT : SelectionMode.COUNT;
		return this;
	}

	/** Turns DISTINCT on or off, keeping the COUNT/plain distinction of the current mode. */
	public JpqlQueryBuilder distinct(boolean value) {
		this.selectionMode = selectionMode.distinct(value);
		return this;
	}

	private void addSelection(String selection) {
		selections.add(selection);
	}

	/**
	 * Adds the translated SNEL {@code expression} as a restriction; {@code null} is ignored.
	 *
	 * @return this builder
	 */
	public JpqlQueryBuilder where(String expression) {
		if (expression != null)
			restrictions.add(translateExpr(expression));
		return this;
	}

	/**
	 * Adds each translated SNEL expression as a restriction. {@code null} elements are skipped,
	 * consistent with {@link #where(String)}.
	 *
	 * @return this builder
	 */
	public JpqlQueryBuilder where(List<String> expressions) {
		restrictions.addAll(expressions.stream()
			.filter(Objects::nonNull)
			.map(this::translateExpr)
			.toList());
		return this;
	}

	/**
	 * Adds a restriction produced by the {@link SearchAdapter} for {@code mode}, applied to the
	 * translated SNEL {@code fields}.
	 *
	 * @return this builder
	 */
	public JpqlQueryBuilder whereSearch(String mode, List<String> fields) {
		List<String> translatedFields = fields.stream().map(this::translateExpr).toList();
		restrictions.add(searchAdapter.toSearchFilter(mode, translatedFields));
		return this;
	}

	/**
	 * Replaces the ORDER BY clause with the translated SNEL sort clauses; {@code null} leaves
	 * the current ordering unchanged.
	 *
	 * @return this builder
	 */
	public JpqlQueryBuilder sort(String sort) {
		if (sort != null)
			orderBy = this.translateSortClauses(sort);
		return this;
	}

	/** Renders the query and creates a {@link Tuple}-typed query from it. */
	public TypedQuery<Tuple> build(EntityManager entityManager) {
		return build(entityManager, Tuple.class);
	}

	/**
	 * Renders the query, logs it at INFO level and creates a typed query from it.
	 *
	 * @param resultClass the query's result type
	 */
	public <T> TypedQuery<T> build(EntityManager entityManager, Class<T> resultClass) {
		String qlString = toQlString();
		LOG.info("qlString={}", qlString);
		return entityManager.createQuery(qlString, resultClass);
	}

	/**
	 * Renders the query and registers it on the {@code EntityManagerFactory} as a named query.
	 *
	 * @param name the name to register the query under
	 * @return this builder
	 */
	public JpqlQueryBuilder register(EntityManager entityManager, String name) {
		String qlString = toQlString();
		LOG.info("register: {} -> {}", name, qlString);
		Query query = entityManager.createQuery(qlString);
		entityManager.getEntityManagerFactory().addNamedQuery(name, query);
		return this;
	}

	/**
	 * Renders the JPQL statement: SELECT, FROM, the joins in creation order, WHERE (each
	 * restriction parenthesized and AND-ed together) and ORDER BY.
	 *
	 * @throws IllegalStateException if {@link #from} has not been called
	 */
	public String toQlString() {
		Preconditions.checkState(root != null, "from(...) must be called before rendering the query");
		StringBuilder sb = new StringBuilder();
		sb.append("SELECT ");
		appendSelections(sb);
		sb.append(SEP).append(root.render());
		for (Join join : joins.values()) {
			sb.append(SEP).append(join.render());
		}
		if (!restrictions.isEmpty()) {
			sb.append(SEP).append("WHERE ");
			sb.append(restrictions.stream().map(r -> "(" + r + ")").collect(Collectors.joining(" AND ")));
		}
		if (orderBy != null) {
			sb.append(SEP).append("ORDER BY ").append(orderBy);
		}
		return sb.toString();
	}

	/** Appends the comma-joined selections, decorated according to the selection mode. */
	private void appendSelections(StringBuilder sb) {
		final String selStr = String.join(", ", selections);
		switch (selectionMode) {
			case STANDARD       -> sb.append(selStr);
			case DISTINCT       -> sb.append("DISTINCT ").append(selStr);
			case COUNT          -> sb.append("COUNT(").append(selStr).append(")");
			case COUNT_DISTINCT -> sb.append("COUNT(DISTINCT ").append(selStr).append(")");
		}
	}

	/**
	 * Returns the first unused alias of the form lowercased first letter of {@code base}
	 * followed by a counter ({@code "o0"}, {@code "o1"}, ...) and reserves it.
	 */
	private String deriveUniqueAlias(String base) {
		char initial = Character.toLowerCase(base.charAt(0));
		for (int i = 0; i < Integer.MAX_VALUE; i++) {
			String candidate = String.valueOf(initial) + i;
			if (aliasses.add(candidate)) {
				return candidate;
			}
		}
		throw new IllegalStateException();
	}

	/** Translates a single SNEL expression into a JPQL fragment text. */
	public String translateExpr(String s) {
		return translateToFragment(s, SnelParser::startExpr).text();
	}

	private String translateSortClauses(String s) {
		return translateToFragment(s, SnelParser::sortClauses).text();
	}

	/**
	 * Parses {@code s} with the given grammar entry rule and translates the parse tree.
	 *
	 * @throws RuntimeException if the visitor produces no fragment for the input
	 */
	private Fragment translateToFragment(String s, Function<SnelParser, ParseTree> production) {
		CharStream input = CharStreams.fromString(s);
		SnelParser parser = new SnelParser(new CommonTokenStream(new SnelLexer(input)));
		ParseTree parseTree = production.apply(parser);
		JpqlVisitor jpqlVisitor = new JpqlVisitor();
		Fragment fragment = jpqlVisitor.visit(parseTree);
		if (fragment == null) throw new RuntimeException("Unable to parse SNEL expression: `" + s + "`");
		return fragment;
	}

	/**
	 * Looks up field {@code name} on {@code leftType}.
	 *
	 * @throws EvaluatorException if {@code leftType} is not an object type, including untyped
	 *                            (null-typed) fragments such as query parameters
	 */
	private static BaseField getField(TypeInfo leftType, String name) {
		if (leftType instanceof ObjectInfo objectInfo) {
			return objectInfo.getField(name);
		}
		// Untyped fragments (e.g. parameters) carry a null type; report them instead of failing
		// with a NullPointerException while building the error message.
		String typeName = leftType == null ? "(untyped)" : leftType.getTypeName();
		throw new EvaluatorException("Type '" + typeName + "' has no fields");
	}

	/** Returns an independent copy of this builder (see the copy constructor). */
	public JpqlQueryBuilder copy() {
		return new JpqlQueryBuilder(this);
	}

	/**
	 * Translates a SNEL parse tree into JPQL {@link Fragment}s, resolving identifiers against the
	 * builder's current entity and registering joins on the enclosing builder as needed.
	 */
	@RequiredArgsConstructor
	class JpqlVisitor extends SnelBaseVisitor<Fragment> {

		/**
		 * When false, {@link #visitIdentExpr} renders an entity-reference identifier as a plain
		 * path instead of joining it. Toggled temporarily via {@link #visit(ParseTree, boolean)}.
		 */
		private boolean withJoins = true;

		private Fragment typed(TypeInfo type, int precedence, String text) {
			return new Fragment(type, precedence, text);
		}

		private Fragment typed(TypeInfo type, String text) {
			return new Fragment(type, -1, text);
		}

		private Fragment untyped(String text) {
			return new Fragment(null, -1, text);
		}

		// === expressions === //

		@Override
		public Fragment visitStartExpr(SnelParser.StartExprContext ctx) {
			return visit(ctx.expr());
		}

		@Override
		public Fragment visitPrefixExpr(SnelParser.PrefixExprContext ctx) {
			String op = ctx.prefixOp().getText();  // '-' | 'not'
			Fragment x = visit(ctx.expr());
			return switch (op) {
				case NEG  -> prefixOp(x.type(),           JpqlOperator.UNARY_MINUS, x);
				case NOT  -> prefixOp(BasicTypes.BOOLEAN, JpqlOperator.LOGICAL_NOT, x);
				default   -> throw unknownOp(op);
			};
		}

		/**
		 * Collection-valued operands ({@link MultipleInfo}) are translated with joins disabled and
		 * tested with IS [NOT] EMPTY; any other operand is translated normally and tested with
		 * IS [NOT] NULL.
		 */
		@Override
		public Fragment visitPostfixExpr(SnelParser.PostfixExprContext ctx) {
			String op = ctx.postfixOp().getText();  // 'exists' | 'does not exist'
			Fragment unjoined = visit(ctx.expr(), false);
			if (unjoined.type() instanceof MultipleInfo) {
				return switch (op) {
					case EXISTS  -> postfixOp(BasicTypes.BOOLEAN, unjoined, JpqlOperator.IS_NOT_EMPTY);
					case NEXISTS -> postfixOp(BasicTypes.BOOLEAN, unjoined, JpqlOperator.IS_EMPTY);
					default      -> throw unknownOp(op);
				};
			}
			else {
				Fragment x = visit(ctx.expr());
				return switch (op) {
					case EXISTS  -> postfixOp(BasicTypes.BOOLEAN, x, JpqlOperator.IS_NOT_NULL);
					case NEXISTS -> postfixOp(BasicTypes.BOOLEAN, x, JpqlOperator.IS_NULL);
					default      -> throw unknownOp(op);
				};
			}
		}

		@Override
		public Fragment visitMulExpr(SnelParser.MulExprContext ctx) {
			String op = ctx.mulOp().getText(); // '*' | '/' | '%'
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return switch (op) {
				case MUL  -> binOp(x.type(),           JpqlOperator.MUL, x, y);
				case DIV  -> binOp(x.type(),           JpqlOperator.DIV, x, y);
				case MOD  -> binOp(BasicTypes.INTEGER, JpqlOperator.MOD, x, y);
				default   -> throw unknownOp(op);
			};
		}

		@Override
		public Fragment visitAddExpr(SnelParser.AddExprContext ctx) {
			String op = ctx.addOp().getText(); // '+' | '-'
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return switch (op) {
				case PLUS  -> binOp(x.type(), JpqlOperator.ADD, x, y);
				case MINUS -> binOp(x.type(), JpqlOperator.SUB, x, y);
				default    -> throw unknownOp(op);
			};
		}

		@Override
		public Fragment visitEqExpr(SnelParser.EqExprContext ctx) {
			String op = ctx.eqOp().getText(); // '==' | '<>' | 'like'
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return switch (op) {
				case EQ   -> binOp(BasicTypes.BOOLEAN, JpqlOperator.EQUAL,     x, y);
				case NE   -> binOp(BasicTypes.BOOLEAN, JpqlOperator.NOT_EQUAL, x, y);
				case LIKE -> binOp(BasicTypes.BOOLEAN, JpqlOperator.LIKE,      x, y);
				default   -> throw unknownOp(op);
			};
		}

		@Override
		public Fragment visitRelExpr(SnelParser.RelExprContext ctx) {
			String op = ctx.relOp().getText(); // '<' | '>' | '<=' | '>='
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return switch (op) {
				case LT   -> binOp(BasicTypes.BOOLEAN, JpqlOperator.LT, x, y);
				case GT   -> binOp(BasicTypes.BOOLEAN, JpqlOperator.GT, x, y);
				case LE   -> binOp(BasicTypes.BOOLEAN, JpqlOperator.LE, x, y);
				case GE   -> binOp(BasicTypes.BOOLEAN, JpqlOperator.GE, x, y);
				default   -> throw unknownOp(op);
			};
		}

		/** {@code x in y} and {@code y contains x} both render as {@code x MEMBER OF y}. */
		@Override
		public Fragment visitMemberExpr(SnelParser.MemberExprContext ctx) {
			String op = ctx.memberOp().getText(); // 'in' | 'contains'
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return switch (op) {
				case IN       -> binOp(BasicTypes.BOOLEAN, JpqlOperator.MEMBER_OF, x, y);
				case CONTAINS -> binOp(BasicTypes.BOOLEAN, JpqlOperator.MEMBER_OF, y, x);
				default       -> throw unknownOp(op);
			};
		}

		@Override
		public Fragment visitAndExpr(SnelParser.AndExprContext ctx) {
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return binOp(BasicTypes.BOOLEAN, JpqlOperator.LOGICAL_AND, x, y);
		}

		@Override
		public Fragment visitOrExpr(SnelParser.OrExprContext ctx) {
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return binOp(BasicTypes.BOOLEAN, JpqlOperator.LOGICAL_OR, x, y);
		}

		/** {@code x otherwise y} renders as {@code COALESCE(x, y)}, typed like {@code x}. */
		@Override
		public Fragment visitOtherwiseExpr(SnelParser.OtherwiseExprContext ctx) {
			Fragment x = visit(ctx.expr(0));
			Fragment y = visit(ctx.expr(1));
			return typed(x.type(), "COALESCE(" + x.text() + ", " + y.text() + ")");
		}

		/**
		 * Translates {@code left.name}. When {@code left} is itself an attribute access
		 * ({@code a.b.name}), the intermediate reference {@code a.b} is LEFT JOINed and
		 * {@code name} is accessed on that join's alias.
		 */
		@Override
		public Fragment visitAttrExpr(SnelParser.AttrExprContext ctx) {
			SnelParser.ExprContext leftExpr = ctx.expr();
			Fragment leftFragment = visit(leftExpr);
			String name = ctx.identifier().getText();

			String leftT = leftFragment.text();
			BaseField field = getField(leftFragment.type(), name);

			if (leftExpr instanceof SnelParser.AttrExprContext innerAttrExpr) {
				Fragment innerLeftFragment = visit(innerAttrExpr.expr());
				BaseField innerField = getField(innerLeftFragment.type(), innerAttrExpr.identifier().getText());
				leftT = getFieldJoin(innerLeftFragment.text(), innerField, false).alias();
			}

			return typed(field.getTypeInfo(), leftT + "." + field.getPropertyName());
		}

		/** Renders the elements comma-separated, without surrounding parentheses. */
		@Override
		public Fragment visitListExpression(SnelParser.ListExpressionContext ctx) {
			String elems = ctx.expr().stream()
				.map(this::visit)
				.map(Fragment::text)
				.collect(Collectors.joining(", "));
			return typed(BasicTypes.ANY, elems);
		}

		/** Delegates to the registered {@link SnelFunction}; unknown names are rejected. */
		@Override
		public Fragment visitFunctionExpression(SnelParser.FunctionExpressionContext ctx) {
			String name = ctx.identifier().getText();
			List<Fragment> params = translateActualParams(ctx.actualParams());
			SnelFunction snelFunction = functions.get(name);
			if (snelFunction == null) throw new RuntimeException("Unsupported function " + name);
			return snelFunction.translate(params);
		}

		private List<Fragment> translateActualParams(SnelParser.ActualParamsContext actualParamsCtx) {
			if (actualParamsCtx == null) {
				return Collections.emptyList();
			}
			return actualParamsCtx.expr().stream()
				.map(this::visit)
				.collect(Collectors.toList());
		}

		@Override
		public Fragment visitParenExpr(SnelParser.ParenExprContext ctx) {
			// We cannot simply pass on the parentheses as such:
			// return new Fragment(inner.type(), "(" + inner + ")");
			// because JPQL is stricter than SNEL about where parentheses are allowed.
			// Grouping is restored from operator precedence in prefixOp/binOp instead.
			return visit(ctx.expr());
		}

		@Override
		public Fragment visitParamExpr(SnelParser.ParamExprContext ctx) {
			return untyped(ctx.parameter().getText());
		}

		/**
		 * Resolves an identifier on the current entity. {@code id} maps to {@code <alias>.id};
		 * an entity reference is LEFT JOINed (unless joins are disabled) and rendered as the
		 * join alias; any other field renders as {@code <alias>.<property>}.
		 */
		@Override
		public Fragment visitIdentExpr(SnelParser.IdentExprContext ctx) {
			Preconditions.checkState(current != null, "current must be set");

			String name = ctx.identifier().getText();
			if ("id".equals(name)) return typed(BasicTypes.TEXT, current.alias() + ".id");

			BaseField field = current.entityInfo().getField(name);
			if (field.isEntityReference() && withJoins) {
				String alias = getFieldJoin(current.alias(), field, false).alias();
				return typed(field.getTypeInfo(), alias);
			}
			return typed(field.getTypeInfo(), current.alias() + "." + field.getPropertyName());
		}

		@Override
		public Fragment visitBooleanLiteral(SnelParser.BooleanLiteralContext ctx) {
			return typed(BasicTypes.BOOLEAN, ctx.getText());
		}

		@Override
		public Fragment visitIntegerLiteral(SnelParser.IntegerLiteralContext ctx) {
			return typed(BasicTypes.INTEGER, ctx.getText());
		}

		@Override
		public Fragment visitStringLiteral(SnelParser.StringLiteralContext ctx) {
			return typed(BasicTypes.TEXT, ctx.getText());
		}

		// === unary/binary expression helpers === //
		// Precedence numbers: a higher number binds more loosely (see requiresParens);
		// -1 marks atoms and other fragments that never need parentheses.

		/**
		 * Renders a prefix operator. An operand at the same precedence level is parenthesized:
		 * JPQL permits only one leading sign or NOT, so {@code -(-x)} must not become {@code - - x}.
		 */
		private Fragment prefixOp(TypeInfo type, JpqlOperator op, Fragment x) {
			String operand = maybeParen(x.text(), x.precedence(), op.precedence, true);
			return typed(type, op.precedence, op.symbol + " " + operand);
		}

		private Fragment postfixOp(TypeInfo type, Fragment x, JpqlOperator op) {
			String operand = maybeParen(x.text(), x.precedence(), op.precedence, false);
			return typed(type, op.precedence, operand + " " + op.symbol);
		}

		/**
		 * Renders a binary operator. A left operand at the same precedence level needs no
		 * parentheses ({@code (a - b) - c == a - b - c}), but a right operand does: otherwise SNEL
		 * {@code a - (b - c)} would render as {@code a - b - c}, because visitParenExpr drops the
		 * source parentheses. Parenthesizing an equal-level right operand is always safe.
		 */
		private Fragment binOp(TypeInfo type, JpqlOperator op, Fragment x, Fragment y) {
			String leftOperand  = maybeParen(x.text(), x.precedence(), op.precedence, false);
			String rightOperand = maybeParen(y.text(), y.precedence(), op.precedence, true);
			return typed(type, op.precedence, leftOperand + " " + op.symbol + " " + rightOperand);
		}

		private RuntimeException unknownOp(String op) {
			return new RuntimeException("Unknown SNEL operator: " + op);
		}

		private String maybeParen(String s, int innerPrecedence, int outerPrecedence, boolean parenOnTie) {
			return requiresParens(innerPrecedence, outerPrecedence, parenOnTie) ? "(" + s + ")" : s;
		}

		/**
		 * Whether an inner fragment must be parenthesized inside an outer operator: when it binds
		 * more loosely, or (if {@code parenOnTie}) equally loosely. Fragments with precedence -1
		 * are never parenthesized.
		 */
		private boolean requiresParens(int innerPrecedence, int outerPrecedence, boolean parenOnTie) {
			if (innerPrecedence == -1 || outerPrecedence == -1) return false;
			return innerPrecedence > outerPrecedence
				|| (parenOnTie && innerPrecedence == outerPrecedence);
		}

		// === sort clauses === //

		@Override
		public Fragment visitStartSortClauses(SnelParser.StartSortClausesContext ctx) {
			return visit(ctx.sortClauses());
		}

		@Override
		public Fragment visitSortClauses(SnelParser.SortClausesContext ctx) {
			return untyped(
				ctx.sortClause().stream()
					.map(this::visit)
					.map(Fragment::text)
					.collect(Collectors.joining(", "))
			);
		}

		@Override
		public Fragment visitSortClause(SnelParser.SortClauseContext ctx) {
			StringBuilder sb = new StringBuilder();
			sb.append(visit(ctx.expr()).text());
			if (ctx.sortDirection() != null) {
				sb.append(SEP).append(translateSortDirection(ctx.sortDirection().getText()));
			}
			return untyped(sb.toString());
		}

		private String translateSortDirection(String d) {
			return switch (d) {
				case "ascending"  -> "ASC";
				case "descending" -> "DESC";
				default           -> throw new RuntimeException("Unknown sort direction: '" + d + "'");
			};
		}

		/** Visits {@code tree} with {@link #withJoins} temporarily set to the given value. */
		private Fragment visit(ParseTree tree, boolean withJoins) {
			boolean before = this.withJoins;
			this.withJoins = withJoins;
			Fragment result = visit(tree);
			this.withJoins = before;
			return result;
		}
	}

	/** An entity source in the query (the FROM root or a join) that identifiers resolve against. */
	private interface From {
		EntityInfo entityInfo();
		String alias();
	}

	/** The FROM root; renders as {@code FROM <EntityClassSimpleName> <alias>}. */
	private record Root(EntityInfo entityInfo, String alias) implements From {
		public String entityClassName() {
			return entityInfo.getObjectClass().getSimpleName();
		}

		public String render() {
			return "FROM " + entityClassName() + " " + alias;
		}
	}

	/** A join over path {@code expr}; renders as {@code <kind> <expr> <alias>}. */
	private record Join(String expr, String alias, JoinKind kind, EntityInfo entityInfo) implements From {
		public String render() {
			return kind + " " + expr + " " + alias;
		}
	}

	/** Join kinds; {@link #toString()} yields the JPQL keyword(s). */
	enum JoinKind {
		INNER("JOIN"),
		LEFT("LEFT JOIN");

		private final String code;
		JoinKind(String code) { this.code = code; }

		@Override
		public String toString() {
			return code;
		}
	}

	/** How the selections are rendered in the SELECT clause. */
	enum SelectionMode {
		STANDARD,
		DISTINCT,
		COUNT,
		COUNT_DISTINCT;

		/** Returns the mode with DISTINCT switched on/off, keeping COUNT vs. plain selection. */
		public SelectionMode distinct(boolean value) {
			return switch (this) {
				case STANDARD       -> value ? DISTINCT : this;
				case DISTINCT       -> value ? this : STANDARD;
				case COUNT          -> value ? COUNT_DISTINCT : this;
				case COUNT_DISTINCT -> value ? this : COUNT;
			};
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy