import logging
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
import tango # type: ignore
import yaml # type: ignore
from ska_ser_logging import configure_logging
from tango.utils import CaselessDict # type: ignore
from . import ARCHIVING_PARAMS, AttributeConfig
from .schema import check_config, check_defaults
# Configure logging via ska_ser_logging before this module emits any records.
configure_logging()
logger = logging.getLogger(__name__)

# Archiving parameters every desired attribute starts from; per-attribute
# configuration is layered on top of these in get_desired_attributes.
ARCHIVING_DEFAULT_PARAMS = dict(archive_strategy="ALWAYS")
[docs]
def get_tango_database(tango_host):
    """Get connection to Tango DB, exit with error if it fails.

    :param tango_host: address of the Tango database as ``host:port``.
    :returns: ``tango.Database(host, port)`` for that address.
    :raises ValueError: if ``tango_host`` is not of the form ``host:port``.
    :raises RuntimeError: if connecting fails with ``tango.DevFailed``
        (the original error is chained as the cause).
    """
    parts = tango_host.split(":")
    if len(parts) != 2:
        # Validate up front so a malformed value produces a readable message
        # instead of an opaque tuple-unpacking error.
        raise ValueError(
            f"Expected tango_host in 'host:port' form, got {tango_host!r}"
        )
    host, port = parts
    try:
        return tango.Database(host, port)
    except tango.DevFailed as e:
        logger.critical(
            "Can't connect to control system %r!"
            " Is it spelled correctly? Are you on the correct network?",
            tango_host,
        )
        raise RuntimeError("No control system found") from e
[docs]
def log_validation_errors(filename, errors):
    """Log one error line per schema validation failure.

    :param filename: name shown in each message to identify the source file.
    :param errors: iterable of error objects exposing ``.message`` and
        ``.path``; the path is rendered as a list so nested locations show.
    """
    # TODO we could probably make these errors more human readable
    template = "Validation error in %r: %s (%r)"
    for err in errors:
        logger.error(template, filename, err.message, [*err.path])
[docs]
def load_yaml_file(path: Path):
    """Parse the YAML file at ``path`` and return its content.

    :param path: file to read.
    :returns: the parsed YAML document.
    :raises RuntimeError: if the file is not valid YAML; the YAML error is
        logged and chained as the cause.
    """
    logger.debug("Loading file: %s.", path)
    # Open in binary mode so PyYAML detects the encoding itself (UTF-8 by
    # default, UTF-16 via BOM) instead of using the platform locale encoding.
    with path.open("rb") as stream:
        try:
            # NOTE(review): FullLoader can build some Python objects from
            # tags. If config files may come from untrusted sources, consider
            # yaml.safe_load instead.
            content = yaml.load(stream, Loader=yaml.FullLoader)
        except yaml.YAMLError as e:
            logger.error("YAML error in %r: %s", path.name, e)
            raise RuntimeError("Failed to parse YAML file") from e
    return content
[docs]
def merge_defaults(config, defaults):
    """Fill in a device class configuration from defaults.

    :param config: one class entry with a ``class`` key and optional
        ``attributes`` and ``filtering`` mappings.
    :param defaults: mapping of class name to that class's defaults, whose
        ``attributes`` mapping gives per-attribute default parameters.
    :returns: ``config`` itself if its class has no defaults; otherwise a new
        dict with ``class``, merged ``attributes`` and formatted
        ``filtering``. An attribute or parameter set to ``None`` in
        ``config`` is left out of the result.
    :raises RuntimeError: if ``format_filters`` reports an invalid regex.
    """
    logger.debug("Merging defaults for config: %r", config)
    defaults = CaselessDict(defaults)  # Case must be ignored for Tango stuff!
    clss = config.get("class")
    if clss not in defaults:
        return config

    # Both sides must be caseless. Otherwise an attribute spelled differently
    # in defaults and config (e.g. "State" vs "state") produces two entries,
    # one of which silently misses its defaults.
    attribute_defaults = CaselessDict(defaults[clss].get("attributes", {}))
    attributes = CaselessDict(config.get("attributes", {}))
    # Ordered, caseless union (defaults first) gives deterministic output.
    attr_names = list(attribute_defaults)
    attr_names += [
        name for name in attributes if name not in attribute_defaults
    ]

    merged_attributes = {}
    for attr_name in attr_names:
        params = attributes.get(attr_name, {})
        if params is None:
            # Attribute is explicitly skipped
            continue
        param_defaults = attribute_defaults.get(attr_name, {})
        new_params = merged_attributes[attr_name] = {}
        for param in ARCHIVING_PARAMS:
            if param in params and params[param] is None:
                # Parameter is explicitly skipped
                continue
            value = params.get(param, param_defaults.get(param))
            if value is not None:
                new_params[param] = value

    new_config = {"class": clss, "attributes": merged_attributes}
    try:
        new_config["filtering"] = format_filters(config.get("filtering", {}))
    except re.error as e:
        # Exceptions don't do %-formatting; build the message explicitly.
        raise RuntimeError(
            f"Filter pattern {e.pattern!r} is not valid: {e.msg}"
        ) from e
    return new_config
[docs]
def check_filters(config):
    """Check that device filters are sane, else report errors.

    Yields one human-readable message per ``filtering.device`` pattern that
    ``re.compile`` rejects. A device filter may be a single pattern or a
    list of patterns.

    :param config: configuration dict with a ``configuration`` list of class
        entries.
    """
    for class_config in config["configuration"]:
        if "filtering" not in class_config:
            continue
        patterns = class_config["filtering"].get("device") or []
        if isinstance(patterns, str):
            patterns = [patterns]
        for pattern in patterns:
            try:
                re.compile(pattern)
            except re.error as e:
                yield (
                    f"Bad regex '{e.pattern}' "
                    f"for class {class_config.get('class')}: {e.msg}"
                )
    # TODO also sanity check server patterns somehow?
[docs]
def _load_defaults(config_dir: Path, default_files) -> Dict[str, Dict[str, Any]]:
    """Load, validate and combine default files; later files win on overlap.

    ``classes`` entries are replaced per class name, while ``db`` and
    ``manager`` are merged key by key. Only sections present in some file
    appear in the result (``classes`` is always present).
    """
    combined: Dict[str, Dict[str, Any]] = {"classes": {}}
    for name in default_files:
        path = config_dir / name
        if not path.exists():
            raise RuntimeError(f"Defaults file {path} does not exist.")
        defs = load_yaml_file(path)
        errors = list(check_defaults(defs))
        if errors:
            log_validation_errors(name, errors)
            raise RuntimeError(
                "Default configuration is not valid, see errors above"
            )
        combined["classes"].update(defs.get("classes", {}))
        for section in ("db", "manager"):
            if section in defs:
                combined.setdefault(section, {}).update(defs[section])
    return combined


def load_configuration(filename: str):
    """
    Reads a YAML configuration file and combines it with any defaults.
    Returns a dict containing the complete configuration.
    Raises RuntimeError on any fatal errors.

    Default files listed under ``defaults`` are resolved relative to the
    configuration file. Explicit ``db``/``manager`` sections in the
    configuration take precedence over the defaults.
    """
    yaml_file = Path(filename)
    if not yaml_file.exists():
        raise RuntimeError(f"{filename} does not exist.")
    config = load_yaml_file(yaml_file)

    errors = list(check_config(config))
    if errors:
        log_validation_errors(filename, errors)
        raise RuntimeError("Configuration is not valid, see errors above")

    filter_errors = list(check_filters(config))
    if filter_errors:
        # Report every broken filter before giving up, not just the first.
        for error in filter_errors:
            logger.error("Broken filter in %r: %s", filename, error)
        raise RuntimeError("Configuration is not valid, see errors above")

    # Note: If defaults overlap, later defaults in the list overwrite earlier.
    defaults = _load_defaults(yaml_file.parent, config.get("defaults", []))

    # Fill in the configuration with defaults
    for section in ("db", "manager"):
        if section not in config:
            if section not in defaults:
                raise RuntimeError(
                    f"No {section!r} section in {filename} or its defaults"
                )
            config[section] = defaults[section]
    if defaults["classes"]:
        config["configuration"] = [
            merge_defaults(c, defaults["classes"])
            for c in config["configuration"]
        ]

    # OK hopefully the config is still valid. Just checking again to be safe
    errors = list(check_config(config))
    if errors:
        # Somehow valid config and default combined into something broken.
        # If this happens, it should be a bug.
        log_validation_errors("...", errors)
        raise RuntimeError(
            "Final configuration is not valid. This should not happen!"
        )
    # Note: from here we know that the config is valid. We could probably type
    # it too, but that seems tedious...
    logger.debug("Getting DB: %s ", config["db"])
    logger.debug("Getting Manager: %s ", config["manager"])
    return config
[docs]
def get_class_devices(
    db: "tango.Database",
    clss: str,
    device: Optional[Union[List[str], str]] = None,
    server: Union[List[str], str] = "*",
) -> List[str]:
    """
    Returns a list of devices for a Class. Also
    can apply different filters i.e device or server.

    :param db: database queried via ``db.get_device_name(server, clss)``.
    :param clss: device class name.
    :param device: regex or list of regexes; a device is kept if any pattern
        fully matches its name, ignoring case. ``None``/empty disables it.
    :param server: server pattern or list of patterns handed to the
        database; ``None`` is treated as ``"*"``.
    :returns: sorted device names without duplicates.
    """
    if server is None:
        server = "*"
    server_patterns = [server] if isinstance(server, str) else server

    # Several server patterns may match the same server, so de-duplicate.
    found: Set[str] = set()
    for pattern in server_patterns:
        found.update(db.get_device_name(pattern, clss))
    db_devices = sorted(found)
    if not db_devices:
        logger.warning(
            "No devices found for class %r! Perhaps it was mis-spelled?", clss
        )
        return db_devices  # no sense in filtering an empty list
    if not device:
        return db_devices

    # device pattern can be a single pattern or a list of patterns
    device_patterns = [device] if isinstance(device, str) else device
    selected: Set[str] = set()
    for pattern in device_patterns:
        regex = re.compile(pattern, flags=re.IGNORECASE)
        matches = [dev for dev in db_devices if regex.fullmatch(dev)]
        if not matches:
            logger.warning(
                "Pattern %r for class %r matches no devices!",
                pattern,
                clss,
            )
        selected.update(matches)
    return sorted(selected)
[docs]
def get_desired_attributes(
    tango_host: str, configs, db: Optional["tango.Database"] = None
) -> Dict[str, Any]:
    """Based on a given config, get the list of attributes affected.

    :param tango_host: ``host:port`` of the control system. It is embedded
        in every returned name and used to connect when ``db`` is omitted.
    :param configs: iterable of class configurations (presumably the
        ``configuration`` list from ``load_configuration`` — confirm).
    :param db: database to look devices up in; if ``None``, one is obtained
        with ``get_tango_database(tango_host)`` the first time it is needed.
    :returns: mapping of lower-cased ``tango://<host>/<device>/<attribute>``
        names to ``ARCHIVING_DEFAULT_PARAMS`` overlaid with the attribute's
        own parameters. Attributes whose parameters are ``None`` are skipped.
    """
    desired_attributes: Dict[str, Any] = {}
    for config in configs:
        clss = config.get("class")
        attributes = config.get("attributes", {})
        if not attributes:
            logger.warning("No attributes given for class %r!", clss)
            continue
        if not clss:
            continue
        if db is None:
            # Connect lazily so callers that don't pass a database still work.
            db = get_tango_database(tango_host)
        filtering = config.get("filtering") or {}
        devices = get_class_devices(db, clss, **filtering)
        if not devices:
            continue
        for attr, params in attributes.items():
            if params is None:
                # Attribute is explicitly skipped
                continue
            for device in devices:
                full_attr_name = (
                    f"tango://{tango_host}/{device}/{attr}".lower()
                )
                desired_attributes[full_attr_name] = {
                    **ARCHIVING_DEFAULT_PARAMS,
                    **params,
                }
    return desired_attributes