com.komputation.cpu.demos.mnist.MnistBatchDropout.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of komputation Show documentation
Show all versions of komputation Show documentation
Komputation is a neural network framework for the JVM written in the Kotlin programming language.
package com.komputation.cpu.demos.mnist
import com.komputation.cpu.network.Network
import com.komputation.demos.mnist.MnistData
import com.komputation.initialization.heInitialization
import com.komputation.layers.entry.inputLayer
import com.komputation.layers.forward.activation.ActivationFunction
import com.komputation.layers.forward.dense.denseLayer
import com.komputation.layers.forward.dropout.dropoutLayer
import com.komputation.loss.crossEntropyLoss
import com.komputation.optimization.historical.momentum
import java.io.File
import java.util.*
// The data set for this demo can be found here: https://pjreddie.com/projects/mnist-in-csv/
fun main(args: Array) {
if (args.size != 2) {
throw Exception("Please specify the paths to the MNIST training data and the test data (in the CSV format).")
}
val random = Random(1)
val numberIterations = 50
val batchSize = 64
val (trainingInputs, trainingTargets) = MnistData.loadMnistTraining(File(args.first()))
val (testInputs, testTargets) = MnistData.loadMnistTest(File(args.last()))
val inputDimension = 784
val hiddenDimension = 100
val numberCategories = MnistData.numberCategories
val initialization = heInitialization(random)
val optimizer = momentum(0.01f, 0.9f)
val keepProbability = 0.85f
val hiddenLayer = denseLayer(
inputDimension,
hiddenDimension,
initialization,
initialization,
ActivationFunction.ReLU,
optimizer
)
val outputLayer = denseLayer(
hiddenDimension,
numberCategories,
initialization,
initialization,
ActivationFunction.Softmax,
optimizer
)
val network = Network(
batchSize,
inputLayer(inputDimension),
hiddenLayer,
dropoutLayer(random, keepProbability, hiddenDimension),
outputLayer
)
val test = network
.test(
testInputs,
testTargets,
batchSize,
numberCategories)
network.training(
trainingInputs,
trainingTargets,
numberIterations,
crossEntropyLoss(numberCategories)) { _ : Int, _ : Float ->
println(test.run())
}
.run()
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy