"""Common utility routines"""
import dataclasses
import functools
import inspect
import json
import os
import sys
import tarfile
import typing
import urllib
[docs]
def load_json_resource(json_resource_url: str) -> dict:
    """
    Load the given JSON resource.

    :param str json_resource_url: The json resource location, which can either
        be a filename or a URL. A location without a URL scheme is opened as
        a local file; anything else is fetched with ``urllib``.
    :return: The decoded JSON document.
    :raises OSError: If the local file cannot be opened or the URL cannot be
        retrieved (``urllib.error.URLError`` is an ``OSError`` subclass).
    :raises ValueError: If the content is not valid JSON.
    """
    # A bare ``import urllib`` does not load the ``parse`` and ``request``
    # submodules; they would only be reachable if some other module happened
    # to import them first. Import them explicitly so this function does not
    # depend on import order elsewhere in the process.
    # pylint: disable=import-outside-toplevel
    import urllib.parse
    import urllib.request

    parse_result = urllib.parse.urlparse(json_resource_url)
    if not parse_result.scheme:
        # No scheme: treat the location as a path on the local filesystem.
        with open(json_resource_url, "rb") as f:
            data = f.read()
    else:
        with urllib.request.urlopen(parse_result.geturl()) as resource:
            data = resource.read()
    # json.loads accepts bytes and detects the UTF-8/16/32 encoding itself.
    return json.loads(data)
[docs]
def from_dict(cls, data: dict):
    """
    Constructs an object of the given type from a dictionary using the
    dictionary keys/value pairs that match parameter names of the type's
    constructor. Keys that do not name a constructor parameter are ignored.

    :param cls: The type of the object to build
    :param data: A dictionary where parameters for construction will be
        extracted from
    :return: The result of calling ``cls`` with the matching entries passed
        as keyword arguments.
    """
    # Introspect the constructor once; doing it inside the comprehension
    # would rebuild the signature for every key in ``data``.
    accepted = inspect.signature(cls).parameters
    return cls(**{key: value for key, value in data.items() if key in accepted})
[docs]
def untar(path, target_dir=None):
    """
    Extracts the given file from a similarly-named tarfile into the given
    directory. Nothing is done when the target file already exists.

    :param path: The file to extract. The corresponding tarfile should have the
        same name, suffixed with ``.tar.gz``.
    :param target_dir: The directory where the tarfile will be extracted. If not
        given the dirname of ``path`` is used.
    """
    target_dir = target_dir or os.path.dirname(path)
    target_file = os.path.join(target_dir, os.path.basename(path))
    if os.path.exists(target_file):
        # Already extracted on a previous call.
        return

    with tarfile.open(f"{path}.tar.gz", "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters are available: use the "data" filter so that
            # archive members cannot escape ``target_dir`` (absolute paths,
            # ``..`` components, links pointing outside) and so that newer
            # interpreters do not emit a DeprecationWarning.
            tar.extractall(target_dir, filter="data")
        else:
            # Older interpreters without extraction filter support.
            tar.extractall(target_dir)
[docs]
def strtobool(val):
    """Interpret a string representation of truth as a ``bool``.

    The comparison is case-insensitive. ``'y'``, ``'yes'``, ``'t'``,
    ``'true'``, ``'on'`` and ``'1'`` yield ``True``; ``'n'``, ``'no'``,
    ``'f'``, ``'false'``, ``'off'`` and ``'0'`` yield ``False``.

    Adapted from distutils.util.strtobool, which will disappear in Python 3.12.

    :param val: The string to interpret.
    :return: ``True`` or ``False``.
    :raises ValueError: If ``val`` is none of the recognised spellings.
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    val = val.lower()
    if val in truthy:
        return True
    if val in falsy:
        return False
    # Note: the message reports the lower-cased form of the input.
    raise ValueError(f"invalid truth value {val}")
def _cast(value, typ):
    """
    Convert ``value`` to ``typ``.

    Booleans are parsed with ``strtobool`` instead of calling ``bool()``,
    because ``bool("false")`` would be ``True``. Every other type is built by
    calling ``typ(value)``.

    :param value: The value to convert.
    :param typ: The target type.
    :return: ``value`` converted to ``typ``.
    :raises ValueError: If a boolean is requested and ``value`` is not a
        recognised truth value.
    """
    if typ is bool:
        # strtobool calls ``.lower()`` on its argument, so hand it text even
        # when the caller supplied e.g. the integers 1 / 0.
        return strtobool(str(value))
    return typ(value)
def _underlying_type(cls, field: dataclasses.Field):
    """
    Return the type a dataclass field's value should be cast to.

    An optional annotation -- ``Optional[X]``, ``Union[X, None]``,
    ``Union[None, X]`` or, where the interpreter supports it, ``X | None`` --
    is unwrapped to ``X``. Any other annotation is returned unchanged.

    :param cls: The class the field belongs to. Kept for interface
        compatibility; the annotation is taken from ``field.type`` so that
        fields inherited from a base class (which are absent from the
        subclass' own ``__annotations__``) are resolved as well.
    :param field: The dataclass field to inspect.
    :return: The unwrapped type for optional fields, ``field.type`` otherwise.
    """
    # pylint: disable=import-outside-toplevel
    import types

    annotation = field.type
    if hasattr(typing, "get_origin"):
        typing_origin = typing.get_origin(annotation)
        typing_args = typing.get_args(annotation)
    else:
        # Python 3.7 has no typing.get_origin/get_args; read the
        # (undocumented) attributes of the typing alias directly.
        typing_origin = getattr(annotation, "__origin__", None)
        typing_args = getattr(annotation, "__args__", None) or ()

    # ``X | None`` reports ``types.UnionType`` as its origin on interpreters
    # where that is a distinct type; fall back to typing.Union elsewhere.
    union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
    is_union = any(typing_origin is candidate for candidate in union_origins)

    # Optional[X] == Union[X, None]
    if is_union and len(typing_args) == 2:
        none_type = type(None)
        # Compare by identity: isinstance() against a subscripted generic
        # such as List[str] would raise TypeError.
        remaining = [arg for arg in typing_args if arg is not none_type]
        if len(remaining) == 1:
            return remaining[0]
    return field.type
def _ensure_field_types(cls, self):
    """
    Coerce the dataclass fields of ``self`` to their declared types, in place.

    For every field: ``None`` is left alone, a value equal to the string
    ``"None"`` is replaced by ``None``, and any other value that is not
    already an instance of the field's type is converted with ``_cast``.

    :param cls: The class passed on to ``_underlying_type`` to resolve
        each field's type.
    :param self: The dataclass instance whose attributes are updated.
    """
    for spec in dataclasses.fields(self):
        name = spec.name
        current = getattr(self, name)

        if current is None:
            continue

        if current == "None":
            setattr(self, name, None)
            continue

        expected = _underlying_type(cls, spec)
        if isinstance(current, expected):
            # Already the right type, nothing to convert.
            continue
        setattr(self, name, _cast(current, expected))
[docs]
def autocast_fields(cls):
    """
    Class decorator for dataclasses that converts field values to their
    declared types when an instance is initialised.

    The conversion only succeeds when the constructor of the declared type T
    accepts the supplied value (e.g., int() accepts str objects, so fields of
    type int can be initialised with str values). ``None`` values are left
    unset. ``"None"`` strings are casted to ``None``. Boolean fields are
    casted using ``strtobool``.

    How the conversion is hooked in depends on what ``cls`` itself defines:

    * its own ``__post_init__``: fields are converted before it runs;
    * otherwise its own ``__init__``: fields are converted after it runs;
    * neither: the conversion is installed as ``__post_init__``.

    :param cls: The class to decorate.
    :return: The same class object, modified in place.
    """
    # pylint: disable=protected-access
    cls._ensure_field_types = lambda self: _ensure_field_types(cls, self)

    # Try to wrap special methods instead of adding them, works better with
    # dataclasses. Only methods defined on the class itself are considered.
    own_attrs = cls.__dict__
    if own_attrs.get("__post_init__"):
        hook_name, convert_first = "__post_init__", True
    elif own_attrs.get("__init__"):
        hook_name, convert_first = "__init__", False
    else:
        cls.__post_init__ = cls._ensure_field_types
        return cls

    wrapped_method = getattr(cls, hook_name)

    @functools.wraps(wrapped_method)
    def wrapper(self, *args, **kwargs):
        if convert_first:
            # __post_init__ should already see converted field values.
            self._ensure_field_types()
            wrapped_method(self, *args, **kwargs)
        else:
            # __init__ has to assign the fields before they can be converted.
            wrapped_method(self, *args, **kwargs)
            self._ensure_field_types()

    setattr(cls, hook_name, wrapper)
    return cls