import logging
import dask
from distributed import as_completed, futures_of, get_client
from ska_sdp_piper.piper.runners import DaskRunner
from .tagger import Tags
logger = logging.getLogger()
[docs]
class UpstreamOutput:
"""
Container for managing outputs and metadata between pipeline stages.
This class acts as a shared context object, allowing downstream stages to
access the results of upstream stages via dictionary-style or attribute-
style access. It also tracks computational tasks, checkpoint keys, and
execution counts for each stage.
Attributes
----------
stage_compute_tasks : list
A list of delayed compute tasks (e.g., Dask graphs) accumulated
during the pipeline execution.
checkpoint_keys : list
A list of keys identifying data that should be checkpointed or
persisted.
compute_outputs : list
A list to store results of computations.
"""
def __init__(self):
"""
Initialize the UpstreamOutput container.
"""
self.__stage_outputs = {}
self.stage_compute_tasks = []
self.checkpoint_keys = []
self.compute_outputs = []
self.__call_count = {}
def __setitem__(self, key, value):
"""
Store an output value for a specific stage key.
Parameters
----------
key : str
The identifier for the stage output.
value : any
The data or object to store.
"""
self.__stage_outputs[key] = value
def __getitem__(self, key):
"""
Retrieve a stage output by key.
Parameters
----------
key : str
The identifier of the output to retrieve.
Returns
-------
any
The value associated with the key.
Raises
------
AttributeError
If the key is not present in the outputs.
"""
if key not in self.__stage_outputs:
raise AttributeError(f"{key} not present in upstream-output.")
return self.__stage_outputs[key]
def __getattr__(self, key):
"""
Retrieve a stage output using attribute access syntax.
Parameters
----------
key : str
The identifier of the output to retrieve.
Returns
-------
any
The value associated with the key.
Raises
------
AttributeError
If the key is not present in the outputs.
"""
if key not in self.__stage_outputs:
raise AttributeError(f"{key} not present in upstream-output.")
return self.__stage_outputs[key]
def __contains__(self, key):
"""
Check if a key exists in the stage outputs.
Parameters
----------
key : str
The identifier to check.
Returns
-------
bool
True if the key exists, False otherwise.
"""
return key in self.__stage_outputs
[docs]
def get_call_count(self, stage_name):
"""
Get the number of times a specific stage has been executed.
Parameters
----------
stage_name : str
The name of the stage.
Returns
-------
int
The execution count (default is 0).
"""
return self.__call_count.get(stage_name, 0)
[docs]
def increment_call_count(self, stage_name):
"""
Increment the execution counter for a specific stage.
Parameters
----------
stage_name : str
The name of the stage to increment.
"""
self.__call_count[stage_name] = (
self.__call_count.get(stage_name, 0) + 1
)
@property
def compute_tasks(self):
"""
list: Get the list of accumulated compute tasks.
"""
return self.stage_compute_tasks
[docs]
def add_compute_tasks(self, *args):
"""
Register new compute tasks to the pipeline.
Parameters
----------
*args
One or more task objects (e.g., Dask delayed objects) to add
to the execution queue.
"""
self.stage_compute_tasks.extend(args)
[docs]
def add_checkpoint_key(self, *args):
"""
Register keys that should be checkpointed.
Parameters
----------
*args
One or more string keys identifying outputs that require
checkpointing or persistence.
"""
self.checkpoint_keys.extend(args)
[docs]
class InstrumentalDaskRunner(DaskRunner):
"""DaskRunner implementation for Instrumental Calibration"""
@classmethod
def _execute_stage(cls, stage, output):
if stage.stage_definition in Tags.BROADCASTER:
return stage(output)
outputs = output if isinstance(output, list) else [output]
if stage.stage_definition in Tags.AGGREGATOR:
return stage(outputs)
return [stage(output) for output in outputs]
@classmethod
def _process_upstream_output(cls, output, is_client_present):
outputs = output if isinstance(output, list) else [output]
checkpoints = [
output[key] for output in outputs for key in output.checkpoint_keys
]
compute_tasks = [
task for output in outputs for task in output.compute_tasks
]
persisted_values = dask.persist(
*(checkpoints + compute_tasks), optimize_graph=True
)
idx = 0
for output in outputs:
for key in output.checkpoint_keys:
output[key] = persisted_values[idx]
idx += 1
computed_tasks = persisted_values[idx:]
slider = 0
for output in outputs:
output.compute_outputs += computed_tasks[
slider : slider + len(output.compute_tasks) # noqa E203
]
slider += len(output.compute_tasks)
output.checkpoint_keys = []
output.stage_compute_tasks = []
if is_client_present:
for task in as_completed(futures_of(persisted_values)):
if task.status == "error":
raise task.result()
return outputs
[docs]
def execute(self):
"""
Execute the provided list of pipeline stages.
Iterates through the stages, executing each one sequentially. For
each stage, it:
1. Logs the start of the stage.
2. Invokes the stage callable with the current upstream outputs.
3. Persists any data flagged for checkpointing and any accumulated
compute tasks using `dask.persist`.
4. Updates the output container with the persisted results.
5. Waits for completion if a Dask client is present.
6. Logs the completion of the stage.
Parameters
----------
stages : list of Stage
A list of stage objects to be executed.
"""
is_client_present = False
try:
get_client()
is_client_present = True
except Exception:
pass
output = UpstreamOutput()
for stage in self.pipeline.executable_stages:
logger.info(
f"Starting {stage.name}",
extra={"tags": f"sdpPhase:{stage.name.upper()},state:START"},
)
output = self._execute_stage(stage, output)
output = self._process_upstream_output(output, is_client_present)
logger.info(
f"Finished {stage.name}",
extra={
"tags": f"sdpPhase:{stage.name.upper()},state:FINISHED"
},
)