All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.jita.memory.impl.CudaCachingZeroProvider Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.jita.memory.impl;

import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.memory.MemoryProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;


/**
 * {@link MemoryProvider} implementation that adds a cache for memory reuse purposes. Only host memory is cached.
 *
 * When a host memory chunk is released via the allocator and is small enough (and the cache budget is not yet
 * exhausted), it is zeroed and kept for reuse within the same JVM process instead of being freed. A later host
 * allocation with an equal {@link AllocationShape} is then served from that cache.
 *
 * @author [email protected]
 */
public class CudaCachingZeroProvider extends CudaDirectProvider implements MemoryProvider {
    private static final Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class);

    /** Per-shape buckets of released host chunks that are available for reuse. */
    protected volatile ConcurrentHashMap<AllocationShape, CacheHolder> zeroCache = new ConcurrentHashMap<>();

    /** Number of host allocations served from {@link #zeroCache}. */
    protected final AtomicLong cacheZeroHit = new AtomicLong(0);
    /** Number of cacheable host allocations that had to fall through to a real allocation. */
    protected final AtomicLong cacheZeroMiss = new AtomicLong(0);

    // Device cache counters; this class itself never updates them (presumably maintained by subclasses).
    protected final AtomicLong cacheDeviceHit = new AtomicLong(0);
    protected final AtomicLong cacheDeviceMiss = new AtomicLong(0);

    /** Total bytes currently held in {@link #zeroCache}. */
    protected final AtomicLong zeroCachedAmount = new AtomicLong(0);
    // Per-device cached byte counters; not populated by this class (presumably filled in by subclasses).
    protected List<AtomicLong> deviceCachedAmount = new ArrayList<>();

    // Not used by this class; retained because it is part of the protected API subclasses may rely on.
    protected final Semaphore singleLock = new Semaphore(1);

    // Size limits are read from CudaEnvironment's Configuration on every call:
    // getMaximumHostCacheableLength(), getMaximumHostCache() and getPreallocationCalls().

    // Chunks up to this many bytes were meant to be cached regardless of the number of cache entries
    // (covering shapeInfo buffers). Per-bucket depth limiting is currently disabled, so every cacheable
    // chunk is cached and this threshold has no effect in free().
    protected final long FORCED_CACHE_THRESHOLD = 96;

    public CudaCachingZeroProvider() {

    }

    /**
     * Provides a PointersPair for a memory chunk described by the given AllocationShape.
     *
     * PLEASE NOTE: for HOST requests no larger than the configured maximum cacheable length, this method may
     * skip the actual allocation and hand out a previously cached chunk of an equal shape instead. On a cache
     * miss it may also start a background {@link CachePreallocator} for that shape.
     *
     * @param shape    shape of the desired memory chunk
     * @param point    target AllocationPoint; its status is set to HOST when served from cache
     * @param location either HOST or DEVICE
     * @return pointers to the allocated (or reused) chunk
     */
    @Override
    public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        Configuration configuration = CudaEnvironment.getInstance().getConfiguration();

        // Use the same inclusive bound as free(), otherwise max-sized chunks are cached but never reused.
        if (location == AllocationStatus.HOST && reqMemory <= configuration.getMaximumHostCacheableLength()) {
            CacheHolder cache = zeroCache.get(shape);
            if (cache != null) {
                Pointer pointer = cache.poll();
                if (pointer != null) {
                    cacheZeroHit.incrementAndGet();

                    // this chunk leaves the cache, so it no longer counts towards the cached total
                    zeroCachedAmount.addAndGet(-reqMemory);

                    PointersPair pair = new PointersPair();
                    pair.setDevicePointer(new CudaPointer(pointer.address()));
                    pair.setHostPointer(new CudaPointer(pointer.address()));

                    point.setAllocationStatus(AllocationStatus.HOST);
                    return pair;
                }
            }

            // each miss is counted exactly once
            cacheZeroMiss.incrementAndGet();

            // warm the cache for this shape while it is still well below its budget
            if (configuration.isUsePreallocation()
                            && zeroCachedAmount.get() < configuration.getMaximumHostCache() / 10
                            && reqMemory < 16 * 1024 * 1024L) {
                CachePreallocator preallocator =
                                new CachePreallocator(shape, location, configuration.getPreallocationCalls());
                preallocator.start();
            }
        }

        return super.malloc(shape, point, location);
    }

    /**
     * Makes sure a CacheHolder bucket exists in {@link #zeroCache} for the given shape.
     *
     * @param shape shape whose bucket should exist after this call
     */
    protected void ensureCacheHolder(AllocationShape shape) {
        if (!zeroCache.containsKey(shape)) {
            // putIfAbsent is atomic, so concurrent callers always end up sharing a single holder per shape
            zeroCache.putIfAbsent(shape, new CacheHolder(shape, zeroCachedAmount));
        }
    }

    /**
     * Frees the chunk of memory described by the AllocationPoint passed in.
     *
     * PLEASE NOTE: for non-DEVICE chunks this method may skip the actual free and keep the released chunk
     * (after zeroing it) for future reuse.
     *
     * @param point allocation to release
     */
    @Override
    public void free(AllocationPoint point) {
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            super.free(point);
            return;
        }

        AllocationShape shape = point.getShape();
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        Configuration configuration = CudaEnvironment.getInstance().getConfiguration();

        // we don't cache too big objects, nor anything once the cache budget is exhausted
        if (reqMemory > configuration.getMaximumHostCacheableLength()
                        || zeroCachedAmount.get() >= configuration.getMaximumHostCache()) {
            super.free(point);
            return;
        }

        ensureCacheHolder(shape);
        CacheHolder cache = zeroCache.get(shape);

        // Zero the chunk before caching it, so that a later cache hit in malloc() receives cleared memory.
        Pointer.memset(point.getHostPointer(), 0, reqMemory);
        cache.put(new CudaPointer(point.getHostPointer().address()));
    }

    /** @return host cache hit percentage, or 0 when there have been no cacheable host requests yet */
    private float getZeroCacheHitRatio() {
        return hitRatio(cacheZeroHit.get(), cacheZeroMiss.get());
    }

    /** @return device cache hit percentage, or 0 when no device hits/misses were recorded */
    private float getDeviceCacheHitRatio() {
        return hitRatio(cacheDeviceHit.get(), cacheDeviceMiss.get());
    }

    /** Percentage of hits among hits + misses; 0 (instead of NaN) when there are no samples. */
    private static float hitRatio(long hits, long misses) {
        long total = hits + misses;
        return total == 0 ? 0f : hits * 100 / (float) total;
    }

    /** Logs cache statistics at debug level. */
    @Deprecated
    public void printCacheStats() {
        // deviceCachedAmount is empty unless a subclass populated it
        long deviceAmount = deviceCachedAmount.isEmpty() ? 0L : deviceCachedAmount.get(0).get();

        log.debug("Cached host amount: {}", zeroCachedAmount.get());
        log.debug("Cached device amount: {}", deviceAmount);
        log.debug("Total shapes in cache: {}", zeroCache.size());
        log.debug("Current host hit ratio: {}", getZeroCacheHitRatio());
        log.debug("Current device hit ratio: {}", getDeviceCacheHitRatio());
    }

    /**
     * Bucket of cached chunks that all share one AllocationShape. Each put() adds the bucket's chunk size to
     * the shared byte counter passed in. poll() does not subtract it; callers adjust the counter themselves.
     */
    protected class CacheHolder {
        private final Queue<Pointer> queue = new ConcurrentLinkedQueue<>();
        private final AtomicInteger counter = new AtomicInteger(0);
        private final long reqMem;
        private final AtomicLong allocCounter;

        public CacheHolder(AllocationShape shape, AtomicLong counter) {
            this.reqMem = AllocationUtils.getRequiredMemory(shape);
            this.allocCounter = counter;
        }

        /** @return number of chunks currently queued in this bucket */
        public int size() {
            return counter.get();
        }

        /** @return a cached chunk, or null if the bucket is empty */
        public Pointer poll() {
            Pointer pointer = queue.poll();
            if (pointer != null)
                counter.decrementAndGet();

            return pointer;
        }

        /** Adds a chunk to this bucket and accounts its size in the shared byte counter. */
        public void put(Pointer pointer) {
            allocCounter.addAndGet(reqMem);
            counter.incrementAndGet();
            queue.add(pointer);
        }
    }

    /**
     * Background thread that allocates a number of chunks of one shape and places them into the cache.
     * Only HOST allocations are cached; for any other location the allocated chunks are not stored here.
     */
    protected class CachePreallocator extends Thread {

        private final AllocationShape shape;
        private final AllocationStatus location;
        private final int target;

        public CachePreallocator(AllocationShape shape, AllocationStatus location, int numberOfEntries) {
            this.shape = shape;
            this.target = numberOfEntries;
            this.location = location;
        }

        @Override
        public void run() {
            ensureCacheHolder(shape);
            CacheHolder cache = CudaCachingZeroProvider.this.zeroCache.get(shape);

            for (int i = 0; i < target; i++) {
                AllocationPoint point = new AllocationPoint();

                // bypass this class's malloc so preallocation never consumes cached chunks
                PointersPair pair = CudaCachingZeroProvider.super.malloc(shape, point, this.location);
                if (this.location == AllocationStatus.HOST) {
                    cache.put(new CudaPointer(pair.getHostPointer().address()));
                }
            }
        }
    }

    /** Releases every cached host chunk via freeHost() and resets the cached-bytes counter. */
    @Override
    public void purgeCache() {
        for (CacheHolder cache : zeroCache.values()) {
            Pointer ptr;
            while ((ptr = cache.poll()) != null) {
                freeHost(ptr);
            }
        }

        zeroCachedAmount.set(0);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy