All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlinx.dl.dataset.handler.Cifar10Util.kt Maven / Gradle / Ivy

There is a newer version: 0.6.0-alpha-1
Show newest version
/*
 * Copyright 2020 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
 */

package org.jetbrains.kotlinx.dl.dataset.handler

import com.github.doyaaaaaken.kotlincsv.dsl.csvReader
import org.jetbrains.kotlinx.dl.dataset.OnHeapDataset
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
import java.io.File
import java.io.IOException

/** Expected number of images and labels; shared by every loader in this file. */
private const val DATASET_SIZE = 50_000

/**
 * Loads the CIFAR-10 images from the directory [archiveName] into heap memory and applies basic
 * normalization preprocessing.
 *
 * Images are read from files named `1.png` … `50000.png`, in that numeric order.
 *
 * @param archiveName path to the directory that holds the extracted PNG files.
 * @return one normalized [FloatArray] per image; element `i` comes from file `(i + 1).png`.
 * @throws IOException declared for failures while reading the image files.
 */
@Throws(IOException::class)
public fun extractCifar10Images(archiveName: String): Array<FloatArray> =
    loadImagesFromDirectory(DATASET_SIZE, archiveName)

/**
 * Reads the consecutively numbered images `1.png` … `<subDatasetSize>.png` from the directory
 * [archiveName]. Each one is converted with [ImageConverter.toNormalizedFloatArray] using [ColorMode.BGR].
 *
 * @param subDatasetSize number of images to load.
 * @param archiveName directory that contains the image files.
 * @return the converted images; element `i` corresponds to file `(i + 1).png`.
 */
private fun loadImagesFromDirectory(
    subDatasetSize: Int,
    archiveName: String
): Array<FloatArray> = Array(subDatasetSize) { index ->
    // Image files are numbered from 1, while array slots start at 0.
    ImageConverter.toNormalizedFloatArray(File(archiveName, "${index + 1}.png"), colorMode = ColorMode.BGR)
}

/**
 * Loads labels from the [pathToLabels] CSV file into heap memory and converts them to Floats.
 *
 * Each row is read as `id,className`. Only the second column is used, and labels are stored in file
 * row order. Class names are mapped to indices 0..9 (`airplane` … `truck`).
 *
 * NOTE(review): an unrecognized class name (e.g. a header row, if the file has one) silently falls back
 * to index 1 ("automobile"). This preserves the original behavior; confirm it is intended.
 *
 * @param pathToLabels path to the labels CSV file.
 * @param numClasses number of classes; currently not used by this function.
 * @return [DATASET_SIZE] labels. Positions past the last row of the file are left at byte value 0
 *   before conversion.
 * @throws IOException declared for failures while reading the CSV file.
 * @throws IllegalStateException if the file contains more than [DATASET_SIZE] rows.
 */
@Throws(IOException::class)
public fun extractCifar10Labels(pathToLabels: String, numClasses: Int): FloatArray {
    val labelCount = DATASET_SIZE
    println("Extracting $labelCount labels from $pathToLabels")
    val labelBuffer = ByteArray(labelCount)

    val dictionary = mapOf(
        "airplane" to 0,
        "automobile" to 1,
        "bird" to 2,
        "cat" to 3,
        "deer" to 4,
        "dog" to 5,
        "frog" to 6,
        "horse" to 7,
        "ship" to 8,
        "truck" to 9
    )

    csvReader().open(pathToLabels) {
        readAllAsSequence().forEachIndexed { rowIndex, row ->
            // Report an oversized file clearly instead of failing with a bare ArrayIndexOutOfBoundsException.
            check(rowIndex < labelCount) {
                "Labels file $pathToLabels contains more than $labelCount rows"
            }
            labelBuffer[rowIndex] = dictionary.getOrElse(row[1]) { 1 }.toByte()
        }
    }

    return FloatArray(labelCount) { i -> OnHeapDataset.convertByteToFloat(labelBuffer[i]) }
}

/**
 * Loads labels from the [pathToLabels] CSV file into heap memory and converts them to Floats, after
 * that it sorts them to have the same order as the image files.
 *
 * Each row is read as `id,className`. Labels are keyed by the id column, and the keys are sorted as
 * **strings**, i.e. lexicographically (`1`, `10`, `100`, …), not numerically. If an id occurs more
 * than once, the last row wins. Class names are mapped to indices 0..9 (`airplane` … `truck`).
 *
 * NOTE: It's important if you are going to use it with [org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset].
 *
 * NOTE(review): an unrecognized class name (e.g. a header row, if the file has one) silently falls back
 * to index 1 ("automobile"). The function name's "Ans" is a typo for "And", kept for API compatibility.
 *
 * @param pathToLabels path to the labels CSV file.
 * @param numClasses number of classes; currently not used by this function.
 * @return the first [DATASET_SIZE] labels in sorted-id order.
 * @throws IOException declared for failures while reading the CSV file.
 * @throws IllegalStateException if the file yields fewer than [DATASET_SIZE] distinct ids.
 */
@Throws(IOException::class)
public fun extractCifar10LabelsAnsSort(pathToLabels: String, numClasses: Int): FloatArray {
    val labelCount = DATASET_SIZE
    println("Extracting $labelCount labels from $pathToLabels")
    val labelSorter = mutableMapOf<String, Int>()

    val dictionary = mapOf(
        "airplane" to 0,
        "automobile" to 1,
        "bird" to 2,
        "cat" to 3,
        "deer" to 4,
        "dog" to 5,
        "frog" to 6,
        "horse" to 7,
        "ship" to 8,
        "truck" to 9
    )

    csvReader().open(pathToLabels) {
        readAllAsSequence().forEach { row ->
            labelSorter[row[0]] = dictionary.getOrElse(row[1]) { 1 }
        }
    }

    // toSortedMap() on String keys orders ids lexicographically.
    val sortedLabels = labelSorter.toSortedMap().values.toIntArray()
    // Report a short file clearly instead of failing with a bare ArrayIndexOutOfBoundsException below.
    check(sortedLabels.size >= labelCount) {
        "Labels file $pathToLabels contains ${sortedLabels.size} distinct ids, expected at least $labelCount"
    }

    return FloatArray(labelCount) { i -> OnHeapDataset.convertByteToFloat(sortedLabels[i].toByte()) }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy