All Downloads are FREE. Search and download functionality uses the official Maven repository.

ai.djl.ndarray.BaseNDManager Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.engine.StandardCapabilities;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import ai.djl.util.RandomUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.nio.charset.Charset;
import java.nio.file.Path;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** {@code BaseNDManager} is the default implementation of {@link NDManager}. */
public abstract class BaseNDManager implements NDManager {

    private static final Logger logger = LoggerFactory.getLogger(BaseNDManager.class);

    protected NDManager parent;
    /**
     * Manager created on the engine's alternative engine, or {@code null} if the engine has no
     * alternative engine. It is created by (and owned by) this manager.
     */
    protected NDManager alternativeManager;
    protected String uid;
    protected String name;
    protected Device device;
    /** Resources owned by this manager, keyed by resource id; closed when this manager closes. */
    protected ConcurrentHashMap<String, AutoCloseable> resources;
    /**
     * Resources temporarily attached to this manager, keyed by resource id. Unless detached, they
     * are handed back to their original manager (or closed if that manager is already closed)
     * when this manager closes.
     */
    protected ConcurrentHashMap<String, TempResource> tempResources;
    /** Set once {@link #close()} has started; guards against double close and late attaches. */
    protected AtomicBoolean closed = new AtomicBoolean(false);
    /** When set, {@link #attachInternal(String, AutoCloseable...)} rejects new resources. */
    protected AtomicBoolean capped = new AtomicBoolean(false);

    /**
     * Creates a new manager.
     *
     * @param parent the parent manager this manager detaches itself from on {@link #close()};
     *     may be {@code null}
     * @param device the default device for this manager, or {@code null} to use the engine's
     *     default device
     */
    // getEngine() is implemented by the subclass and is called here, before the subclass
    // constructor body has run; hence the this-escape suppression.
    @SuppressWarnings("this-escape")
    protected BaseNDManager(NDManager parent, Device device) {
        this.parent = parent;
        this.device = device == null ? defaultDevice() : device;
        resources = new ConcurrentHashMap<>();
        tempResources = new ConcurrentHashMap<>();
        uid = UUID.randomUUID().toString();
        Engine engine = getEngine().getAlternativeEngine();
        if (engine != null) {
            // Use the same device if possible for efficiency
            if (this.device.isGpu() && engine.hasCapability(StandardCapabilities.CUDA)) {
                alternativeManager = engine.newBaseManager(this.device);
            } else {
                // Use the default device
                alternativeManager = engine.newBaseManager();
            }
        }
    }

    /** {@inheritDoc} */
    @Override
    public final Device defaultDevice() {
        return getEngine().defaultDevice();
    }

    /** {@inheritDoc} */
    @Override
    public NDArray create(String[] data, Charset charset, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray create(Shape shape, DataType dataType) {
        return zeros(shape, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray createCoo(Buffer data, long[][] indices, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDList load(Path path) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public void setName(String name) {
        this.name = name;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Falls back to this manager's generated uid when no name has been set.
     */
    @Override
    public String getName() {
        return this.name == null ? uid : this.name;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray full(Shape shape, float value, DataType dataType) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray arange(float start, float stop, float step, DataType dataType) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray eye(int rows, int cols, int k, DataType dataType) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray linspace(float start, float stop, int num, boolean endpoint) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomPermutation(long n) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /**
     * {@inheritDoc}
     *
     * <p>This default implementation fills a Java {@code double[]} with samples from {@link
     * RandomUtils#nextGaussian()}, redrawing any sample outside {@code [-2, 2]}. It then scales
     * by {@code scale}, shifts by {@code loc}, reshapes to {@code shape} and converts to {@code
     * dataType}.
     *
     * @throws ArithmeticException if {@code shape} has more than {@link Integer#MAX_VALUE}
     *     elements
     */
    @Override
    public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
        // Fail fast instead of silently truncating the element count with an (int) cast.
        int sampleSize = Math.toIntExact(shape.size());
        double[] dist = new double[sampleSize];

        for (int i = 0; i < sampleSize; i++) {
            double sample = RandomUtils.nextGaussian();
            // Rejection sampling: keep only samples within two standard deviations.
            while (sample < -2 || sample > 2) {
                sample = RandomUtils.nextGaussian();
            }

            dist[i] = sample;
        }

        return create(dist).muli(scale).addi(loc).reshape(shape).toType(dataType, false);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomMultinomial(int n, NDArray pValues) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sampleNormal(NDArray mu, NDArray sigma) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sampleNormal(NDArray mu, NDArray sigma, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray samplePoisson(NDArray lam) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray samplePoisson(NDArray lam, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sampleGamma(NDArray alpha, NDArray beta) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sampleGamma(NDArray alpha, NDArray beta, Shape shape) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public boolean isOpen() {
        return !closed.get();
    }

    /** {@inheritDoc} */
    @Override
    public void cap() {
        this.capped.set(true);
    }

    /** {@inheritDoc} */
    @Override
    public NDManager getParentManager() {
        return parent;
    }

    /** {@inheritDoc} */
    @Override
    public NDManager newSubManager() {
        return newSubManager(device);
    }

    /** {@inheritDoc} */
    @Override
    public Device getDevice() {
        return device;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Collects arrays from owned {@link NDResource}s, recursively from owned sub-managers, and
     * from temporarily attached resources.
     */
    @Override
    public List<NDArray> getManagedArrays() {
        return Stream.concat(
                        // Main resources
                        resources.values().stream()
                                .flatMap(
                                        r -> {
                                            if (r instanceof NDResource) {
                                                return ((NDResource) r)
                                                        .getResourceNDArrays().stream();
                                            } else if (r instanceof NDManager) {
                                                return ((NDManager) r).getManagedArrays().stream();
                                            } else {
                                                return Stream.empty();
                                            }
                                        }),

                        // Temp resources
                        tempResources.values().stream()
                                .flatMap(tr -> tr.resource.getResourceNDArrays().stream()))
                .collect(Collectors.toList());
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        String parentName = parent == null ? "No Parent" : parent.getName();
        return "Name: "
                + getName()
                + " Parent Name: "
                + parentName
                + " isOpen: "
                + isOpen()
                + " Resource size: "
                + resources.size();
    }

    /**
     * {@inheritDoc}
     *
     * <p>When several resources are given, each is registered under {@code resourceId + "_" +
     * index}.
     *
     * @throws IllegalStateException if this manager has been capped or closed
     */
    @Override
    public synchronized void attachInternal(String resourceId, AutoCloseable... resources) {
        if (capped.get()) {
            throw new IllegalStateException("NDManager is capped for addition of resources.");
        }
        for (int i = 0; i < resources.length; i++) {
            attachUncappedInternal(
                    resources.length == 1 ? resourceId : resourceId + "_" + i, resources[i]);
        }
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalStateException if this manager has been closed
     */
    @Override
    public synchronized void attachUncappedInternal(String resourceId, AutoCloseable resource) {
        if (closed.get()) {
            throw new IllegalStateException("NDManager has been closed already.");
        }
        tempResources.compute(
                resourceId,
                (key, tempResource) -> {
                    if (tempResource != null) {
                        // This happens when this manager (manA) temp-attaches a resource that is
                        // later temp-attached to another manager (manB). When manB is closed, it
                        // uses attach to return the resource to this manager (manA). In that
                        // case, the resource should stay a temp resource of manA.
                        tempResource.detached = false;
                    } else {
                        resources.put(resourceId, resource);
                    }
                    return tempResource;
                });
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalStateException if this is the system manager or it has been closed
     */
    @Override
    public void tempAttachInternal(
            NDManager originalManager, String resourceId, NDResource resource) {
        if (this instanceof SystemNDManager) {
            throw new IllegalStateException(
                    "System manager cannot be temp attached because it can't be closed..");
        }
        if (closed.get()) {
            throw new IllegalStateException("NDManager has been closed already.");
        }
        tempResources.put(resourceId, new TempResource(resource, originalManager));
    }

    /**
     * {@inheritDoc}
     *
     * <p>A temp resource is only marked as detached (so it is not returned on close); an owned
     * resource is removed. Does nothing once this manager is closed.
     */
    @Override
    public synchronized void detachInternal(String resourceId) {
        if (closed.get()) {
            // This may happen in the middle of BaseNDManager.close()
            return;
        }
        tempResources.computeIfPresent(
                resourceId,
                (key, tempResource) -> {
                    tempResource.detached = true;
                    return tempResource;
                });
        resources.remove(resourceId);
    }

    /** {@inheritDoc} */
    @Override
    public void invoke(
            String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** {@inheritDoc} */
    @Override
    public NDList invoke(String operation, NDList src, PairList<String, ?> params) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /**
     * {@inheritDoc}
     *
     * <p>Closes every owned resource and returns every non-detached temp resource. It also
     * releases the alternative manager created in the constructor and detaches this manager
     * from its parent, if any. Errors while closing individual resources are logged, not
     * rethrown. Subsequent calls are no-ops.
     *
     * @throws IllegalStateException if this is the {@link SystemNDManager}
     */
    @Override
    public void close() {
        if (this instanceof SystemNDManager) {
            throw new IllegalStateException(
                    "The SystemNDManager can not be closed. It is global and lives for the duration"
                            + " of the process");
        }
        // getAndSet ensures only the first caller performs the cleanup.
        if (!closed.getAndSet(true)) {
            // A resource that calls back into detachInternal() while closing is ignored there,
            // because closed is already set.
            for (AutoCloseable closeable : resources.values()) {
                try {
                    closeable.close();
                } catch (Exception e) {
                    logger.error("Resource close failed.", e);
                }
            }
            for (TempResource resource : tempResources.values()) {
                resource.returnResource();
            }
            // The alternative manager was created exclusively for this manager in the
            // constructor; release it with us instead of leaking it on the alternative engine.
            if (alternativeManager != null) {
                try {
                    alternativeManager.close();
                } catch (Exception e) {
                    logger.error("Alternative manager close failed.", e);
                }
            }
            // A null parent is tolerated elsewhere (see toString()), so guard it here too.
            if (parent != null) {
                parent.detachInternal(uid);
            }
            resources.clear();
            tempResources.clear();
        }
    }

    /**
     * Prints information about this {@link NDManager} and all sub-managers to the console.
     *
     * <p>Each line shows the last 12 characters of the manager's uid and its resource count,
     * indented four spaces per level.
     *
     * @param level the level of this {@link NDManager} in the hierarchy
     */
    public void debugDump(int level) {
        StringBuilder sb = new StringBuilder(100);
        for (int i = 0; i < level; ++i) {
            sb.append("    ");
        }
        sb.append("\\--- NDManager(")
                .append(uid.substring(24))
                .append(") resource count: ")
                .append(resources.size());

        System.out.println(sb); // NOPMD
        for (AutoCloseable c : resources.values()) {
            if (c instanceof BaseNDManager) {
                ((BaseNDManager) c).debugDump(level + 1);
            }
        }
    }

    /**
     * Returns the manager on the alternative engine.
     *
     * @return the alternative manager, or {@code null} if the engine has no alternative engine
     */
    NDManager getAlternativeManager() {
        return alternativeManager;
    }

    /**
     * Checks that the input buffer's element type and size match the expected data type and
     * element count.
     *
     * <p>A {@link ByteBuffer} skips the type check (it is copied byte-wise), and its size is
     * measured in bytes. If the buffer holds more than needed, a warning is logged and the
     * buffer's limit is reduced to the expected size.
     *
     * @param buffer the input buffer
     * @param dataType the desired {@code DataType}
     * @param expected the expected number of elements
     * @throws IllegalArgumentException if the buffer type is incompatible or the buffer is too
     *     small
     */
    public static void validateBuffer(Buffer buffer, DataType dataType, int expected) {
        boolean isByteBuffer = buffer instanceof ByteBuffer;
        DataType type = DataType.fromBuffer(buffer);
        if (!isCompatible(type, dataType) && !isByteBuffer) {
            // It's ok if type != datatype and buffer is ByteBuffer,
            // since buffer will be copied into ByteBuffer
            throw new IllegalArgumentException(
                    "The input data type: "
                            + type
                            + " does not match target array data type: "
                            + dataType);
        }

        int remaining = buffer.remaining();
        int expectedSize = isByteBuffer ? dataType.getNumOfBytes() * expected : expected;
        if (remaining < expectedSize) {
            throw new IllegalArgumentException(
                    "The NDArray size is: " + expected + ", but buffer size is: " + remaining);
        }
        if (remaining > expectedSize) {
            logger.warn(
                    "Input buffer size is greater than the NDArray size, please set limit"
                            + " explicitly.");
            buffer.limit(expectedSize);
        }
    }

    /**
     * Returns whether data of {@code type1} can be copied into an array of {@code type2}.
     *
     * <p>The element widths must match. Beyond that, any two 2-byte types, any two types with
     * the same format, or any two integer types are compatible.
     */
    private static boolean isCompatible(DataType type1, DataType type2) {
        // Compare the two types' widths (previously type1 was compared with itself, so e.g. a
        // FloatBuffer was accepted for a FLOAT64 array).
        if (type1.getNumOfBytes() != type2.getNumOfBytes()) {
            return false;
        }
        if (type1.getNumOfBytes() == 2) {
            // fp16, bf16, int16, uint16 all use ShortBuffer
            return true;
        }
        if (type1.getFormat() == type2.getFormat()) {
            return true;
        }
        return type1.isInteger() && type2.isInteger();
    }

    /**
     * Copies data from the source {@code Buffer} to the target {@code ByteBuffer}.
     *
     * <p>The target is rewound before and after the copy. The copy uses a typed view of the
     * target that matches the source's buffer type.
     *
     * @param src the source {@code Buffer}
     * @param target the target {@code ByteBuffer}
     * @throws AssertionError if the source buffer type is not supported
     */
    public static void copyBuffer(Buffer src, ByteBuffer target) {
        target.rewind();
        DataType inputType = DataType.fromBuffer(src);
        switch (inputType) {
            case FLOAT16:
                target.asShortBuffer().put((ShortBuffer) src);
                break;
            case FLOAT32:
                target.asFloatBuffer().put((FloatBuffer) src);
                break;
            case FLOAT64:
                target.asDoubleBuffer().put((DoubleBuffer) src);
                break;
            case UINT8:
            case INT8:
            case BOOLEAN:
                target.put((ByteBuffer) src);
                break;
            case INT32:
                target.asIntBuffer().put((IntBuffer) src);
                break;
            case INT64:
                target.asLongBuffer().put((LongBuffer) src);
                break;
            default:
                throw new AssertionError("Unsupported datatype: " + inputType);
        }
        target.rewind();
    }

    /**
     * A resource temporarily attached to a manager, remembering the manager it came from so it
     * can be handed back.
     */
    protected static final class TempResource {

        private final NDResource resource;
        private final NDManager manager;
        /** Set when the resource is detached; detached resources are not handed back. */
        private boolean detached;

        public TempResource(NDResource resource, NDManager manager) {
            this.resource = resource;
            this.manager = manager;
            this.detached = false;
        }

        /**
         * Returns the resource to its original manager if that manager is still open, otherwise
         * closes it. Does nothing if the resource was detached. Failures are logged.
         */
        public void returnResource() {
            try {
                if (!detached) {
                    if (manager.isOpen()) {
                        resource.returnResource(manager);
                    } else {
                        resource.close();
                    }
                }
            } catch (Exception e) {
                logger.error("Temporary resource return failed.", e);
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy