All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.stripe.rainier.ir.DataFunction.scala Maven / Gradle / Ivy

The newest version!
package com.stripe.rainier.ir

/*
Input layout:
- numParamInputs param inputs
- for 0 <= i < data.size:
   - data[i].size inputs

Output layout:
- for 0 <= i < data.size:
   - numOutputs outputs
 */
/**
  * Evaluates a [[CompiledFunction]] once per row of every data batch and sums
  * the results into a single block of `numOutputs` values.
  *
  * The input array handed to `cf` starts with `numParamInputs` slots, followed
  * by one slot per column of each batch in `data`. `cf` is indexed as having
  * `numOutputs` outputs per batch. Both totals are checked against `cf` when
  * the instance is constructed.
  *
  * @param cf             the compiled function that is evaluated
  * @param numParamInputs number of leading input slots that precede the data slots
  * @param numOutputs     number of values accumulated into the caller's output array
  * @param data           indexed as `data(batch)(column)(row)`; the row count of a
  *                       batch is taken from its first column
  * @throws IllegalArgumentException if the input or output layout implied by the
  *                                  arguments does not match `cf`
  */
case class DataFunction(cf: CompiledFunction,
                        numParamInputs: Int,
                        numOutputs: Int,
                        data: Array[Array[Array[Double]]]) {
  val numInputs: Int = cf.numInputs
  val numGlobals: Int = cf.numGlobals

  // inputStartIndices(i) is the first input slot written for batch i; the last
  // entry is the total number of input slots the layout requires.
  private val inputStartIndices =
    data.map(_.length).scanLeft(numParamInputs)(_ + _)
  require(
    inputStartIndices(data.length) == cf.numInputs,
    s"input layout requires ${inputStartIndices(data.length)} inputs, " +
      s"but the compiled function has ${cf.numInputs}"
  )

  // outputStartIndices(i) is the index of the first cf output read for batch i;
  // the last entry is the total number of cf outputs the layout requires.
  private val outputStartIndices =
    data.scanLeft(0)((start, _) => start + numOutputs)
  require(
    outputStartIndices(data.length) == cf.numOutputs,
    s"output layout requires ${outputStartIndices(data.length)} outputs, " +
      s"but the compiled function has ${cf.numOutputs}"
  )

  /**
    * Clears `outputs`, then adds the contribution of every batch to it.
    *
    * @param inputs  input array passed to `cf`; the data slots are overwritten
    *                while rows are processed
    * @param globals passed through unchanged to `CompiledFunction.output`
    * @param outputs reset to all zeros, after which its first `numOutputs`
    *                entries receive the accumulated results
    */
  def apply(inputs: Array[Double],
            globals: Array[Double],
            outputs: Array[Double]): Unit = {
    java.util.Arrays.fill(outputs, 0.0)

    val numBatches = data.length
    var i = 0
    while (i < numBatches) {
      compute(inputs, globals, outputs, i)
      i += 1
    }
  }

  /**
    * Adds the contribution of batch `i` to `outputs`.
    *
    * For each row, the row's column values are copied into the batch's input
    * slots and the batch's outputs are accumulated. A batch with no columns
    * has no rows to iterate over, so its outputs are accumulated exactly once.
    */
  private def compute(inputs: Array[Double],
                      globals: Array[Double],
                      outputs: Array[Double],
                      i: Int): Unit = {
    val inputStartIndex = inputStartIndices(i)
    val outputStartIndex = outputStartIndices(i)
    val d = data(i)
    val numColumns = d.length
    if (numColumns > 0) {
      // The first column decides how many rows are read from every column.
      val numRows = d(0).length
      var k = 0
      while (k < numRows) {
        var j = 0
        while (j < numColumns) {
          inputs(inputStartIndex + j) = d(j)(k)
          j += 1
        }
        accumulate(inputs, globals, outputs, outputStartIndex)
        k += 1
      }
    } else {
      accumulate(inputs, globals, outputs, outputStartIndex)
    }
  }

  /**
    * Adds `cf` outputs `outputStartIndex` until `outputStartIndex + numOutputs`
    * (exclusive), evaluated against the current `inputs` and `globals`, to
    * `outputs(0)` until `outputs(numOutputs)` (exclusive).
    */
  private def accumulate(inputs: Array[Double],
                         globals: Array[Double],
                         outputs: Array[Double],
                         outputStartIndex: Int): Unit = {
    var o = 0
    while (o < numOutputs) {
      outputs(o) += CompiledFunction.output(cf,
                                            inputs,
                                            globals,
                                            outputStartIndex + o)
      o += 1
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy