// All downloads are free. The search and download functionality uses the official Maven repository.

// ml.combust.bundle.tensor.ArraySerializer.scala (Maven / Gradle / Ivy)

package ml.combust.bundle.tensor

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer

import ml.combust.mleap.tensor.ByteString
import resource._

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

/**
  * Created by hollinwilkins on 1/15/17.
  *
  * Converts between a typed array and a flat byte-array encoding of it.
  * Each implementation defines its own byte layout.
  *
  * @tparam T element type of the arrays handled by this serializer
  */
trait ArraySerializer[T] {
  /** Encodes the elements of `arr` into a byte array. */
  def write(arr: Array[T]): Array[Byte]
  /** Decodes a byte array into an array of `T`. */
  def read(arr: Array[Byte]): Array[T]
}

/**
  * Serializes booleans using one byte per element: `1` for true, `0` for false.
  */
object BooleanArraySerializer extends ArraySerializer[Boolean] {
  /** Encodes every flag as a single byte (1 = true, 0 = false). */
  override def write(arr: Array[Boolean]): Array[Byte] =
    arr.map(flag => if(flag) 1.toByte else 0.toByte)

  /** Decodes one flag per byte; only a byte equal to 1 yields true. */
  override def read(arr: Array[Byte]): Array[Boolean] =
    arr.map(_ == 1)
}

/**
  * Pass-through serializer: a byte array is used directly as its own encoding.
  */
object ByteArraySerializer extends ArraySerializer[Byte] {
  /** Returns `arr` itself; no copy is made. */
  override def write(arr: Array[Byte]): Array[Byte] = identity(arr)

  /** Returns `arr` itself; no copy is made. */
  override def read(arr: Array[Byte]): Array[Byte] = identity(arr)
}

/**
  * Serializes shorts as consecutive 2-byte values in ByteBuffer's default
  * (big-endian) byte order.
  */
object ShortArraySerializer extends ArraySerializer[Short] {
  /** Encodes `arr` into `arr.length * 2` bytes. */
  override def write(arr: Array[Short]): Array[Byte] = {
    val encoded = ByteBuffer.allocate(arr.length * 2)
    // Bulk-copy through a short-typed view that shares the backing array.
    encoded.asShortBuffer().put(arr)
    encoded.array()
  }

  /** Decodes 2-byte values until the input is exhausted. */
  override def read(arr: Array[Byte]): Array[Short] = {
    val source = ByteBuffer.wrap(arr)
    val decoded = mutable.ArrayBuilder.make[Short]

    // Element-wise reads, so a trailing partial value still raises
    // BufferUnderflowException instead of being ignored.
    @scala.annotation.tailrec
    def drain(): Array[Short] =
      if(source.hasRemaining) {
        decoded += source.getShort()
        drain()
      } else {
        decoded.result()
      }

    drain()
  }
}

/**
  * Serializes ints as consecutive 4-byte values in ByteBuffer's default
  * (big-endian) byte order.
  */
object IntArraySerializer extends ArraySerializer[Int] {
  /** Encodes `arr` into `arr.length * 4` bytes. */
  override def write(arr: Array[Int]): Array[Byte] = {
    val encoded = ByteBuffer.allocate(arr.length * 4)
    // Bulk-copy through an int-typed view that shares the backing array.
    encoded.asIntBuffer().put(arr)
    encoded.array()
  }

  /** Decodes 4-byte values until the input is exhausted. */
  override def read(arr: Array[Byte]): Array[Int] = {
    val source = ByteBuffer.wrap(arr)
    val decoded = mutable.ArrayBuilder.make[Int]

    // Element-wise reads, so a trailing partial value still raises
    // BufferUnderflowException instead of being ignored.
    @scala.annotation.tailrec
    def drain(): Array[Int] =
      if(source.hasRemaining) {
        decoded += source.getInt()
        drain()
      } else {
        decoded.result()
      }

    drain()
  }
}

/**
  * Serializes longs as consecutive 8-byte values in ByteBuffer's default
  * (big-endian) byte order.
  */
object LongArraySerializer extends ArraySerializer[Long] {
  /** Encodes `arr` into `arr.length * 8` bytes. */
  override def write(arr: Array[Long]): Array[Byte] = {
    val encoded = ByteBuffer.allocate(arr.length * 8)
    // Bulk-copy through a long-typed view that shares the backing array.
    encoded.asLongBuffer().put(arr)
    encoded.array()
  }

  /** Decodes 8-byte values until the input is exhausted. */
  override def read(arr: Array[Byte]): Array[Long] = {
    val source = ByteBuffer.wrap(arr)
    val decoded = mutable.ArrayBuilder.make[Long]

    // Element-wise reads, so a trailing partial value still raises
    // BufferUnderflowException instead of being ignored.
    @scala.annotation.tailrec
    def drain(): Array[Long] =
      if(source.hasRemaining) {
        decoded += source.getLong()
        drain()
      } else {
        decoded.result()
      }

    drain()
  }
}

/**
  * Serializes floats as consecutive 4-byte values in ByteBuffer's default
  * (big-endian) byte order.
  */
object FloatArraySerializer extends ArraySerializer[Float] {
  /** Encodes `arr` into `arr.length * 4` bytes. */
  override def write(arr: Array[Float]): Array[Byte] = {
    val encoded = ByteBuffer.allocate(arr.length * 4)
    // Bulk-copy through a float-typed view that shares the backing array.
    encoded.asFloatBuffer().put(arr)
    encoded.array()
  }

  /** Decodes 4-byte values until the input is exhausted. */
  override def read(arr: Array[Byte]): Array[Float] = {
    val source = ByteBuffer.wrap(arr)
    val decoded = mutable.ArrayBuilder.make[Float]

    // Element-wise reads, so a trailing partial value still raises
    // BufferUnderflowException instead of being ignored.
    @scala.annotation.tailrec
    def drain(): Array[Float] =
      if(source.hasRemaining) {
        decoded += source.getFloat()
        drain()
      } else {
        decoded.result()
      }

    drain()
  }
}

/**
  * Serializes doubles as consecutive 8-byte values in ByteBuffer's default
  * (big-endian) byte order.
  */
object DoubleArraySerializer extends ArraySerializer[Double] {
  /** Encodes `arr` into `arr.length * 8` bytes. */
  override def write(arr: Array[Double]): Array[Byte] = {
    val encoded = ByteBuffer.allocate(arr.length * 8)
    // Bulk-copy through a double-typed view that shares the backing array.
    encoded.asDoubleBuffer().put(arr)
    encoded.array()
  }

  /** Decodes 8-byte values until the input is exhausted. */
  override def read(arr: Array[Byte]): Array[Double] = {
    val source = ByteBuffer.wrap(arr)
    val decoded = mutable.ArrayBuilder.make[Double]

    // Element-wise reads, so a trailing partial value still raises
    // BufferUnderflowException instead of being ignored.
    @scala.annotation.tailrec
    def drain(): Array[Double] =
      if(source.hasRemaining) {
        decoded += source.getDouble()
        drain()
      } else {
        decoded.result()
      }

    drain()
  }
}

/**
  * Serializes strings as a sequence of length-prefixed records: a 4-byte
  * byte count (as written by `DataOutputStream.writeInt`) followed by that
  * many bytes of UTF-8 encoded text.
  *
  * For pure-ASCII strings the produced bytes are identical to what the
  * previous `writeBytes`-based encoding produced.
  */
object StringArraySerializer extends ArraySerializer[String] {
  import java.nio.charset.StandardCharsets

  /**
    * Encodes every string of `arr` as one length-prefixed UTF-8 record.
    *
    * @param arr strings to encode
    * @return concatenation of all records
    */
  override def write(arr: Array[String]): Array[Byte] = {
    (for(out <- managed(new ByteArrayOutputStream())) yield {
      val dout = new DataOutputStream(out)
      arr.foreach {
        str =>
          // Encode explicitly: DataOutputStream.writeBytes keeps only the low
          // eight bits of each char, which corrupts any non-Latin-1 text, and
          // the prefix must be the encoded byte count, not the char count.
          val encoded = str.getBytes(StandardCharsets.UTF_8)
          dout.writeInt(encoded.length)
          dout.write(encoded)
      }

      dout.flush()
      out.toByteArray
    }).opt.get
  }

  /**
    * Decodes length-prefixed UTF-8 records until the input is exhausted.
    *
    * Reading stops silently at the first record that cannot be read in full
    * (end of input, truncated record, invalid length); the strings decoded up
    * to that point are returned.
    *
    * @param arr bytes previously produced by `write`
    * @return the decoded strings, in order
    */
  override def read(arr: Array[Byte]): Array[String] = {
    (for(in <- managed(new ByteArrayInputStream(arr))) yield {
      val din = new DataInputStream(in)
      val strings = mutable.ArrayBuilder.make[String]

      var done = false
      while(!done) {
        Try {
          val size = din.readInt()
          val bytes = new Array[Byte](size)
          din.readFully(bytes)
          // Decode with the same fixed charset used by write rather than the
          // platform default, so results do not depend on the JVM's locale.
          new String(bytes, StandardCharsets.UTF_8)
        } match {
          case Success(str) => strings += str
          case Failure(_) => done = true
        }
      }

      strings.result()
    }).opt.get
  }
}

/**
  * Serializes byte strings as a sequence of length-prefixed records: a 4-byte
  * length (as written by `DataOutputStream.writeInt`) followed by the raw bytes.
  */
object ByteStringArraySerializer extends ArraySerializer[ByteString] {
  /** Writes each element of `arr` as its byte count followed by its bytes. */
  override def write(arr: Array[ByteString]): Array[Byte] = {
    val encoded = managed(new ByteArrayOutputStream()).map { sink =>
      val writer = new DataOutputStream(sink)
      for(bs <- arr) {
        writer.writeInt(bs.bytes.length)
        writer.write(bs.bytes)
      }

      writer.flush()
      sink.toByteArray
    }

    encoded.opt.get
  }

  /**
    * Reads length-prefixed records until one cannot be read in full, then
    * returns everything decoded up to that point.
    */
  override def read(arr: Array[Byte]): Array[ByteString] = {
    val decoded = managed(new ByteArrayInputStream(arr)).map { source =>
      val reader = new DataInputStream(source)
      val items = mutable.ArrayBuilder.make[ByteString]

      // Any failure while reading a record (including plain end of input)
      // ends the loop; it is not propagated to the caller.
      @scala.annotation.tailrec
      def drain(): Array[ByteString] = {
        val record = Try {
          val size = reader.readInt()
          val payload = new Array[Byte](size)
          reader.readFully(payload)
          ByteString(payload)
        }

        record match {
          case Success(bs) =>
            items += bs
            drain()
          case Failure(_) =>
            items.result()
        }
      }

      drain()
    }

    decoded.opt.get
  }
}




// © 2015 - 2025 Weber Informatics LLC | Privacy Policy