All downloads are free. The search and download functionality uses the official Maven repository.

org.nd4j.jita.flow.impl.SynchronousFlowController Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.jita.flow.impl;


import lombok.Getter;
import lombok.NonNull;
import lombok.val;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.concurrency.EventsProvider;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.JCublasNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link FlowController} implementation that prepares arrays for execution on the
 * current device and records device reads/writes on their allocation points.
 *
 * <p>According to the comments kept from the original code, event-based
 * synchronization is now handled on the native (C++) side, so several methods
 * here are thin delegates or intentional no-ops.
 *
 * <p>The {@link Allocator} must be supplied through {@link #init(Allocator)}
 * before any of the {@code prepare*}/{@code register*} methods are called.
 *
 * @author [email protected]
 */
public class SynchronousFlowController implements FlowController {
    private static final Logger log = LoggerFactory.getLogger(SynchronousFlowController.class);
    private volatile Allocator allocator;
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    @Getter
    protected EventsProvider eventsProvider = new EventsProvider();

    /**
     * Stores the allocator used by every other method of this controller.
     *
     * @param allocator allocator to use for allocation point / context lookups
     */
    @Override
    public void init(Allocator allocator) {
        this.allocator = allocator;
    }

    /**
     * This method makes sure HOST memory contains latest data from GPU
     *
     * @param point allocation point whose underlying data buffer is synchronized
     */
    @Override
    public void synchronizeToHost(AllocationPoint point) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(point.getPtrDataBuffer());
    }

    /**
     * Counterpart of {@link #synchronizeToHost(AllocationPoint)}: delegates to the
     * native {@code dbSyncToSpecial} call for the point's data buffer.
     *
     * @param point allocation point whose underlying data buffer is synchronized
     */
    @Override
    public void synchronizeToDevice(@NonNull AllocationPoint point) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(point.getPtrDataBuffer());
    }

    /**
     * Blocks on the last write event of the given point, if there is one.
     *
     * @param point allocation point to wait for
     */
    @Override
    public void waitTillFinished(AllocationPoint point) {
        // this should be always null, since synchronization happens in C++ now
        if (point.getLastWriteEvent() != null) {
            point.getLastWriteEvent().synchronize();
        }
    }


    /**
     * Prepares every operand for an operation on the current device.
     *
     * <p>For each non-null, non-empty operand: decompresses it if needed, relocates
     * its data and shape-info buffers when they are bound to another (non-negative)
     * device id, applies delayed-memory preparation and binds the current context
     * to its allocation point.
     *
     * @param operands arrays taking part in the operation; {@code null} entries and
     *                 a {@code null} array are tolerated
     * @return the device context obtained from the allocator
     */
    @Override
    public CudaContext prepareActionAllWrite(INDArray... operands) {
        val context = allocator.getDeviceContext();
        val cId = allocator.getDeviceId();

        // same guard as prepareAction(INDArray, INDArray...)
        if (operands == null)
            return context;

        for (INDArray operand : operands) {
            if (operand == null || operand.isEmpty())
                continue;

            Nd4j.getCompressor().autoDecompress(operand);

            val pointData = allocator.getAllocationPoint(operand);
            val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());


            if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0) {
                allocator.getMemoryHandler().relocateObject(rootBuffer(operand));
            }

            if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
                ((JCublasNDArray) operand).setShapeInfoDataBuffer(
                                Nd4j.getConstantHandler().relocateConstantSpace(operand.shapeInfoDataBuffer()));
            }

            prepareDelayedMemory(operand);
            allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        return context;
    }

    /**
     * Prepares the result array and the operand arrays for an operation on the
     * current device.
     *
     * <p>Data buffers bound to another device are relocated only when cross-device
     * access is disabled in the configuration or P2P is reported unavailable.
     * Shape-info buffers bound to another device are always re-created through the
     * executioner.
     *
     * @param result   array written by the operation; may be {@code null} or empty
     * @param operands arrays read by the operation; {@code null}, empty and string
     *                 arrays are skipped, and a {@code null} array is tolerated
     * @return the device context obtained from the allocator
     */
    @Override
    public CudaContext prepareAction(INDArray result, INDArray... operands) {
        val context = allocator.getDeviceContext();
        val cId = allocator.getDeviceId();


        if (result != null && !result.isEmpty()) {
            Nd4j.getCompressor().autoDecompress(result);
            prepareDelayedMemory(result);
            val pointData = allocator.getAllocationPoint(result);
            val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());

            if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
                allocator.getMemoryHandler().relocateObject(rootBuffer(result));
            }

            if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
                ((JCublasNDArray) result).setShapeInfoDataBuffer(Nd4j.getExecutioner().createShapeInfo(result.shape(), result.stride(), result.elementWiseStride(), result.ordering(), result.dataType(), result.isEmpty()));
            }

            allocator.getAllocationPoint(result).setCurrentContext(context);
        }

        if (operands == null)
            return context;

        for (INDArray operand : operands) {
            // empty or String arrays can be skipped
            if (operand == null || operand.isEmpty() || operand.isS())
                continue;

            Nd4j.getCompressor().autoDecompress(operand);

            val pointData = allocator.getAllocationPoint(operand);
            val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
            Nd4j.getAffinityManager().ensureLocation(operand, AffinityManager.Location.DEVICE);

            if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0 && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
                allocator.getMemoryHandler().relocateObject(rootBuffer(operand));
            }

            if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
                ((JCublasNDArray) operand).setShapeInfoDataBuffer(Nd4j.getExecutioner().createShapeInfo(operand.shape(), operand.stride(), operand.elementWiseStride(), operand.ordering(), operand.dataType(), operand.isEmpty()));
            }

            prepareDelayedMemory(operand);
            allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        return context;
    }

    /**
     * Waits for the last write event and then for the last read event of the
     * given point, whichever of them are present.
     *
     * @param point allocation point to wait for
     */
    @Override
    public void waitTillReleased(AllocationPoint point) {
        waitTillFinished(point);

        if (point.getLastReadEvent() != null)
            point.getLastReadEvent().synchronize();
    }

    /**
     * Intentionally does nothing.
     *
     * @param context  unused
     * @param result   unused
     * @param operands unused
     */
    @Override
    public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
        // this method is irrelevant now, everything happens in C++ now
    }

    /**
     * Marks every non-null operand as written on the device.
     *
     * @param context  unused by this implementation
     * @param operands arrays written by the operation; {@code null} entries and a
     *                 {@code null} array are tolerated
     */
    @Override
    public void registerActionAllWrite(CudaContext context, INDArray... operands) {
        if (operands == null)
            return;

        for (INDArray operand : operands) {
            if (operand == null)
                continue;

            val pointOperand = allocator.getAllocationPoint(operand);
            pointOperand.tickDeviceWrite();
        }
    }

    /**
     * Marks the result as written on the device and every operand as read on the
     * device. Nothing is recorded at all when the result is {@code null} or empty.
     *
     * @param context  unused by this implementation
     * @param result   array written by the operation
     * @param operands arrays read by the operation; {@code null} and empty entries
     *                 are skipped, and a {@code null} array is tolerated
     */
    public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
        if (result == null || result.isEmpty())
            return;

        val point = allocator.getAllocationPoint(result);
        point.tickDeviceWrite();

        if (operands == null)
            return;

        for (INDArray operand : operands) {
            if (operand == null || operand.isEmpty())
                continue;

            val pointOperand = allocator.getAllocationPoint(operand);
            pointOperand.tickDeviceRead();
        }
    }

    /**
     * Binds the current device context to the result and to every operand.
     *
     * @param result   allocation point of the result; may be {@code null}
     * @param operands allocation points of the operands; {@code null} entries and
     *                 a {@code null} array are tolerated
     * @return the device context obtained from the allocator
     */
    @Override
    public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) {
        val context = allocator.getDeviceContext();

        if (result != null) {
            result.setCurrentContext(context);
        }

        if (operands == null)
            return context;

        for (AllocationPoint operand : operands) {
            if (operand == null)
                continue;

            operand.setCurrentContext(context);
        }

        return context;
    }

    /**
     * Synchronizes the stream that was used for a transfer.
     *
     * @param streamUsed stream to synchronize
     */
    @Override
    public void commitTransfer(cudaStream_t streamUsed) {
        streamUsed.synchronize();
    }

    /**
     * When the configured memory model is {@code DELAYED}, promotes the array's
     * data buffer if it is not in {@code DEVICE} state, and relocates its
     * shape-info buffer if that one is still in {@code HOST} state.
     * Does nothing for any other memory model.
     *
     * @param array array to prepare; expected to be a {@link JCublasNDArray}
     */
    protected void prepareDelayedMemory(INDArray array) {
        if (configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            // the data status must come from the array's own allocation point,
            // not from its shape-info buffer
            val pointData = allocator.getAllocationPoint(array);
            val pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());

            if (pointData.getAllocationStatus() != AllocationStatus.DEVICE)
                prepareDelayedMemory(array.data());

            if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
                val oShape = array.shapeInfoDataBuffer();
                val nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);

                // relocation handed back the very same buffer: move it explicitly
                if (nShape == oShape)
                    Nd4j.getConstantHandler().moveToConstantSpace(nShape);
                ((JCublasNDArray) array).setShapeInfoDataBuffer(nShape);
            }
        }
    }

    /**
     * Asks the allocator's memory handler to promote the given buffer.
     *
     * @param buffer buffer to promote
     */
    protected void prepareDelayedMemory(DataBuffer buffer) {

        allocator.getMemoryHandler().promoteObject(buffer);
    }

    /** Returns the array's original data buffer, or its own data buffer when there is no original. */
    private static DataBuffer rootBuffer(INDArray array) {
        DataBuffer original = array.data().originalDataBuffer();
        return original == null ? array.data() : original;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy