com.intel.analytics.zoo.models.image.objectdetection.common.dataset.CustomizedDataSet.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of analytics-zoo-bigdl_0.13.0-spark_2.4.6 Show documentation
Show all versions of analytics-zoo-bigdl_0.13.0-spark_2.4.6 Show documentation
Big Data AI platform for distributed TensorFlow and PyTorch on Apache Spark.
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.models.image.objectdetection.common.dataset
import java.io.File
import java.nio.file.Paths
import com.intel.analytics.bigdl.transform.vision.image.ImageFeature
import com.intel.analytics.zoo.feature.image.{ImageSet, LocalImageSet}
import scala.io.Source
/**
* Parse the customized dataset, load images and annotations
* @param imageSet train, val, test, etc
* @param devkitPath dataset folder
*/
class CustomizedDataSet(val imageSet: String, devkitPath: String) extends Imdb {
def getRoidb(readImage: Boolean = true): LocalImageSet = {
val classFile = new File(devkitPath + "/" + "classname.txt")
require(classFile.exists(), s"if labelMap is null," +
s" there should be a classname.txt in $devkitPath")
val labelMap = getLabelMap(classFile.getAbsolutePath)
val imageSetFile = Paths.get(devkitPath, "ImageSets", s"$imageSet.txt").toFile
assert(imageSetFile.exists(), "Path does not exist " + imageSetFile.getAbsolutePath)
val array = Source.fromFile(imageSetFile).getLines()
.map(line => line.trim.split("\\s")).map(x => {
val imagePath = devkitPath + "/" + x(0)
val image = if (readImage) loadImage(imagePath) else null
ImageFeature(image,
PascalVoc.loadAnnotation(devkitPath + "/" + x(1), labelMap),
imagePath)
}).toArray
ImageSet.array(array)
}
def getLabelMap(labelFile: String): Map[String, Float] = {
val classes = Source.fromFile(labelFile).getLines().toArray
(classes zip (Stream from 1).map(_.toFloat)).toMap
}
}