All downloads are free. Search and download functionality uses the official Maven repository.

io.micronaut.data.hibernate.reactive.operations.DefaultHibernateReactiveRepositoryOperations Maven / Gradle / Ivy

There is a newer version: 4.10.5
Show newest version
/*
 * Copyright 2017-2022 original authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.micronaut.data.hibernate.reactive.operations;

import io.micronaut.aop.InvocationContext;
import io.micronaut.context.annotation.EachBean;
import io.micronaut.context.annotation.Parameter;
import io.micronaut.core.annotation.AnnotationMetadata;
import io.micronaut.core.annotation.Internal;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.convert.ConversionService;
import io.micronaut.core.type.Argument;
import io.micronaut.data.annotation.QueryHint;
import io.micronaut.data.connection.reactive.ReactorConnectionOperations;
import io.micronaut.data.hibernate.conf.RequiresReactiveHibernate;
import io.micronaut.data.hibernate.operations.AbstractHibernateOperations;
import io.micronaut.data.model.Page;
import io.micronaut.data.model.Pageable;
import io.micronaut.data.model.runtime.DeleteBatchOperation;
import io.micronaut.data.model.runtime.DeleteOperation;
import io.micronaut.data.model.runtime.InsertBatchOperation;
import io.micronaut.data.model.runtime.InsertOperation;
import io.micronaut.data.model.runtime.PagedQuery;
import io.micronaut.data.model.runtime.PreparedQuery;
import io.micronaut.data.model.runtime.RuntimeEntityRegistry;
import io.micronaut.data.model.runtime.RuntimePersistentEntity;
import io.micronaut.data.model.runtime.StoredQuery;
import io.micronaut.data.model.runtime.UpdateBatchOperation;
import io.micronaut.data.model.runtime.UpdateOperation;
import io.micronaut.data.operations.reactive.ReactorCriteriaRepositoryOperations;
import io.micronaut.data.runtime.convert.DataConversionService;
import io.micronaut.transaction.reactive.ReactorReactiveTransactionOperations;
import jakarta.persistence.EntityGraph;
import jakarta.persistence.FlushModeType;
import jakarta.persistence.Tuple;
import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.CriteriaDelete;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.CriteriaUpdate;
import org.hibernate.SessionFactory;
import org.hibernate.reactive.stage.Stage;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Collection;
import java.util.function.Function;

/**
 * Hibernate Reactive implementation of {@link io.micronaut.data.operations.reactive.ReactiveRepositoryOperations}
 * and {@link ReactorCriteriaRepositoryOperations}.
 *
 * <p>One instance is created per {@link SessionFactory} bean. Most repository operations run their work inside a
 * transaction obtained from the injected {@link ReactorReactiveTransactionOperations} (see {@code operation} /
 * {@code operationFlux}). {@link #flush()}, {@code withSession}/{@code withSessionFlux} and the criteria-API methods
 * only obtain a session through the injected {@link ReactorConnectionOperations}.</p>
 *
 * @author Denis Stepanov
 * @since 3.5.0
 */
@RequiresReactiveHibernate
@EachBean(SessionFactory.class)
@Internal
final class DefaultHibernateReactiveRepositoryOperations extends AbstractHibernateOperations>
        implements HibernateReactorRepositoryOperations, ReactorCriteriaRepositoryOperations {

    // The ORM session factory; its criteria builder is used for paged and count queries.
    private final SessionFactory sessionFactory;
    // The Hibernate Reactive Stage view of sessionFactory (obtained via unwrap in the constructor).
    private final Stage.SessionFactory stageSessionFactory;
    // Adapts Stage API calls to Reactor types (persist, merge, flush, list, ...).
    private final ReactiveHibernateHelper helper;
    // Supplies sessions for work that is not wrapped in a transaction by this class.
    private final ReactorConnectionOperations connectionOperations;
    // Supplies transactions (and their sessions) for repository operations.
    private final ReactorReactiveTransactionOperations transactionOperations;

    /**
     * Creates the repository operations for one {@link SessionFactory}.
     *
     * @param sessionFactory        the Hibernate session factory; unwrapped to {@link Stage.SessionFactory}
     * @param runtimeEntityRegistry the runtime entity registry
     * @param dataConversionService the data conversion service
     * @param connectionOperations  the reactive connection operations used by {@code withSession}/{@code withSessionFlux}
     * @param transactionOperations the reactive transaction operations used to run repository operations
     */
    DefaultHibernateReactiveRepositoryOperations(SessionFactory sessionFactory,
                                                 RuntimeEntityRegistry runtimeEntityRegistry,
                                                 DataConversionService dataConversionService,
                                                 @Parameter ReactorConnectionOperations connectionOperations,
                                                 @Parameter ReactorReactiveTransactionOperations transactionOperations) {
        super(runtimeEntityRegistry, dataConversionService);
        this.sessionFactory = sessionFactory;
        this.stageSessionFactory = sessionFactory.unwrap(Stage.SessionFactory.class);
        this.connectionOperations = connectionOperations;
        this.transactionOperations = transactionOperations;
        this.helper = new ReactiveHibernateHelper(stageSessionFactory);
    }

    // Parameter binding: every overload below delegates to Stage.AbstractQuery#setParameter,
    // for scalar and collection values alike; the Argument type information is not used.

    @Override
    protected void setParameter(Stage.AbstractQuery query, String parameterName, Object value) {
        query.setParameter(parameterName, value);
    }

    @Override
    protected void setParameter(Stage.AbstractQuery query, String parameterName, Object value, Argument argument) {
        query.setParameter(parameterName, value);
    }

    @Override
    protected void setParameterList(Stage.AbstractQuery query, String parameterName, Collection value) {
        query.setParameter(parameterName, value);
    }

    @Override
    protected void setParameterList(Stage.AbstractQuery query, String parameterName, Collection value, Argument argument) {
        query.setParameter(parameterName, value);
    }

    @Override
    protected void setParameter(Stage.AbstractQuery query, int parameterIndex, Object value) {
        query.setParameter(parameterIndex, value);
    }

    @Override
    protected void setParameter(Stage.AbstractQuery query, int parameterIndex, Object value, Argument argument) {
        query.setParameter(parameterIndex, value);
    }

    @Override
    protected void setParameterList(Stage.AbstractQuery query, int parameterIndex, Collection value) {
        query.setParameter(parameterIndex, value);
    }

    @Override
    protected void setParameterList(Stage.AbstractQuery query, int parameterIndex, Collection value, Argument argument) {
        query.setParameter(parameterIndex, value);
    }

    /**
     * Applies a query hint. Only {@link EntityGraph} values are supported; they become the query's fetch plan.
     *
     * @throws IllegalStateException for any other hint value
     */
    @Override
    protected void setHint(Stage.SelectionQuery query, String hintName, Object value) {
        if (value instanceof EntityGraph plan) {
            query.setPlan(plan);
            return;
        }
        throw new IllegalStateException("Unrecognized parameter: " + hintName + " with value: " + value);
    }

    @Override
    protected void setMaxResults(Stage.SelectionQuery query, int max) {
        query.setMaxResults(max);
    }

    @Override
    protected void setOffset(Stage.SelectionQuery query, int offset) {
        // The Stage API names the offset "first result".
        query.setFirstResult(offset);
    }

    @Override
    protected  EntityGraph getEntityGraph(Stage.Session session, Class entityType, String graphName) {
        return session.getEntityGraph(entityType, graphName);
    }

    @Override
    protected  EntityGraph createEntityGraph(Stage.Session session, Class entityType) {
        return session.createEntityGraph(entityType);
    }

    @Override
    protected  RuntimePersistentEntity getEntity(Class type) {
        return runtimeEntityRegistry.getEntity(type);
    }

    /**
     * @return the criteria builder of the Hibernate Reactive (Stage) session factory
     */
    @Override
    public CriteriaBuilder getCriteriaBuilder() {
        return stageSessionFactory.getCriteriaBuilder();
    }

    /**
     * Flushes a session obtained through {@code withSession} (not through a transaction of this class).
     */
    @Override
    public Mono flush() {
        return withSession(helper::flush);
    }

    /**
     * Persists the entity and then flushes the session, inside a transaction.
     */
    @Override
    public Mono persistAndFlush(Object entity) {
        return operation(session -> helper.persist(session, entity).then(helper.flush(session)));
    }

    @Override
    public  Mono findOne(Class type, Object id) {
        return operation(session -> helper.find(session, type, id));
    }

    /**
     * Emits whether {@link #findOne(PreparedQuery)} produces an element.
     */
    @Override
    public  Mono exists(PreparedQuery preparedQuery) {
        return findOne(preparedQuery).hasElement();
    }

    /**
     * Creates a native query; untyped when {@code resultType} is {@code null}.
     */
    @Override
    protected Stage.SelectionQuery createNativeQuery(Stage.Session session, String query, Class resultType) {
        if (resultType == null) {
            return session.createNativeQuery(query);
        }
        return session.createNativeQuery(query, resultType);
    }

    /**
     * Creates an HQL/JPQL query; untyped when {@code resultType} is {@code null}.
     */
    @Override
    protected Stage.SelectionQuery createQuery(Stage.Session session, String query, Class resultType) {
        if (resultType == null) {
            return session.createQuery(query);
        }
        return session.createQuery(query, resultType);
    }

    @Override
    protected Stage.SelectionQuery createQuery(Stage.Session session, CriteriaQuery criteriaQuery) {
        return session.createQuery(criteriaQuery);
    }

    /**
     * Emits the first result of the prepared query, inside a transaction.
     */
    @Override
    public  Mono findOne(PreparedQuery preparedQuery) {
        return operation(session -> {
            // TODO: once https://github.com/hibernate/hibernate-reactive/issues/1551 is fixed, limit non-native
            // queries to a single row again (new FirstResultCollector<>(!preparedQuery.isNative())).
            // Until then maxResults must not be set, otherwise results can be wrong.
            FirstResultCollector collector = new FirstResultCollector<>(false);
            collectFindOne(session, preparedQuery, collector);
            return collector.result;
        });
    }

    @Override
    public  Mono findOptional(Class type, Object id) {
        return findOne(type, id);
    }

    @Override
    public  Mono findOptional(PreparedQuery preparedQuery) {
        return findOne(preparedQuery);
    }

    @Override
    public  Flux findAll(PagedQuery pagedQuery) {
        return operationFlux(session -> findPaged(session, pagedQuery));
    }

    /**
     * Loads one page of results and the total count of the root entity, in the same transaction.
     */
    @Override
    public  Mono> findPage(PagedQuery pagedQuery) {
        return operation(session -> findPaged(session, pagedQuery).collectList()
                .flatMap(resultList -> countOf(session, pagedQuery.getRootEntity(), pagedQuery.getPageable())
                        .map(total -> Page.of(resultList, pagedQuery.getPageable(), total))));
    }

    /**
     * Counts the rows of the paged query's root entity.
     */
    @Override
    public  Mono count(PagedQuery pagedQuery) {
        // The count is taken over the query's root entity; Long is only the result type of a count,
        // not an entity to select from.
        return operation(session -> countOf(session, pagedQuery.getRootEntity(), null));
    }

    /**
     * Runs the paged query in the given session and emits its rows.
     */
    private  Flux findPaged(Stage.Session session, PagedQuery pagedQuery) {
        ListResultCollector collector = new ListResultCollector<>();
        collectPagedResults(sessionFactory.getCriteriaBuilder(), session, pagedQuery, collector);
        return collector.result;
    }

    /**
     * Runs a count query over {@code entity} (optionally constrained by {@code pageable}) and emits the single result.
     */
    private  Mono countOf(Stage.Session session, Class entity, @Nullable Pageable pageable) {
        SingleResultCollector collector = new SingleResultCollector<>();
        collectCountOf(sessionFactory.getCriteriaBuilder(), session, entity, pageable, collector);
        return collector.result;
    }

    @Override
    public  Flux findAll(PreparedQuery preparedQuery) {
        return operationFlux(session -> {
            ListResultCollector resultCollector = new ListResultCollector<>();
            collectFindAll(session, preparedQuery, resultCollector);
            return resultCollector.result;
        });
    }

    /**
     * Inserts an entity: runs the stored query as a mutation when present, otherwise persists through the helper.
     * Flushes afterwards when the operation's query hints request {@link FlushModeType#AUTO}.
     */
    @Override
    public  Mono persist(InsertOperation operation) {
        return operation(session -> {
            StoredQuery storedQuery = operation.getStoredQuery();
            T entity = operation.getEntity();
            Mono result;
            if (storedQuery != null) {
                result = executeEntityUpdate(session, storedQuery, operation.getInvocationContext(), entity)
                    .thenReturn(entity);
            } else {
                result = helper.persist(session, entity);
            }
            return flushIfNecessary(result, session, operation.getAnnotationMetadata());
        });
    }

    /**
     * Updates an entity: runs the stored query as a mutation when present, otherwise merges through the helper.
     * Flushes afterwards when the operation's query hints request {@link FlushModeType#AUTO}.
     */
    @Override
    public  Mono update(UpdateOperation operation) {
        return operation(session -> {
            StoredQuery storedQuery = operation.getStoredQuery();
            T entity = operation.getEntity();
            Mono result;
            if (storedQuery != null) {
                result = executeEntityUpdate(session, storedQuery, operation.getInvocationContext(), entity)
                        .thenReturn(entity);
            } else {
                result = helper.merge(session, entity);
            }
            return flushIfNecessary(result, session, operation.getAnnotationMetadata());
        });
    }

    /**
     * Executes {@code storedQuery} as a mutation query with parameters bound from the invocation context and
     * {@code entity}, emitting the helper's {@code executeUpdate} result.
     */
    private  Mono executeEntityUpdate(Stage.Session session,
                                                  StoredQuery storedQuery,
                                                  InvocationContext invocationContext,
                                                  T entity) {
        Stage.MutationQuery query = session.createMutationQuery(storedQuery.getQuery());
        bindParameters(query, storedQuery, invocationContext, true, entity);
        return helper.executeUpdate(query);
    }

    /**
     * Updates a batch: one stored-query mutation per entity (sequentially) when present, otherwise a helper merge.
     */
    @Override
    public  Flux updateAll(UpdateBatchOperation operation) {
        return operationFlux(session -> {
            StoredQuery storedQuery = operation.getStoredQuery();
            Flux result;
            if (storedQuery != null) {
                result = Flux.fromIterable(operation)
                        .concatMap(t -> executeEntityUpdate(session, storedQuery, operation.getInvocationContext(), t).thenReturn(t));
            } else {
                result = helper.mergeAll(session, operation);
            }
            return flushIfNecessaryFlux(result, session, operation.getAnnotationMetadata());
        });
    }

    /**
     * Inserts a batch: one stored-query mutation per entity (sequentially) when present, otherwise a helper persist.
     */
    @Override
    public  Flux persistAll(InsertBatchOperation operation) {
        return operationFlux(session -> {
            StoredQuery storedQuery = operation.getStoredQuery();
            Flux result;
            if (storedQuery != null) {
                result = Flux.fromIterable(operation)
                    .concatMap(t -> executeEntityUpdate(session, storedQuery, operation.getInvocationContext(), t).thenReturn(t));
            } else {
                result = helper.persistAll(session, operation);
            }
            return flushIfNecessaryFlux(result, session, operation.getAnnotationMetadata());
        });
    }

    /**
     * Executes the prepared query as a mutation and emits the helper's update result as a {@link Number}.
     */
    @Override
    public Mono executeUpdate(PreparedQuery preparedQuery) {
        return operation(session -> {
            String query = preparedQuery.getQuery();
            Stage.MutationQuery q = session.createMutationQuery(query);
            bindParameters(q, preparedQuery, true);
            Mono result = helper.executeUpdate(q).cast(Number.class);
            return flushIfNecessary(result, session, preparedQuery.getAnnotationMetadata());
        });
    }

    @Override
    public Mono executeDelete(PreparedQuery preparedQuery) {
        return executeUpdate(preparedQuery);
    }

    /**
     * Deletes one entity: emits the stored query's update result when present, otherwise {@code 1} after a helper remove.
     */
    @Override
    public  Mono delete(DeleteOperation operation) {
        return operation(session -> {
            StoredQuery storedQuery = operation.getStoredQuery();
            Mono result;
            if (storedQuery != null) {
                result = executeEntityUpdate(session, storedQuery, operation.getInvocationContext(), operation.getEntity()).cast(Number.class);
            } else {
                result = helper.remove(session, operation.getEntity()).thenReturn(1);
            }
            return flushIfNecessary(result, session, operation.getAnnotationMetadata());
        });
    }

    /**
     * Deletes a batch: with a stored query, runs it once per entity (sequentially) and emits the summed results;
     * otherwise delegates to the helper's removeAll.
     */
    @Override
    public  Mono deleteAll(DeleteBatchOperation operation) {
        return operation(session -> {
            StoredQuery storedQuery = operation.getStoredQuery();
            Mono result;
            if (storedQuery != null) {
                result = Flux.fromIterable(operation)
                        .concatMap(entity -> executeEntityUpdate(session, storedQuery, operation.getInvocationContext(), entity))
                        .reduce(0, (i1, i2) -> i1 + i2)
                        .cast(Number.class);
            } else {
                result = helper.removeAll(session, operation);
            }
            return flushIfNecessary(result, session, operation.getAnnotationMetadata());
        });
    }

    /**
     * Flushes the session after {@code m} emits when the metadata carries {@link QueryHint} annotations whose
     * flush mode is {@link FlushModeType#AUTO}; otherwise returns {@code m} unchanged.
     * An empty {@code m} never triggers a flush.
     */
    private  Mono flushIfNecessary(Mono m, Stage.Session session, AnnotationMetadata annotationMetadata) {
        if (annotationMetadata.hasAnnotation(QueryHint.class)) {
            FlushModeType flushModeType = getFlushModeType(annotationMetadata);
            if (flushModeType == FlushModeType.AUTO) {
                return m.flatMap(t -> helper.flush(session).thenReturn(t));
            }
        }
        return m;
    }

    /**
     * {@link Flux} variant of {@code flushIfNecessary}. The whole flux is buffered so that any flush happens after
     * every element is processed; buffering also means no element is emitted before the batch completes.
     */
    private  Flux flushIfNecessaryFlux(Flux flux, Stage.Session session, AnnotationMetadata annotationMetadata) {
        return flushIfNecessary(flux.collectList(), session, annotationMetadata).flatMapMany(Flux::fromIterable);
    }

    /**
     * Runs {@code work} with the session (connection) of a transaction from {@code transactionOperations}.
     */
    private  Mono operation(Function> work) {
        return transactionOperations.withTransactionMono(tx -> work.apply(tx.getConnection()));
    }

    /**
     * Flux variant of {@code operation}.
     */
    private  Flux operationFlux(Function> work) {
        return transactionOperations.withTransactionFlux(tx -> work.apply(tx.getConnection()));
    }

    /**
     * Runs {@code work} with a session obtained from {@code connectionOperations} (not from transactionOperations).
     */
    @Override
    public  Mono withSession(Function> work) {
        return connectionOperations.withConnectionMono(status -> work.apply(status.getConnection()));
    }

    /**
     * Flux variant of {@code withSession}.
     */
    @Override
    public  Flux withSessionFlux(Function> work) {
        return connectionOperations.withConnectionFlux(status -> work.apply(status.getConnection()));
    }

    @Override
    public ConversionService getConversionService() {
        return dataConversionService;
    }

    /**
     * Emits the criteria query's single result via {@code getSingleResult()}.
     */
    @Override
    public  Mono findOne(CriteriaQuery query) {
        return withSession(session -> helper.monoFromCompletionStage(() -> session.createQuery(query).getSingleResult()));
    }

    @Override
    public  Flux findAll(CriteriaQuery query) {
        return withSession(session -> helper.monoFromCompletionStage(() -> session.createQuery(query).getResultList()))
            .flatMapIterable(res -> res);
    }

    /**
     * Runs the criteria query with optional paging; {@code -1} for {@code offset} or {@code limit} means "not set".
     */
    @Override
    public  Flux findAll(CriteriaQuery query, int offset, int limit) {
        return withSession(session -> helper.monoFromCompletionStage(() -> {
            Stage.SelectionQuery sessionQuery = session.createQuery(query);
            if (offset != -1) {
                sessionQuery = sessionQuery.setFirstResult(offset);
            }
            if (limit != -1) {
                sessionQuery = sessionQuery.setMaxResults(limit);
            }
            return sessionQuery.getResultList();
        })).flatMapIterable(res -> res);
    }

    @Override
    public Mono updateAll(CriteriaUpdate query) {
        // map(n -> n) widens the emitted update count to the declared element type.
        return withSession(session -> helper.monoFromCompletionStage(() -> session.createQuery(query).executeUpdate()).map(n -> n));
    }

    @Override
    public Mono deleteAll(CriteriaDelete query) {
        // map(n -> n) widens the emitted update count to the declared element type.
        return withSession(session -> helper.monoFromCompletionStage(() -> session.createQuery(query).executeUpdate()).map(n -> n));
    }

    /**
     * Collects a multi-row result into {@link #result}, assigned by whichever collect method the base class invokes.
     */
    private final class ListResultCollector extends ResultCollector {

        private Flux result;

        @Override
        protected void collectTuple(Stage.SelectionQuery query, Function fn) {
            Flux tuples = (Flux) helper.list(query);
            result = tuples.map(fn);
        }

        @Override
        protected void collect(Stage.SelectionQuery query) {
            result = (Flux) helper.list(query);
        }
    }

    /**
     * Collects the query's single result (via the helper's {@code singleResult}) into {@link #result}.
     */
    private final class SingleResultCollector extends ResultCollector {

        private Mono result;

        @Override
        protected void collectTuple(Stage.SelectionQuery query, Function fn) {
            result = ((Mono) helper.singleResult(query)).map(fn);
        }

        @Override
        protected void collect(Stage.SelectionQuery query) {
            result = (Mono) helper.singleResult(query);
        }

    }

    /**
     * Collects the first row of the query's list result into {@link #result}.
     */
    private final class FirstResultCollector extends ResultCollector {

        // When true, the query's max results is set to 1 before it is executed.
        private final boolean limitOne;
        private Mono result;

        private FirstResultCollector(boolean limitOne) {
            this.limitOne = limitOne;
        }

        @Override
        protected void collectTuple(Stage.SelectionQuery query, Function fn) {
            result = getFirst((Stage.SelectionQuery) query).map(fn);
        }

        @Override
        protected void collect(Stage.SelectionQuery query) {
            result = getFirst((Stage.SelectionQuery) query);
        }

        private  Mono getFirst(Stage.SelectionQuery q) {
            if (limitOne) {
                q.setMaxResults(1);
            }
            return helper.list(q).next();
        }

    }

}