All downloads are free. The search and download functionality uses the official Maven repository.

org.deeplearning4j.spark.iterator.SparkAMDSI Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.iterator;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContextHelper;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
import org.nd4j.linalg.dataset.callbacks.DefaultCallback;
import org.nd4j.linalg.factory.Nd4j;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/**
 * {@link AsyncMultiDataSetIterator} variant for use inside Spark tasks.
 *
 * <p>The constructor captures the {@link TaskContext} of the thread that creates the iterator and
 * {@link #externalCall()} re-installs that context through {@link TaskContextHelper}. All public
 * constructors delegate to the six-argument one, which starts a daemon prefetch thread
 * immediately.
 */
@Slf4j
public class SparkAMDSI extends AsyncMultiDataSetIterator {
    /** Task context captured from the constructing thread via {@link TaskContext#get()}. */
    protected TaskContext context;

    /** Bare constructor: only runs the superclass no-arg constructor; no thread is started. */
    protected SparkAMDSI() {
        super();
    }

    /**
     * Creates an iterator with a queue size of 8.
     *
     * @param baseIterator the iterator to wrap
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator) {
        this(baseIterator, 8);
    }

    /**
     * Creates an iterator over a caller-supplied queue, with workspaces enabled.
     *
     * @param iterator  the iterator to wrap
     * @param queueSize requested prefetch size (values below 2 are raised to 2)
     * @param queue     buffer used between the prefetch thread and the consumer
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue queue) {
        this(iterator, queueSize, queue, true);
    }

    /**
     * Creates an iterator backed by a new {@link LinkedBlockingQueue} of capacity
     * {@code queueSize}, with workspaces enabled.
     *
     * @param baseIterator the iterator to wrap
     * @param queueSize    queue capacity and requested prefetch size
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize) {
        this(baseIterator, queueSize, new LinkedBlockingQueue(queueSize));
    }

    /**
     * Creates an iterator backed by a new {@link LinkedBlockingQueue} of capacity
     * {@code queueSize}.
     *
     * @param baseIterator the iterator to wrap
     * @param queueSize    queue capacity and requested prefetch size
     * @param useWorkspace value stored in the inherited {@code useWorkspaces} flag
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
        this(baseIterator, queueSize, new LinkedBlockingQueue(queueSize), useWorkspace);
    }

    /**
     * Creates an iterator backed by a new {@link LinkedBlockingQueue}, using a
     * {@link DefaultCallback} and an explicit device id.
     *
     * @param baseIterator the iterator to wrap
     * @param queueSize    queue capacity and requested prefetch size
     * @param useWorkspace value stored in the inherited {@code useWorkspaces} flag
     * @param deviceId     device id handed to the prefetch thread; when {@code null} the device of
     *                     the constructing thread is used for the thread instead
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(baseIterator, queueSize, new LinkedBlockingQueue(queueSize), useWorkspace,
                        new DefaultCallback(), deviceId);
    }

    /**
     * Creates an iterator backed by a new {@link LinkedBlockingQueue} with a caller-supplied
     * callback.
     *
     * @param baseIterator the iterator to wrap
     * @param queueSize    queue capacity and requested prefetch size
     * @param useWorkspace value stored in the inherited {@code useWorkspaces} flag
     * @param callback     callback stored in the inherited {@code callback} field
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace,
                    DataSetCallback callback) {
        this(baseIterator, queueSize, new LinkedBlockingQueue(queueSize), useWorkspace, callback);
    }

    /**
     * Creates an iterator over a caller-supplied queue with a {@code null} callback.
     *
     * @param iterator     the iterator to wrap
     * @param queueSize    requested prefetch size (values below 2 are raised to 2)
     * @param queue        buffer used between the prefetch thread and the consumer
     * @param useWorkspace value stored in the inherited {@code useWorkspaces} flag
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue queue,
                    boolean useWorkspace) {
        this(iterator, queueSize, queue, useWorkspace, null);
    }

    /**
     * Creates an iterator over a caller-supplied queue, using the device currently associated with
     * the constructing thread.
     *
     * @param iterator     the iterator to wrap
     * @param queueSize    requested prefetch size (values below 2 are raised to 2)
     * @param queue        buffer used between the prefetch thread and the consumer
     * @param useWorkspace value stored in the inherited {@code useWorkspaces} flag
     * @param callback     callback stored in the inherited {@code callback} field; may be
     *                     {@code null}
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue queue,
                    boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    /**
     * Primary constructor: configures the inherited state, resets the wrapped iterator when it
     * supports reset, captures the current Spark {@link TaskContext} and starts the daemon
     * prefetch thread.
     *
     * @param iterator     the iterator to wrap
     * @param queueSize    requested prefetch size; values below 2 are raised to 2. Note that this
     *                     only affects the stored {@code prefetchSize}, not the capacity of the
     *                     already-constructed {@code queue}
     * @param queue        buffer used between the prefetch thread and the consumer
     * @param useWorkspace value stored in the inherited {@code useWorkspaces} flag
     * @param callback     callback stored in the inherited {@code callback} field; may be
     *                     {@code null}
     * @param deviceId     device id stored in the inherited {@code deviceId} field and handed to
     *                     the prefetch thread; when {@code null} the thread receives the device of
     *                     the constructing thread instead
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue queue,
                    boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        this();

        if (queueSize < 2)
            queueSize = 2;

        this.callback = callback;
        this.buffer = queue;
        this.backedIterator = iterator;
        this.useWorkspaces = useWorkspace;
        this.prefetchSize = queueSize;
        this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString();
        this.deviceId = deviceId;

        if (iterator.resetSupported())
            this.backedIterator.reset();

        // Hand the caller-requested device to the prefetch thread. Previously the explicit
        // deviceId was stored in the field but the thread always received the constructing
        // thread's device, so the argument had no effect on the thread.
        int prefetchDeviceId = deviceId != null
                        ? deviceId
                        : Nd4j.getAffinityManager().getDeviceForCurrentThread();
        this.thread = new SparkPrefetchThread(buffer, iterator, terminator, prefetchDeviceId);

        // Capture the context before the thread starts so externalCall() can re-install it.
        context = TaskContext.get();

        thread.setDaemon(true);
        thread.start();
    }

    /**
     * Installs the {@link TaskContext} captured at construction time on the calling thread.
     *
     * <p>NOTE(review): presumably invoked by the parent class from the prefetch thread; confirm
     * against {@code AsyncMultiDataSetIterator}.
     */
    @Override
    protected void externalCall() {
        TaskContextHelper.setTaskContext(context);
    }

    /**
     * Prefetch thread type used by this iterator; adds no behaviour beyond its superclass
     * {@code AsyncPrefetchThread}.
     */
    protected class SparkPrefetchThread extends AsyncPrefetchThread {

        /**
         * @param queue      buffer the thread works with
         * @param iterator   iterator the thread reads from
         * @param terminator terminator instance forwarded to the superclass
         * @param deviceId   device id forwarded to the superclass
         */
        protected SparkPrefetchThread(@NonNull BlockingQueue queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
            super(queue, iterator, terminator, deviceId);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy