All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jetbrains.kotlin.storage.LockBasedStorageManager Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2015 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.kotlin.storage;

import kotlin.Unit;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.text.StringsKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlin.utils.ExceptionUtilsKt;
import org.jetbrains.kotlin.utils.WrappedValues;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

public class LockBasedStorageManager implements StorageManager {
    private static final String PACKAGE_NAME = StringsKt.substringBeforeLast(LockBasedStorageManager.class.getCanonicalName(), ".", "");

    public interface ExceptionHandlingStrategy {
        ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
            @NotNull
            @Override
            public RuntimeException handleException(@NotNull Throwable throwable) {
                throw ExceptionUtilsKt.rethrow(throwable);
            }
        };

        /*
         * The signature of this method is a trick: it is used as
         *
         *     throw strategy.handleException(...)
         *
         * most implementations of this method throw exceptions themselves, so it does not matter what they return
         */
        @NotNull
        RuntimeException handleException(@NotNull Throwable throwable);
    }

    public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, EmptySimpleLock.INSTANCE) {
        @NotNull
        @Override
        protected  RecursionDetectedResult recursionDetectedDefault(@NotNull String source, K input) {
            return RecursionDetectedResult.fallThrough();
        }
    };

    @NotNull
    public static LockBasedStorageManager createWithExceptionHandling(
            @NotNull String debugText,
            @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy
    ) {
        return createWithExceptionHandling(debugText, exceptionHandlingStrategy, null, null);
    }

    @NotNull
    public static LockBasedStorageManager createWithExceptionHandling(
            @NotNull String debugText,
            @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
            @Nullable Runnable checkCancelled,
            @Nullable Function1 interruptedExceptionHandler
    ) {
        return new LockBasedStorageManager(debugText, exceptionHandlingStrategy,
                                           SimpleLock.Companion.simpleLock(checkCancelled, interruptedExceptionHandler));
    }

    protected final SimpleLock lock;
    private final ExceptionHandlingStrategy exceptionHandlingStrategy;
    private final String debugText;

    private LockBasedStorageManager(
            @NotNull String debugText,
            @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
            @NotNull SimpleLock lock
    ) {
        this.lock = lock;
        this.exceptionHandlingStrategy = exceptionHandlingStrategy;
        this.debugText = debugText;
    }

    public LockBasedStorageManager(String debugText) {
        this(debugText, (Runnable) null, null);
    }

    public LockBasedStorageManager(
            String debugText,
            @Nullable Runnable checkCancelled,
            @Nullable Function1 interruptedExceptionHandler
    ) {
        this(debugText, ExceptionHandlingStrategy.THROW, SimpleLock.Companion.simpleLock(checkCancelled, interruptedExceptionHandler));
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
    }

    public LockBasedStorageManager replaceExceptionHandling(
            @NotNull String debugText, @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy
    ) {
        return new LockBasedStorageManager(debugText, exceptionHandlingStrategy, lock);
    }

    @NotNull
    @Override
    public  MemoizedFunctionToNotNull createMemoizedFunction(@NotNull Function1 compute) {
        return createMemoizedFunction(compute, LockBasedStorageManager.createConcurrentHashMap());
    }

    @NotNull
    @Override
    public  MemoizedFunctionToNotNull createMemoizedFunction(
            @NotNull Function1 compute,
            @NotNull Function2 onRecursiveCall
    ) {
        return createMemoizedFunction(compute, onRecursiveCall, LockBasedStorageManager.createConcurrentHashMap());
    }

    @NotNull
    @Override
    public  MemoizedFunctionToNotNull createMemoizedFunction(
            @NotNull Function1 compute,
            @NotNull ConcurrentMap map
    ) {
        return new MapBasedMemoizedFunctionToNotNull(this, map, compute);
    }

    @NotNull
    @Override
    public  MemoizedFunctionToNotNull createMemoizedFunction(
            @NotNull Function1 compute,
            @NotNull final Function2 onRecursiveCall,
            @NotNull ConcurrentMap map
    ) {
        return new MapBasedMemoizedFunctionToNotNull(this, map, compute) {
            @NotNull
            @Override
            protected RecursionDetectedResult recursionDetected(K input, boolean firstTime) {
                return RecursionDetectedResult.value(onRecursiveCall.invoke(input, firstTime));
            }
        };
    }

    @NotNull
    @Override
    public  MemoizedFunctionToNullable createMemoizedFunctionWithNullableValues(@NotNull Function1 compute) {
        return createMemoizedFunctionWithNullableValues(compute, LockBasedStorageManager.createConcurrentHashMap());
    }

    @Override
    @NotNull
    public   MemoizedFunctionToNullable createMemoizedFunctionWithNullableValues(
            @NotNull Function1 compute,
            @NotNull ConcurrentMap map
    ) {
        return new MapBasedMemoizedFunction(this, map, compute);
    }

    @NotNull
    @Override
    public  NotNullLazyValue createLazyValue(@NotNull Function0 computable) {
        return new LockBasedNotNullLazyValue(this, computable);
    }

    @NotNull
    @Override
    public  NotNullLazyValue createLazyValue(
            @NotNull Function0 computable,
            @NotNull final Function1 onRecursiveCall
    ) {
        return new LockBasedNotNullLazyValue(this, computable) {
            @NotNull
            @Override
            protected RecursionDetectedResult recursionDetected(boolean firstTime) {
                return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
            }
        };
    }

    @NotNull
    @Override
    public  NotNullLazyValue createRecursionTolerantLazyValue(
            @NotNull Function0 computable, @NotNull final T onRecursiveCall
    ) {
        return new LockBasedNotNullLazyValue(this, computable) {
            @NotNull
            @Override
            protected RecursionDetectedResult recursionDetected(boolean firstTime) {
                return RecursionDetectedResult.value(onRecursiveCall);
            }

            @Override
            protected String presentableName() {
                return "RecursionTolerantLazyValue";
            }
        };
    }

    @NotNull
    @Override
    public  NotNullLazyValue createLazyValueWithPostCompute(
            @NotNull Function0 computable,
            final Function1 onRecursiveCall,
            @NotNull final Function1 postCompute
    ) {
        return new LockBasedNotNullLazyValueWithPostCompute(this, computable) {
            @NotNull
            @Override
            protected RecursionDetectedResult recursionDetected(boolean firstTime) {
                if (onRecursiveCall == null) {
                    return super.recursionDetected(firstTime);
                }
                return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
            }

            @Override
            protected void doPostCompute(@NotNull T value) {
                postCompute.invoke(value);
            }

            @Override
            protected String presentableName() {
                return "LockBasedNotNullLazyValueWithPostCompute";
            }
        };
    }

    @NotNull
    @Override
    public  NullableLazyValue createNullableLazyValue(@NotNull Function0 computable) {
        return new LockBasedLazyValue(this, computable);
    }

    @NotNull
    @Override
    public  NullableLazyValue createRecursionTolerantNullableLazyValue(@NotNull Function0 computable, final T onRecursiveCall) {
        return new LockBasedLazyValue(this, computable) {
            @NotNull
            @Override
            protected RecursionDetectedResult recursionDetected(boolean firstTime) {
                return RecursionDetectedResult.value(onRecursiveCall);
            }

            @Override
            protected String presentableName() {
                return "RecursionTolerantNullableLazyValue";
            }
        };
    }

    @NotNull
    @Override
    public  NullableLazyValue createNullableLazyValueWithPostCompute(
            @NotNull Function0 computable, @NotNull final Function1 postCompute
    ) {
        return new LockBasedLazyValueWithPostCompute(this, computable) {
            @Override
            protected void doPostCompute(T value) {
                postCompute.invoke(value);
            }

            @Override
            protected String presentableName() {
                return "NullableLazyValueWithPostCompute";
            }
        };
    }

    @Override
    public  T compute(@NotNull Function0 computable) {
        lock.lock();
        try {
            return computable.invoke();
        }
        catch (Throwable throwable) {
            throw exceptionHandlingStrategy.handleException(throwable);
        }
        finally {
            lock.unlock();
        }
    }

    @NotNull
    private static  ConcurrentMap createConcurrentHashMap() {
        // memory optimization: fewer segments and entries stored
        return new ConcurrentHashMap(3, 1, 2);
    }

    @NotNull
    protected  RecursionDetectedResult recursionDetectedDefault(@NotNull String source, K input) {
        throw sanitizeStackTrace(
                new AssertionError("Recursion detected " + source +
                        (input == null
                         ? ""
                         : "on input: " + input
                         ) + " under " + this
                )
        );
    }

    private static class RecursionDetectedResult {

        @NotNull
        public static  RecursionDetectedResult value(T value) {
            return new RecursionDetectedResult(value, false);
        }

        @NotNull
        public static  RecursionDetectedResult fallThrough() {
            return new RecursionDetectedResult(null, true);
        }

        private final T value;
        private final boolean fallThrough;

        private RecursionDetectedResult(T value, boolean fallThrough) {
            this.value = value;
            this.fallThrough = fallThrough;
        }

        public T getValue() {
            assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
            return value;
        }

        public boolean isFallThrough() {
            return fallThrough;
        }

        @Override
        public String toString() {
            return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
        }
    }

    private enum NotValue {
        NOT_COMPUTED,
        COMPUTING,
        RECURSION_WAS_DETECTED
    }

    private static class LockBasedLazyValue implements NullableLazyValue {
        private final LockBasedStorageManager storageManager;
        private final Function0 computable;

        @Nullable
        private volatile Object value = NotValue.NOT_COMPUTED;

        public LockBasedLazyValue(@NotNull LockBasedStorageManager storageManager, @NotNull Function0 computable) {
            this.storageManager = storageManager;
            this.computable = computable;
        }

        @Override
        public boolean isComputed() {
            return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
        }

        @Override
        public boolean isComputing() {
            return value == NotValue.COMPUTING;
        }

        @Override
        public T invoke() {
            Object _value = value;
            if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);

            storageManager.lock.lock();
            try {
                _value = value;
                if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);

                if (_value == NotValue.COMPUTING) {
                    value = NotValue.RECURSION_WAS_DETECTED;
                    RecursionDetectedResult result = recursionDetected(/*firstTime = */ true);
                    if (!result.isFallThrough()) {
                        return result.getValue();
                    }
                }

                if (_value == NotValue.RECURSION_WAS_DETECTED) {
                    RecursionDetectedResult result = recursionDetected(/*firstTime = */ false);
                    if (!result.isFallThrough()) {
                        return result.getValue();
                    }
                }

                value = NotValue.COMPUTING;
                try {
                    T typedValue = computable.invoke();

                    // Don't publish computed value till post compute is finished as it may cause a race condition
                    // if post compute modifies value internals.
                    postCompute(typedValue);

                    value = typedValue;
                    return typedValue;
                }
                catch (Throwable throwable) {
                    if (ExceptionUtilsKt.isProcessCanceledException(throwable)) {
                        value = NotValue.NOT_COMPUTED;
                        //noinspection ConstantConditions
                        throw (RuntimeException)throwable;
                    }

                    if (value == NotValue.COMPUTING) {
                        // Store only if it's a genuine result, not something thrown through recursionDetected()
                        value = WrappedValues.escapeThrowable(throwable);
                    }
                    throw storageManager.exceptionHandlingStrategy.handleException(throwable);
                }
            }
            finally {
                storageManager.lock.unlock();
            }
        }

        /**
         * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
         * @return a value to be returned on a recursive call or subsequent calls
         */
        @NotNull
        protected RecursionDetectedResult recursionDetected(boolean firstTime) {
            return storageManager.recursionDetectedDefault("in a lazy value", null);
        }

        protected void postCompute(T value) {
            // Default post compute implementation doesn't publish the value till it is finished
        }

        @NotNull
        public String renderDebugInformation() {
            return presentableName() + ", storageManager=" + storageManager;
        }

        protected String presentableName() {
            return this.getClass().getName();
        }
    }

    /**
     * Computed value has an early publication and accessible from the same thread while executing a post-compute lambda.
     * For other threads value will be accessible only after post-compute lambda is finished (when a real lock is used).
     */
    private static abstract class LockBasedLazyValueWithPostCompute extends LockBasedLazyValue {
        @Nullable
        private volatile SingleThreadValue valuePostCompute = null;

        public LockBasedLazyValueWithPostCompute(
                @NotNull LockBasedStorageManager storageManager,
                @NotNull Function0 computable
        ) {
            super(storageManager, computable);
        }

        @Override
        public T invoke() {
            SingleThreadValue postComputeCache = valuePostCompute;
            if (postComputeCache != null && postComputeCache.hasValue()) {
                return postComputeCache.getValue();
            }

            return super.invoke();
        }

        // Doing something in post-compute helps prevent infinite recursion
        @Override
        protected final void postCompute(T value) {
            // Protected from rewrites in other threads because it is executed under lock in invoke().
            // May be overwritten when NO_LOCK is used.
            valuePostCompute = new SingleThreadValue(value);
            try {
                doPostCompute(value);
            } finally {
                valuePostCompute = null;
            }
        }

        protected abstract void doPostCompute(T value);
    }

    private static abstract class LockBasedNotNullLazyValueWithPostCompute extends LockBasedLazyValueWithPostCompute
            implements NotNullLazyValue {
        public LockBasedNotNullLazyValueWithPostCompute(
                @NotNull LockBasedStorageManager storageManager,
                @NotNull Function0 computable
        ) {
            super(storageManager, computable);
        }

        @Override
        @NotNull
        public T invoke() {
            T result = super.invoke();
            assert result != null : "compute() returned null";
            return result;
        }
    }


    private static class LockBasedNotNullLazyValue extends LockBasedLazyValue implements NotNullLazyValue {
        public LockBasedNotNullLazyValue(@NotNull LockBasedStorageManager storageManager, @NotNull Function0 computable) {
            super(storageManager, computable);
        }

        @Override
        @NotNull
        public T invoke() {
            T result = super.invoke();
            assert result != null : "compute() returned null";
            return result;
        }
    }

    private static class MapBasedMemoizedFunction implements MemoizedFunctionToNullable {
        private final LockBasedStorageManager storageManager;
        private final ConcurrentMap cache;
        private final Function1 compute;

        public MapBasedMemoizedFunction(
                @NotNull LockBasedStorageManager storageManager,
                @NotNull ConcurrentMap map,
                @NotNull Function1 compute
        ) {
            this.storageManager = storageManager;
            this.cache = map;
            this.compute = compute;
        }

        @Override
        @Nullable
        public V invoke(K input) {
            Object value = cache.get(input);
            if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);

            storageManager.lock.lock();
            try {
                value = cache.get(input);

                if (value == NotValue.COMPUTING) {
                    value = NotValue.RECURSION_WAS_DETECTED;
                    RecursionDetectedResult result = recursionDetected(input, /*firstTime = */ true);
                    if (!result.isFallThrough()) {
                        return result.getValue();
                    }
                }

                if (value == NotValue.RECURSION_WAS_DETECTED) {
                    RecursionDetectedResult result = recursionDetected(input, /*firstTime = */ false);
                    if (!result.isFallThrough()) {
                        return result.getValue();
                    }
                }

                if (value != null) return WrappedValues.unescapeExceptionOrNull(value);

                AssertionError error = null;
                try {
                    cache.put(input, NotValue.COMPUTING);
                    V typedValue = compute.invoke(input);
                    Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));

                    // This code effectively asserts that oldValue is null
                    // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
                    // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
                    // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
                    if (oldValue != NotValue.COMPUTING) {
                        error = raceCondition(input, oldValue);
                        throw error;
                    }

                    return typedValue;
                }
                catch (Throwable throwable) {
                    if (ExceptionUtilsKt.isProcessCanceledException(throwable)) {
                        cache.remove(input);
                        //noinspection ConstantConditions
                        throw (RuntimeException)throwable;
                    }
                    if (throwable == error) {
                        throw storageManager.exceptionHandlingStrategy.handleException(throwable);
                    }

                    Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
                    if (oldValue != NotValue.COMPUTING) {
                        throw raceCondition(input, oldValue);
                    }

                    throw storageManager.exceptionHandlingStrategy.handleException(throwable);
                }
            }
            finally {
                storageManager.lock.unlock();
            }
        }

        @NotNull
        protected RecursionDetectedResult recursionDetected(K input, boolean firstTime) {
            return storageManager.recursionDetectedDefault("", input);
        }

        @NotNull
        private AssertionError raceCondition(K input, Object oldValue) {
            return sanitizeStackTrace(
                    new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
                                       " under " + storageManager)
            );
        }

        @Override
        public boolean isComputed(K key) {
            Object value = cache.get(key);
            return value != null && value != NotValue.COMPUTING;
        }

        protected LockBasedStorageManager getStorageManager() {
            return storageManager;
        }
    }

    private static class MapBasedMemoizedFunctionToNotNull extends MapBasedMemoizedFunction implements MemoizedFunctionToNotNull {

        public MapBasedMemoizedFunctionToNotNull(
                @NotNull LockBasedStorageManager storageManager, @NotNull ConcurrentMap map,
                @NotNull Function1 compute
        ) {
            super(storageManager, map, compute);
        }

        @NotNull
        @Override
        public V invoke(K input) {
            V result = super.invoke(input);
            assert result != null : "compute() returned null under " + getStorageManager();
            return result;
        }
    }

    @NotNull
    private static  T sanitizeStackTrace(@NotNull T throwable) {
        StackTraceElement[] stackTrace = throwable.getStackTrace();
        int size = stackTrace.length;

        int firstNonStorage = -1;
        for (int i = 0; i < size; i++) {
            // Skip everything (memoized functions and lazy values) from package org.jetbrains.kotlin.storage
            if (!stackTrace[i].getClassName().startsWith(PACKAGE_NAME)) {
                firstNonStorage = i;
                break;
            }
        }
        assert firstNonStorage >= 0 : "This method should only be called on exceptions created in LockBasedStorageManager";

        List list = Arrays.asList(stackTrace).subList(firstNonStorage, size);
        throwable.setStackTrace(list.toArray(new StackTraceElement[list.size()]));
        return throwable;
    }

    @NotNull
    @Override
    public  CacheWithNullableValues createCacheWithNullableValues() {
        return new CacheWithNullableValuesBasedOnMemoizedFunction(
                this, LockBasedStorageManager.>createConcurrentHashMap());
    }

    private static class CacheWithNullableValuesBasedOnMemoizedFunction extends MapBasedMemoizedFunction, V> implements CacheWithNullableValues {

        private CacheWithNullableValuesBasedOnMemoizedFunction(
                @NotNull LockBasedStorageManager storageManager,
                @NotNull ConcurrentMap, Object> map
        ) {
            super(storageManager, map, new Function1, V>() {
                @Override
                public V invoke(KeyWithComputation computation) {
                    return computation.computation.invoke();
                }
            });
        }

        @Nullable
        @Override
        public V computeIfAbsent(K key, @NotNull Function0 computation) {
            return invoke(new KeyWithComputation(key, computation));
        }
    }

    @NotNull
    @Override
    public  CacheWithNotNullValues createCacheWithNotNullValues() {
        return new CacheWithNotNullValuesBasedOnMemoizedFunction(this, LockBasedStorageManager.>createConcurrentHashMap());
    }

    private static class CacheWithNotNullValuesBasedOnMemoizedFunction extends CacheWithNullableValuesBasedOnMemoizedFunction implements CacheWithNotNullValues {

        private CacheWithNotNullValuesBasedOnMemoizedFunction(
                @NotNull LockBasedStorageManager storageManager,
                @NotNull ConcurrentMap, Object> map
        ) {
            super(storageManager, map);
        }

        @NotNull
        @Override
        public V computeIfAbsent(K key, @NotNull Function0 computation) {
            V result = super.computeIfAbsent(key, computation);
            assert result != null : "computeIfAbsent() returned null under " + getStorageManager();
            return result;
        }
    }

    // equals and hashCode use only key
    private static class KeyWithComputation {
        private final K key;
        private final Function0 computation;

        public KeyWithComputation(K key, Function0 computation) {
            this.key = key;
            this.computation = computation;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;

            KeyWithComputation that = (KeyWithComputation) o;

            if (!key.equals(that.key)) return false;

            return true;
        }

        @Override
        public int hashCode() {
            return key.hashCode();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy