All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.lang.cudnn.CudaTensorList Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.lang.cudnn;

import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.lang.UncheckedSupplier;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefIgnore;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.util.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.function.IntFunction;

public class CudaTensorList extends ReferenceCountingBase implements TensorList, CudaSystem.CudaDeviceResource {
  /** Class-wide SLF4J logger. */
  public static final Logger logger = LoggerFactory.getLogger(CudaTensorList.class);
  /**
   * Stack trace captured when this instance is created if {@code CudaSettings.INSTANCE().profileMemoryIO}
   * is set, otherwise an empty array. It is appended to the debug messages logged when data is read back
   * from the GPU.
   */
  public final StackTraceElement[] createdBy = CudaSettings.INSTANCE().profileMemoryIO ? Util.getStackTrace()
      : new StackTraceElement[]{};
  /** Shape of a single tensor in the list; copied on the way in (constructor) and out (getDimensions). */
  @Nonnull
  private final int[] dimensions;
  /** Number of tensors in the list; the constructor requires it to equal the descriptor's batch count. */
  private final int length;
  /** GPU-side data. Set to {@code null} by {@code evictToHeap()} and {@code _free()}. */
  @Nullable
  volatile CudaTensor cudaTensor;
  /** Heap-side copy of the data, populated on demand by {@code heapCopy(boolean)}; {@code null} until then. */
  @Nullable
  volatile TensorArray heapCopy = null;

  public CudaTensorList(@Nullable final CudaTensor cudaTensor, final int length, @Nonnull final int[] dimensions,
                        @Nonnull final Precision precision) {
    //assert 1 == ptr.currentRefCount() : ptr.referenceReport(false, false);
    if (null == cudaTensor) {
      throw new IllegalArgumentException("ptr");
    }
    if (null == cudaTensor.memory.getPtr()) {
      cudaTensor.freeRef();
      throw new IllegalArgumentException("ptr.getPtr()");
    }
    if (length <= 0) {
      cudaTensor.freeRef();
      throw new IllegalArgumentException();
    }
    if (Tensor.length(dimensions) <= 0) {
      cudaTensor.freeRef();
      throw new IllegalArgumentException();
    }
    if (cudaTensor.memory.size < (long) (length - 1) * Tensor.length(dimensions) * precision.size) {
      String message = String.format("%s < %s", cudaTensor.memory.size, (long) length * Tensor.length(dimensions) * precision.size);
      cudaTensor.freeRef();
      throw new AssertionError(message);
    }
    if (cudaTensor.descriptor.batchCount != length) {
      cudaTensor.freeRef();
      throw new AssertionError();
    }
    if (cudaTensor.descriptor.channels != (dimensions.length < 3 ? 1 : dimensions[2])) {
      String message = RefString.format(
          "%s != (%d,%d,%d,%d)", RefArrays.toString(dimensions), cudaTensor.descriptor.batchCount, cudaTensor.descriptor.channels,
          cudaTensor.descriptor.height, cudaTensor.descriptor.width);
      cudaTensor.freeRef();
      throw new AssertionError(message);
    }
    if (cudaTensor.descriptor.height != (dimensions.length < 2 ? 1 : dimensions[1])) {
      String message = RefString.format(
          "%s != (%d,%d,%d,%d)", RefArrays.toString(dimensions), cudaTensor.descriptor.batchCount, cudaTensor.descriptor.channels,
          cudaTensor.descriptor.height, cudaTensor.descriptor.width);
      cudaTensor.freeRef();
      throw new AssertionError(message);
    }
    if (cudaTensor.descriptor.width != (dimensions.length < 1 ? 1 : dimensions[0])) {
      String message = RefString.format("%s != (%d,%d,%d,%d)",
          RefArrays.toString(dimensions), cudaTensor.descriptor.batchCount, cudaTensor.descriptor.channels, cudaTensor.descriptor.height,
          cudaTensor.descriptor.width);
      cudaTensor.freeRef();
      throw new AssertionError(message);
    }
    if (cudaTensor.getPrecision() != precision) {
      cudaTensor.freeRef();
      throw new AssertionError();
    }
    if (cudaTensor.memory.getPtr() == null) {
      cudaTensor.freeRef();
      throw new AssertionError();
    }
    setCudaTensor(cudaTensor);
    this.dimensions = RefArrays.copyOf(dimensions, dimensions.length);
    this.length = length;
    ObjectRegistry.register(this.addRef());
    //assert this.stream().flatMapToDouble(x-> Arrays.stream(x.getData())).allMatch(v->Double.isFinite(v));
  }

  /**
   * Reports which device holds the GPU-side data.
   *
   * @return the id reported by the current {@code cudaTensor}, or {@code -1} when there is none
   */
  public int getDeviceId() {
    // Read the volatile field once: evictToHeap() and _free() can null it between a check and a use.
    final CudaTensor current = this.cudaTensor;
    return null == current ? -1 : current.getDeviceId();
  }

  /**
   * Returns the shape of a single tensor in this list.
   *
   * @return a fresh copy of the dimensions array, so callers cannot alter the stored shape
   */
  @Nonnull
  @Override
  public int[] getDimensions() {
    final int[] shape = this.dimensions;
    return RefArrays.copyOf(shape, shape.length);
  }

  /**
   * @return the number of tensors in this list, as fixed at construction
   */
  public int getLength() {
    return this.length;
  }

  /**
   * Reports the numeric precision of the GPU-side data.
   *
   * @return the precision of the current {@code cudaTensor}, or {@code null} when there is no GPU tensor
   *         (the field is cleared by {@code evictToHeap()} and {@code _free()})
   */
  @Nullable
  public Precision getPrecision() {
    // Read the volatile field once so a concurrent eviction cannot null it between check and use.
    final CudaTensor current = this.cudaTensor;
    return null == current ? null : current.getPrecision();
  }

  /**
   * Replaces the GPU tensor held by this list, taking over the reference that is passed in.
   * <p>
   * If a different tensor was held, its reference is released. If the argument is the instance already
   * held, the surplus reference passed in is released instead.
   *
   * @param cudaTensor the tensor to hold; may be {@code null}
   */
  void setCudaTensor(CudaTensor cudaTensor) {
    synchronized (this) {
      if (cudaTensor != this.cudaTensor) {
        RefUtil.freeRef(this.cudaTensor);
        this.cudaTensor = cudaTensor;
      } else {
        // Both may be null here; RefUtil.freeRef is already relied on with a null field by the
        // constructor path above, whereas calling cudaTensor.freeRef() directly would throw.
        RefUtil.freeRef(cudaTensor);
      }
    }
  }

  /**
   * Installs {@code toHeap} as the heap copy unless a live heap copy is already present.
   * <p>
   * The reference passed in is always consumed: it is either stored or released. A previously stored
   * copy is only ever replaced when {@code hasHeapCopy()} is false (absent or already freed), so it
   * never needs to be released here.
   *
   * @param toHeap candidate heap copy; may be {@code null}
   */
  private synchronized void setHeapCopy(TensorArray toHeap) {
    if (toHeap != heapCopy && !hasHeapCopy()) {
      this.heapCopy = toHeap;
    } else if (null != toHeap) {
      // Either the same instance is already stored or a live copy exists: drop the extra reference.
      toHeap.freeRef();
    }
  }

  /**
   * Evicts every registered {@code CudaTensorList} that still has a GPU tensor and matches the device.
   * <p>
   * An instance matches when its device id equals {@code deviceId}, or when either id is negative.
   * The work runs inside {@code CudaSystem.withDevice(deviceId, ...)}.
   *
   * @param deviceId device to clear; a negative value matches every instance
   * @return the sum of the values returned by each instance's {@code evictToHeap()}
   */
  public static long evictToHeap(int deviceId) {
    return CudaSystem.withDevice(deviceId, gpu -> {
      long evictedBytes = ObjectRegistry.getLivingInstances(CudaTensorList.class).filter(candidate -> {
        final int tensorDevice = candidate.getDeviceId();
        final boolean deviceMatches = deviceId < 0 || tensorDevice < 0 || tensorDevice == deviceId;
        final boolean resident = null != candidate.cudaTensor && deviceMatches;
        candidate.freeRef();
        return resident;
      }).mapToLong(resident -> {
        final long bytes = resident.evictToHeap();
        resident.freeRef();
        return bytes;
      }).sum();
      gpu.freeRef();
      logger.debug(RefString.format("Cleared %s bytes from GpuTensorLists for device %s", evictedBytes, deviceId));
      return evictedBytes;
    });
  }

  /**
   * Factory that builds the {@code CudaTensor} from raw memory and a descriptor, then wraps it in a list.
   *
   * @param ptr        memory passed to the {@code CudaTensor} constructor
   * @param descriptor descriptor passed to the {@code CudaTensor} constructor
   * @param length     number of tensors in the list
   * @param dimensions shape of a single tensor
   * @param precision  precision used for both the tensor and the list
   * @return the new list
   */
  @Nonnull
  public static CudaTensorList create(@Nullable final CudaMemory ptr, CudaDevice.CudaTensorDescriptor descriptor,
                                      final int length, @Nonnull final int[] dimensions, @Nonnull final Precision precision) {
    final CudaTensor wrapped = new CudaTensor(ptr, descriptor, precision);
    return new CudaTensorList(wrapped, length, dimensions, precision);
  }

  /**
   * Adds {@code right} to this list and releases the caller's reference to {@code right}
   * (the {@code finally} block frees it on every exit path).
   * <p>
   * Strategy, in order:
   * <ol>
   *   <li>a {@code ReshapedTensorList} is unwrapped and the call is repeated on its inner list;</li>
   *   <li>if this list has more than one reference, the work is delegated to {@link #add(TensorList)};</li>
   *   <li>if neither side has a heap copy, both are {@code CudaTensorList}s and their precisions match,
   *       the sum is computed on the GPU: {@code addInPlace} when the handle's device is this tensor's
   *       device, otherwise {@link #add(TensorList)};</li>
   *   <li>otherwise the tensors are added pairwise with {@code Tensor.add} into a new {@code TensorArray}.</li>
   * </ol>
   *
   * @param right the list to add; must be alive and, when assertions are enabled, of the same length
   * @return the sum; {@code this.addRef()} when {@code right} is empty
   * @throws IllegalArgumentException if this list is empty while {@code right} is not
   */
  @Override
  public TensorList addAndFree(@Nonnull final TensorList right) {
    assertAlive();
    right.assertAlive();
    try {
      if (right instanceof ReshapedTensorList) {
        return addAndFree(((ReshapedTensorList) right).getInner());
      }
      // Other holders exist, so go through add(), which returns a separate result.
      if (1 < currentRefCount()) {
        return add(right.addRef());
      }
      assert length() == right.length();
      if (heapCopy == null) {
        if (right instanceof CudaTensorList) {
          @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right.addRef();
          if (nativeRight.getPrecision() == this.getPrecision()) {
            if (nativeRight.heapCopy == null) {
              assert nativeRight.cudaTensor != this.cudaTensor;
              int deviceId = this.cudaTensor.getDeviceId();
              // nativeRight and an extra reference to right are handed to wrapInterface together
              // with the lambda that uses them.
              return CudaSystem
                  .run(RefUtil.wrapInterface((RefFunction) gpu -> {
                        try {
                          if (gpu.getDeviceId() == deviceId) {
                            return gpu.addInPlace(this.addRef(), nativeRight.addRef());
                          } else {
                            assertAlive();
                            right.assertAlive();
                            return add(right.addRef());
                          }
                        } finally {
                          gpu.freeRef();
                        }
                      }, nativeRight, right.addRef()), this.addRef(),
                      right.addRef());
            }
          }
          // GPU path not taken: release the reference acquired for the cast above.
          nativeRight.freeRef();
        }
      }
      if (right.length() == 0) {
        return this.addRef();
      }
      if (length() == 0) {
        throw new IllegalArgumentException();
      }
      assert length() == right.length();
      // Heap fallback: element-wise sum of the two lists.
      return new TensorArray(
          RefIntStream.range(0, length()).mapToObj(RefUtil.wrapInterface((IntFunction) i -> {
            return Tensor.add(get(i), right.get(i));
          }, right.addRef())).toArray(i -> new Tensor[i]));
    } finally {
      right.freeRef();
    }
  }

  /**
   * Returns the element-wise sum of this list and {@code right}, releasing the caller's reference to
   * {@code right} (the {@code finally} block frees it on every exit path).
   * <p>
   * A {@code ReshapedTensorList} is unwrapped first. When neither side has a heap copy, both are
   * {@code CudaTensorList}s and their precisions match, the sum is computed by {@code gpu.add}.
   * Otherwise the tensors are added pairwise with {@code Tensor.add} into a new {@code TensorArray}.
   *
   * @param right the list to add; must be alive and, when assertions are enabled, of the same length
   * @return the sum; {@code this.addRef()} when {@code right} is empty
   * @throws IllegalArgumentException if this list is empty while {@code right} is not
   */
  @Override
  public TensorList add(@Nonnull final TensorList right) {
    try {
      assertAlive();
      right.assertAlive();
      assert length() == right.length();
      if (right instanceof ReshapedTensorList) {
        return add(((ReshapedTensorList) right).getInner());
      }
      if (heapCopy == null) {
        if (right instanceof CudaTensorList) {
          @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right.addRef();
          if (nativeRight.getPrecision() == this.getPrecision()) {
            if (nativeRight.heapCopy == null) {
              // nativeRight is handed to wrapInterface together with the lambda that uses it.
              return CudaSystem
                  .run(RefUtil.wrapInterface((RefFunction) gpu -> {
                    CudaTensorList add = gpu.add(this.addRef(), nativeRight.addRef());
                    gpu.freeRef();
                    return add;
                  }, nativeRight), this.addRef());
            }
          }
          // GPU path not taken: release the reference acquired for the cast above.
          nativeRight.freeRef();
        }
      }
      if (right.length() == 0) {
        return this.addRef();
      }
      if (length() == 0) {
        throw new IllegalArgumentException();
      }
      assert length() == right.length();
      // Heap fallback: element-wise sum of the two lists.
      return new TensorArray(
          RefIntStream.range(0, length()).mapToObj(RefUtil.wrapInterface((IntFunction) i -> {
            return Tensor.add(get(i), right.get(i));
          }, right.addRef())).toArray(i -> new Tensor[i]));
    } finally {
      right.freeRef();
    }
  }

  /**
   * Returns the tensor at index {@code i}.
   * <p>
   * When a heap copy exists the tensor is taken from it. Otherwise a new {@code Tensor} with this list's
   * dimensions is filled from the GPU tensor inside {@code CudaSystem.run}; the read is timed and, when
   * debug logging is enabled, reported together with the current and creation stack traces.
   *
   * @param i index of the tensor to fetch
   * @return the tensor at that index
   */
  @Override
  @Nonnull
  @RefAware
  public Tensor get(final int i) {
    assertAlive();
    if (heapCopy != null)
      return heapCopy.get(i);
    // Local reference that shadows the field; it is handed to the outer wrapInterface call below.
    CudaTensor cudaTensor = this.cudaTensor.addRef();
    return CudaSystem.run(RefUtil.wrapInterface((RefFunction) gpu -> {
      TimedResult timedResult = TimedResult.time(RefUtil.wrapInterface((UncheckedSupplier) () -> {
        assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
        Tensor t = new Tensor(getDimensions());
        if (cudaTensor.isDense()) {
          // Dense layout: read straight from memory at an offset of i whole tensors.
          CudaMemory memory = cudaTensor.getMemory(gpu.addRef());
          assert memory != null;
          memory.read(cudaTensor.getPrecision(), t.addRef(), i * Tensor.length(getDimensions()));
          memory.freeRef();
        } else {
          // Non-dense layout: let the tensor itself extract element i.
          cudaTensor.read(gpu.addRef(), i, t.addRef(), false);
        }
        return t;
      }, cudaTensor.addRef(), gpu));
      Tensor result = timedResult.getResult();
      if (CudaTensorList.logger.isDebugEnabled()) CudaTensorList.logger.debug(RefString.format(
          "Read %s bytes in %.4f from Tensor %s, GPU at %s, created by %s",
          cudaTensor.size(),
          timedResult.seconds(),
          Integer.toHexString(RefSystem.identityHashCode(RefUtil.addRef(result))),
          Util.toString(Util.getStackTrace()).replaceAll("\n", "\n\t"),
          Util.toString(createdBy).replaceAll("\n", "\n\t")
      ));
      timedResult.freeRef();
      return result;
    }, cudaTensor), this.addRef());
  }

  /**
   * @return the number of tensors in this list; same value as {@link #getLength()}
   */
  @Override
  public int length() {
    final int count = getLength();
    return count;
  }

  /**
   * Streams the tensors of this list from its heap copy, creating that copy first if necessary.
   *
   * @return the stream produced by the heap copy
   */
  @Nonnull
  @Override
  public RefStream stream() {
    final TensorArray onHeap = heapCopy();
    assert onHeap != null;
    final RefStream tensors = onHeap.stream();
    // The stream has been obtained; the temporary reference to the heap copy is no longer needed.
    onHeap.freeRef();
    return tensors;
  }

  /**
   * Creates a new {@code CudaTensorList} backed by a copy of this list's GPU memory.
   * <p>
   * Inside {@code CudaSystem.run} the tensor for this list is obtained from the handle, its memory is
   * copied with {@code MemoryType.Managed.ifEnabled()}, and the copy is wrapped in a new
   * {@code CudaTensor} that reuses the source descriptor and this list's precision.
   *
   * @return a new list with the same length, dimensions and precision
   */
  @Override
  public TensorList copy() {
    return CudaSystem.run(gpu -> {
      // NOTE(review): the meaning of the trailing `false` flag is defined by gpu.getTensor, not visible here.
      CudaTensor ptr = gpu.getTensor(this.addRef(), MemoryType.Device, false);
      assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
      CudaMemory cudaMemory = ptr.getMemory(gpu.addRef(), MemoryType.Device);
      assert cudaMemory != null;
      // gpu is passed without addRef here, unlike the call above.
      CudaMemory copyPtr = cudaMemory.copy(gpu, MemoryType.Managed.ifEnabled());
      cudaMemory.freeRef();
      Precision precision = getPrecision();
      CudaTensor cudaTensor = new CudaTensor(copyPtr, ptr.descriptor.addRef(), precision);
      ptr.freeRef();
      return new CudaTensorList(cudaTensor, getLength(), getDimensions(), precision);
    }, this.addRef());
  }

  /**
   * Moves this list's data off the GPU: makes sure a heap copy exists, then detaches and releases the
   * GPU tensor.
   *
   * @return {@code getElements() * precision.size} for the released tensor, or {@code 0} when this list
   *         is already freed or there was no live GPU tensor to release
   * @throws IllegalStateException if no heap copy could be obtained
   */
  @RefIgnore
  public long evictToHeap() {
    if (isFreed())
      return 0;
    // Build (or reuse) the heap copy before the GPU data is let go; only its existence matters here.
    TensorArray heapCopy = heapCopy(true);
    if (null == heapCopy) {
      throw new IllegalStateException();
    }
    heapCopy.freeRef();
    final CudaTensor ptr;
    // Detach the GPU tensor under the lock so it can be released outside of it.
    synchronized (this) {
      ptr = this.cudaTensor;
      this.cudaTensor = null;
    }
    if (null != ptr && !ptr.isFreed()) {
      long elements = getElements();
      assert 0 < length;
      assert 0 < elements : RefArrays.toString(dimensions);
      long size = elements * ptr.getPrecision().size;
      ptr.freeRef();
      return size;
    } else {
      // Nothing was released: put the detached value back, freeing any tensor that was set meanwhile.
      synchronized (this) {
        if (null != this.cudaTensor)
          this.cudaTensor.freeRef();
        this.cudaTensor = ptr;
      }
      return 0;
    }
  }

  /**
   * Runs the superclass cleanup, then releases and clears the GPU tensor and the heap copy while
   * holding this object's monitor.
   */
  public void _free() {
    super._free();
    synchronized (this) {
      final CudaTensor gpuData = cudaTensor;
      if (gpuData != null) {
        gpuData.freeRef();
        cudaTensor = null;
      }
      final TensorArray heapData = heapCopy;
      if (heapData != null) {
        heapData.freeRef();
        heapCopy = null;
      }
    }
  }

  /**
   * Delegates to the superclass {@code addRef()} and narrows the result to this type.
   *
   * @return the value returned by {@code super.addRef()}, cast to {@code CudaTensorList}
   */
  @Nonnull
  @Override
  @SuppressWarnings("unused")
  public CudaTensorList addRef() {
    final Object self = super.addRef();
    return (CudaTensorList) self;
  }

  /**
   * Convenience overload of {@link #heapCopy(boolean)} with {@code avoidAllocations} set to false.
   *
   * @return a new reference to the heap copy
   */
  @Nullable
  private TensorArray heapCopy() {
    final boolean avoidAllocations = false;
    return heapCopy(avoidAllocations);
  }

  /**
   * Returns a new reference to the heap copy, creating the copy from the GPU data first when no live
   * copy exists.
   *
   * @param avoidAllocations forwarded to {@code toHeap(boolean)} when the copy has to be created
   * @return a new reference to the heap copy, or {@code null} if none is present afterwards (for
   *         example because the list was freed concurrently); callers already check for null
   */
  @Nullable
  private TensorArray heapCopy(final boolean avoidAllocations) {
    if (!hasHeapCopy()) {
      setHeapCopy(toHeap(avoidAllocations));
    }
    // Read the volatile field once and honour @Nullable instead of dereferencing it blindly.
    final TensorArray current = this.heapCopy;
    return null == current ? null : current.addRef();
  }

  /**
   * @return true when a heap copy is present and has not been freed
   */
  private boolean hasHeapCopy() {
    // Read the volatile field once: _free() may null it between the null check and isFreed().
    final TensorArray current = this.heapCopy;
    return null != current && !current.isFreed();
  }

  /**
   * Produces heap-side data for this list.
   * <p>
   * If a reference to the GPU tensor can be acquired, the data is read back from it. Otherwise (the
   * tensor has been evicted, so the field is null, or it has been freed) a new reference to the
   * existing heap copy is returned.
   *
   * @param avoidAllocations forwarded to the GPU read
   * @return heap data; the caller owns the returned reference
   * @throws IllegalStateException if the GPU tensor is unavailable and there is no live heap copy
   */
  @Nullable
  @RefIgnore
  private TensorArray toHeap(final boolean avoidAllocations) {
    assertAlive();
    // The field is null after evictToHeap(); guard it so the fallback below is reached instead of an NPE.
    final CudaTensor gpuData = this.cudaTensor;
    if (null != gpuData && gpuData.tryAddRef()) {
      return toHeap(avoidAllocations, gpuData);
    }
    final TensorArray current = this.heapCopy;
    if (null == current) {
      throw new IllegalStateException("No data");
    }
    if (current.isFreed()) {
      throw new IllegalStateException("Local data has been freed");
    }
    return current.addRef();
  }

  /**
   * Reads every tensor of {@code gpuCopy} into newly created heap tensors and wraps them in a
   * {@code TensorArray}.
   * <p>
   * The reference to {@code gpuCopy} is consumed: the {@code finally} block releases it on every path.
   * The read is timed and, when debug logging is enabled, reported together with the current and
   * creation stack traces.
   *
   * @param avoidAllocations forwarded to each {@code gpuCopy.read} call
   * @param gpuCopy          GPU tensor to read from; when {@code null} the method returns {@code null}
   * @return the heap copy, or {@code null} if {@code gpuCopy} is {@code null}
   * @throws IllegalStateException if this list's length is not positive
   */
  private TensorArray toHeap(boolean avoidAllocations, CudaTensor gpuCopy) {
    try {
      if (null == gpuCopy) return null;
      int length = getLength();
      if (0 >= length) {
        throw new IllegalStateException();
      }
      // One empty destination tensor per list element.
      final Tensor[] output = RefIntStream.range(0, length)
          .mapToObj(dataIndex -> new Tensor(getDimensions()))
          .toArray(i -> new Tensor[i]);
      TimedResult timedResult = TimedResult
          .time(RefUtil.wrapInterface((UncheckedSupplier) () -> CudaDevice
              .run(gpu -> {
                    assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
                    //        assert getPrecision() == gpuCopy.getPrecision();
                    //        assert getPrecision() == gpuCopy.descriptor.dataType;
                    for (int i = 0; i < length; i++) {
                      gpuCopy.read(gpu.addRef(), i, output[i].addRef(), avoidAllocations);
                    }
                    gpu.freeRef();
                    return new TensorArray(RefUtil.addRef(output));
                  }, this.addRef()
              ), output));
      TensorArray result = timedResult.getResult();
      if (CudaTensorList.logger.isDebugEnabled()) {
        CudaTensorList.logger.debug(RefString.format(
            "Read %s bytes in %.4f from Tensor %s on GPU at %s, created by %s",
            gpuCopy.size(),
            timedResult.seconds(),
            Integer.toHexString(RefSystem.identityHashCode(result.addRef())),
            Util.toString(Util.getStackTrace()).replaceAll("\n", "\n\t"),
            Util.toString(createdBy).replaceAll("\n", "\n\t"))
        );
      }
      timedResult.freeRef();
      return result;
    } finally {
      if (null != gpuCopy) gpuCopy.freeRef();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy