Source code for gatelfpytorchjson.concat

import torch.nn
import logging
import sys

# Module-level logging setup: this module's logger gets its own stream
# handler on stderr with a timestamp / logger-name / level / message layout,
# and its threshold is set to INFO.
formatter = logging.Formatter(
    '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
streamhandler = logging.StreamHandler(stream=sys.stderr)
streamhandler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.addHandler(streamhandler)
logger.setLevel(logging.INFO)



class Concat(torch.nn.Module):
    """Run one sub-module per input and concatenate the results.

    Each element of the input list is passed through the sub-module at the
    same position, and the resulting tensors are joined with ``torch.cat``.
    """

    def __init__(self, inputs, name="concat", dim=None):
        """Concatenates the outputs of the given layers to a single output.

        The default axis for concatenating is the last dimension of the
        tensor. This can be overridden by setting dim to the axis.

        Args:
            inputs: sequence of modules, one per expected input tensor.
            name: prefix used when registering the sub-modules; the i-th
                module is registered as ``"<name>-<i>"``.
            dim: axis to concatenate along, or None for the last axis.
        """
        super().__init__()
        self.inputs = inputs
        # Register every layer explicitly so its parameters are tracked by
        # this module even when `inputs` is a plain Python list.
        for i, layer in enumerate(inputs):
            self.add_module(name + ("-%s" % i), layer)
        self.name = name
        self.dim = dim

    def forward(self, inputslist):
        """Apply each sub-module to its input and concatenate the outputs.

        Args:
            inputslist: sequence of inputs, one per registered sub-module,
                in the same order as the modules given to the constructor.

        Returns:
            The tensor produced by ``torch.cat`` over the sub-module
            outputs along ``self.dim`` (last axis when ``self.dim`` is
            None).

        Raises:
            Exception: if the number of inputs differs from the number of
                sub-modules.
        """
        if len(inputslist) != len(self.inputs):
            raise Exception("Number of modules and number of inputs differ")
        outs = [layer(x) for layer, x in zip(self.inputs, inputslist)]
        # Compare against None rather than truthiness so an explicit dim=0
        # is honoured; -1 selects the last axis of the outputs.
        axis = -1 if self.dim is None else self.dim
        return torch.cat(outs, axis)