Source code for ska_sdp_piper.piper.runners.dask_runner

from __future__ import annotations

import argparse

from dask.distributed import Client, performance_report

from ..command import CLIArgument
from ..utils.log_util import LogPlugin, LogUtil
from .default_runner import DefaultRunner


class DaskRunner(DefaultRunner):
    """
    Runner that executes a pipeline with Dask.

    When a scheduler address is supplied, a ``dask.distributed`` client is
    opened for the lifetime of the runner context; otherwise no client is
    created by this class.

    Parameters
    ----------
    dask_scheduler
        The address of the Dask scheduler.
    with_report
        Whether to generate a Dask performance report.
    **kwargs
        Additional keyword arguments for the DefaultRunner.
    """

    # CLI options this runner always exposes. The double underscore makes
    # the name class-private (name-mangled), so subclasses extend the CLI
    # through ``additional`` rather than by overriding this list.
    __mandatory = [
        CLIArgument(
            "--dask-scheduler",
            dest="dask_scheduler",
            type=str,
            default=None,
            help=(
                "Optional dask scheduler address to which to submit jobs. "
                "If specified, any eligible pipeline step will be "
                "distributed on the associated Dask cluster. "
                "Without this option, the pipeline will perform dask "
                "computation locally using threads"
            ),
        ),
        CLIArgument(
            "--with-report",
            dest="with_report",
            action=argparse.BooleanOptionalAction,
            default=False,
            help=(
                "Capture performance report for dask distributed tasks. "
                "Defaults to false. "
                "If the flag is true, the diagnostic report would be saved "
                "as dask_report.html file in the run output folder."
            ),
        ),
    ]

    # Extra CLI options appended after the mandatory ones by ``cli_args``.
    additional = []

    @classmethod
    def cli_args(cls):
        """
        Build the full CLI argument list for this runner.

        Returns
        -------
        list
            A new list holding the mandatory CLIArgument objects followed
            by the entries of ``additional``.
        """
        required = cls.__mandatory
        extras = cls.additional
        return required + extras

    def __init__(
        self,
        *,
        dask_scheduler: str = None,
        with_report: bool = False,
        **kwargs,
    ):
        """
        Store the Dask options and initialise the parent runner.

        Parameters
        ----------
        dask_scheduler
            The address of the Dask scheduler; a falsy value means no
            distributed client is opened in ``__enter__``.
        with_report
            Whether to generate a Dask performance report.
        **kwargs
            Forwarded unchanged to ``DefaultRunner.__init__``.
        """
        super().__init__(**kwargs)
        self.with_report = with_report
        self.dask_scheduler = dask_scheduler

    def __enter__(self) -> DaskRunner:
        """
        Enter the parent context, then set up Dask when a scheduler is set.

        With a scheduler address, a ``Client`` is entered on the runner's
        context stack, the logging plugin is registered on it and log
        forwarding is switched on. If ``with_report`` is also set, a
        performance report context writing ``dask_report.html`` under the
        pipeline output directory is entered on the same stack.

        Returns
        -------
        DaskRunner
            This runner instance.
        """
        super().__enter__()

        # No scheduler configured: nothing Dask-specific to set up.
        if not self.dask_scheduler:
            return self

        # NOTE(review): if any call below raises, the ``with`` statement
        # will not invoke ``__exit__``; confirm that contexts already
        # pushed on ``self._stack`` are unwound in that case.
        dask_client = self._stack.enter_context(Client(self.dask_scheduler))

        worker_log_plugin = LogPlugin(verbose=LogUtil.verbose)
        dask_client.register_worker_plugin(worker_log_plugin)
        dask_client.forward_logging()

        if self.with_report:
            report_context = performance_report(
                filename=f"{self.pipeline.output_dir}/dask_report.html"
            )
            self._stack.enter_context(report_context)

        return self