
runtime.src.rosetta.runtime.utils.py Maven / Gradle / Ivy
'''Utility functions (runtime) for rosetta models.'''
from __future__ import annotations
import logging as log
import keyword
from enum import Enum
from typing import get_args, get_origin
from typing import TypeVar, Generic, Callable, Any
from functools import wraps
from collections import defaultdict
from pydantic import BaseModel, ValidationError, ConfigDict
# Public names exported by `from <this module> import *` (order preserved).
__all__ = [
    'if_cond',
    'if_cond_fn',
    'Multiprop',
    'rosetta_condition',
    'BaseDataClass',
    'ConditionViolationError',
    'any_elements',
    'get_only_element',
    'rosetta_filter',
    'all_elements',
    'contains',
    'disjoint',
    'join',
    'rosetta_local_condition',
    'execute_local_conditions',
    'flatten_list',
    'rosetta_resolve_attr',
    'rosetta_count',
    'rosetta_attr_exists',
    '_get_rosetta_object',
    'set_rosetta_attr',
    'add_rosetta_attr',
    'check_cardinality',
    'AttributeWithMeta',
    'AttributeWithAddress',
    'AttributeWithReference',
    'AttributeWithMetaWithAddress',
    'AttributeWithMetaWithReference',
    'AttributeWithAddressWithReference',
    'AttributeWithMetaWithAddressWithReference',
    'rosetta_str',
]
def if_cond(ifexpr, thenexpr: str, elseexpr: str, obj: object):
    '''Return the value of a ternary expression whose branches are Python
    source strings.

    The branch selected by the truthiness of `ifexpr` is evaluated with
    `eval`, using this module's globals and with `obj` bound to the local
    name `self`.
    '''
    # NOTE(review): eval() runs arbitrary code. The branch strings are
    # presumably produced by the code generator -- never pass untrusted text.
    if ifexpr:
        source = thenexpr
    else:
        source = elseexpr
    namespace = {'self': obj}
    return eval(source, globals(), namespace)  # pylint: disable=eval-used
def if_cond_fn(ifexpr, thenexpr: Callable, elseexpr: Callable) -> Any:
    '''Functional ternary operator: call and return `thenexpr()` when
    `ifexpr` is truthy, otherwise `elseexpr()`. Only the selected callable
    is invoked.
    '''
    if not ifexpr:
        return elseexpr()
    return thenexpr()
def _to_list(obj) -> list | tuple:
    '''Return lists and tuples unchanged (same object); wrap any other value,
    including None, in a one-element tuple.'''
    return obj if isinstance(obj, (list, tuple)) else (obj,)
def _is_meta(obj: Any) -> bool:
    '''Return True if `obj` is one of the generic meta wrappers of this
    module, i.e. a model that carries metadata next to an embedded `value`.

    `AttributeWithReference` is not included: it is a `BaseDataClass`
    subclass without a `value` field.
    '''
    return isinstance(
        obj, (AttributeWithMeta, AttributeWithAddress,
              AttributeWithMetaWithAddress, AttributeWithMetaWithReference,
              # Also a generic wrapper with a `value` field; it was missing
              # here, so it was neither unwrapped by rosetta_resolve_attr
              # nor recursed into by the condition validation.
              AttributeWithAddressWithReference,
              AttributeWithMetaWithAddressWithReference))
def mangle_name(attrib: str) -> str:
    '''Return the Python attribute name used for the Rosetta name `attrib`.

    Names that start with an underscore, or that are Python keywords or soft
    keywords (as reported by the running interpreter's `keyword` module),
    get the prefix 'rosetta_attr_'; all other names are returned unchanged.
    '''
    needs_prefix = (attrib.startswith('_')
                    or keyword.iskeyword(attrib)
                    or keyword.issoftkeyword(attrib))
    if not needs_prefix:
        return attrib
    return f'rosetta_attr_{attrib}'
def rosetta_resolve_attr(obj: Any | None,
                         attrib: str) -> Any | list[Any] | None:
    ''' Rosetta semantics compliant attribute resolver.

    * `None` resolves to `None`.
    * For a list or tuple the attribute is resolved on every element:
      list/tuple results are spliced into the output, `None` results are
      dropped, and an empty output is reported as `None`.
    * Meta wrappers (see `_is_meta`) are unwrapped to their `value` first.
    * The attribute name goes through `mangle_name`; a missing attribute
      resolves to `None`.
    '''
    if obj is None:
        return None
    if not isinstance(obj, (list, tuple)):
        # NOTE: ignores (for now) all meta attributes in the expressions.
        # In the future one might want to check if the attrib is contained
        # in the metadata and return it instead of failing.
        target = obj.value if _is_meta(obj) else obj
        return getattr(target, mangle_name(attrib), None)
    flattened = []
    for element in obj:
        for item in _to_list(rosetta_resolve_attr(element, attrib)):
            if item is not None:
                flattened.append(item)
    return flattened or None
def rosetta_count(obj: Any | None) -> int:
    '''Implement the loose `count` semantics of the Rosetta DSL.

    * `None` (absent) counts as 0.
    * Lists, tuples and other sized containers count their elements.
    * Any other value counts as one present value. This includes strings
      and falsy scalars such as `0`, `False` or `''`, matching
      `rosetta_attr_exists`, which treats those as existing.
    '''
    if obj is None:
        return 0
    if isinstance(obj, (str, bytes)):
        # A single string is one value, not a collection of characters.
        return 1
    try:
        return len(obj)
    except TypeError:
        # Unsized scalar (number, enum, model instance, ...).
        return 1
def rosetta_attr_exists(val: Any) -> bool:
    '''Implement the Rosetta notion of "exists": a value exists unless it is
    `None` or compares equal to an empty list. Falsy scalars such as `0`,
    `False` or `''` therefore exist.'''
    missing = val is None or val == []
    return not missing
def rosetta_str(x: Any) -> str:
    '''Return the Rosetta string form of `x`: the `str()` of an enum
    member's `value`, or plain `str(x)` for anything else.'''
    value = x.value if isinstance(x, Enum) else x
    return str(value)
def _get_rosetta_object(base_model: str | type, attribute: str,
                        value: Any) -> Any:
    '''Instantiate a model with a single keyword argument.

    Parameters:
        base_model: the model class itself, or the name of a class that is
            visible in this module's global namespace. Passing the class lets
            callers use models defined in other modules, which a name lookup
            here cannot find.
        attribute: keyword (field) name to set.
        value: value passed for `attribute`.
    Returns:
        The result of `model_class(**{attribute: value})`.
    Raises:
        KeyError: if `base_model` is a name not present in this module's
            globals.
    '''
    if isinstance(base_model, type):
        model_class = base_model
    else:
        # Name lookup only sees names defined in / imported into this module.
        model_class = globals()[base_model]
    return model_class(**{attribute: value})
class Multiprop(list):
    ''' A list that forwards attribute access to all of its elements.

    `mp.attr` returns a new `Multiprop` holding `getattr(x, 'attr')` for
    every element `x`. Elements that are themselves `Multiprop` instances
    are expanded, so their results are spliced in rather than nested.
    '''

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. Refuse dunder names
        # so protocol probes such as copy.deepcopy's
        # getattr(x, '__deepcopy__', None) or hasattr(x, '__setstate__') are
        # not answered with a (possibly empty, hence non-callable) Multiprop.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {attr!r}')
        res = Multiprop()
        for x in self:
            if isinstance(x, Multiprop):
                res.extend(x.__getattr__(attr))
            else:
                res.append(getattr(x, attr))
        return res
# Global condition registry: dotted path of the owning scope (module name
# plus the enclosing qualname) -> {condition function name: function}.
_CONDITIONS_REGISTRY: defaultdict[str, dict[str, Any]] = defaultdict(dict)


def rosetta_condition(condition):
    '''Decorator: register `condition` in the global condition registry and
    return a transparent (functools.wraps) wrapper around it.

    The registry key is the module name followed by every qualname component
    except the last; the last qualname component is the condition's name.
    '''
    module_name = condition.__module__ or ''
    owner, _, name = condition.__qualname__.rpartition('.')
    path = f'{module_name}.{owner}' if owner else module_name
    _CONDITIONS_REGISTRY[path][name] = condition

    @wraps(condition)
    def wrapper(*args, **kwargs):
        return condition(*args, **kwargs)
    return wrapper
