All downloads are free. Search and download functionality uses the official Maven repository.

com.intel.analytics.bigdl.nn.quantized.QuantSerializer.scala Maven / Gradle / Ivy

There is a newer version: 0.11.1
Show newest version
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.nn.quantized

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table
import com.intel.analytics.bigdl.utils.serializer._
import com.intel.analytics.bigdl.utils.serializer.converters.TensorConverter
import com.intel.analytics.bigdl.serialization.Bigdl.{AttrValue, BigDLModule}

import scala.reflect.ClassTag

/**
 * Serialization mixin for quantized modules.
 *
 * Splits parameter (de)serialization into three hooks: weight, bias and "others".
 * Weight handling is abstract. Bias handling has a default implementation that reads the
 * module's parameters table. The "others" hooks default to no-ops.
 */
trait QuantSerializer extends ModuleSerializable {

  /** Writes the module's weight into `modelBuilder`; supplied by each concrete serializer. */
  def serializeWeight[T: ClassTag](context: SerializeContext[T],
    modelBuilder: BigDLModule.Builder)(implicit ev: TensorNumeric[T]): Unit

  /**
   * Writes the module's bias tensor, if it has one, into `modelBuilder`.
   *
   * The bias is converted through [[TensorConverter]] and then stored with
   * `modelBuilder.setBias`. If the module has no bias entry, nothing is written.
   */
  def serializeBias[T: ClassTag](context: SerializeContext[T],
    modelBuilder: BigDLModule.Builder)(implicit ev: TensorNumeric[T]): Unit = {
    lookupBias(context.moduleData).foreach { bias =>
      val biasAttr = AttrValue.newBuilder
      TensorConverter.setAttributeValue(context, biasAttr, bias)
      modelBuilder.setBias(biasAttr.getTensorValue)
    }
  }

  /** Hook for extra module state; does nothing by default. */
  def serializeOthers[T: ClassTag](context: SerializeContext[T],
    modelBuilder: BigDLModule.Builder)(implicit ev: TensorNumeric[T]): Unit = {
  }

  /** Restores the module's weight from `context`; supplied by each concrete serializer. */
  def loadWeight[T: ClassTag](context: DeserializeContext,
    module: ModuleData[T])(implicit ev: TensorNumeric[T]): Unit

  /**
   * Restores the bias from `context.bigdlModule`.
   *
   * The serialized values are copied into the module's existing bias tensor. Modules
   * without a bias entry in their parameters table are left untouched.
   */
  def loadBias[T: ClassTag](context: DeserializeContext,
    moduleData: ModuleData[T])(implicit ev: TensorNumeric[T]): Unit = {
    lookupBias(moduleData).foreach { bias =>
      // NOTE(review): getBias is read unconditionally whenever the in-memory module has a bias;
      // this presumably relies on the serialized module always carrying one. Confirm.
      val attrValue = AttrValue.newBuilder
      attrValue.setTensorValue(context.bigdlModule.getBias)
      val loaded = TensorConverter.getAttributeValue(context, attrValue.build)
      bias.copy(loaded.asInstanceOf[Tensor[T]])
    }
  }

  /** Hook for extra module state; does nothing by default. */
  def loadOthers[T: ClassTag](context: DeserializeContext,
    module: ModuleData[T])(implicit ev: TensorNumeric[T]): Unit = {
  }

  /**
   * Serializes weight, bias and other state, in that order.
   *
   * @throws IllegalArgumentException if `context.storageType` is not `ProtoStorageType`
   */
  override protected def copyFromBigDL[T: ClassTag](context: SerializeContext[T],
    modelBuilder: BigDLModule.Builder)(implicit ev: TensorNumeric[T]): Unit = {
    val storageType = context.storageType
    if (storageType == ProtoStorageType) {
      serializeWeight(context, modelBuilder)
      serializeBias(context, modelBuilder)
      serializeOthers(context, modelBuilder)
    } else {
      throw new IllegalArgumentException(s"$storageType not supported!")
    }
  }

  /** Restores weight, bias and other state, in that order. */
  override protected def copy2BigDL[T: ClassTag](context: DeserializeContext, module: ModuleData[T])
    (implicit ev: TensorNumeric[T]): Unit = {
    loadWeight(context, module)
    loadBias(context, module)
    loadOthers(context, module)
  }

  /**
   * Finds the module's bias tensor in its parameters table.
   *
   * The bias is looked up under key "bias" inside the sub-table named after the module.
   *
   * @return the bias tensor, or None when the parameters table, the module's sub-table,
   *         or the bias entry is absent or null
   */
  private def lookupBias[T](moduleData: ModuleData[T]): Option[Tensor[T]] = {
    val module = moduleData.module
    val moduleName = module.getName()
    for {
      paramTable <- Option(module.getParametersTable())
      if paramTable.contains(moduleName)
      moduleParams <- Option(paramTable.apply[Table](moduleName))
      if moduleParams.contains("bias")
      bias <- Option(moduleParams.apply[Tensor[T]]("bias"))
    } yield bias
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy