// com.intel.analytics.zoo.common.PythonZoo.scala Maven / Gradle / Ivy
// The newest version!
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.common
import java.util
import com.intel.analytics.bigdl.python.api.{JTensor, PythonBigDLKeras, Sample}
import com.intel.analytics.bigdl.tensor.{DenseType, SparseType, Tensor}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.zoo.pipeline.api.Predictable
import org.apache.spark.api.java.JavaRDD
import java.util.{List => JList, Map => JMap}
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.optim.LocalPredictor
import com.intel.analytics.bigdl.utils.Table
import com.intel.analytics.zoo.feature.image.ImageSet
import com.intel.analytics.zoo.feature.text.TextSet
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
/** Factory methods producing a [[PythonZoo]] fixed to a concrete numeric type. */
object PythonZoo {

  /** @return a new `PythonZoo` parameterized on `Float` */
  def ofFloat(): PythonZoo[Float] = {
    new PythonZoo[Float]()
  }

  /** @return a new `PythonZoo` parameterized on `Double` */
  def ofDouble(): PythonZoo[Double] = {
    new PythonZoo[Double]()
  }
}
/**
 * Conversion and prediction helpers layered on top of [[PythonBigDLKeras]].
 *
 * Translates between `JTensor` values and BigDL `Tensor[T]` instances, and wraps
 * model forward/predict calls so that their results are returned as Java collections.
 *
 * @tparam T numeric element type; `toTensor` only accepts `Float` and `Double`
 */
class PythonZoo[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonBigDLKeras[T] {

  // Simple name of T's runtime class; matched against "float" / "double" in toTensor
  // and forwarded as the bigdlType of every JTensor built by toJTensor.
  private val typeName = implicitly[ClassTag[T]].runtimeClass.getSimpleName

  /**
   * Converts a `JTensor` into a `Tensor[T]`.
   *
   * A `JTensor` carrying indices becomes a sparse tensor. Without indices, a null
   * shape or a shape whose product is zero yields an empty tensor; otherwise a
   * tensor is built from the converted storage and the shape.
   *
   * @param jTensor value to convert; may be null
   * @return the converted tensor, or null when `jTensor` is null
   * @throws IllegalArgumentException if `T` is neither `Float` nor `Double`
   */
  override def toTensor(jTensor: JTensor): Tensor[T] = {
    // Construction is identical for both element types; only the Float => T step differs.
    def assemble(convert: Float => T): Tensor[T] = {
      if (jTensor.indices != null) {
        Tensor.sparse(jTensor.indices, jTensor.storage.map(convert), jTensor.shape)
      } else if (jTensor.shape == null || jTensor.shape.product == 0) {
        Tensor[T]()
      } else {
        Tensor(jTensor.storage.map(convert), jTensor.shape)
      }
    }

    if (jTensor == null) {
      null
    } else {
      typeName match {
        case "float" => assemble(value => ev.fromType(value))
        case "double" => assemble(value => ev.fromType(value.toDouble))
        case unsupported =>
          throw new IllegalArgumentException(s"Not supported type: ${unsupported}")
      }
    }
  }

  /**
   * Converts a `Tensor[T]` into a `JTensor` whose storage holds `Float` values.
   *
   * @param tensor tensor to convert; must not be null
   * @return a `JTensor` tagged with this instance's type name
   * @throws IllegalArgumentException if `tensor` is null (raised by `require`) or
   *                                  its tensor type is neither sparse nor dense
   */
  override def toJTensor(tensor: Tensor[T]): JTensor = {
    require(tensor != null, "tensor cannot be null")

    // Copies the storage of a dense tensor into a Float array alongside its size.
    def pack(dense: Tensor[T]): JTensor = {
      val values = dense.storage().array().map(elem => ev.toType[Float](elem))
      JTensor(values, dense.size(), bigdlType = typeName)
    }

    tensor.getTensorType match {
      case SparseType if tensor.nElement() == 0 =>
        JTensor(Array(), Array(0), bigdlType = typeName)
      case SparseType =>
        // Note: a SparseTensor's indices are inaccessible here, so it is
        // converted to a DenseTensor first. Just for testing.
        pack(Tensor.dense(tensor))
      case DenseType if tensor.nElement() != 0 =>
        // Clone first, in case the size of the storage is larger than the size of the tensor.
        pack(tensor.clone())
      case DenseType if tensor.dim() == 0 =>
        JTensor(null, null, bigdlType = typeName)
      case DenseType =>
        JTensor(Array(), tensor.size(), bigdlType = typeName)
      case _ =>
        throw new IllegalArgumentException(
          s"toJTensor: Unsupported tensor type ${tensor.getTensorType}")
    }
  }

  /**
   * Turns a model output into a Java list.
   *
   * @param outputActivity a tensor, or otherwise something convertible via `toTable`
   * @return a one-element list holding a `JTensor` for a tensor input, or the
   *         (possibly nested) list produced from the table otherwise
   */
  def activityToList(outputActivity: Activity): JList[Object] = {
    outputActivity match {
      case _: Tensor[_] =>
        val single = new util.ArrayList[Object]()
        single.add(toJTensor(outputActivity.toTensor))
        single
      case _ =>
        table2JList(outputActivity.toTable)
    }
  }

  /**
   * Recursively converts the entries stored under keys `1..t.length()` of a table.
   * Tensors become `JTensor`s and nested tables become nested lists.
   *
   * @throws IllegalArgumentException if an entry is neither a tensor nor a table
   */
  private def table2JList(t: Table): JList[Object] = {
    val converted = new util.ArrayList[Object]()
    (1 to t.length()).foreach { key =>
      t[Object](key) match {
        case tensor: Tensor[_] =>
          converted.add(toJTensor(tensor.asInstanceOf[Tensor[T]]))
        case nested: Table =>
          converted.add(table2JList(nested))
        case other =>
          throw new IllegalArgumentException(s"Table contains unrecognizable objects $other")
      }
    }
    converted
  }

  /**
   * Runs `module.predict` over an RDD of samples and converts each resulting
   * activity with `activityToList`.
   *
   * @param module         model to predict with
   * @param x              input samples
   * @param batchPerThread forwarded unchanged to `module.predict`
   */
  def zooPredict(
      module: Predictable[T],
      x: JavaRDD[Sample],
      batchPerThread: Int): JavaRDD[JList[Object]] = {
    val predictions = module.predict(x.rdd.map(toJSample), batchPerThread)
    predictions.map(activityToList).toJavaRDD()
  }

  /**
   * Performs a single forward pass and returns the output as a Java list.
   *
   * @param model        module to call `forward` on
   * @param input        input tensors
   * @param inputIsTable forwarded to `jTensorsToActivity` together with `input`
   */
  def zooForward(model: AbstractModule[Activity, Activity, T],
      input: JList[JTensor],
      inputIsTable: Boolean): JList[Object] = {
    activityToList(model.forward(jTensorsToActivity(input, inputIsTable)))
  }

  /**
   * Predicts locally, through a `LocalPredictor`, on tensors supplied as a Java list.
   *
   * @param module         model to predict with
   * @param x              input tensors, converted with `toTensor` and `toSampleArray`
   * @param batchPerThread passed to `LocalPredictor` as `batchPerCore`
   * @return one converted output list per prediction result
   */
  def zooPredict(
      module: Module[T],
      x: JList[JTensor],
      batchPerThread: Int): JList[JList[Object]] = {
    val inputTensors = x.asScala.toList.map(jt => toTensor(jt))
    val predictor = LocalPredictor(module, batchPerCore = batchPerThread)
    val outputs = predictor.predict(toSampleArray(inputTensors))
    outputs.map(activityToList).toList.asJava
  }

  /** Delegates to `module.predict` for an `ImageSet` and returns its result. */
  def zooPredict(
      module: Predictable[T],
      x: ImageSet,
      batchPerThread: Int): ImageSet = {
    module.predict(x, batchPerThread)
  }

  /** Delegates to `module.predict` for a `TextSet` and returns its result. */
  def zooPredict(
      module: Predictable[T],
      x: TextSet,
      batchPerThread: Int): TextSet = {
    module.predict(x, batchPerThread)
  }

  /**
   * Delegates to `module.predictClasses` over an RDD of samples.
   *
   * @param module         model to predict with
   * @param x              input samples
   * @param batchPerThread forwarded unchanged to `module.predictClasses`
   * @param zeroBasedLabel forwarded unchanged to `module.predictClasses`; defaults to true
   */
  def zooPredictClasses(
      module: Predictable[T],
      x: JavaRDD[Sample],
      batchPerThread: Int,
      zeroBasedLabel: Boolean = true): JavaRDD[Int] = {
    val classes = module.predictClasses(toJSample(x.rdd), batchPerThread, zeroBasedLabel)
    classes.toJavaRDD()
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy