All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.dynet.TreeLSTMBuilder.scala Maven / Gradle / Ivy

package edu.cmu.dynet


abstract class TreeLSTMBuilder[A<: internal.TreeLSTMBuilder] private[dynet](private[dynet] val builder: A)
  extends RnnBuilder(builder) {

  def addInput(id: Int, children: IntVector, x: Expression): Expression = {
    val expr = builder.add_input(id, children.vector, x.expr)
    new Expression(expr)
  }

  def setNumElements(num: Int) = {
    builder.set_num_elements(num)
  }
}


class BidirectionalTreeLSTMBuilder private[dynet](private[dynet] builder: internal.BidirectionalTreeLSTMBuilder)
  extends TreeLSTMBuilder[internal.BidirectionalTreeLSTMBuilder](builder) {

  def this(layers: Long, inputDim: Long, hiddenDim: Long, model: ParameterCollection) {
    this(new internal.BidirectionalTreeLSTMBuilder(layers, inputDim, hiddenDim, model.model))
  }
}

class UnidirectionalTreeLSTMBuilder private[dynet](private[dynet] builder: internal.UnidirectionalTreeLSTMBuilder)
  extends TreeLSTMBuilder[internal.UnidirectionalTreeLSTMBuilder](builder) {

  def this(layers: Long, inputDim: Long, hiddenDim: Long, model: ParameterCollection) {
    this(new internal.UnidirectionalTreeLSTMBuilder(layers, inputDim, hiddenDim, model.model))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy