All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.jita.allocator.concurrency.DeviceAllocationsTracker Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.jita.allocator.concurrency;

import lombok.NonNull;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Tracks, per device id, how much memory has been allocated and how much has been reserved
 * ahead of allocation.
 *
 * <p>Counters are {@link AtomicLong}s held in {@link ConcurrentHashMap}s and are created on first
 * use of a device id. Each device also has a {@link ReentrantReadWriteLock}:
 * {@link #addToAllocation} and the allocated-size getters take the read lock, while
 * {@link #subFromAllocation} and {@link #reserveAllocationIfPossible} take the write lock.
 *
 * <p>The {@code threadId} parameters are accepted for API compatibility; no state in this class is
 * keyed by thread.
 *
 * @author [email protected]
 */
public class DeviceAllocationsTracker {
    private final Configuration configuration;

    /** One lock per device id; pre-created for known devices, created lazily for any other id. */
    private final Map<Integer, ReentrantReadWriteLock> deviceLocks = new ConcurrentHashMap<>();

    /** Allocated-memory counter per device id. */
    private final Map<Integer, AtomicLong> memoryTackled = new ConcurrentHashMap<>();

    /**
     * Reserved-memory counter per device id: increased by {@link #reserveAllocationIfPossible},
     * decreased by {@link #addToAllocation}.
     */
    private final Map<Integer, AtomicLong> reservedSpace = new ConcurrentHashMap<>();

    private static final Logger log = LoggerFactory.getLogger(DeviceAllocationsTracker.class);

    /**
     * Creates a tracker and pre-creates a lock for every device reported by the affinity manager.
     *
     * @param configuration allocator configuration; stored but not otherwise read by this class
     */
    public DeviceAllocationsTracker(@NonNull Configuration configuration) {
        this.configuration = configuration;

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

        for (int device = 0; device < numDevices; device++) {
            deviceLocks.put(device, new ReentrantReadWriteLock());
        }
    }

    /** Returns the lock for {@code deviceId}, creating it if this id has not been seen before. */
    private ReentrantReadWriteLock lockFor(Integer deviceId) {
        return deviceLocks.computeIfAbsent(deviceId, id -> new ReentrantReadWriteLock());
    }

    /** Returns the allocated-memory counter for {@code deviceId}, creating it at 0 if absent. */
    private AtomicLong allocatedCounter(Integer deviceId) {
        return memoryTackled.computeIfAbsent(deviceId, id -> new AtomicLong(0));
    }

    /** Returns the reserved-memory counter for {@code deviceId}, creating it at 0 if absent. */
    private AtomicLong reservedCounter(Integer deviceId) {
        return reservedSpace.computeIfAbsent(deviceId, id -> new AtomicLong(0));
    }

    /**
     * Ensures both the allocated and reserved counters exist for {@code deviceId}.
     *
     * <p>Each {@code computeIfAbsent} is atomic, so once this returns both counters are present.
     * No caller can observe one counter without the other.
     *
     * @param threadId unused; kept for API compatibility
     * @param deviceId device whose counters should exist
     */
    protected void ensureThreadRegistered(Long threadId, Integer deviceId) {
        reservedCounter(deviceId);
        allocatedCounter(deviceId);
    }

    /**
     * Records an allocation of {@code memorySize} on {@code deviceId}.
     *
     * <p>Also subtracts the same amount from the device's reserved counter.
     *
     * @param threadId must be non-null; otherwise unused
     * @param deviceId target device
     * @param memorySize amount to add to the allocated counter
     * @return the allocated counter value after the addition
     */
    public long addToAllocation(@NonNull Long threadId, Integer deviceId, long memorySize) {
        ensureThreadRegistered(threadId, deviceId);

        ReentrantReadWriteLock.ReadLock lock = lockFor(deviceId).readLock();
        // Acquire before try: if lock() fails, finally must not call unlock() on an unheld lock.
        lock.lock();
        try {
            long res = allocatedCounter(deviceId).addAndGet(memorySize);

            subFromReservedSpace(deviceId, memorySize);

            return res;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Records a release of {@code memorySize} on {@code deviceId}.
     *
     * @param threadId unused; kept for API compatibility
     * @param deviceId target device
     * @param memorySize amount to subtract from the allocated counter
     * @return the allocated counter value after the subtraction
     */
    public long subFromAllocation(Long threadId, Integer deviceId, long memorySize) {
        ensureThreadRegistered(threadId, deviceId);

        ReentrantReadWriteLock.WriteLock lock = lockFor(deviceId).writeLock();
        lock.lock();
        try {
            return allocatedCounter(deviceId).addAndGet(-memorySize);
        } finally {
            lock.unlock();
        }
    }

    /**
     * "Reserves" {@code memorySize} on {@code deviceId} by adding it to the reserved counter.
     *
     * <p>NOTE: no capacity check is performed, so this currently always succeeds.
     *
     * @param threadId unused; kept for API compatibility
     * @param deviceId target device
     * @param memorySize amount to reserve
     * @return always {@code true}
     */
    public boolean reserveAllocationIfPossible(Long threadId, Integer deviceId, long memorySize) {
        ensureThreadRegistered(threadId, deviceId);

        ReentrantReadWriteLock.WriteLock lock = lockFor(deviceId).writeLock();
        lock.lock();
        try {
            addToReservedSpace(deviceId, memorySize);
            return true;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns the allocated counter for {@code deviceId}, registering its counters first.
     *
     * @param threadId unused; kept for API compatibility
     * @param deviceId device to query
     * @return current allocated counter value
     */
    public long getAllocatedSize(Long threadId, Integer deviceId) {
        ensureThreadRegistered(threadId, deviceId);

        // getAllocatedSize(Integer) already takes the device read lock.
        return getAllocatedSize(deviceId);
    }

    /**
     * Returns the allocated counter for {@code deviceId} without registering it.
     *
     * @param deviceId device to query
     * @return current allocated counter value, or {@code 0} if the device was never registered
     */
    public long getAllocatedSize(Integer deviceId) {
        AtomicLong counter = memoryTackled.get(deviceId);
        if (counter == null) {
            return 0L;
        }

        ReentrantReadWriteLock.ReadLock lock = lockFor(deviceId).readLock();
        lock.lock();
        try {
            return counter.get();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns the reserved counter for {@code deviceId} without registering it.
     *
     * @param deviceId device to query
     * @return current reserved counter value, or {@code 0} if the device was never registered
     */
    public long getReservedSpace(Integer deviceId) {
        AtomicLong counter = reservedSpace.get(deviceId);
        return counter == null ? 0L : counter.get();
    }

    /** Adds {@code memorySize} to the reserved counter of {@code deviceId}. */
    protected void addToReservedSpace(Integer deviceId, long memorySize) {
        ensureThreadRegistered(Thread.currentThread().getId(), deviceId);

        reservedCounter(deviceId).addAndGet(memorySize);
    }

    /** Subtracts {@code memorySize} from the reserved counter of {@code deviceId}. */
    protected void subFromReservedSpace(Integer deviceId, long memorySize) {
        ensureThreadRegistered(Thread.currentThread().getId(), deviceId);

        reservedCounter(deviceId).addAndGet(-memorySize);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy