from inspect import cleandoc
from typing import Any, Callable, Optional, Tuple
import schema
class TMSchema(schema.Schema):
    """Wrapper on top of schema.Schema for incremental schema build-up.

    Fields are added one at a time via :meth:`add_field` /
    :meth:`add_opt_field` (or merged via :meth:`update`). The schema also
    remembers the version it was built for and whether strict mode is on.
    """

    def __init__(
        self,
        schema: Any = None,
        error=None,
        ignore_extra_keys: bool = False,
        name: Optional[str] = None,
        description: Optional[str] = None,
        as_reference: bool = False,
        version: Optional[str] = None,
        strict: bool = False,
    ):
        """
        :param schema: Schema data (can be dictionary, list, value, see
            `schema`). Defaults to a new empty dictionary.
        :param error: Error message to show (see `schema`)
        :param ignore_extra_keys: Allows extra keys in strict mode. In
            non-strict mode extra keys are always allowed.
        :param name: Name to use in error messages
        :param description: Description to show in documentation
        :param as_reference: Generate separate sub-schema in JSON +
            documentation?
        :param version: Version of the schema. The text after the last '/'
            (or the whole string if there is no '/') is appended to the
            name passed on to `schema`.
        :param strict: Strict mode?
        """
        self._version = version
        self._strict = strict
        self._raw_name = name
        if schema is None:
            # Fresh dict per instance, so fields can be added incrementally
            schema = {}
        if not strict:
            ignore_extra_keys = True
        if version is not None and name is not None:
            # Decorate the name with the version number. Using [-1] also
            # accepts a bare version string that contains no '/'.
            version_num = version.rsplit("/", 1)[-1]
            name = name.replace("/", "_") + f" {version_num}"
        super().__init__(
            schema=schema,
            error=error,
            ignore_extra_keys=ignore_extra_keys,
            name=name,
            description=description,
            as_reference=as_reference,
        )

    @classmethod
    def new(cls, name: str, version: str, strict: bool, **kwargs):
        """Create a schema with the given name, version and strictness.

        Instantiates ``cls``, so calling this on a subclass returns an
        instance of that subclass.

        :param name: Name of the schema
        :param version: Version of the schema
        :param strict: Strict mode?
        :param kwargs: Further arguments passed on to the constructor
        """
        return cls(name=name, version=version, strict=strict, **kwargs)

    @property
    def raw_name(self):
        """Name as passed to the constructor (without version decoration)."""
        return self._raw_name

    @property
    def version(self):
        """Version passed to the constructor (may be ``None``)."""
        return self._version

    @property
    def strict(self):
        """Whether this schema was built in strict mode."""
        return self._strict

    @staticmethod
    def _unwrap_key(key: Any) -> Tuple[Any, bool]:
        """Strip ``schema.Optional`` / ``schema.Literal`` wrappers off a key.

        Uses ``Optional.schema`` rather than ``Optional.key``, because the
        latter only exists when the ``Optional`` was given a default.

        :param key: Dictionary key of the schema
        :returns: (bare key, whether it was wrapped in ``schema.Optional``)
        """
        optional = isinstance(key, schema.Optional)
        if optional:
            key = key.schema
        if isinstance(key, schema.Literal):
            key = key.schema
        return key, optional

    def add_field(
        self,
        name: str,
        check: Any,
        check_strict: Any = None,
        description: Optional[str] = None,
        optional: bool = False,
        default: Any = None,
    ):
        """Add a field to the schema.

        :param name: Name of the field (dictionary key)
        :param check: Check for the field value (anything `schema` accepts)
        :param check_strict: Additional check, only applied in strict mode
        :param description: Description to show in documentation; cleaned
            up with :func:`inspect.cleandoc`
        :param optional: Whether the field may be omitted
        :param default: Default for an optional field. NOTE(review): this is
            always passed to ``schema.Optional``, even when ``None``.
        """
        # Description + optional get indicated on the key
        key: Any = name
        if description is not None:
            key = schema.Literal(key, description=cleandoc(description))
        if optional:
            key = schema.Optional(key, default=default)
        # Stricter check given? Only combine it in when strict.
        if check_strict is not None and self._strict:
            check = schema.And(check, check_strict)
        self._schema[key] = check

    def add_opt_field(
        self,
        name: str,
        check: Any,
        check_strict: Any = None,
        description: Optional[str] = None,
        default: Any = None,
    ):
        """Add an optional field to the schema (see :meth:`add_field`)."""
        return self.add_field(
            name,
            check,
            check_strict=check_strict,
            description=description,
            optional=True,
            default=default,
        )

    def update(self, dct):
        """Merge fields from another :class:`TMSchema` or a mapping.

        :param dct: ``TMSchema`` or anything accepted by ``dict()``
        """
        other = dct._schema if isinstance(dct, TMSchema) else dict(dct)
        self._schema.update(other)

    def __getitem__(self, name: str):
        """Look up the check of a field by its bare name.

        Keys wrapped in ``schema.Optional`` / ``schema.Literal`` are matched
        by the name they wrap.

        :param name: Name of the field
        :returns: The field's check, or ``None`` if the field does not exist
        """
        if name in self._schema:
            return self._schema[name]
        for key, item in self._schema.items():
            if self._unwrap_key(key)[0] == name:
                return item
        return None

    def is_field_optional(self, name: str) -> Optional[bool]:
        """
        Checks whether the field with the given name is optional.

        Returns `None` if the field does not exist.

        :param name: Name of the field
        :returns: bool
        """
        for key in self._schema:
            bare_key, optional = self._unwrap_key(key)
            if bare_key == name:
                return optional
        return None

    def find_field_recursive(self, name: str) -> Optional["TMSchema"]:
        """
        Recursively finds a field of the given name in the schema.

        If the key exists multiple times, an arbitrary item will get
        returned. Note that to be returned by this function, the field must be
        in a `TMSchema` - if the schema is specified as a dictionary, it won't
        be found.

        :param name: Name of the field to look for
        :returns: A schema containing the given key, or ``None``
        """
        # Does exist?
        if self[name] is not None:
            return self
        # Recurse into nested schemas, and into TMSchemas held directly in
        # list or dict values (one level only, as before)
        for item in self._schema.values():
            if isinstance(item, TMSchema):
                children = [item]
            elif isinstance(item, list):
                children = item
            elif isinstance(item, dict):
                children = item.values()
            else:
                continue
            for child in children:
                if isinstance(child, TMSchema):
                    parent = child.find_field_recursive(name)
                    if parent is not None:
                        return parent
        return None
def mk_if(cond: bool) -> Callable[[Any], Any]:
    """Generate schema combinator to conditionally activate a part.

    If ``cond`` is truthy, the returned function hands back its argument
    unchanged. Otherwise it ignores the argument and returns a fresh, empty
    ``schema.And()`` (no sub-checks), deactivating that part.
    """

    def _keep(part):
        return part

    def _drop(part):
        return schema.And()

    if cond:
        return _keep
    return _drop
def get_channel_map_schema(
    elem_type: Any, version: int, strict: bool
) -> list:
    """Return a schema for a channel map (a list of channel map entries).

    Each entry is indexed as a sequence. Its first element must be an
    ``int``. In strict mode, every further element must also be valid
    according to ``elem_type``.

    :param elem_type: Schema for the elements after the first in each entry
    :param version: Currently unused by this function
    :param strict: Whether to validate the elements after the first
    :returns: One-element list wrapping the entry validator (`schema`'s
        notation for "a list whose items all satisfy this")
    """
    elem_schema = schema.Schema(elem_type)

    def valid_channel_map_entry(entry):
        # Generator (not a list) so checking stops at the first bad element
        if strict and not all(elem_schema.is_valid(e) for e in entry[1:]):
            return False
        return isinstance(entry[0], int)

    return [valid_channel_map_entry]
[docs]def get_unique_id_schema(
strict: bool, type_re: str = r"[a-z0-9]+"
) -> schema.Schema:
"""Return schema for unique identifier.
:param type_re: Restricts ID type(s) to accept.
"""
if strict:
return schema.Regex(
r"^" + type_re + r"\-[a-z0-9]+\-[0-9]{8}\-[a-z0-9]+$"
)
else:
return str
def interface_uri(prefix: str, *versions: int) -> str:
    """Make a URI from the given prefix and version components.

    For example ``interface_uri("https://example.org/foo/", 1, 2)`` returns
    ``"https://example.org/foo/1.2"``.

    :param prefix: Schema URI prefix. Must end in '/'
    :param versions: Components of the version, joined with '.'
    :returns: The interface URI
    :raises ValueError: If ``prefix`` does not end in '/'
    """
    # Explicit check instead of ``assert``: it stays active under
    # ``python -O`` and rejects an empty prefix without an IndexError.
    if not prefix.endswith("/"):
        raise ValueError(f"Schema URI prefix must end in '/': {prefix!r}")
    return prefix + ".".join(str(v) for v in versions)
def split_interface_version(version: str) -> Tuple[int, int]:
    """Extracts version number from interface URI.

    The version is the text after the last '/' (or the whole string if
    there is no '/'), which must have the form ``X.Y`` with integer parts.

    :param version: Version string
    :returns: (major version, minor version) tuple
    :raises ValueError: If the version is not of the form ``X.Y``
    """
    # get the string with the interface semantic version (X.Y)
    version_num = version.rpartition("/")[2]
    parts = version_num.split(".")
    if len(parts) != 2:
        raise ValueError(f"Expected 'X.Y' interface version in {version!r}")
    try:
        return int(parts[0]), int(parts[1])
    except ValueError as err:
        raise ValueError(
            f"Non-integer interface version in {version!r}"
        ) from err