Source code for gatelfdata.targetnominal

"""Module for the TargetNominal class"""


from collections import Counter
from gatelfdata.vocab import Vocab
import sys
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
streamhandler = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(
                '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
streamhandler.setFormatter(formatter)
logger.addHandler(streamhandler)


[docs]class TargetNominal(object): def __init__(self, meta, vocabs, targets_need_padding=False): self.meta = meta self.isSequence = meta["isSequence"] if self.isSequence: self.seq_max = meta["sequLengths.max"] self.seq_avg = meta["sequLengths.mean"] targetstats = meta["targetStats"] self.stringCounts = targetstats["stringCounts"] self.nrTargets = len(self.stringCounts) self.freqs = Counter(self.stringCounts) # so if we need to include a padding character for the targets, we set pad_index_only to True, if not, # we set no_special_indices True nspi = False pio = False if targets_need_padding: pio = True else: nspi = True self.vocab = Vocab(self.freqs, emb_id="<<TARGET>>", no_special_indices=nspi, pad_index_only=pio, emb_train="no") self.vocab.finish() vocabs.vocabs["<<TARGET>>"] = self.vocab # print("DEBUG!!!! Created vocab for target, itos is ", self.vocab.itos, "pad_index_only is", self.vocab.pad_index_only, file=sys.stderr) # influences if the conversion will return the index or # the onehot vector self.as_onehot = False
[docs] def set_as_onehot(self, flag=False): """Influence hot the original class label is converted. If the flag is False, then the string is converted to the corresponding string index, otherwise, to the corresponding onehot vector.""" self.as_onehot = flag
[docs] def zero_onehotvec(self): """Returns a zero vector with as many 0 as the one-hot representation would have.""" return self.vocab.zero_onehotvec()
def __call__(self, value, as_onehot=False): as_onehot = self.as_onehot or as_onehot if self.isSequence: if as_onehot: ret = [self.vocab.string2onehot(v) for v in value] else: ret = [self.vocab.string2idx(v) for v in value] else: if as_onehot: ret = self.vocab.string2onehot(value) else: ret = self.vocab.string2idx(value) # print("DEBUG looking up index for", value,"as_onehot=",as_onehot,"returning",ret,file=sys.stderr) return ret
[docs] def idx2label(self, idx): return self.vocab.idx2string(idx)
def __str__(self): return "TargetNominal()" def __repr__(self): return "TargetNominal()"