gatelfpytorchjson.modelwrapper module — GATE LF Pytorch Wrapper (gatelfpytorch) documentation
gatelfpytorchjson.modelwrapper module
class gatelfpytorchjson.modelwrapper.ModelWrapper(dataset, config={})
Bases: object
Common base class for all wrappers. Defines instance methods which are the same
for all subclasses plus common static utility methods.
static accuracy(batch_predictions, batch_targets, pad_index=-1)
Calculate the accuracy from a tensor of predictions, which contains scores for each
class in the last dimension (higher scores are better), and a tensor of target indices.
Tensor elements where the target equals the padding index are ignored.
If the tensors represent sequences, the predictions have shape (batchsize, maxseqlen, nclasses)
and the targets have shape (batchsize, maxseqlen); otherwise the predictions have shape
(batchsize, nclasses) and the targets have shape (batchsize).
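The calculation described above can be sketched roughly as follows. This is an illustration only, not the library's source: the function name accuracy_sketch, the choice to return a single float, and the 0.0 result when every target is padding are all assumptions.

```python
import torch


def accuracy_sketch(batch_predictions, batch_targets, pad_index=-1):
    """Hypothetical re-implementation for illustration; returns a float."""
    # Pick the highest-scoring class along the last (class) dimension.
    # This works for both (batchsize, nclasses) and
    # (batchsize, maxseqlen, nclasses) predictions.
    predicted = batch_predictions.argmax(dim=-1)
    # Positions whose target is the padding index do not count.
    mask = batch_targets != pad_index
    n_total = int(mask.sum().item())
    if n_total == 0:
        # Assumption: report 0.0 when there is nothing to score.
        return 0.0
    n_correct = int(((predicted == batch_targets) & mask).sum().item())
    return n_correct / n_total


# Non-sequence case: three instances, two classes, last target is padding.
flat_preds = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
flat_targets = torch.tensor([1, 1, -1])
flat_acc = accuracy_sketch(flat_preds, flat_targets)

# Sequence case: one sequence of length three, last position is padding.
seq_preds = torch.tensor([[[0.1, 0.9], [0.8, 0.2], [0.5, 0.4]]])
seq_targets = torch.tensor([[1, 0, -1]])
seq_acc = accuracy_sketch(seq_preds, seq_targets)
```

Because the class scores always sit in the last dimension, the same argmax and mask logic covers both tensor layouts without reshaping.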
static early_stopping_checker(losses=None, accs=None, patience=2, mindelta=0.0)
Takes two lists of numbers, representing the losses and/or accuracies of all validation
steps.
If accs is not None, it is used; otherwise losses is used if it is not None; if both are None,
this always returns False (do not stop).
If accuracies are used, at most patience of the most recent validation accuracies may
fail to be at least mindelta larger than the best one so far.
If losses are used, at most patience of the most recent validation losses may fail to be
at least mindelta smaller than the best one so far.
In other words, this stops if more than patience of the last metrics are not an improvement
of at least mindelta over the current best.
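One possible reading of this rule is sketched below in plain Python. It is not the library's implementation: the name early_stopping_sketch is hypothetical, and the sketch assumes that "the last metrics" means the run of validation steps since the best value was set, and that a value improving by less than mindelta does not replace the current best.

```python
def early_stopping_sketch(losses=None, accs=None, patience=2, mindelta=0.0):
    """Hypothetical illustration; returns True if training should stop."""
    if accs is not None:
        values, sign = accs, 1.0      # larger accuracy is better
    elif losses is not None:
        values, sign = losses, -1.0   # smaller loss is better, so negate
    else:
        return False                  # nothing to check: never stop

    best = None
    stale = 0  # validation steps since the last sufficient improvement
    for value in values:
        score = sign * value
        if best is None or score - best >= mindelta:
            # Improvement of at least mindelta: new best, reset the counter.
            best = score
            stale = 0
        else:
            stale += 1
    # Stop only once more than `patience` recent steps failed to improve.
    return stale > patience


# Two non-improving steps are tolerated with patience=2 ...
keep_going = early_stopping_sketch(accs=[0.5, 0.6, 0.59, 0.58])
# ... a third one triggers the stop.
stop_now = early_stopping_sketch(accs=[0.5, 0.6, 0.59, 0.58, 0.57])
```

In a training loop, a check like this would typically run after each validation step, with the loop ending as soon as it returns True.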