def rosetta_local_condition(registry: dict):
    '''Return a decorator that stores a condition function in `registry`.

    The key is `<module>.<qualname>` of the decorated function. The
    decorator returns a transparent (functools.wraps) wrapper.
    '''
    def register(condition):
        key = f"{condition.__module__ or ''}.{condition.__qualname__}"
        registry[key] = condition

        @wraps(condition)
        def wrapper(*args, **kwargs):
            return condition(*args, **kwargs)
        return wrapper
    return register
def execute_local_conditions(registry: dict, cond_type: str):
    '''Execute all conditions registered in a local registry.

    Each registered function is called without arguments, in the registry's
    iteration order. The first one returning a falsy value raises
    `ConditionViolationError("<cond_type> '<path>' failed.")`.
    '''
    for path, check in registry.items():
        if check():
            continue
        raise ConditionViolationError(f"{cond_type} '{path}' failed.")


class ConditionViolationError(ValueError):
    '''Raised when a Rosetta condition (constraint) is not satisfied.'''
def _fqcn(cls) -> str:
    '''Return the fully qualified name `<module>.<qualname>` of `cls`.'''
    module_name = cls.__module__ or ''
    return f'{module_name}.{cls.__qualname__}'
def _get_conditions(cls) -> list:
    '''Collect the globally registered conditions of `cls` and of every class
    that precedes `BaseDataClass` in its MRO.

    Returns `(qualified condition name, function)` tuples in reversed MRO
    order, so base-class conditions come before those of `cls` itself.
    Raises ValueError if `BaseDataClass` is not in the MRO of `cls`.
    '''
    mro = cls.__mro__
    own_classes = mro[:mro.index(BaseDataClass)]
    collected = []
    for klass in own_classes[::-1]:
        owner = _fqcn(klass)
        for cond_name, func in _CONDITIONS_REGISTRY.get(owner, {}).items():
            collected.append((f'{owner}.{cond_name}', func))
    return collected
class MetaAddress(BaseModel):
    '''Address metadata consisting of a `scope` string and a `value` string.

    Used as the type of the optional `address` field of `BaseDataClass` and
    of the address-carrying meta wrappers in this module.
    '''
    scope: str
    value: str
class BaseDataClass(BaseModel):
    ''' A base class for all cdm generated classes. It is derived from
        `pydantic.BaseModel`, which provides type checking at object creation
        for all cdm classes. It also provides the `validate_model`,
        `validate_conditions` and `validate_attribs` methods, which perform
        the conditions, cardinality and type checks as specified in the
        rosetta type model. The method `validate_model` is not invoked
        automatically; it is left to the user to decide when to check the
        validity of the cdm model.
    '''
    # extra='forbid' rejects unknown attributes at construction time;
    # revalidate_instances='always' makes pydantic re-validate model
    # instances passed to validation, which `validate_attribs` relies on.
    model_config = ConfigDict(extra='forbid', revalidate_instances='always')
    # Optional metadata available on every generated class.
    meta: dict | None = None
    address: MetaAddress | None = None

    def validate_model(self,
                       recursively: bool = True,
                       raise_exc: bool = True,
                       strict: bool = True) -> list:
        ''' Perform full model validation: attribute type validation via
            `validate_attribs`, followed by `validate_conditions` (conditions
            and cardinality of this object and, if `recursively` is True, of
            its attributes).

            Parameters:
                recursively: forwarded to `validate_conditions`.
                raise_exc: if True the first violation is raised, otherwise
                    all encountered violations are returned.
                strict: forwarded to `validate_attribs` (pydantic strict
                    mode).
            Returns:
                list: attribute validation errors followed by condition
                violations; empty if the model is valid.
        '''
        att_errors = self.validate_attribs(raise_exc=raise_exc, strict=strict)
        return att_errors + self.validate_conditions(recursively=recursively,
                                                     raise_exc=raise_exc)

    def validate_attribs(self, raise_exc: bool = True,
                         strict: bool = True) -> list:
        ''' Re-validate the attribute types of this object with pydantic.

            Parameters:
                raise_exc: if True a `ValidationError` is re-raised,
                    otherwise it is returned inside a list.
                strict: passed to `model_validate` (pydantic strict mode).
            Returns:
                list: `[ValidationError]` on failure (when not raised),
                otherwise an empty list.
            Raises:
                ValidationError: on failure when `raise_exc` is True.
        '''
        try:
            self.model_validate(self, strict=strict)
        except ValidationError as validation_error:
            if raise_exc:
                # Bare raise re-raises the original exception unchanged.
                raise
            return [validation_error]
        return []

    def validate_conditions(self,
                            recursively: bool = True,
                            raise_exc: bool = True) -> list:
        ''' Check all registered conditions (including the cardinality
            conditions) of this object, including those registered for its
            base classes. If `recursively` is True, the conditions of all
            attribute values (models, lists of models and meta-wrapped
            models) are checked as well.

            Parameters:
                recursively: also validate the attribute values.
                raise_exc: if True the first failing condition raises a
                    `ConditionViolationError`; otherwise all violations are
                    collected and returned.
            Returns:
                list: the collected `ConditionViolationError`s.
        '''
        # object.__repr__ gives a short identity-based label for the logs.
        self_rep = object.__repr__(self)
        log.info('Checking conditions for %s ...', self_rep)
        exceptions = []
        for name, condition in _get_conditions(self.__class__):
            log.info('Checking condition %s for %s...', name, self_rep)
            if not condition(self):
                msg = f'Condition "{name}" for {repr(self)} failed!'
                log.error(msg)
                exc = ConditionViolationError(msg)
                if raise_exc:
                    raise exc
                exceptions.append(exc)
            else:
                log.info('Condition %s for %s satisfied.', name, self_rep)
        if recursively:
            for k, v in self.__dict__.items():
                log.info('Validating conditions of property %s', k)
                exceptions += _validate_conditions_recursively(
                    v, raise_exc=raise_exc)
        err = f'with {len(exceptions)}' if exceptions else 'without'
        log.info('Done conditions checking for %s %s errors.', self_rep, err)
        return exceptions

    def check_one_of_constraint(self, *attr_names, necessity=True) -> bool:
        """ Checks that one and only one attribute is set.

        An attribute counts as set when it is neither None nor an empty
        list. With `necessity=True` exactly one of `attr_names` must be set;
        with `necessity=False` at most one may be set. Violations are logged
        and reported by returning False.
        """
        # Read the attributes directly instead of serializing the whole
        # (possibly deeply nested) model with model_dump() on every check.
        vals = [getattr(self, n, None) for n in attr_names]
        n_attr = sum(1 for v in vals if v is not None and v != [])
        if necessity and n_attr != 1:
            log.error('One and only one of %s should be set!', attr_names)
            return False
        if not necessity and n_attr > 1:
            log.error('Only one of %s can be set!', attr_names)
            return False
        return True

    def add_to_list_attribute(self, attr_name: str, value) -> None:
        """
        Adds a value to a list attribute, ensuring the value is of an allowed
        type.

        The allowed types come from `get_allowed_types_for_list_field` for
        this object's class; the value is appended in place.

        Parameters:
            attr_name (str): Name of the list attribute.
            value: Value to add to the list.
        Raises:
            AttributeError: If the attribute name is not found or not a list.
            TypeError: If the value type is not one of the allowed types.
        """
        if not hasattr(self, attr_name):
            raise AttributeError(f"Attribute {attr_name} not found.")
        attr = getattr(self, attr_name)
        if not isinstance(attr, list):
            raise AttributeError(f"Attribute {attr_name} is not a list.")
        # Get allowed types for the list elements
        allowed_types = get_allowed_types_for_list_field(
            self.__class__, attr_name)
        # Check if value is an instance of one of the allowed types
        if not isinstance(value, allowed_types):
            raise TypeError(
                f"Value must be an instance of {allowed_types}, "
                f"not {type(value)}"
            )
        attr.append(value)
def _validate_conditions_recursively(obj, raise_exc=True):
    '''Run the condition checks on `obj` and everything it contains.

    Falsy values yield no errors. A `BaseDataClass` instance is validated
    recursively, the items of a list or tuple are visited one by one, and a
    meta wrapper is checked through its `value`. Any other value yields no
    errors. Returns the list of collected violations.
    '''
    if not obj:
        return []
    if isinstance(obj, BaseDataClass):
        return obj.validate_conditions(recursively=True,  # type:ignore
                                       raise_exc=raise_exc)
    if isinstance(obj, (list, tuple)):
        return [error
                for element in obj
                for error in _validate_conditions_recursively(
                    element, raise_exc=raise_exc)]
    if _is_meta(obj):
        return _validate_conditions_recursively(obj.value,
                                                raise_exc=raise_exc)
    return []
def get_allowed_types_for_list_field(model_class: type, field_name: str):
    """
    Gets the allowed element types of a list field of a (Pydantic) model,
    supporting both `Union[...]` and the `|` operator.

    The field annotation is resolved, in order of preference, from:
    Pydantic's `model_fields` (resolved types, inherited fields included);
    `typing.get_type_hints` (evaluates string annotations such as those
    produced by `from __future__ import annotations`, and walks the MRO);
    and finally the raw `__annotations__`.

    `Optional[list[X]]` / `list[X] | None` is unwrapped. A union element
    type is expanded to its members. Parameterized generic members (e.g.
    `list[int]`) are reduced to their origin class so the result can be
    passed to `isinstance`.

    Parameters:
        model_class (type): The model class.
        field_name (str): The field name.
    Returns:
        tuple: A tuple of allowed types; empty if the field is unknown or
        not annotated as a list.
    """
    import types
    import typing

    union_origins = tuple(
        origin for origin in (typing.Union, getattr(types, 'UnionType', None))
        if origin is not None)

    model_fields = getattr(model_class, 'model_fields', None)
    if isinstance(model_fields, dict) and field_name in model_fields:
        field_type = model_fields[field_name].annotation
    else:
        try:
            field_type = typing.get_type_hints(model_class).get(field_name)
        except (NameError, AttributeError, TypeError, SyntaxError):
            # Unresolvable forward references: fall back to the raw value.
            field_type = getattr(model_class, '__annotations__',
                                 {}).get(field_name)

    # Unwrap Optional[list[X]] / list[X] | None.
    if get_origin(field_type) in union_origins:
        non_none = [arg for arg in get_args(field_type)
                    if arg is not type(None)]
        if len(non_none) == 1:
            field_type = non_none[0]

    if get_origin(field_type) is not list:
        return ()
    list_args = get_args(field_type)
    if not list_args:
        return ()
    elem_type = list_args[0]
    if get_origin(elem_type) in union_origins:
        members = get_args(elem_type)
    else:
        members = (elem_type,)
    # isinstance() rejects parameterized generics such as list[int].
    return tuple(get_origin(t) if isinstance(get_origin(t), type) else t
                 for t in members)
# Type variable for the wrapped `value` of the generic meta wrappers below.
ValueT = TypeVar('ValueT')


class AttributeWithMeta(BaseModel, Generic[ValueT]):
    '''Meta support: wraps a `value` of type `ValueT` together with an
    optional `meta` dictionary.'''
    meta: dict | None = None
    value: ValueT


class AttributeWithAddress(BaseModel, Generic[ValueT]):
    '''Meta support: wraps a `value` of type `ValueT` together with an
    optional `MetaAddress`. Unlike the other wrappers, `value` is optional
    here (it defaults to None).'''
    address: MetaAddress | None = None
    value: ValueT | None = None


class AttributeWithReference(BaseDataClass):
    '''Meta support: optional `externalReference` / `globalReference`
    strings. This is a `BaseDataClass` subclass (not a generic wrapper) and
    has no `value` field.'''
    externalReference: str | None = None
    globalReference: str | None = None


class AttributeWithMetaWithAddress(BaseModel, Generic[ValueT]):
    '''Meta support: wraps a `value` of type `ValueT` together with an
    optional `meta` dictionary and an optional `MetaAddress`.'''
    meta: dict | None = None
    address: MetaAddress | None = None
    value: ValueT


class AttributeWithMetaWithReference(BaseModel, Generic[ValueT]):
    '''Meta support: wraps a `value` of type `ValueT` together with an
    optional `meta` dictionary and optional external/global reference
    strings.'''
    meta: dict | None = None
    externalReference: str | None = None
    globalReference: str | None = None
    value: ValueT


class AttributeWithAddressWithReference(BaseModel, Generic[ValueT]):
    '''Meta support: wraps a `value` of type `ValueT` together with an
    optional `MetaAddress` and optional external/global reference
    strings.'''
    address: MetaAddress | None = None
    externalReference: str | None = None
    globalReference: str | None = None
    value: ValueT


class AttributeWithMetaWithAddressWithReference(BaseModel, Generic[ValueT]):
    '''Meta support: wraps a `value` of type `ValueT` together with an
    optional `meta` dictionary, an optional `MetaAddress` and optional
    external/global reference strings.'''
    meta: dict | None = None
    address: MetaAddress | None = None
    externalReference: str | None = None
    globalReference: str | None = None
    value: ValueT
def _ntoz(v):
    '''Support the loose Rosetta treatment of None in comparisons: map None
    to 0 and return any other value unchanged.'''
    return 0 if v is None else v


# Rosetta comparison operators by symbol; both operands pass through
# `_ntoz` first, so None compares like 0.
_cmp = {
    '=': lambda lhs, rhs: _ntoz(lhs) == _ntoz(rhs),
    '<>': lambda lhs, rhs: _ntoz(lhs) != _ntoz(rhs),
    '>=': lambda lhs, rhs: _ntoz(lhs) >= _ntoz(rhs),
    '<=': lambda lhs, rhs: _ntoz(lhs) <= _ntoz(rhs),
    '>': lambda lhs, rhs: _ntoz(lhs) > _ntoz(rhs),
    '<': lambda lhs, rhs: _ntoz(lhs) < _ntoz(rhs),
}
def all_elements(lhs, op, rhs) -> bool:
    '''Return True if the Rosetta comparison `op` (a key of `_cmp`) holds for
    every pair (a, b) with a from `lhs` and b from `rhs`.

    Scalars act as one-element lists and None compares like 0. If either
    operand is an empty list/tuple the result is True. An unknown `op`
    raises KeyError.
    '''
    compare = _cmp[op]
    left, right = _to_list(lhs), _to_list(rhs)
    for a in left:
        for b in right:
            if not compare(a, b):
                return False
    return True
def disjoint(op1, op2):
    '''Check whether two lists have no common elements.

    Scalars are treated as one-element lists. Set operations are used when
    all elements are hashable. Otherwise (e.g. dicts, lists or mutable model
    instances) the check falls back to equality-based membership tests
    instead of raising TypeError.
    '''
    left = _to_list(op1)
    right = _to_list(op2)
    try:
        return set(left).isdisjoint(right)
    except TypeError:
        # Unhashable elements: O(len(op1) * len(op2)) membership test.
        return not any(item in left for item in right)
def contains(op1, op2):
    ''' Check whether op2 is contained in op1 (i.e. every element of op2 is
        also an element of op1). Scalars are treated as one-element lists.

        Set operations are used when all elements are hashable. Otherwise
        (e.g. dicts, lists or mutable model instances) the check falls back
        to equality-based membership tests instead of raising TypeError.
    '''
    container = _to_list(op1)
    candidates = _to_list(op2)
    try:
        return set(candidates).issubset(set(container))
    except TypeError:
        # Unhashable elements: O(len(op1) * len(op2)) membership test.
        return all(item in container for item in candidates)
def join(lst, sep=''):
    ''' Concatenate `str()` of every element of `lst`, placing `sep`
        (empty by default) between consecutive elements.
    '''
    return sep.join(map(str, lst))
def any_elements(lhs, op, rhs) -> bool:
    '''Return True if the Rosetta comparison `op` (a key of `_cmp`) holds for
    at least one pair (a, b) with a from `lhs` and b from `rhs`.

    Scalars act as one-element lists and None compares like 0. If either
    operand is an empty list/tuple the result is False. An unknown `op`
    raises KeyError.
    '''
    compare = _cmp[op]
    left, right = _to_list(lhs), _to_list(rhs)
    for a in left:
        for b in right:
            if compare(a, b):
                return True
    return False
def check_cardinality(prop, inf: int, sup: int | None = None) -> bool:
    ''' Check that the number of values held by `prop` lies in [inf, sup].

        `None` counts as 0 values and a list or tuple as its length. Any
        other value counts as exactly one value, including falsy scalars
        such as `0`, `False` or `''`: they are present values in Rosetta, so
        a required field holding them must pass. If the supremum is not
        supplied (i.e. is None), the property is unbounded (it corresponds
        to (x..*) in rosetta).
    '''
    if prop is None:
        prop_card = 0
    elif isinstance(prop, (list, tuple)):
        prop_card = len(prop)
    else:
        prop_card = 1
    if sup is None:
        return inf <= prop_card
    return inf <= prop_card <= sup
def get_only_element(collection):
    ''' Return the single element of a list or tuple. If the collection is
        empty, has more than one element, or is not a list/tuple at all,
        None is returned.
    '''
    is_sequence = isinstance(collection, (list, tuple))
    if not is_sequence or len(collection) != 1:
        return None
    return collection[0]
def flatten_list(nested_list):
    '''Flatten an iterable of iterables by exactly one level
    (non-recursively) into a new list.'''
    flat = []
    for sublist in nested_list:
        flat.extend(sublist)
    return flat
def rosetta_filter(items, filter_func, item_name='item'):  # pylint: disable=unused-argument
    """
    Filters a list of items with a boolean predicate.

    :param items: Iterable of items to be filtered. ``None`` (an absent
        Rosetta list) yields an empty list.
    :param filter_func: A one-argument callable (e.g. a lambda representing
        the Rosetta boolean expression); items for which it returns a truthy
        value are kept.
    :param item_name: The name used to refer to each item in the originating
        boolean expression. Kept for backward compatibility only: the item
        is always passed to `filter_func` positionally.
    :return: Filtered list.
    """
    if items is None:
        return []
    # The previous `locals()[item_name]` lookup inside the comprehension only
    # found the loop variable for the default name 'item'; any other name
    # raised KeyError (or, on Python 3.12+, could pick up another local).
    return [item for item in items if filter_func(item)]
def set_rosetta_attr(obj: Any, path: str, value: Any) -> None:
    """
    Sets an attribute of a Rosetta model object to a specified value using a
    path.

    Intermediate path components are resolved with `rosetta_resolve_attr`,
    which applies `mangle_name`. The final component is first tried
    verbatim. If the target has no such attribute, its mangled form
    (`mangle_name`) is tried as well, so names are treated the same way as
    in the intermediate components.

    Parameters:
        obj (Any): The object whose attribute is to be set.
        path (str): The path to the attribute, with components separated by
            '->'.
        value (Any): The value to set the attribute to.
    Raises:
        ValueError: If the object or attribute at any level in the path is
            None.
        AttributeError: If an invalid attribute path is provided.
    """
    if obj is None:
        raise ValueError(
            "Cannot set attribute on a None object in set_rosetta_attr.")
    *parent_path, final_attr = path.split('->')
    parent_obj = obj
    # Walk every component except the last one.
    for attrib in parent_path:
        parent_obj = rosetta_resolve_attr(parent_obj, attrib)
        if parent_obj is None:
            raise ValueError(
                f"Attribute '{attrib}' in the path is None, cannot "
                "proceed to set value."
            )
    if hasattr(parent_obj, final_attr):
        target_attr = final_attr
    else:
        target_attr = mangle_name(final_attr)
        if not hasattr(parent_obj, target_attr):
            raise AttributeError(
                f"Invalid attribute '{final_attr}' for object of "
                f"type {type(parent_obj).__name__}"
            )
    setattr(parent_obj, target_attr, value)
def add_rosetta_attr(obj: Any, attrib: str, value: Any) -> None:
    """
    Appends a value to a list-like attribute of a Rosetta model object.

    If the attribute does not exist yet, or currently holds None (an unset
    optional list), it is set to a new single-element list `[value]`.
    Otherwise the value is appended to the existing list in place.

    Parameters:
        obj (Any): The object whose attribute is to be modified.
        attrib (str): The list-like attribute to add the value to.
        value (Any): The value to add to the attribute.
    Raises:
        ValueError: If `obj` is None.
        TypeError: If the attribute holds a value that is neither None nor a
            list.
    """
    if obj is None:
        raise ValueError("Object for add_rosetta_attr cannot be None.")
    current_attr = getattr(obj, attrib, None)
    if current_attr is None:
        setattr(obj, attrib, [value])
    elif isinstance(current_attr, list):
        current_attr.append(value)
    else:
        raise TypeError(f"Attribute {attrib} is not list-like.")
# EOF
© 2015 - 2025 Weber Informatics LLC | Privacy Policy