All downloads are free. The search and download features use the official Maven repository.

ml.combust.mleap.tensorflow.converter.MleapConverter.scala Maven / Gradle / Ivy

The newest version!
package ml.combust.mleap.tensorflow.converter

import ml.combust.mleap.tensor.DenseTensor

import org.tensorflow.ndarray.{NdArray, NdArraySequence, NdArrays, Shape}
import org.tensorflow.types.TFloat32
import org.tensorflow.types.TFloat64
import org.tensorflow.types.TInt32
import org.tensorflow.types.TInt64
import org.tensorflow.types.TString
import org.tensorflow.types.TUint8
import org.tensorflow.ndarray.buffer.DataBuffers
import org.tensorflow

import scala.reflect.ClassTag
import ml.combust.mleap.tensor.{ByteString, Tensor}

import java.util.function.BiConsumer

/**
  * Created by hollinwilkins on 1/12/17.
  */

/**
  * Converts MLeap tensors into TensorFlow tensors.
  */
object MleapConverter {

  /**
    * Builds a TensorFlow tensor with the same dimensions and elements as an MLeap tensor.
    *
    * The input is densified first. The target TensorFlow type is then chosen from
    * `value.base.runtimeClass`:
    *
    *  - byte       -> `TUint8`
    *  - int        -> `TInt32`
    *  - long       -> `TInt64`
    *  - float      -> `TFloat32`
    *  - double     -> `TFloat64`
    *  - String     -> `TString` (via `TString.tensorOf`)
    *  - ByteString -> `TString` (via `TString.tensorOfBytes`, using each element's `bytes`)
    *
    * @param value the MLeap tensor to convert
    * @tparam T element type of the MLeap tensor
    * @return the TensorFlow tensor created for the matching element type
    * @throws IllegalArgumentException if the element class is none of the classes listed above
    */
  def convert[T: ClassTag](value: Tensor[T]): tensorflow.Tensor = {
    val denseValue = value.toDense
    val dims: Array[Long] = denseValue.dimensions.map(_.toLong).toArray
    val tfShape = Shape.of(dims: _*)

    // Flags handed to every DataBuffers.of call below, named so the call sites are readable.
    // NOTE(review): per the TensorFlow NdArray API these are (readOnly, makeCopy) - confirm.
    val readOnly = true
    val makeCopy = false

    value.base.runtimeClass match {
      case Tensor.ByteClass =>
        val bytes = denseValue.values.asInstanceOf[Array[Byte]]
        TUint8.tensorOf(tfShape, DataBuffers.of(bytes, readOnly, makeCopy))

      case Tensor.IntClass =>
        val ints = denseValue.values.asInstanceOf[Array[Int]]
        TInt32.tensorOf(tfShape, DataBuffers.of(ints, readOnly, makeCopy))

      case Tensor.LongClass =>
        val longs = denseValue.values.asInstanceOf[Array[Long]]
        TInt64.tensorOf(tfShape, DataBuffers.of(longs, readOnly, makeCopy))

      case Tensor.FloatClass =>
        val floats = denseValue.values.asInstanceOf[Array[Float]]
        TFloat32.tensorOf(tfShape, DataBuffers.of(floats, readOnly, makeCopy))

      case Tensor.DoubleClass =>
        val doubles = denseValue.values.asInstanceOf[Array[Double]]
        TFloat64.tensorOf(tfShape, DataBuffers.of(doubles, readOnly, makeCopy))

      case Tensor.StringClass =>
        val target = NdArrays.ofObjects(classOf[String], tfShape)
        val strings = denseValue.asInstanceOf[DenseTensor[String]]
        fillObjects(target, tfShape)(at => strings(at: _*))
        TString.tensorOf(target)

      case Tensor.ByteStringClass =>
        val target = NdArrays.ofObjects(classOf[Array[Byte]], tfShape)
        val byteStrings = denseValue.asInstanceOf[DenseTensor[ByteString]]
        fillObjects(target, tfShape)(at => byteStrings(at: _*).bytes)
        TString.tensorOfBytes(target)

      case _ =>
        throw new IllegalArgumentException(s"unsupported tensor type ${value.getClass}[${value.base.runtimeClass}]")
    }
  }

  /**
    * Writes one object into every cell of `target`.
    *
    * For a scalar shape the single value is looked up at index 0; otherwise each cell is
    * visited with its coordinates, which are narrowed from `Long` to `Int` before the lookup.
    *
    * @param target    object array to populate in place
    * @param shape     shape that `target` was created with
    * @param elementAt returns the source element stored at the given indices
    */
  private def fillObjects[A](target: NdArray[A], shape: Shape)(elementAt: Array[Int] => A): Unit = {
    if (shape.isScalar()) {
      target.setObject(elementAt(Array(0)))
    } else {
      // scalars is typed with a wildcard element; the cast pins it so the consumer can be typed.
      val cells = target.scalars.asInstanceOf[NdArraySequence[NdArray[A]]]
      cells.forEachIndexed(new BiConsumer[Array[Long], NdArray[A]] {
        override def accept(coords: Array[Long], cell: NdArray[A]): Unit = {
          cell.setObject(elementAt(coords.map(_.toInt)))
        }
      })
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy