import logging
from functools import reduce
from typing import Any
import yaml
from ..constants import ConfigRoot
from ..exceptions import StageNotFoundException
from ..utils.io_utils import read_yml, write_yml
# Module logger; RuntimeConfig reports unknown stage names through it.
logger: logging.Logger = logging.getLogger(__name__)
class RuntimeConfig:
    """
    Runtime configuration manager for the pipeline.

    Holds three configuration sections (stage flags, per-stage parameters and
    global parameters) that can be updated from a yaml file, from CLI
    overrides and from a CLI stage selection, and written out with
    :meth:`write_yml`.

    Parameters
    ----------
    pipeline
        Pipeline stages state configuration, mapping each stage name to a
        boolean flag telling whether the stage is enabled.
    parameters
        Pipeline stages parameters configuration, keyed by stage name.
    global_parameters
        Pipeline global parameters configuration.
    version
        Optional configuration version, written out by :meth:`write_yml`.
    """

    def __init__(
        self,
        pipeline: dict,
        parameters: dict,
        global_parameters: dict,
        version: str = None,
    ):
        """
        Initialise the config manager object.

        The dictionaries are stored without copying, so in-place edits made
        by :meth:`set` are also visible through the caller's references.
        """
        self.pipeline = pipeline
        self.parameters = parameters
        self.global_parameters = global_parameters
        self.version = version

    def set(self, config_path: str, value: Any):
        """
        Update a single key in the configuration.

        Parameters
        ----------
        config_path
            Dotted path to the key. The first component must match one of
            ``ConfigRoot.GLOBAL_PARAMETERS``, ``ConfigRoot.PARAMETERS`` or
            ``ConfigRoot.PIPELINE``; the remaining components walk nested
            dictionaries, each key in the hierarchy being separated by a
            ``.`` character. A bare global-parameters path replaces that
            whole section; a longer one updates only the addressed key.
        value
            Value to update the key with. Pipeline flags must be ``bool``.

        Raises
        ------
        StageNotFoundException
            If the stage named in a parameters or pipeline path is unknown.
        ValueError
            If the section is unknown, the path is malformed or an
            intermediate key is missing, or a pipeline flag is not a bool.

        Examples
        --------
        >>> stage_params = {"stage1": {"param1": 10}}
        >>> runtimeconfig = RuntimeConfig({"stage1": True}, stage_params, {})
        >>> runtimeconfig.set("parameters.stage1.param1", 20)
        >>> runtimeconfig.set("pipeline.stage1", False)
        """
        root_path, *path = config_path.split(".")
        if root_path == ConfigRoot.GLOBAL_PARAMETERS:
            if path:
                # Update only the addressed key instead of clobbering the
                # whole global section with ``value``.
                self._assign_nested(self.global_parameters, path, value)
            else:
                self.global_parameters = value
        elif root_path == ConfigRoot.PARAMETERS:
            if not path:
                raise ValueError(f"Invalid configuration path {config_path}")
            if path[0] not in self.parameters:
                raise StageNotFoundException(path[0])
            self._assign_nested(self.parameters, path, value)
        elif root_path == ConfigRoot.PIPELINE:
            if not isinstance(value, bool):
                raise ValueError(
                    f"Stage flags need to be boolean, {type(value)} provided"
                )
            if not path:
                raise ValueError("Illegal stage name parameter provided")
            if path[0] not in self.pipeline:
                raise StageNotFoundException(path[0])
            if len(path) != 1:
                raise ValueError("Illegal stage name parameter provided")
            self.pipeline[path[0]] = value
        else:
            raise ValueError(f"Invalid configuration path {config_path}")

    @staticmethod
    def _assign_nested(root: dict, keys: list[str], value: Any) -> None:
        """
        Set ``value`` at the nested location ``keys`` inside ``root``.

        Every key except the last must already map to a dictionary; the last
        key is created or overwritten.

        Raises
        ------
        ValueError
            If ``keys`` is empty, or an intermediate key is missing or does
            not hold a dictionary.
        """
        if not keys:
            raise ValueError("Empty configuration path provided")
        node = root
        for key in keys[:-1]:
            node = node.get(key)
            # Checked at every level so a missing or non-dict intermediate key
            # yields a clear ValueError instead of an AttributeError on None.
            if not isinstance(node, dict):
                raise ValueError(
                    f"Path {'.'.join(keys)} not found in configuration yaml"
                )
        node[keys[-1]] = value

    def update_from_yaml(self, yaml_path: str):
        """
        Update the current runtime configuration from a yaml file.

        Sections found in the file are merged one level deep: a stage listed
        in the file replaces that stage's whole entry, and likewise for each
        top-level global parameter. Stage names unknown to the current
        configuration are still merged in, but reported with a warning.

        Parameters
        ----------
        yaml_path: str
            Path to the yaml configuration file. A falsy value (``None`` or
            ``""``) leaves the configuration untouched.

        Returns
        -------
        RuntimeConfig
            Same instance but with updated parameters.
        """
        if not yaml_path:
            return self
        # NOTE(review): ``or {}`` guards against read_yml returning None for an
        # empty document -- confirm against read_yml's implementation.
        config_dict = read_yml(yaml_path) or {}
        # ``or {}`` also covers sections written without a body (``pipeline:``),
        # which yaml loads as None.
        file_pipeline = config_dict.get(ConfigRoot.PIPELINE) or {}
        file_parameters = config_dict.get(ConfigRoot.PARAMETERS) or {}
        file_globals = config_dict.get(ConfigRoot.GLOBAL_PARAMETERS) or {}
        if diff := file_pipeline.keys() - self.pipeline.keys():
            logger.warning("Found unknown stages %s in pipeline.", diff)
        if diff := file_parameters.keys() - self.parameters.keys():
            logger.warning("Found unknown stages %s in parameters.", diff)
        self.pipeline = {**self.pipeline, **file_pipeline}
        self.parameters = {**self.parameters, **file_parameters}
        self.global_parameters = {**self.global_parameters, **file_globals}
        return self

    def update_from_cli_overrides(self, cli_overrides: list[tuple[str, str]]):
        """
        Update the runtime configuration from the CLI overrides option.

        Parameters
        ----------
        cli_overrides
            A list of tuples/lists, each holding 2 elements. The first is the
            dotted path of the key to override (see :meth:`set`). The second
            is a string parsed with YAML rules into a rich object.

        Returns
        -------
        RuntimeConfig
            Same instance but with updated parameters.

        Examples
        --------
        >>> runtimeconfig.update_from_cli_overrides(
        ...     [
        ...         ("parameters.stage1.param1", "20"),
        ...         ("pipeline.stage1", "False"),
        ...     ]
        ... )
        """
        if not cli_overrides:
            return self
        # Parse each value on its own: gluing every override into a single
        # yaml document let a value containing yaml syntax (": ", newlines)
        # break parsing or leak into neighbouring entries. All values are
        # parsed before any is applied, so a value that fails to parse aborts
        # before the configuration is touched.
        parsed = [
            (config_path, yaml.safe_load(str(raw_value)))
            for config_path, raw_value in cli_overrides
        ]
        for config_path, value in parsed:
            self.set(config_path, value)
        return self

    def update_from_cli_stages(self, cli_stages: list[str]):
        """
        Update the pipeline stage states from a CLI stage selection.

        Every known stage is enabled if it is listed in ``cli_stages`` and
        disabled otherwise. Names that are not known stages are ignored, with
        a warning.

        Parameters
        ----------
        cli_stages
            Names of the stages to be enabled. An empty or missing selection
            leaves the pipeline untouched.

        Returns
        -------
        RuntimeConfig
            Same instance but with updated parameters.
        """
        if not cli_stages:
            return self
        requested = frozenset(cli_stages)
        if unknown := requested.difference(self.pipeline):
            logger.warning(
                "Found unknown stages %s in CLI stages.", sorted(unknown)
            )
        self.pipeline = {stage: stage in requested for stage in self.pipeline}
        return self

    def write_yml(self, path):
        """
        Write the configuration to ``path`` in yaml format.

        Parameters
        ----------
        path: str
            Location of the config file to write to.
        """
        config = {
            ConfigRoot.VERSION: self.version,
            ConfigRoot.GLOBAL_PARAMETERS: self.global_parameters,
            ConfigRoot.PARAMETERS: self.parameters,
            ConfigRoot.PIPELINE: self.pipeline,
        }
        # The bare name resolves to the module-level ``write_yml`` helper
        # imported from ``..utils.io_utils``, not to this method.
        write_yml(path, config)

    @property
    def stages_to_run(self) -> list[str]:
        """
        Return the names of the stages enabled for this run.

        Returns
        -------
        list[str]
            Names of the stages to run, in pipeline order.
        """
        return [stage for stage, enabled in self.pipeline.items() if enabled]