com.komputation.cuda.workflow.CudaTrainer.kt

Komputation is a neural network framework for the JVM written in the Kotlin programming language.

package com.komputation.cuda.workflow

import jcuda.Pointer
import com.komputation.cuda.network.CudaBackwardPropagator
import com.komputation.cuda.network.CudaForwardPropagator
import com.komputation.cuda.loss.CudaLossFunction
import com.komputation.cuda.memory.InputMemory
import com.komputation.cuda.memory.TargetMemory
import com.komputation.matrix.Matrix
import com.komputation.matrix.partitionIndices
import com.komputation.optimization.Optimizable

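/**
 * Mini-batch training driver for the GPU backend: for every batch, the forward propagator
 * computes predictions in device memory, the loss function turns them into a gradient,
 * the backward propagator pushes that gradient back through the network, and each
 * optimizable parameter is then updated.
 */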
class CudaTrainer(
    private val forwardPropagator : CudaForwardPropagator,
    private val backwardPropagator : CudaBackwardPropagator,
    private val optimizables : Array<Optimizable>,
    private val inputs : Array<Matrix>,
    private val targets: Array<FloatArray>,
    private val numberIterations : Int,
    private val maximumBatchSize : Int,
    private val lossFunction : CudaLossFunction,
    private val afterEachIteration : ((index : Int, loss : Float) -> Unit)? = null) {

    private val numberExamples = this.inputs.size

    private val batches = partitionIndices(this.numberExamples, this.maximumBatchSize)

    private val inputMemory = InputMemory()
    private val targetMemory = TargetMemory(this.targets.first().size)

    init {

        this.lossFunction.acquire(this.maximumBatchSize)

    }

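    // Releases the device memory held by the loss function and the cached input and target buffers.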
    fun free() {

        this.lossFunction.release()

        this.inputMemory.free()
        this.targetMemory.free()

    }

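    // Trains for the configured number of iterations and returns the elapsed wall-clock time in milliseconds.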
    fun run(): Long {

        val trackLoss = this.afterEachIteration != null

        val start = System.currentTimeMillis()

        repeat(this.numberIterations) { indexIteration ->

            var iterationLoss = if(trackLoss) 0f else Float.NaN

            for ((batchId, batch) in this.batches.withIndex()) {

                val currentBatchSize = batch.size

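                // Forward pass: the predictions for this batch remain in device memory.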
                val devicePredictions = this.forwardPropagator.forward(batchId, currentBatchSize, batch, this.inputs, this.inputMemory, true)
                val pointerToDevicePredictions = Pointer.to(devicePredictions)

                val pointerToTargets = this.targetMemory.get(batchId, currentBatchSize, batch, this.targets)

                if (trackLoss) {

                    this.lossFunction.accumulate(pointerToDevicePredictions, pointerToTargets, currentBatchSize)

                }

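                // Gradient of the loss with respect to the predictions, computed on the device.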
                val backwardLoss = this.lossFunction.backward(pointerToDevicePredictions, pointerToTargets, currentBatchSize)

                this.backwardPropagator.backward(backwardLoss, currentBatchSize)

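                // Update every optimizable parameter using the gradients from the backward pass.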
                for (optimizable in this.optimizables) {

                    optimizable.optimize(currentBatchSize)

                }

                if (trackLoss) {

                    val batchLoss = this.lossFunction.accessAccumulation()

                    iterationLoss += batchLoss

                }

            }

            this.afterEachIteration?.invoke(indexIteration, iterationLoss)

        }

        val stop = System.currentTimeMillis()

        val time = stop - start

        return time

    }

}
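
For reference, here is a minimal, hypothetical driver showing how the class above could be wired up and run. It assumes the propagators, loss function and optimizable layers have already been assembled by the surrounding network-building code (not shown on this page); the train function and its hyperparameter values are placeholders, while the constructor call, run() and free() follow the source above.

import com.komputation.cuda.network.CudaBackwardPropagator
import com.komputation.cuda.network.CudaForwardPropagator
import com.komputation.cuda.loss.CudaLossFunction
import com.komputation.cuda.workflow.CudaTrainer
import com.komputation.matrix.Matrix
import com.komputation.optimization.Optimizable

// Hypothetical wiring: every argument passed in here must come from the network setup code.
fun train(
    forwardPropagator: CudaForwardPropagator,
    backwardPropagator: CudaBackwardPropagator,
    optimizables: Array<Optimizable>,
    lossFunction: CudaLossFunction,
    inputs: Array<Matrix>,
    targets: Array<FloatArray>): Long {

    val trainer = CudaTrainer(
        forwardPropagator,
        backwardPropagator,
        optimizables,
        inputs,
        targets,
        numberIterations = 10,   // placeholder hyperparameters
        maximumBatchSize = 32,
        lossFunction = lossFunction,
        afterEachIteration = { index, loss -> println("iteration $index: loss $loss") })

    return try {
        trainer.run()            // elapsed wall-clock time in milliseconds
    } finally {
        trainer.free()           // release the device memory acquired by the trainer
    }
}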



