/*
 * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
 */

package org.jetbrains.kotlinx.dl.dataset

import io.jhdf.HdfFile
import io.jhdf.api.Dataset
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.AWS_S3_URL
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.LoadingMode
import org.jetbrains.kotlinx.dl.dataset.audio.wav.WavFile
import org.jetbrains.kotlinx.dl.dataset.handler.*
import java.io.*
import java.net.URL
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.nio.file.StandardCopyOption
import java.util.zip.ZipEntry
import java.util.zip.ZipFile


/**
 * Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).
 * This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
 * along with a test set of 10,000 images.
 * More info can be found at the [MNIST homepage](http://yann.lecun.com/exdb/mnist/).
 *
 * NOTE: Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
 * which is a derivative work from original NIST datasets.
 * MNIST dataset is made available under the terms of the
 * [Creative Commons Attribution-Share Alike 3.0 license.](https://creativecommons.org/licenses/by-sa/3.0/)
 *
 * @param [cacheDirectory] Cache directory for the cached models and datasets.
 *
 * @return Train and test datasets. Each dataset includes X and Y data. X data are uint8 arrays of grayscale image data with shapes
 * (num_samples, 28, 28). Y data are uint8 arrays of digit labels (integers in range 0-9) with shapes (num_samples,).
 */
public fun mnist(cacheDirectory: File = File("cache")): Pair<OnHeapDataset, OnHeapDataset> {
    cacheDirectory.existsOrMkdirs()

    val trainXpath = loadFile(cacheDirectory, TRAIN_IMAGES_ARCHIVE).absolutePath
    val trainYpath = loadFile(cacheDirectory, TRAIN_LABELS_ARCHIVE).absolutePath
    val testXpath = loadFile(cacheDirectory, TEST_IMAGES_ARCHIVE).absolutePath
    val testYpath = loadFile(cacheDirectory, TEST_LABELS_ARCHIVE).absolutePath

    return OnHeapDataset.createTrainAndTestDatasets(
        trainXpath,
        trainYpath,
        testXpath,
        testYpath,
        NUMBER_OF_CLASSES,
        ::extractImages,
        ::extractLabels
    )
}
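
// Usage sketch (not part of the library): loads MNIST into the default cache directory and
// inspects the result. `xSize()` and `getY()` are assumed to be the Dataset accessors for the
// number of samples and a single label; adjust the calls if the actual API differs.
private fun mnistUsageExample() {
    val (train, test) = mnist(cacheDirectory = File("cache"))
    println("MNIST: ${train.xSize()} train and ${test.xSize()} test samples")
    println("Label of the first training sample: ${train.getY(0)}")
}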

/**
 * Loads the Fashion-MNIST dataset.
 *
 * This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
 * along with a test set of 10,000 images. This dataset can be used as
 * a drop-in replacement for MNIST. The class labels are:
 *
 * | Label | Description |
 * |:-----:|-------------|
 * |   0   | T-shirt/top |
 * |   1   | Trouser     |
 * |   2   | Pullover    |
 * |   3   | Dress       |
 * |   4   | Coat        |
 * |   5   | Sandal      |
 * |   6   | Shirt       |
 * |   7   | Sneaker     |
 * |   8   | Bag         |
 * |   9   | Ankle boot  |
 *
 * NOTE: The copyright for Fashion-MNIST is held by Zalando SE.
 * Fashion-MNIST is licensed under the [MIT license](https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
 *
 * @param [cacheDirectory] Cache directory for the cached models and datasets.
 *
 * @return Train and test datasets. Each dataset includes X and Y data. X data are uint8 arrays of grayscale image data with shapes
 * (num_samples, 28, 28). Y data are uint8 arrays of class labels (integers in range 0-9) with shapes (num_samples,).
 */
public fun fashionMnist(cacheDirectory: File = File("cache")): Pair<OnHeapDataset, OnHeapDataset> {
    cacheDirectory.existsOrMkdirs()

    val trainXpath = loadFile(cacheDirectory, FASHION_TRAIN_IMAGES_ARCHIVE).absolutePath
    val trainYpath = loadFile(cacheDirectory, FASHION_TRAIN_LABELS_ARCHIVE).absolutePath
    val testXpath = loadFile(cacheDirectory, FASHION_TEST_IMAGES_ARCHIVE).absolutePath
    val testYpath = loadFile(cacheDirectory, FASHION_TEST_LABELS_ARCHIVE).absolutePath

    return OnHeapDataset.createTrainAndTestDatasets(
        trainXpath,
        trainYpath,
        testXpath,
        testYpath,
        NUMBER_OF_CLASSES,
        ::extractImages,
        ::extractLabels
    )
}
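
// Illustrative sketch (not part of the library): maps the numeric Fashion-MNIST labels to the
// human-readable class names documented in the table above. Uses only the Kotlin standard library;
// `fashionMnistLabelNames` and `fashionMnistLabelName` are hypothetical helpers.
private val fashionMnistLabelNames: List<String> = listOf(
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
)

private fun fashionMnistLabelName(label: Float): String = fashionMnistLabelNames[label.toInt()]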

/** Path to the H5 file of the MNIST 3D dataset. */
public const val MNIST_3D_DATASET: String = "datasets/mnist-3d/dataset.h5"

/**
 * Loads the [MNIST 3D dataset](https://www.kaggle.com/daavoo/3d-mnist).
 * This is a dataset of 10,000 16x16x16 grayscale 3D images of the 10 digits,
 * along with a test set of 2,000 3D images.
 *
 * NOTE: Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
 * which is a derivative work from original NIST datasets.
 * MNIST dataset is made available under the terms of the
 * [Creative Commons Attribution-Share Alike 3.0 license.](https://creativecommons.org/licenses/by-sa/3.0/)
 * The MNIST 3D dataset was created by [daavoo](https://github.com/daavoo) as a transformation of
 * the original MNIST dataset into 3D images, to provide a simple example of working with 3D data.
 *
 * @param [cacheDirectory] Cache directory for the cached models and datasets.
 *
 * @return Train and test datasets. Each dataset includes X and Y data.
 * X data are float arrays of grayscale image data with shapes (num_samples, 16, 16, 16).
 * Y data are float arrays of digit labels (integers in range 0-9) with shapes (num_samples,).
 */
public fun mnist3D(cacheDirectory: File = File("cache")): Pair<OnHeapDataset, OnHeapDataset> {
    cacheDirectory.existsOrMkdirs()

    return HdfFile(loadFile(cacheDirectory, MNIST_3D_DATASET)).use {

        val (trainData, trainLabels) = it.extractMnist3DDataset("train")
        val (testData, testLabels) = it.extractMnist3DDataset("test")

        Pair(
            OnHeapDataset.create(trainData, trainLabels),
            OnHeapDataset.create(testData, testLabels)
        )
    }
}
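
// Usage sketch (not part of the library): loads the 3D MNIST data and checks that every flattened
// sample holds 16 * 16 * 16 = 4096 values, matching the (num_samples, 16, 16, 16) shape documented
// above. `getX()` and `xSize()` are assumed Dataset accessors; adjust if the actual API differs.
private fun mnist3DUsageExample() {
    val (train3D, test3D) = mnist3D()
    check(train3D.getX(0).size == 16 * 16 * 16) { "Unexpected flattened 3D sample size" }
    println("3D MNIST: ${train3D.xSize()} train and ${test3D.xSize()} test samples")
}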

/** Extracts MNIST 3D X data from the HDF5 file [dataset]. */
private fun extractMnist3DData(dataset: Dataset) =
    (dataset.data as Array<*>)
        .map { (it as DoubleArray).map(Double::toFloat).toFloatArray() }
        .toTypedArray()

/** Extracts MNIST 3D Y labels from the HDF5 file [dataset]. */
private fun extractMnist3DLabels(dataset: Dataset) =
    (dataset.data as LongArray).map(Long::toFloat).toFloatArray()

/** Extracts MNIST 3D data and labels from the HDF5 file under the specified [label]. */
private fun HdfFile.extractMnist3DDataset(label: String): Pair<Array<FloatArray>, FloatArray> =
    Pair(
        extractMnist3DData(getDatasetByPath("X_$label")),
        extractMnist3DLabels(getDatasetByPath("y_$label"))
    )

/** Sound data size (length, in values, to which each sample is zero-padded) of the [Free Spoken Digits Dataset](https://github.com/Jakobovski/free-spoken-digit-dataset). */
public const val FSDD_SOUND_DATA_SIZE: Long = 20480

/**
 * Loads the [Free Spoken Digits Dataset](https://github.com/Jakobovski/free-spoken-digit-dataset).
 * This is a dataset of wav sound files of the 10 digits spoken by different people many times each.
 * The test set officially consists of the first 10% of the recordings. Recordings numbered 0-4 (inclusive)
 * are in the test set and recordings 5-49 are in the training set.
 *
 * As the input files may contain several channels of data, every input file is split into per-channel
 * samples that are treated as separate samples with the same label.
 *
 * Free Spoken Digits Dataset is made available under the terms of the
 * [Creative Commons Attribution-ShareAlike 4.0 International.](https://creativecommons.org/licenses/by-sa/4.0/)
 *
 * @param [cacheDirectory] Cache directory for the cached models and datasets.
 * @param [maxTestIndex] Maximum recording index (exclusive) to be selected into the test part of the data.
 *
 * @return Train and test datasets. Each dataset includes X and Y data. X data are float arrays of sound data with shapes
 * (num_samples, FSDD_SOUND_DATA_SIZE), where FSDD_SOUND_DATA_SIZE is at least as long as the longest input sequence and all
 * sequences are padded with zeros to have equal length. Y data are float arrays of digit labels (integers in range 0-9)
 * with shapes (num_samples,).
 */
public fun freeSpokenDigits(
    cacheDirectory: File = File("cache"),
    maxTestIndex: Int = 5
): Pair<OnHeapDataset, OnHeapDataset> {
    cacheDirectory.existsOrMkdirs()

    val path = freeSpokenDigitDatasetPath(cacheDirectory)
    val dataset = File(path)
        .listFiles()?.flatMap(::extractWavFileSamples)
        ?: throw IllegalStateException("Cannot find Free Spoken Digits Dataset files in $path")
    val maxDataSize = dataset.map { it.first.size }.maxOrNull()
        ?: throw IllegalStateException("Empty Free Spoken Digits Dataset")
    check(maxDataSize <= FSDD_SOUND_DATA_SIZE) {
        "Sound data should be limited to $FSDD_SOUND_DATA_SIZE values but has $maxDataSize"
    }
    val data = dataset.map(::extractPaddedDataWithIndex)
    val labels = dataset.map(::extractLabelWithIndex)

    val (trainData, testData) = data.splitToTrainAndTestByIndex(maxTestIndex)
    val (trainLabels, testLabels) = labels.splitToTrainAndTestByIndex(maxTestIndex)

    return Pair(
        OnHeapDataset.create(trainData, trainLabels.toFloatArray()),
        OnHeapDataset.create(testData, testLabels.toFloatArray())
    )
}
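
// Usage sketch (not part of the library): loads the Free Spoken Digits Dataset with the default
// split (recordings 0-4 form the test part) and verifies that every sample is zero-padded to
// FSDD_SOUND_DATA_SIZE values. `getX()` and `xSize()` are assumed Dataset accessors.
private fun freeSpokenDigitsUsageExample() {
    val (soundTrain, soundTest) = freeSpokenDigits(maxTestIndex = 5)
    check(soundTrain.getX(0).size.toLong() == FSDD_SOUND_DATA_SIZE) { "Unexpected padded sample length" }
    println("FSDD: ${soundTrain.xSize()} train and ${soundTest.xSize()} test samples")
}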

/**
 * Extracts WAV file samples from the given file and returns the data from all its
 * channels as a list of (channel_data, label, sample_index) triples.
 *
 * @param file File to read the sound data from.
 * @return List of (channel_data, label, sample_index) triples, one per channel of the file.
 */
private fun extractWavFileSamples(file: File): List<Triple<FloatArray, Float, Int>> =
    WavFile(file).use {
        val data = it.readRemainingFrames()
        val parts = file.name.split("_")
        val label = parts[0].toFloat()
        val index = parts[2].split(".")[0].toInt()
        data.map { channel -> Triple(channel, label, index) }
    }

private fun extractPaddedDataWithIndex(dataLabelIndex: Triple<FloatArray, Float, Int>): Pair<FloatArray, Int> =
    Pair(dataLabelIndex.first.copyInto(FloatArray(FSDD_SOUND_DATA_SIZE.toInt())), dataLabelIndex.third)

private fun extractLabelWithIndex(dataLabelIndex: Triple<FloatArray, Float, Int>): Pair<Float, Int> =
    Pair(dataLabelIndex.second, dataLabelIndex.third)

private inline fun <reified T> List<Pair<T, Int>>.splitToTrainAndTestByIndex(maxTestIndex: Int): Pair<Array<T>, Array<T>> {
    val test = filter { it.second < maxTestIndex }.map { it.first }.toTypedArray()
    val train = filter { it.second >= maxTestIndex }.map { it.first }.toTypedArray()
    return Pair(train, test)
}

/** Path to the images archive of the CIFAR-10 dataset. */
private const val CIFAR_10_IMAGES_ARCHIVE: String = "datasets/cifar10/images.zip"

/** Path to the train labels file of the CIFAR-10 dataset. */
private const val CIFAR_10_LABELS_ARCHIVE: String = "datasets/cifar10/trainLabels.csv"

/** Returns paths to the images and their labels for the CIFAR-10 dataset. */
public fun cifar10Paths(cacheDirectory: File = File("cache")): Pair<String, String> {
    cacheDirectory.existsOrMkdirs()

    val pathToLabel = loadFile(cacheDirectory, CIFAR_10_LABELS_ARCHIVE).absolutePath

    val datasetDirectory = File(cacheDirectory.absolutePath + "/datasets/cifar10")
    val toFolder = datasetDirectory.toPath()

    val imageDataDirectory = File(cacheDirectory.absolutePath + "/datasets/cifar10/images")
    if (!imageDataDirectory.exists()) {
        Files.createDirectories(imageDataDirectory.toPath())

        val pathToImageArchive = loadFile(cacheDirectory, CIFAR_10_IMAGES_ARCHIVE)
        extractFromZipArchiveToFolder(pathToImageArchive.toPath(), toFolder)
        val deleted = pathToImageArchive.delete()
        if (!deleted) throw Exception("Archive ${pathToImageArchive.absolutePath} could not be deleted! Delete this archive manually.")
    }

    return Pair(imageDataDirectory.toPath().toAbsolutePath().toString(), pathToLabel)
}
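
// Illustrative sketch (not part of the library): consumes the paths returned by [cifar10Paths]
// using only the standard library - counts the extracted image files and parses the label CSV,
// assuming it has a header line followed by "id,label" rows (as in the Kaggle CIFAR-10 data).
private fun cifar10PathsUsageExample() {
    val (imagesPath, labelsPath) = cifar10Paths()
    val imageCount = File(imagesPath).listFiles()?.size ?: 0
    val labelsById = File(labelsPath).readLines()
        .drop(1) // skip the CSV header, assuming one is present
        .associate { line -> line.substringBefore(',') to line.substringAfter(',') }
    println("CIFAR-10: $imageCount image files, ${labelsById.size} labelled ids")
}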

/** Path to the Dogs-vs-Cats dataset. */
private const val DOGS_CATS_IMAGES_ARCHIVE: String = "datasets/dogs-vs-cats/data.zip"

/** Returns path to images of the Dogs-vs-Cats dataset. */
public fun dogsCatsDatasetPath(cacheDirectory: File = File("cache")): String =
    unzipDatasetPath(
        cacheDirectory,
        loadFile(cacheDirectory, DOGS_CATS_IMAGES_ARCHIVE),
        "/datasets/dogs-vs-cats"
    )

/** Path to the subset of Dogs-vs-Cats dataset. */
private const val DOGS_CATS_SMALL_IMAGES_ARCHIVE: String = "datasets/small-dogs-vs-cats/data.zip"

/** Returns path to images of the subset of the Dogs-vs-Cats dataset. */
public fun dogsCatsSmallDatasetPath(cacheDirectory: File = File("cache")): String =
    unzipDatasetPath(
        cacheDirectory,
        loadFile(cacheDirectory, DOGS_CATS_SMALL_IMAGES_ARCHIVE),
        "/datasets/small-dogs-vs-cats"
    )
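
// Illustrative sketch (not part of the library): lists the top-level entries of the extracted
// Dogs-vs-Cats subset. The directory layout inside the archive is an assumption; this only shows
// how the returned path can be inspected with the standard library.
private fun dogsCatsSmallUsageExample() {
    val datasetPath = dogsCatsSmallDatasetPath()
    File(datasetPath).listFiles()?.forEach { entry -> println(entry.name) }
}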

/** Path to the Free Spoken Digits Dataset. */
private const val FSDD_SOUNDS_ARCHIVE: String = "datasets/fsdd.zip"

/** URL to download the Free Spoken Digits Dataset from. */
private const val FSS_SOUNDS_SOURCE: String =
    "https://codeload.github.com/Jakobovski/free-spoken-digit-dataset/zip/refs/heads/master"

/** Returns path to sound data files from Free Spoken Digits Dataset. */
public fun freeSpokenDigitDatasetPath(cacheDirectory: File = File("cache")): String =
    unzipDatasetPath(
        cacheDirectory,
        loadFile(cacheDirectory, FSDD_SOUNDS_ARCHIVE, downloadURLFromRelativePath = { FSS_SOUNDS_SOURCE }),
        "/datasets/free-spoken-digit"
    ).run {
        "$this/free-spoken-digit-dataset-master/recordings"
    }

/**
 * Decompresses the downloaded dataset [archive] into the cache and removes the archive file,
 * leaving only the decompressed data.
 *
 * @param [cacheDirectory] The directory where the downloaded files are stored.
 * @param [archive] Archive file to decompress.
 * @param [dirRelativePath] The relative path under [cacheDirectory] where the data is decompressed.
 * @return The absolute path to the directory where the dataset is decompressed.
 */
private fun unzipDatasetPath(cacheDirectory: File, archive: File, dirRelativePath: String): String {
    cacheDirectory.existsOrMkdirs()

    val dataDirectory = File(cacheDirectory.absolutePath + dirRelativePath)
    val toFolder = dataDirectory.toPath()

    if (!dataDirectory.exists()) Files.createDirectories(dataDirectory.toPath())

    if (archive.exists()) {
        extractFromZipArchiveToFolder(archive.toPath(), toFolder)
        val deleted = archive.delete()
        if (!deleted) {
            throw Exception("Archive ${archive.absolutePath} could not be deleted! Delete this archive manually.")
        }
    } else {
        throw Exception("No archive file ${archive.absolutePath} in the cache folder!")
    }

    return toFolder.toAbsolutePath().toString()
}

/**
 * Downloads a file from a URL if it is not already in the cache.
 *
 * By default, the download location is defined as the concatenation of [AWS_S3_URL] and
 * [relativePathToFile], but an arbitrary download location can be provided via
 * [downloadURLFromRelativePath].
 *
 * @param [cacheDirectory] Directory where the downloaded file is stored.
 * @param [relativePathToFile] Relative path under [cacheDirectory] where the downloaded file is stored;
 * by default it also defines the location of the file to be downloaded.
 * @param [downloadURLFromRelativePath] Produces the download URL of the file from its relative path.
 * Defaults to [AWS_S3_URL]/[relativePathToFile].
 * @param [loadingMode] Loading mode of the file. Defaults to [LoadingMode.SKIP_LOADING_IF_EXISTS].
 * @return Downloaded [File] on the local file system.
 */
private fun loadFile(
    cacheDirectory: File,
    relativePathToFile: String,
    downloadURLFromRelativePath: (String) -> String = { "$AWS_S3_URL/$it" },
    loadingMode: LoadingMode = LoadingMode.SKIP_LOADING_IF_EXISTS
): File {
    val fileName = cacheDirectory.absolutePath + "/" + relativePathToFile
    val file = File(fileName)
    file.parentFile.mkdirs() // Creates parent directories if they do not exist.

    if (!file.exists() || loadingMode == LoadingMode.OVERRIDE_IF_EXISTS) {
        val urlString = downloadURLFromRelativePath(relativePathToFile)
        URL(urlString).openStream().use { inputStream ->
            Files.copy(inputStream, Paths.get(fileName), StandardCopyOption.REPLACE_EXISTING)
        }
    }

    return file
}

/** Extracts the file structure archived in a zip file, including all directories and sub-directories. */
@Throws(IOException::class)
internal fun extractFromZipArchiveToFolder(zipArchivePath: Path, toFolder: Path, bufferSize: Int = 4096) {
    val zipFile = ZipFile(zipArchivePath.toFile())
    val entries = zipFile.entries()

    while (entries.hasMoreElements()) {
        val entry = entries.nextElement() as ZipEntry
        var currentEntry = entry.name
        currentEntry = currentEntry.replace('\\', '/')

        val destFile = File(toFolder.toFile(), currentEntry)

        val destinationParent = destFile.parentFile
        destinationParent.mkdirs()

        if (!entry.isDirectory && !destFile.exists()) {
            val inputStream = BufferedInputStream(
                zipFile.getInputStream(entry)
            )
            var currentByte: Int
            // establish buffer for writing file
            val data = ByteArray(bufferSize)

            // write the current file to disk
            val fos = FileOutputStream(destFile)
            val dest = BufferedOutputStream(
                fos,
                bufferSize
            )

            // read and write until last byte is encountered
            while (inputStream.read(data, 0, bufferSize).also { currentByte = it } != -1) {
                dest.write(data, 0, currentByte)
            }
            dest.flush()
            dest.close()
            inputStream.close()
        }
    }
    zipFile.close()
}
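
// Alternative sketch (not used by the loaders above): the same extraction written with Kotlin's
// `use` blocks and `Files.copy`, which close resources automatically. Intended to mirror
// [extractFromZipArchiveToFolder] for regular entries; an illustration, not a drop-in replacement.
internal fun extractFromZipArchiveToFolderCompact(zipArchivePath: Path, toFolder: Path) {
    ZipFile(zipArchivePath.toFile()).use { zipFile ->
        for (entry in zipFile.entries()) {
            val destFile = File(toFolder.toFile(), entry.name.replace('\\', '/'))
            destFile.parentFile.mkdirs()
            if (!entry.isDirectory && !destFile.exists()) {
                zipFile.getInputStream(entry).use { input ->
                    Files.copy(input, destFile.toPath(), StandardCopyOption.REPLACE_EXISTING)
                }
            }
        }
    }
}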

internal fun File.existsOrMkdirs() {
    if (!exists()) {
        val created = mkdirs()
        if (!created) {
            throw Exception("Directory $absolutePath could not be created! Create this directory manually.")
        }
    }
}



