All downloads are free. The search and download functionality uses the official Maven repository.

ml.combust.bundle.tensor.TensorSerializer.scala Maven / Gradle / Ivy

package ml.combust.bundle.tensor

import com.google.protobuf
import ml.bundle.{BasicType, Tensor, TensorDimension, TensorShape}
import ml.combust.mleap.tensor
import ml.combust.mleap.tensor.ByteString

import scala.reflect.{ClassTag, classTag}

/**
  * Created by hollinwilkins on 1/15/17.
  */
/** Converts between MLeap tensors (`ml.combust.mleap.tensor.Tensor`) and the
  * bundle protobuf representation (`ml.bundle.Tensor`).
  *
  * Element values are encoded to and decoded from a byte array by the
  * per-type `*ArraySerializer` objects; the element type is recorded in the
  * proto as a [[BasicType]].
  */
object TensorSerializer {
  /** Maps the runtime class of `T` to the corresponding bundle [[BasicType]].
    *
    * @tparam T element type, resolved through its `ClassTag`
    * @return the bundle basic type matching the runtime class of `T`
    * @throws RuntimeException if the runtime class of `T` is none of the classes matched below
    */
  def toBundleType[T: ClassTag]: BasicType = {
    // Resolve the class once; it is needed both for matching and for the error message.
    val runtimeClass = classTag[T].runtimeClass

    runtimeClass match {
      case tensor.Tensor.BooleanClass => BasicType.BOOLEAN
      case tensor.Tensor.ByteClass => BasicType.BYTE
      case tensor.Tensor.ShortClass => BasicType.SHORT
      case tensor.Tensor.IntClass => BasicType.INT
      case tensor.Tensor.LongClass => BasicType.LONG
      case tensor.Tensor.FloatClass => BasicType.FLOAT
      case tensor.Tensor.DoubleClass => BasicType.DOUBLE
      case tensor.Tensor.StringClass => BasicType.STRING
      case tensor.Tensor.ByteStringClass => BasicType.BYTE_STRING
      case _ => throw new RuntimeException(s"unsupported base type ${runtimeClass.getName}")
    }
  }

  /** Serializes an MLeap tensor into its bundle protobuf form.
    *
    * The element type is selected from `t.base.runtimeClass`, the raw values
    * are encoded with the matching array serializer, and every entry of
    * `t.dimensions` becomes one `TensorDimension` of the resulting shape.
    *
    * NOTE(review): only `rawValues` and `dimensions` are written. If a tensor
    * implementation carries additional state (for example sparse indices), it
    * is not serialized here -- confirm against the tensor implementations.
    *
    * @param t tensor to serialize
    * @tparam T element type of the tensor
    * @return protobuf tensor holding the base type, shape and encoded values
    * @throws IllegalArgumentException if the tensor's base class is not supported
    */
  def toProto[T](t: tensor.Tensor[T]): Tensor = {
    // The casts are guarded by the runtime class matched in each case.
    val (tpe, values) = t.base.runtimeClass match {
      case tensor.Tensor.BooleanClass =>
        (BasicType.BOOLEAN, BooleanArraySerializer.write(t.rawValues.asInstanceOf[Array[Boolean]]))
      case tensor.Tensor.ByteClass =>
        (BasicType.BYTE, ByteArraySerializer.write(t.rawValues.asInstanceOf[Array[Byte]]))
      case tensor.Tensor.ShortClass =>
        (BasicType.SHORT, ShortArraySerializer.write(t.rawValues.asInstanceOf[Array[Short]]))
      case tensor.Tensor.IntClass =>
        (BasicType.INT, IntArraySerializer.write(t.rawValues.asInstanceOf[Array[Int]]))
      case tensor.Tensor.LongClass =>
        (BasicType.LONG, LongArraySerializer.write(t.rawValues.asInstanceOf[Array[Long]]))
      case tensor.Tensor.FloatClass =>
        (BasicType.FLOAT, FloatArraySerializer.write(t.rawValues.asInstanceOf[Array[Float]]))
      case tensor.Tensor.DoubleClass =>
        (BasicType.DOUBLE, DoubleArraySerializer.write(t.rawValues.asInstanceOf[Array[Double]]))
      case tensor.Tensor.StringClass =>
        (BasicType.STRING, StringArraySerializer.write(t.rawValues.asInstanceOf[Array[String]]))
      case tensor.Tensor.ByteStringClass =>
        (BasicType.BYTE_STRING, ByteStringArraySerializer.write(t.rawValues.asInstanceOf[Array[ByteString]]))
      case _ => throw new IllegalArgumentException(s"unsupported tensor type ${t.base}")
    }

    val dimensions = t.dimensions.map(i => TensorDimension(size = i))
    Tensor(base = tpe,
      shape = Some(TensorShape(dimensions)),
      value = protobuf.ByteString.copyFrom(values))
  }

  /** Deserializes a bundle protobuf tensor into an MLeap tensor.
    *
    * The values are decoded with the array serializer selected by `t.base`
    * and combined with the dimension sizes read from the proto's shape.
    *
    * NOTE(review): the result is cast to `Tensor[T]` without checking that
    * `T` agrees with `t.base`; the caller is responsible for requesting the
    * right element type.
    *
    * @param t protobuf tensor to read
    * @tparam T element type expected by the caller (not verified)
    * @return the decoded tensor
    * @throws NoSuchElementException if the proto has no shape set
    * @throws IllegalArgumentException if `t.base` is not a supported basic type
    */
  def fromProto[T](t: Tensor): tensor.Tensor[T] = {
    // Fail with a descriptive message instead of the opaque "None.get"; the
    // exception type is kept so existing handlers continue to work.
    val shape = t.shape.getOrElse {
      throw new NoSuchElementException("cannot deserialize tensor: shape is not set")
    }
    val dimensions = shape.dimensions.map(_.size)
    val valueBytes = t.value.toByteArray

    val tn = t.base match {
      case BasicType.BOOLEAN =>
        tensor.Tensor.create(BooleanArraySerializer.read(valueBytes), dimensions)
      case BasicType.BYTE =>
        tensor.Tensor.create(ByteArraySerializer.read(valueBytes), dimensions)
      case BasicType.SHORT =>
        tensor.Tensor.create(ShortArraySerializer.read(valueBytes), dimensions)
      case BasicType.INT =>
        tensor.Tensor.create(IntArraySerializer.read(valueBytes), dimensions)
      case BasicType.LONG =>
        tensor.Tensor.create(LongArraySerializer.read(valueBytes), dimensions)
      case BasicType.FLOAT =>
        tensor.Tensor.create(FloatArraySerializer.read(valueBytes), dimensions)
      case BasicType.DOUBLE =>
        tensor.Tensor.create(DoubleArraySerializer.read(valueBytes), dimensions)
      case BasicType.STRING =>
        tensor.Tensor.create(StringArraySerializer.read(valueBytes), dimensions)
      case BasicType.BYTE_STRING =>
        tensor.Tensor.create(ByteStringArraySerializer.read(valueBytes), dimensions)
      case _ => throw new IllegalArgumentException(s"unsupported tensor type ${t.base}")
    }

    tn.asInstanceOf[tensor.Tensor[T]]
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy