All downloads are free. The search and download functionality uses the official Maven repository.

org.dataloader.DataLoaderHelper Maven / Gradle / Ivy

There is a newer version: 2022-09-12T23-25-35-08559ba
Show newest version
package org.dataloader;

import org.dataloader.annotations.GuardedBy;
import org.dataloader.annotations.Internal;
import org.dataloader.impl.CompletableFutureKit;
import org.dataloader.scheduler.BatchLoaderScheduler;
import org.dataloader.stats.StatisticsCollector;
import org.dataloader.stats.context.IncrementBatchLoadCountByStatisticsContext;
import org.dataloader.stats.context.IncrementBatchLoadExceptionCountStatisticsContext;
import org.dataloader.stats.context.IncrementCacheHitCountStatisticsContext;
import org.dataloader.stats.context.IncrementLoadCountStatisticsContext;
import org.dataloader.stats.context.IncrementLoadErrorCountStatisticsContext;

import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.stream.Collectors.toList;
import static org.dataloader.impl.Assertions.assertState;
import static org.dataloader.impl.Assertions.nonNull;

/**
 * This helps break up the large DataLoader class functionality, and it contains the logic to dispatch the
 * promises on behalf of its peer dataloader.
 *
 * @param  the type of keys
 * @param  the type of values
 */
@Internal
class DataLoaderHelper {

    static class LoaderQueueEntry {

        final K key;
        final V value;
        final Object callContext;

        public LoaderQueueEntry(K key, V value, Object callContext) {
            this.key = key;
            this.value = value;
            this.callContext = callContext;
        }

        K getKey() {
            return key;
        }

        V getValue() {
            return value;
        }

        Object getCallContext() {
            return callContext;
        }
    }

    private final DataLoader dataLoader;
    private final Object batchLoadFunction;
    private final DataLoaderOptions loaderOptions;
    private final CacheMap futureCache;
    private final ValueCache valueCache;
    private final List>> loaderQueue;
    private final StatisticsCollector stats;
    private final Clock clock;
    private final AtomicReference lastDispatchTime;

    DataLoaderHelper(DataLoader dataLoader,
                     Object batchLoadFunction,
                     DataLoaderOptions loaderOptions,
                     CacheMap futureCache,
                     ValueCache valueCache,
                     StatisticsCollector stats,
                     Clock clock) {
        this.dataLoader = dataLoader;
        this.batchLoadFunction = batchLoadFunction;
        this.loaderOptions = loaderOptions;
        this.futureCache = futureCache;
        this.valueCache = valueCache;
        this.loaderQueue = new ArrayList<>();
        this.stats = stats;
        this.clock = clock;
        this.lastDispatchTime = new AtomicReference<>();
        this.lastDispatchTime.set(now());
    }

    Instant now() {
        return clock.instant();
    }

    public Instant getLastDispatchTime() {
        return lastDispatchTime.get();
    }

    Optional> getIfPresent(K key) {
        synchronized (dataLoader) {
            boolean cachingEnabled = loaderOptions.cachingEnabled();
            if (cachingEnabled) {
                Object cacheKey = getCacheKey(nonNull(key));
                if (futureCache.containsKey(cacheKey)) {
                    stats.incrementCacheHitCount(new IncrementCacheHitCountStatisticsContext<>(key));
                    return Optional.of(futureCache.get(cacheKey));
                }
            }
        }
        return Optional.empty();
    }

    Optional> getIfCompleted(K key) {
        synchronized (dataLoader) {
            Optional> cachedPromise = getIfPresent(key);
            if (cachedPromise.isPresent()) {
                CompletableFuture promise = cachedPromise.get();
                if (promise.isDone()) {
                    return cachedPromise;
                }
            }
        }
        return Optional.empty();
    }


    CompletableFuture load(K key, Object loadContext) {
        synchronized (dataLoader) {
            boolean batchingEnabled = loaderOptions.batchingEnabled();
            boolean cachingEnabled = loaderOptions.cachingEnabled();

            stats.incrementLoadCount(new IncrementLoadCountStatisticsContext<>(key, loadContext));

            if (cachingEnabled) {
                return loadFromCache(key, loadContext, batchingEnabled);
            } else {
                return queueOrInvokeLoader(key, loadContext, batchingEnabled, false);
            }
        }
    }

    Object getCacheKey(K key) {
        return loaderOptions.cacheKeyFunction().isPresent() ?
                loaderOptions.cacheKeyFunction().get().getKey(key) : key;
    }

    Object getCacheKeyWithContext(K key, Object context) {
        return loaderOptions.cacheKeyFunction().isPresent() ?
                loaderOptions.cacheKeyFunction().get().getKeyWithContext(key, context) : key;
    }

    DispatchResult dispatch() {
        boolean batchingEnabled = loaderOptions.batchingEnabled();
        final List keys;
        final List callContexts;
        final List> queuedFutures;
        synchronized (dataLoader) {
            int queueSize = loaderQueue.size();
            if (queueSize == 0) {
                lastDispatchTime.set(now());
                return emptyDispatchResult();
            }

            // we copy the pre-loaded set of futures ready for dispatch
            keys = new ArrayList<>(queueSize);
            callContexts = new ArrayList<>(queueSize);
            queuedFutures = new ArrayList<>(queueSize);

            loaderQueue.forEach(entry -> {
                keys.add(entry.getKey());
                queuedFutures.add(entry.getValue());
                callContexts.add(entry.getCallContext());
            });
            loaderQueue.clear();
            lastDispatchTime.set(now());
        }
        if (!batchingEnabled) {
            return emptyDispatchResult();
        }
        final int totalEntriesHandled = keys.size();
        //
        // order of keys -> values matter in data loader hence the use of linked hash map
        //
        // See https://github.com/facebook/dataloader/blob/master/README.md for more details
        //

        //
        // when the promised list of values completes, we transfer the values into
        // the previously cached future objects that the client already has been given
        // via calls to load("foo") and loadMany(["foo","bar"])
        //
        int maxBatchSize = loaderOptions.maxBatchSize();
        CompletableFuture> futureList;
        if (maxBatchSize > 0 && maxBatchSize < keys.size()) {
            futureList = sliceIntoBatchesOfBatches(keys, queuedFutures, callContexts, maxBatchSize);
        } else {
            futureList = dispatchQueueBatch(keys, callContexts, queuedFutures);
        }
        return new DispatchResult<>(futureList, totalEntriesHandled);
    }

    private CompletableFuture> sliceIntoBatchesOfBatches(List keys, List> queuedFutures, List callContexts, int maxBatchSize) {
        // the number of keys is > than what the batch loader function can accept
        // so make multiple calls to the loader
        int len = keys.size();
        int batchCount = (int) Math.ceil(len / (double) maxBatchSize);
        List>> allBatches = new ArrayList<>(batchCount);
        for (int i = 0; i < batchCount; i++) {

            int fromIndex = i * maxBatchSize;
            int toIndex = Math.min((i + 1) * maxBatchSize, len);

            List subKeys = keys.subList(fromIndex, toIndex);
            List> subFutures = queuedFutures.subList(fromIndex, toIndex);
            List subCallContexts = callContexts.subList(fromIndex, toIndex);

            allBatches.add(dispatchQueueBatch(subKeys, subCallContexts, subFutures));
        }
        //
        // now reassemble all the futures into one that is the complete set of results
        return allOf(allBatches.toArray(new CompletableFuture[0]))
                .thenApply(v -> allBatches.stream()
                        .map(CompletableFuture::join)
                        .flatMap(Collection::stream)
                        .collect(toList()));
    }

    @SuppressWarnings("unchecked")
    private CompletableFuture> dispatchQueueBatch(List keys, List callContexts, List> queuedFutures) {
        stats.incrementBatchLoadCountBy(keys.size(), new IncrementBatchLoadCountByStatisticsContext<>(keys, callContexts));
        CompletableFuture> batchLoad = invokeLoader(keys, callContexts, loaderOptions.cachingEnabled());
        return batchLoad
                .thenApply(values -> {
                    assertResultSize(keys, values);

                    List clearCacheKeys = new ArrayList<>();
                    for (int idx = 0; idx < queuedFutures.size(); idx++) {
                        K key = keys.get(idx);
                        V value = values.get(idx);
                        Object callContext = callContexts.get(idx);
                        CompletableFuture future = queuedFutures.get(idx);
                        if (value instanceof Throwable) {
                            stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext));
                            future.completeExceptionally((Throwable) value);
                            clearCacheKeys.add(keys.get(idx));
                        } else if (value instanceof Try) {
                            // we allow the batch loader to return a Try so we can better represent a computation
                            // that might have worked or not.
                            Try tryValue = (Try) value;
                            if (tryValue.isSuccess()) {
                                future.complete(tryValue.get());
                            } else {
                                stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext));
                                future.completeExceptionally(tryValue.getThrowable());
                                clearCacheKeys.add(keys.get(idx));
                            }
                        } else {
                            future.complete(value);
                        }
                    }
                    possiblyClearCacheEntriesOnExceptions(clearCacheKeys);
                    return values;
                }).exceptionally(ex -> {
                    stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts));
                    if (ex instanceof CompletionException) {
                        ex = ex.getCause();
                    }
                    for (int idx = 0; idx < queuedFutures.size(); idx++) {
                        K key = keys.get(idx);
                        CompletableFuture future = queuedFutures.get(idx);
                        future.completeExceptionally(ex);
                        // clear any cached view of this key because they all failed
                        dataLoader.clear(key);
                    }
                    return emptyList();
                });
    }


    private void assertResultSize(List keys, List values) {
        assertState(keys.size() == values.size(), () -> "The size of the promised values MUST be the same size as the key list");
    }

    private void possiblyClearCacheEntriesOnExceptions(List keys) {
        if (keys.isEmpty()) {
            return;
        }
        // by default, we don't clear the cached view of this entry to avoid
        // frequently loading the same error.  This works for short-lived request caches
        // but might work against long-lived caches. Hence, we have an option that allows
        // it to be cleared
        if (!loaderOptions.cachingExceptionsEnabled()) {
            keys.forEach(dataLoader::clear);
        }
    }

    @GuardedBy("dataLoader")
    private CompletableFuture loadFromCache(K key, Object loadContext, boolean batchingEnabled) {
        final Object cacheKey = loadContext == null ? getCacheKey(key) : getCacheKeyWithContext(key, loadContext);

        if (futureCache.containsKey(cacheKey)) {
            // We already have a promise for this key, no need to check value cache or queue up load
            stats.incrementCacheHitCount(new IncrementCacheHitCountStatisticsContext<>(key, loadContext));
            return futureCache.get(cacheKey);
        }

        CompletableFuture loadCallFuture = queueOrInvokeLoader(key, loadContext, batchingEnabled, true);
        futureCache.set(cacheKey, loadCallFuture);
        return loadCallFuture;
    }

    @GuardedBy("dataLoader")
    private CompletableFuture queueOrInvokeLoader(K key, Object loadContext, boolean batchingEnabled, boolean cachingEnabled) {
        if (batchingEnabled) {
            CompletableFuture loadCallFuture = new CompletableFuture<>();
            loaderQueue.add(new LoaderQueueEntry<>(key, loadCallFuture, loadContext));
            return loadCallFuture;
        } else {
            stats.incrementBatchLoadCountBy(1, new IncrementBatchLoadCountByStatisticsContext<>(key, loadContext));
            // immediate execution of batch function
            return invokeLoaderImmediately(key, loadContext, cachingEnabled);
        }
    }

    CompletableFuture invokeLoaderImmediately(K key, Object keyContext, boolean cachingEnabled) {
        List keys = singletonList(key);
        List keyContexts = singletonList(keyContext);
        return invokeLoader(keys, keyContexts, cachingEnabled)
                .thenApply(list -> list.get(0))
                .toCompletableFuture();
    }

    CompletableFuture> invokeLoader(List keys, List keyContexts, boolean cachingEnabled) {
        if (!cachingEnabled) {
            return invokeLoader(keys, keyContexts);
        }
        CompletableFuture>> cacheCallCF = getFromValueCache(keys);
        return cacheCallCF.thenCompose(cachedValues -> {

            // the following is NOT a Map because keys in data loader can repeat (by design)
            // and hence "a","b","c","b" is a valid set of keys
            List> valuesInKeyOrder = new ArrayList<>();
            List missedKeyIndexes = new ArrayList<>();
            List missedKeys = new ArrayList<>();
            List missedKeyContexts = new ArrayList<>();

            // if they return a ValueCachingNotSupported exception then we insert this special marker value, and it
            // means it's a total miss, we need to get all these keys via the batch loader
            if (cachedValues == NOT_SUPPORTED_LIST) {
                for (int i = 0; i < keys.size(); i++) {
                    valuesInKeyOrder.add(ALWAYS_FAILED);
                    missedKeyIndexes.add(i);
                    missedKeys.add(keys.get(i));
                    missedKeyContexts.add(keyContexts.get(i));
                }
            } else {
                assertState(keys.size() == cachedValues.size(), () -> "The size of the cached values MUST be the same size as the key list");
                for (int i = 0; i < keys.size(); i++) {
                    Try cacheGet = cachedValues.get(i);
                    valuesInKeyOrder.add(cacheGet);
                    if (cacheGet.isFailure()) {
                        missedKeyIndexes.add(i);
                        missedKeys.add(keys.get(i));
                        missedKeyContexts.add(keyContexts.get(i));
                    }
                }
            }
            if (missedKeys.isEmpty()) {
                //
                // everything was cached
                //
                List assembledValues = valuesInKeyOrder.stream().map(Try::get).collect(toList());
                return completedFuture(assembledValues);
            } else {
                //
                // we missed some keys from cache, so send them to the batch loader
                // and then fill in their values
                //
                CompletableFuture> batchLoad = invokeLoader(missedKeys, missedKeyContexts);
                return batchLoad.thenCompose(missedValues -> {
                    assertResultSize(missedKeys, missedValues);

                    for (int i = 0; i < missedValues.size(); i++) {
                        V v = missedValues.get(i);
                        Integer listIndex = missedKeyIndexes.get(i);
                        valuesInKeyOrder.set(listIndex, Try.succeeded(v));
                    }
                    List assembledValues = valuesInKeyOrder.stream().map(Try::get).collect(toList());
                    //
                    // fire off a call to the ValueCache to allow it to set values into the
                    // cache now that we have them
                    return setToValueCache(assembledValues, missedKeys, missedValues);
                });
            }
        });
    }


    CompletableFuture> invokeLoader(List keys, List keyContexts) {
        CompletableFuture> batchLoad;
        try {
            Object context = loaderOptions.getBatchLoaderContextProvider().getContext();
            BatchLoaderEnvironment environment = BatchLoaderEnvironment.newBatchLoaderEnvironment()
                    .context(context).keyContexts(keys, keyContexts).build();
            if (isMapLoader()) {
                batchLoad = invokeMapBatchLoader(keys, environment);
            } else {
                batchLoad = invokeListBatchLoader(keys, environment);
            }
        } catch (Exception e) {
            batchLoad = CompletableFutureKit.failedFuture(e);
        }
        return batchLoad;
    }

    @SuppressWarnings("unchecked")
    private CompletableFuture> invokeListBatchLoader(List keys, BatchLoaderEnvironment environment) {
        CompletionStage> loadResult;
        BatchLoaderScheduler batchLoaderScheduler = loaderOptions.getBatchLoaderScheduler();
        if (batchLoadFunction instanceof BatchLoaderWithContext) {
            BatchLoaderWithContext loadFunction = (BatchLoaderWithContext) batchLoadFunction;
            if (batchLoaderScheduler != null) {
                BatchLoaderScheduler.ScheduledBatchLoaderCall loadCall = () -> loadFunction.load(keys, environment);
                loadResult = batchLoaderScheduler.scheduleBatchLoader(loadCall, keys, environment);
            } else {
                loadResult = loadFunction.load(keys, environment);
            }
        } else {
            BatchLoader loadFunction = (BatchLoader) batchLoadFunction;
            if (batchLoaderScheduler != null) {
                BatchLoaderScheduler.ScheduledBatchLoaderCall loadCall = () -> loadFunction.load(keys);
                loadResult = batchLoaderScheduler.scheduleBatchLoader(loadCall, keys, null);
            } else {
                loadResult = loadFunction.load(keys);
            }
        }
        return nonNull(loadResult, () -> "Your batch loader function MUST return a non null CompletionStage").toCompletableFuture();
    }


    /*
     * Turns a map of results that MAY be smaller than the key list back into a list by mapping null
     * to missing elements.
     */
    @SuppressWarnings("unchecked")
    private CompletableFuture> invokeMapBatchLoader(List keys, BatchLoaderEnvironment environment) {
        CompletionStage> loadResult;
        Set setOfKeys = new LinkedHashSet<>(keys);
        BatchLoaderScheduler batchLoaderScheduler = loaderOptions.getBatchLoaderScheduler();
        if (batchLoadFunction instanceof MappedBatchLoaderWithContext) {
            MappedBatchLoaderWithContext loadFunction = (MappedBatchLoaderWithContext) batchLoadFunction;
            if (batchLoaderScheduler != null) {
                BatchLoaderScheduler.ScheduledMappedBatchLoaderCall loadCall = () -> loadFunction.load(setOfKeys, environment);
                loadResult = batchLoaderScheduler.scheduleMappedBatchLoader(loadCall, keys, environment);
            } else {
                loadResult = loadFunction.load(setOfKeys, environment);
            }
        } else {
            MappedBatchLoader loadFunction = (MappedBatchLoader) batchLoadFunction;
            if (batchLoaderScheduler != null) {
                BatchLoaderScheduler.ScheduledMappedBatchLoaderCall loadCall = () -> loadFunction.load(setOfKeys);
                loadResult = batchLoaderScheduler.scheduleMappedBatchLoader(loadCall, keys, null);
            } else {
                loadResult = loadFunction.load(setOfKeys);
            }
        }
        CompletableFuture> mapBatchLoad = nonNull(loadResult, () -> "Your batch loader function MUST return a non null CompletionStage").toCompletableFuture();
        return mapBatchLoad.thenApply(map -> {
            List values = new ArrayList<>(keys.size());
            for (K key : keys) {
                V value = map.get(key);
                values.add(value);
            }
            return values;
        });
    }

    private boolean isMapLoader() {
        return batchLoadFunction instanceof MappedBatchLoader || batchLoadFunction instanceof MappedBatchLoaderWithContext;
    }

    int dispatchDepth() {
        synchronized (dataLoader) {
            return loaderQueue.size();
        }
    }

    private final List> NOT_SUPPORTED_LIST = emptyList();
    private final CompletableFuture>> NOT_SUPPORTED = CompletableFuture.completedFuture(NOT_SUPPORTED_LIST);
    private final Try ALWAYS_FAILED = Try.alwaysFailed();

    private CompletableFuture>> getFromValueCache(List keys) {
        try {
            return nonNull(valueCache.getValues(keys), () -> "Your ValueCache.getValues function MUST return a non null CompletableFuture");
        } catch (ValueCache.ValueCachingNotSupported ignored) {
            // use of a final field prevents CF object allocation for this special purpose
            return NOT_SUPPORTED;
        } catch (RuntimeException e) {
            return CompletableFutureKit.failedFuture(e);
        }
    }

    private CompletableFuture> setToValueCache(List assembledValues, List missedKeys, List missedValues) {
        try {
            boolean completeValueAfterCacheSet = loaderOptions.getValueCacheOptions().isCompleteValueAfterCacheSet();
            if (completeValueAfterCacheSet) {
                return nonNull(valueCache
                        .setValues(missedKeys, missedValues), () -> "Your ValueCache.setValues function MUST return a non null CompletableFuture")
                        // we don't trust the set cache to give us the values back - we have them - lets use them
                        // if the cache set fails - then they won't be in cache and maybe next time they will
                        .handle((ignored, setExIgnored) -> assembledValues);
            } else {
                // no one is waiting for the set to happen here so if its truly async
                // it will happen eventually but no result will be dependent on it
                valueCache.setValues(missedKeys, missedValues);
            }
        } catch (ValueCache.ValueCachingNotSupported ignored) {
            // ok no set caching is fine if they say so
        } catch (RuntimeException ignored) {
            // if we can't set values back into the cache - so be it - this must be a faulty
            // ValueCache implementation
        }
        return CompletableFuture.completedFuture(assembledValues);
    }

    private static final DispatchResult EMPTY_DISPATCH_RESULT = new DispatchResult<>(completedFuture(emptyList()), 0);

    @SuppressWarnings("unchecked") // Casting to any type is safe since the underlying list is empty
    private static  DispatchResult emptyDispatchResult() {
        return (DispatchResult) EMPTY_DISPATCH_RESULT;
    }
}