All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.utils.serializer.converters.DataReaderWriter.scala Maven / Gradle / Ivy

/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.utils.serializer.converters

import java.io.{DataInputStream, DataOutputStream}

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Storage
import com.intel.analytics.bigdl.utils.serializer.BigDLDataType
import com.intel.analytics.bigdl.utils.serializer.BigDLDataType.BigDLDataType

/**
 * Defines how to read and write weight data from a bin file.
 *
 * Each implementation in this package handles exactly one element type, reported by
 * [[dataType]], and is looked up through the companion object's `apply` methods.
 */
trait DataReaderWriter {

  /** The element type handled by this reader/writer. */
  def dataType(): BigDLDataType

  /**
   * Serializes every element of `data` to `outputStream`.
   *
   * @param outputStream destination stream
   * @param data elements to serialize; implementations cast each element to their own type
   */
  def write(outputStream: DataOutputStream, data: Array[_]): Unit

  /**
   * Deserializes `size` elements from `inputStream`.
   *
   * @param inputStream source stream
   * @param size number of elements to read
   * @return the decoded data (the implementations in this file return a `Storage`)
   */
  def read(inputStream: DataInputStream, size: Int): Any
}

/**
 * Reader/writer for `Float` elements: one `writeFloat` / `readFloat` per element.
 */
object FloatReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `Float`, in array order. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    for (elem <- data) {
      outputStream.writeFloat(elem.asInstanceOf[Float])
    }
  }

  /** Reads `size` floats in order and wraps them in a [[Storage]]. */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val values = new Array[Float](size)
    values.indices.foreach { idx =>
      values(idx) = inputStream.readFloat()
    }
    Storage[Float](values)
  }

  override def dataType(): BigDLDataType = BigDLDataType.FLOAT
}

/**
 * Reader/writer for `Double` elements: one `writeDouble` / `readDouble` per element.
 */
object DoubleReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `Double`, in array order. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    for (elem <- data) {
      outputStream.writeDouble(elem.asInstanceOf[Double])
    }
  }

  /** Reads `size` doubles in order and wraps them in a [[Storage]]. */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val values = new Array[Double](size)
    values.indices.foreach { idx =>
      values(idx) = inputStream.readDouble()
    }
    Storage[Double](values)
  }

  override def dataType(): BigDLDataType = BigDLDataType.DOUBLE
}

/**
 * Reader/writer for `Char` elements: one `writeChar` / `readChar` per element.
 */
object CharReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `Char`, in array order. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    for (elem <- data) {
      outputStream.writeChar(elem.asInstanceOf[Char])
    }
  }

  /** Reads `size` chars in order and wraps them in a [[Storage]]. */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val values = new Array[Char](size)
    values.indices.foreach { idx =>
      values(idx) = inputStream.readChar()
    }
    Storage[Char](values)
  }

  override def dataType(): BigDLDataType = BigDLDataType.CHAR
}

/**
 * Reader/writer for `Boolean` elements: one `writeBoolean` / `readBoolean` per element.
 */
object BoolReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `Boolean`, in array order. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    for (elem <- data) {
      outputStream.writeBoolean(elem.asInstanceOf[Boolean])
    }
  }

  /** Reads `size` booleans in order and wraps them in a [[Storage]]. */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val values = new Array[Boolean](size)
    values.indices.foreach { idx =>
      values(idx) = inputStream.readBoolean()
    }
    Storage[Boolean](values)
  }

  override def dataType(): BigDLDataType = BigDLDataType.BOOL
}

/**
 * Reader/writer for `String` elements.
 *
 * Each string is stored as the length of its UTF-8 encoding (`writeInt`) followed by
 * the UTF-8 bytes themselves.
 */
object StringReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `String`, as a length-prefixed UTF-8 record. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    data.foreach { str =>
      val value = str.asInstanceOf[String].getBytes("utf-8")
      outputStream.writeInt(value.length)
      outputStream.write(value)
    }
  }

  /**
   * Reads `size` length-prefixed UTF-8 records and wraps the decoded strings in a [[Storage]].
   *
   * @throws java.io.EOFException if the stream ends before a record is complete
   */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val data = new Array[String](size)
    for (i <- 0 until size) {
      val ssize = inputStream.readInt()
      val buffer = new Array[Byte](ssize)
      // `read(buffer)` may return after filling only part of the buffer; ignoring that
      // would decode zero-padded garbage and misalign every following record.
      // `readFully` keeps reading until the buffer is full, or throws EOFException.
      inputStream.readFully(buffer)
      data(i) = new String(buffer, "utf-8")
    }
    Storage[String](data)
  }

  override def dataType(): BigDLDataType = BigDLDataType.STRING
}

/**
 * Reader/writer for `Int` elements: one `writeInt` / `readInt` per element.
 */
object IntReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `Int`, in array order. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    for (elem <- data) {
      outputStream.writeInt(elem.asInstanceOf[Int])
    }
  }

  /** Reads `size` ints in order and wraps them in a [[Storage]]. */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val values = new Array[Int](size)
    values.indices.foreach { idx =>
      values(idx) = inputStream.readInt()
    }
    Storage[Int](values)
  }

  override def dataType(): BigDLDataType = BigDLDataType.INT
}

/**
 * Reader/writer for `Short` elements: one `writeShort` / `readShort` per element.
 */
object ShortReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `Short`, in array order. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    for (elem <- data) {
      outputStream.writeShort(elem.asInstanceOf[Short])
    }
  }

  /** Reads `size` shorts in order and wraps them in a [[Storage]]. */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val values = new Array[Short](size)
    values.indices.foreach { idx =>
      values(idx) = inputStream.readShort()
    }
    Storage[Short](values)
  }

  override def dataType(): BigDLDataType = BigDLDataType.SHORT
}

/**
 * Reader/writer for `Long` elements: one `writeLong` / `readLong` per element.
 */
object LongReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `Long`, in array order. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    for (elem <- data) {
      outputStream.writeLong(elem.asInstanceOf[Long])
    }
  }

  /** Reads `size` longs in order and wraps them in a [[Storage]]. */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val values = new Array[Long](size)
    values.indices.foreach { idx =>
      values(idx) = inputStream.readLong()
    }
    Storage[Long](values)
  }

  override def dataType(): BigDLDataType = BigDLDataType.LONG
}

/**
 * Reader/writer for protobuf `ByteString` elements.
 *
 * Each value is stored as its byte length (`writeInt`) followed by the raw bytes.
 */
object ByteStringReaderWriter extends DataReaderWriter {

  /** Writes each element of `data`, cast to `ByteString`, as a length-prefixed record. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    data.foreach { str =>
      val value = str.asInstanceOf[ByteString].toByteArray
      outputStream.writeInt(value.length)
      outputStream.write(value)
    }
  }

  /**
   * Reads `size` length-prefixed records and wraps them as `ByteString`s in a [[Storage]].
   *
   * @throws java.io.EOFException if the stream ends before a record is complete
   */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val data = new Array[ByteString](size)
    for (i <- 0 until size) {
      val ssize = inputStream.readInt()
      val buffer = new Array[Byte](ssize)
      // `read(buffer)` may return after filling only part of the buffer; ignoring that
      // would keep zero-padded bytes and misalign every following record.
      // `readFully` keeps reading until the buffer is full, or throws EOFException.
      inputStream.readFully(buffer)
      data(i) = ByteString.copyFrom(buffer)
    }
    Storage[ByteString](data)
  }

  override def dataType(): BigDLDataType = BigDLDataType.BYTESTRING
}

/**
 * Reader/writer for `Byte` elements: the array is written and read as a raw byte block.
 */
object ByteReaderWriter extends DataReaderWriter {

  /** Writes `data` (which must be an `Array[Byte]`) as raw bytes. */
  override def write(outputStream: DataOutputStream, data: Array[_]): Unit = {
    outputStream.write(data.asInstanceOf[Array[Byte]])
  }

  /**
   * Reads exactly `size` bytes and wraps them in a [[Storage]].
   *
   * @throws java.io.EOFException if the stream ends before `size` bytes are read
   */
  override def read(inputStream: DataInputStream, size: Int): Any = {
    val data = new Array[Byte](size)
    // `read(data)` may fill only part of the array (or return -1 at EOF) and the
    // original ignored its count; `readFully` guarantees all `size` bytes or fails.
    inputStream.readFully(data)
    Storage[Byte](data)
  }

  override def dataType(): BigDLDataType = BigDLDataType.BYTE
}

/**
 * Factory for the [[DataReaderWriter]] implementations defined in this file.
 */
object DataReaderWriter {

  /**
   * Selects the reader/writer matching the runtime element type of `datas`.
   *
   * @param datas array whose element type determines the codec
   * @return the codec for that element type
   * @throws RuntimeException if the element type is not supported
   */
  def apply(datas : Array[_]): DataReaderWriter = datas match {
    case _: Array[Float]      => FloatReaderWriter
    case _: Array[Double]     => DoubleReaderWriter
    case _: Array[Char]       => CharReaderWriter
    case _: Array[Boolean]    => BoolReaderWriter
    case _: Array[String]     => StringReaderWriter
    case _: Array[Int]        => IntReaderWriter
    case _: Array[Short]      => ShortReaderWriter
    case _: Array[Long]       => LongReaderWriter
    case _: Array[ByteString] => ByteStringReaderWriter
    case _: Array[Byte]       => ByteReaderWriter
    case _                    => throw new RuntimeException("Unsupported Type")
  }

  /**
   * Selects the reader/writer registered for a serialized data type tag.
   *
   * @param dataType type tag, typically the value produced by a codec's `dataType()`
   * @return the codec for that tag
   * @throws RuntimeException if the tag is not supported
   */
  def apply(dataType : BigDLDataType): DataReaderWriter = dataType match {
    case BigDLDataType.FLOAT      => FloatReaderWriter
    case BigDLDataType.DOUBLE     => DoubleReaderWriter
    case BigDLDataType.CHAR       => CharReaderWriter
    case BigDLDataType.BOOL       => BoolReaderWriter
    case BigDLDataType.STRING     => StringReaderWriter
    case BigDLDataType.INT        => IntReaderWriter
    case BigDLDataType.SHORT      => ShortReaderWriter
    case BigDLDataType.LONG       => LongReaderWriter
    case BigDLDataType.BYTESTRING => ByteStringReaderWriter
    case BigDLDataType.BYTE       => ByteReaderWriter
    case _                        => throw new RuntimeException("Unsupported Type")
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy