All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.api.memory.BasicMemoryManager Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.api.memory;

import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Base {@link MemoryManager} with backend-neutral defaults.
 * <p>
 * It provides:
 * <ul>
 *     <li>the per-thread "current workspace" registry;</li>
 *     <li>the "occasional GC" heuristic: {@code System.gc()} on every N-th call to
 *     {@link #invokeGcOccasionally()}, but only once the auto-GC window has elapsed since the last GC;</li>
 *     <li>an optional bounded history of elapsed-since-last-GC samples used by {@link #getAverageLoopTime()}.</li>
 * </ul>
 * {@link #allocate(long, MemoryKind, boolean)} and {@link #collect(INDArray...)} are not supported here
 * and throw {@link UnsupportedOperationException}.
 */
public abstract class BasicMemoryManager implements MemoryManager {
    /** Run {@code System.gc()} on every N-th {@link #invokeGcOccasionally()} call; values {@code <= 0} disable it. */
    protected AtomicInteger frequency = new AtomicInteger(0);
    /** Counts {@link #invokeGcOccasionally()} calls made while {@link #frequency} was positive. */
    protected AtomicLong freqCounter = new AtomicLong(0);

    /** Wall-clock time (ms) of the last GC run through this manager; initialised at construction. */
    protected AtomicLong lastGcTime = new AtomicLong(System.currentTimeMillis());

    /** Value reported by {@link #isPeriodicGcActive()}; this class only stores the flag. */
    protected AtomicBoolean periodicEnabled = new AtomicBoolean(false);

    // NOTE(review): not read or written in this class; presumably kept for subclasses.
    protected AtomicInteger averageLoopTime = new AtomicInteger(0);

    /** Minimum time (ms) since the last GC before an occasional GC may run. */
    protected AtomicInteger noGcWindow = new AtomicInteger(100);

    /** Whether {@link #invokeGcOccasionally()} records samples for {@link #getAverageLoopTime()}. */
    protected AtomicBoolean averagingEnabled = new AtomicBoolean(false);

    /** Maximum number of samples retained in {@link #intervals}. */
    protected static final int intervalTail = 100;

    /** Recent samples of "ms elapsed since the last GC", taken at each {@link #invokeGcOccasionally()} call. */
    protected Queue<Integer> intervals = new ConcurrentLinkedQueue<>();

    /** Workspace registered for each thread via {@link #setCurrentWorkspace(MemoryWorkspace)}. */
    private ThreadLocal<MemoryWorkspace> workspace = new ThreadLocal<>();

    /**
     * Allocates memory of the given kind.
     * <p>
     * This base implementation does not support allocation and always throws; backends that support
     * it must override this method. PLEASE NOTE: cache options depend on specific implementations.
     *
     * @param bytes      number of bytes requested
     * @param kind       kind of memory requested
     * @param initialize whether the memory should be initialized
     * @return never returns normally in this implementation
     * @throws UnsupportedOperationException always
     */
    @Override
    public Pointer allocate(long bytes, MemoryKind kind, boolean initialize) {
        throw new UnsupportedOperationException("This method isn't available for this backend");
    }

    /**
     * Detaches off-heap memory from the passed INDArray instances and optionally caches it for reuse.
     * PLEASE NOTE: cache options depend on specific implementations.
     *
     * @param arrays arrays whose memory should be released
     * @throws UnsupportedOperationException always, in this base implementation
     */
    @Override
    public void collect(INDArray... arrays) {
        throw new UnsupportedOperationException("This method isn't implemented yet");
    }

    /**
     * Enables or disables sampling of intervals in {@link #invokeGcOccasionally()}.
     * Disabling does not clear samples that were already recorded.
     */
    @Override
    public void toggleAveraging(boolean enabled) {
        averagingEnabled.set(enabled);
    }

    /**
     * Purges all cached memory chunks. No-op here.
     */
    @Override
    public void purgeCaches() {
        //No op for CPU (no cache)
    }

    /**
     * Copies {@code srcBuffer.length() * srcBuffer.getElementSize()} bytes from the source buffer's
     * address to the destination buffer's address, and records the copy as a HOST_TO_HOST transaction
     * in the {@link PerformanceTracker}.
     * <p>
     * NOTE(review): the destination's capacity is not checked here; callers must ensure
     * {@code dstBuffer} can hold the source's bytes.
     *
     * @param dstBuffer destination buffer
     * @param srcBuffer source buffer; its length and element size determine the byte count
     */
    @Override
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        val perfD = PerformanceTracker.getInstance().helperStartTransaction();

        long bytes = srcBuffer.length() * srcBuffer.getElementSize();
        Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(), bytes);

        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, bytes, MemcpyDirection.HOST_TO_HOST);
    }

    @Override
    public void notifyScopeEntered() {
        // TODO: to be implemented
    }

    @Override
    public void notifyScopeLeft() {
        // TODO: to be implemented
    }

    /**
     * Possibly triggers a GC.
     * <p>
     * When averaging is enabled, it first records the ms elapsed since the last GC.
     * If the occasional-GC frequency is positive, it increments the call counter. It then runs
     * {@code System.gc()} when the counter is a multiple of the frequency and the auto-GC window
     * has elapsed since the last GC. Finally, the sample history is trimmed to {@link #intervalTail} entries.
     */
    @Override
    public void invokeGcOccasionally() {
        long currentTime = System.currentTimeMillis();
        boolean averaging = averagingEnabled.get();

        if (averaging) {
            // Clamp so a very long gap since the last GC cannot overflow into a negative sample.
            long elapsed = currentTime - lastGcTime.get();
            intervals.add((int) Math.min(Integer.MAX_VALUE, elapsed));
        }

        // Read the frequency once: re-reading it after the "> 0" check would let a concurrent
        // setOccasionalGcFrequency(0) cause a division by zero in the modulo.
        // not sure if we want to conform autoGcWindow here...
        int freq = frequency.get();
        if (freq > 0 && freqCounter.incrementAndGet() % freq == 0
                        && currentTime > getLastGcTime() + getAutoGcWindow()) {
            System.gc();
            lastGcTime.set(System.currentTimeMillis());
        }

        if (averaging) {
            // poll() rather than remove(): another thread may drain the queue concurrently.
            while (intervals.size() > intervalTail) {
                intervals.poll();
            }
        }
    }

    /**
     * Unconditionally runs {@code System.gc()} and records the time.
     */
    @Override
    public void invokeGc() {
        System.gc();
        lastGcTime.set(System.currentTimeMillis());
    }

    @Override
    public boolean isPeriodicGcActive() {
        return periodicEnabled.get();
    }

    /**
     * Sets N for "GC on every N-th {@link #invokeGcOccasionally()} call"; {@code <= 0} disables it.
     */
    @Override
    public void setOccasionalGcFrequency(int frequency) {
        this.frequency.set(frequency);
    }

    /**
     * Sets the minimum time (ms) since the last GC before an occasional GC may run.
     */
    @Override
    public void setAutoGcWindow(int windowMillis) {
        noGcWindow.set(windowMillis);
    }

    @Override
    public int getAutoGcWindow() {
        return noGcWindow.get();
    }

    @Override
    public int getOccasionalGcFrequency() {
        return frequency.get();
    }

    @Override
    public long getLastGcTime() {
        return lastGcTime.get();
    }

    /**
     * Stores the periodic-GC flag reported by {@link #isPeriodicGcActive()}.
     */
    @Override
    public void togglePeriodicGc(boolean enabled) {
        periodicEnabled.set(enabled);
    }

    /**
     * Returns the integer mean of the retained samples (ms since the last GC).
     *
     * @return the mean, or {@code 0} when averaging is disabled or no samples have been recorded yet
     */
    @Override
    public int getAverageLoopTime() {
        if (!averagingEnabled.get()) {
            return 0;
        }

        // Count while iterating so the divisor matches the samples actually summed, even if the
        // queue changes concurrently. Sum in a long to avoid int overflow.
        long sum = 0;
        int count = 0;
        for (Integer value : intervals) {
            sum += value;
            count++;
        }
        return count == 0 ? 0 : (int) (sum / count);
    }

    /**
     * Returns the workspace registered for the calling thread, or {@code null} if none was set.
     */
    @Override
    public MemoryWorkspace getCurrentWorkspace() {
        return workspace.get();
    }

    /**
     * Registers {@code workspace} as the calling thread's current workspace ({@code null} clears it).
     */
    @Override
    public void setCurrentWorkspace(MemoryWorkspace workspace) {
        this.workspace.set(workspace);
    }

    /**
     * Returns a {@link DummyWorkspace} for code that should run outside any workspace.
     * <p>
     * If no workspace is current, a fresh DummyWorkspace is returned. Otherwise, a DummyWorkspace on
     * which {@code notifyScopeEntered()} has already been called is returned. The currently active
     * workspace itself is not modified.
     */
    @Override
    public MemoryWorkspace scopeOutOfWorkspaces() {
        // NOTE(review): looks up the current workspace via the global manager, not via this instance.
        MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (current == null) {
            return new DummyWorkspace();
        }
        return new DummyWorkspace().notifyScopeEntered();
    }

    @Override
    public void releaseCurrentContext() {
        // no-op
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy