All downloads are free. Search and download functionality uses the official Maven repository.

lamp.io.npy.package.scala Maven / Gradle / Ivy

The newest version!
package lamp.io

import java.nio.channels.ReadableByteChannel
import lamp.Device
import lamp.Scope
import aten.ATen
import lamp.STenOptions
import lamp.CPU
import lamp.STen
import aten.Tensor
import java.io.File

/** This package provides methods to read NPY formatted data into STen tensors
  *
  * The data is first read into a regular JVM array, then transferred to
  * off-heap memory. The total tensor size may be larger than what a single JVM
  * array can hold.
  */
package object npy {

  /** Reads a 64-bit floating point NPY stream (scalar type code 7) from
    * `channel` into a tensor on `device`. See [[readFromChannel]].
    */
  def readDoubleFromChannel(
      channel: ReadableByteChannel,
      device: Device
  )(implicit
      scope: Scope
  ): STen = readFromChannel(7, channel, device)

  /** Reads a 32-bit floating point NPY stream (scalar type code 6) from
    * `channel` into a tensor on `device`. See [[readFromChannel]].
    */
  def readFloatFromChannel(
      channel: ReadableByteChannel,
      device: Device
  )(implicit
      scope: Scope
  ): STen = readFromChannel(6, channel, device)

  /** Reads a 64-bit integer NPY stream (scalar type code 4) from `channel`
    * into a tensor on `device`. See [[readFromChannel]].
    */
  def readLongFromChannel(
      channel: ReadableByteChannel,
      device: Device
  )(implicit
      scope: Scope
  ): STen = readFromChannel(4, channel, device)

  /** Reads a 64-bit floating point NPY file (scalar type code 7) into a tensor
    * on `device`. See [[readFromFile]].
    */
  def readDoubleFromFile(
      file: File,
      device: Device
  )(implicit
      scope: Scope
  ): STen = readFromFile(7, file, device)

  /** Reads a 32-bit floating point NPY file (scalar type code 6) into a tensor
    * on `device`. See [[readFromFile]].
    */
  def readFloatFromFile(
      file: File,
      device: Device
  )(implicit
      scope: Scope
  ): STen = readFromFile(6, file, device)

  /** Reads a 64-bit integer NPY file (scalar type code 4) into a tensor on
    * `device`. See [[readFromFile]].
    */
  def readLongFromFile(
      file: File,
      device: Device
  )(implicit
      scope: Scope
  ): STen = readFromFile(4, file, device)

  /** Opens `file`, reads its NPY content via [[readFromChannel]] and closes the
    * underlying stream afterwards, whether or not reading succeeded.
    *
    * @param scalarType
    *   4 (Long), 6 (Float) or 7 (Double)
    * @param file
    *   the NPY file to read
    * @param device
    *   device on which the resulting tensor is placed
    */
  def readFromFile(
      scalarType: Byte,
      file: File,
      device: Device
  )(implicit
      scope: Scope
  ): STen = {
    val fis = new java.io.FileInputStream(file)
    try {
      readFromChannel(scalarType, fis.getChannel, device)
    } finally {
      fis.close
    }
  }

  /** Reads an NPY stream from `channel` into a newly allocated tensor.
    *
    * The header is parsed first, a zero-filled CPU tensor of the declared
    * shape is allocated, and the data chunks are copied into it in order. If
    * `device` is not `CPU`, the tensor is transferred there and the temporary
    * CPU tensor is released. `channel` is not closed by this method.
    *
    * @param scalarType
    *   4 (Long), 6 (Float) or 7 (Double); must match the NPY file's dtype
    * @param channel
    *   source of the NPY bytes
    * @param device
    *   device on which the resulting tensor is placed
    * @return
    *   the tensor, wrapped with `STen.owned` in the implicit `scope`
    * @throws IllegalArgumentException
    *   if `scalarType` is unsupported, the data is in Fortran (column-major)
    *   layout, or the number of elements read differs from the header's shape
    * @throws RuntimeException
    *   if the header or a data chunk cannot be parsed, or a chunk cannot be
    *   copied into the tensor
    */
  def readFromChannel(
      scalarType: Byte,
      channel: ReadableByteChannel,
      device: Device
  )(implicit
      scope: Scope
  ): STen = {
    // Each branch parses the header with the matching saddle dtype and returns
    // a `fill` function that streams its own, statically typed chunks into a
    // tensor. This keeps element types exact without any runtime array casts.
    val (descriptor, topt, fill) = scalarType match {
      case 4 =>
        val (header, chunks) = orFail(
          org.saddle.io.npy.readFromChannel(org.saddle.io.npy.LongType, channel),
          "Failed to read NPY header"
        )
        val fill = (t: Tensor) =>
          copyChunks(chunks)(t.copyFromLongArrayAtOffset(_, _))
        (header, STenOptions.l.value, fill)
      case 6 =>
        val (header, chunks) = orFail(
          org.saddle.io.npy.readFromChannel(org.saddle.io.npy.FloatType, channel),
          "Failed to read NPY header"
        )
        val fill = (t: Tensor) =>
          copyChunks(chunks)(t.copyFromFloatArrayAtOffset(_, _))
        (header, STenOptions.f.value, fill)
      case 7 =>
        val (header, chunks) = orFail(
          org.saddle.io.npy.readFromChannel(org.saddle.io.npy.DoubleType, channel),
          "Failed to read NPY header"
        )
        val fill = (t: Tensor) =>
          copyChunks(chunks)(t.copyFromDoubleArrayAtOffset(_, _))
        (header, STenOptions.d.value, fill)
      case other =>
        throw new IllegalArgumentException(
          s"Unsupported scalar type $other; expected 4 (Long), 6 (Float) or 7 (Double)"
        )
    }

    require(!descriptor.fortran, "Fortran (column-wise) layout not supported")
    val dim = descriptor.shape
    val expected = dim.foldLeft(1L)(_ * _)
    val t = ATen.zeros(
      dim.toArray,
      topt
    )
    // The native tensor is not yet owned by any scope, so release it ourselves
    // if anything goes wrong while filling it.
    try {
      val copied = fill(t)
      require(
        copied == expected,
        s"Premature end or trailing data: read $copied elements, expected $expected"
      )
    } catch {
      case scala.util.control.NonFatal(e) =>
        t.release
        throw e
    }
    if (device != CPU) {
      // The CPU staging tensor is released whether or not the transfer succeeds.
      val moved =
        try device.to(t)
        finally t.release
      STen.owned(moved)
    } else STen.owned(t)
  }

  /** Returns the `Right` value of `result`, or throws a `RuntimeException`
    * whose message combines `context` with the `Left` value.
    */
  private def orFail[L, R](result: Either[L, R], context: String): R =
    result match {
      case Right(value) => value
      case Left(err)    => throw new RuntimeException(s"$context: $err")
    }

  /** Copies each chunk of `chunks` with `copyAt`, placing every chunk at the
    * running element offset (starting at 0).
    *
    * @return
    *   the total number of elements copied
    */
  private def copyChunks[L, A](
      chunks: Iterator[Either[L, Array[A]]]
  )(copyAt: (Array[A], Long) => Boolean): Long =
    chunks.foldLeft(0L) { (offset, chunk) =>
      val ar = orFail(chunk, s"Failed to read NPY data at element offset $offset")
      // Evaluated outside `assert`: an elided assertion would also skip the copy.
      if (!copyAt(ar, offset))
        throw new RuntimeException(
          s"Failed to copy NPY chunk into tensor at element offset $offset"
        )
      offset + ar.length
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy