All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.datasets.iterator.callbacks;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Callback that migrates each incoming {@link DataSet} / {@link MultiDataSet} into one of a set of
 * per-device workspaces, picked in round-robin order. The stated intent is to ensure TDA
 * (presumably thread/device affinity — TODO confirm) for ParallelWrapper.
 *
 * <p>The workspaces (one per device reported by the affinity manager) are created lazily on the
 * first {@code call(...)}, using that first dataset's memory footprint as the initial size.
 * Creation is guarded so that concurrent first calls create the workspaces only once.
 *
 * @author [email protected]
 */
@Slf4j
public class InterleavedDataSetCallback implements DataSetCallback {
    /** Workspaces in device order: entry {@code i} was created for device {@code i}. */
    private final List<MemoryWorkspace> workspaces = new ArrayList<>();
    /** Used as the {@code overallocationLimit} of every workspace configuration. */
    private final int bufferSize;
    private int numWorkspaces;

    // volatile so a thread that observes 'true' also observes the populated list / numWorkspaces
    private volatile boolean isInitialized = false;

    /** Counter used to pick the target workspace round-robin; reset by {@link #reset()}. */
    private final AtomicLong counterInput = new AtomicLong(0);

    /**
     * @param bufferSize value passed as {@code overallocationLimit} when configuring each workspace
     */
    public InterleavedDataSetCallback(int bufferSize) {
        this.bufferSize = bufferSize;
    }

    /**
     * Creates one workspace per device and marks this callback as initialized.
     *
     * <p>Every workspace shares one configuration: initial size {@code size}, overallocation limit
     * {@code bufferSize}, reset at end of buffer, over-allocate policy, external spill, no learning.
     * The current thread is temporarily switched to each device while that device's workspace is
     * created. The thread's original device is restored afterwards, even if creation fails.
     *
     * @param size initial size of each workspace, in the unit reported by
     *             {@code getMemoryFootprint()} (presumably bytes)
     */
    protected void initializeWorkspaces(long size) {
        WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(size)
                        .overallocationLimit(bufferSize).policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
                        .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.EXTERNAL)
                        .policyLearning(LearningPolicy.NONE).build();

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int originalDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        try {
            for (int i = 0; i < numDevices; i++) {
                // bind this thread to device i before creating the workspace for device i
                Nd4j.getAffinityManager().unsafeSetDevice(i);
                workspaces.add(Nd4j.getWorkspaceManager().createNewWorkspace(configuration, "IDSC-" + i, i));
            }
        } finally {
            // always hand the thread back on the device it started on
            Nd4j.getAffinityManager().unsafeSetDevice(originalDevice);
        }

        numWorkspaces = numDevices;
        isInitialized = true;
    }

    /**
     * Runs {@link #initializeWorkspaces(long)} at most once, even when several threads race on
     * their first call.
     */
    private synchronized void ensureInitialized(long size) {
        if (!isInitialized)
            initializeWorkspaces(size);
    }

    /**
     * Commits pending executioner work, then makes the next workspace (round-robin) the current
     * workspace of this thread.
     *
     * @return the workspace that was current before the switch; the caller must restore it
     */
    private MemoryWorkspace activateNextWorkspace() {
        Nd4j.getExecutioner().commit();

        int currIdx = (int) (counterInput.getAndIncrement() % numWorkspaces);
        MemoryWorkspace previous = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace(workspaces.get(currIdx));
        return previous;
    }

    /**
     * Migrates {@code dataSet} while the next round-robin workspace is current. The previously
     * current workspace is restored afterwards, even if migration throws.
     */
    @Override
    public void call(DataSet dataSet) {
        if (!isInitialized)
            ensureInitialized(dataSet.getMemoryFootprint());

        MemoryWorkspace previous = activateNextWorkspace();
        try {
            dataSet.migrate();
        } finally {
            Nd4j.getMemoryManager().setCurrentWorkspace(previous);
        }
    }

    /**
     * Migrates {@code multiDataSet} while the next round-robin workspace is current. The
     * previously current workspace is restored afterwards, even if migration throws.
     */
    @Override
    public void call(MultiDataSet multiDataSet) {
        if (!isInitialized)
            ensureInitialized(multiDataSet.getMemoryFootprint());

        MemoryWorkspace previous = activateNextWorkspace();
        try {
            multiDataSet.migrate();
        } finally {
            Nd4j.getMemoryManager().setCurrentWorkspace(previous);
        }
    }

    /** Restarts the round-robin so the next dataset goes to workspace 0. Workspaces are kept. */
    @Override
    public void reset() {
        counterInput.set(0);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy