All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.memory.deallocation.DeallocatorService Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.memory.deallocation;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.factory.Nd4j;


import java.lang.ref.ReferenceQueue;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * This class provides unified management for Deallocatable resources
 *
 * @author [email protected]
 */
@Slf4j
public class DeallocatorService {
    private Thread[] deallocatorThreads;
    private ReferenceQueue[] queues;
    private Map referenceMap = new ConcurrentHashMap<>();
    private List>> deviceMap = new ArrayList<>();

    public DeallocatorService() {
        // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int numThreads = Math.max(2, numDevices * 2);

        for (int e = 0; e < numDevices; e++)
            deviceMap.add(new ArrayList>());

        deallocatorThreads = new Thread[numThreads];
        queues = new ReferenceQueue[numThreads];
        for (int e = 0; e < numThreads; e++) {
            log.trace("Starting deallocator thread {}", e + 1);
            queues[e] = new ReferenceQueue<>();

            int deviceId = e % numDevices;
            // attaching queue to its own thread
            deallocatorThreads[e] = new DeallocatorServiceThread(queues[e], e, deviceId);
            deallocatorThreads[e].setName("DeallocatorServiceThread_" + e);
            deallocatorThreads[e].setDaemon(true);

            deviceMap.get(deviceId).add(queues[e]);
            
            deallocatorThreads[e].start();
        }
    }

    /**
     * This method adds Deallocatable object instance to tracking system
     *
     * @param deallocatable object to track
     */
    public void pickObject(@NonNull Deallocatable deallocatable) {
        val desiredDevice = deallocatable.targetDevice();
        val map = deviceMap.get(desiredDevice);
        val reference = new DeallocatableReference(deallocatable, map.get(RandomUtils.nextInt(0, map.size())));
        referenceMap.put(deallocatable.getUniqueId(), reference);
    }


    private class DeallocatorServiceThread extends Thread implements Runnable {
        private final ReferenceQueue queue;
        private final int threadIdx;
        public static final String DeallocatorThreadNamePrefix = "DeallocatorServiceThread thread ";
        private final int deviceId;

        private DeallocatorServiceThread(@NonNull ReferenceQueue queue, int threadIdx, int deviceId) {
            this.queue = queue;
            this.threadIdx = threadIdx;
            this.setName(DeallocatorThreadNamePrefix + threadIdx);
            this.deviceId = deviceId;
            setContextClassLoader(null);
        }

        @Override
        public void run() {
            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
            boolean canRun = true;
            long cnt = 0;
            while (canRun) {
                // if periodicGc is enabled, only first thread will call for it
                if (Nd4j.getMemoryManager().isPeriodicGcActive() && threadIdx == 0 && Nd4j.getMemoryManager().getAutoGcWindow() > 0) {
                    val reference = (DeallocatableReference) queue.poll();
                    if (reference == null) {
                        val timeout = Nd4j.getMemoryManager().getAutoGcWindow();
                        try {
                            Thread.sleep(Nd4j.getMemoryManager().getAutoGcWindow());
                            Nd4j.getMemoryManager().invokeGc();
                        } catch (InterruptedException e) {
                            canRun = false;
                        }
                    } else {
                        // invoking deallocator
                        reference.getDeallocator().deallocate();
                        referenceMap.remove(reference.getId());
                    }
                } else {
                    try {
                        val reference = (DeallocatableReference) queue.remove();
                        if (reference == null)
                            continue;

                        // invoking deallocator
                        reference.getDeallocator().deallocate();
                        referenceMap.remove(reference.getId());
                    } catch (InterruptedException e) {
                        canRun = false;
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy