// com.microsoft.ml.spark.core.env.StreamUtilities.scala Maven / Gradle / Ivy
// The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.ml.spark.core.env
import java.io.{ByteArrayOutputStream, InputStream}
import java.util.zip.ZipInputStream
import org.apache.commons.io.IOUtils
import scala.util.Random
object StreamUtilities {
import scala.util.{Failure, Success, Try}
/** Runs `task` against a group of closeable resources and closes every one of them afterwards.
  *
  * Each resource is closed even when closing an earlier one throws, so a single failing
  * `close()` can no longer leave the remaining resources open. If any `close()` throws an
  * `Exception`, the first such exception is rethrown from this method (not wrapped in the
  * returned `Try`) with the later close failures attached as suppressed exceptions.
  *
  * @param disposable resources handed to `task`; `close()` is invoked on each of them, in
  *                   order, once `task` has returned or thrown
  * @param task       computation to run against the resources
  * @return `Success` with the value produced by `task`, or `Failure` wrapping the
  *         `Exception` it threw; throwables that are not `Exception`s are not caught
  */
def using[T <: AutoCloseable, U](disposable: Seq[T])(task: Seq[T] => U): Try[U] = {
  try {
    Success(task(disposable))
  } catch {
    case e: Exception => Failure(e)
  } finally {
    // Fold strictly over the resources so every close() is attempted; remember only the
    // first failure and hang the later ones off it as suppressed exceptions.
    val firstCloseFailure = disposable.foldLeft(Option.empty[Exception]) { (firstFailure, d) =>
      try {
        d.close()
        firstFailure
      } catch {
        case e: Exception =>
          // addSuppressed rejects self-suppression, hence the identity check.
          firstFailure.filter(_ ne e).foreach(_.addSuppressed(e))
          firstFailure.orElse(Some(e))
      }
    }
    firstCloseFailure.foreach(e => throw e)
  }
}
/** Applies `task` to a single closeable resource and closes the resource afterwards.
  *
  * @param disposable resource handed to `task`; `close()` is invoked on it once `task` has
  *                   returned or thrown
  * @param task       computation to run against the resource
  * @return `Success` with the value produced by `task`, or `Failure` wrapping the
  *         `Exception` it threw; throwables that are not `Exception`s are not caught
  */
def using[T <: AutoCloseable, U](disposable: T)(task: T => U): Try[U] = {
  val outcome: Try[U] =
    try {
      val value = task(disposable)
      Success(value)
    } catch {
      case error: Exception => Failure(error)
    } finally {
      // Runs on both the success and the failure path.
      disposable.close()
    }
  outcome
}
/** Iterate through the entries of a streamed .zip file, selecting only sampleRatio of them.
  *
  * Directory entries are always skipped. Every other entry is kept when
  * `random.nextDouble < sampleRatio`; the bytes of a kept entry are read fully into memory.
  * The first element is read eagerly when the iterator is constructed. The zip stream
  * wrapping `stream` is closed once the archive is exhausted or when reading fails.
  *
  * @param stream      Stream of zip file
  * @param zipfile     File name is only used to construct the names of the entries
  * @param random      Source of the random draws used to decide which entries are kept
  * @param sampleRatio What fraction of files is returned from zip
  */
class ZipIterator(stream: InputStream, zipfile: String, random: Random, sampleRatio: Double = 1)
  extends Iterator[(String, Array[Byte])] {

  private val zipStream = new ZipInputStream(stream)

  /** Advances to the next sampled, non-directory entry and reads it.
    *
    * @return the entry name (prefixed with `zipfile` and the platform file separator) and its
    *         bytes, or `None` once the archive has no further entries
    */
  private def getNext: Option[(String, Array[Byte])] = {
    try {
      var entry = zipStream.getNextEntry
      while (entry != null) {
        // The short-circuit keeps directory entries from consuming a random draw.
        if (!entry.isDirectory && random.nextDouble < sampleRatio) {
          val filename = zipfile + java.io.File.separator + entry.getName()
          // extracting all bytes of a given entry
          val byteStream = new ByteArrayOutputStream
          IOUtils.copy(zipStream, byteStream)
          val bytes = byteStream.toByteArray
          assert(bytes.length == entry.getSize,
            "incorrect number of bytes is read from zipstream: " + bytes.length + " instead of " + entry.getSize)
          return Some((filename, bytes))
        }
        entry = zipStream.getNextEntry
      }
    } catch {
      case scala.util.control.NonFatal(e) =>
        // Do not leave the stream open when reading fails; keep the original error primary.
        try {
          zipStream.close()
        } catch {
          case closeError: Exception => e.addSuppressed(closeError)
        }
        throw e
    }
    // Archive exhausted: close the zip stream itself (not just the raw stream) so that the
    // resources held by the ZipInputStream are released along with the wrapped stream.
    zipStream.close()
    None
  }

  // Look-ahead element; None means the archive is exhausted.
  private var nextValue = getNext

  def hasNext: Boolean = nextValue.isDefined

  /** Returns the current element and pre-fetches the following one.
    *
    * Throws `NoSuchElementException` when the iterator is exhausted.
    */
  def next: (String, Array[Byte]) = {
    val result = nextValue.getOrElse(
      throw new NoSuchElementException("next on exhausted ZipIterator"))
    nextValue = getNext
    result
  }
}
}