from __future__ import annotations
import argparse
from dask.distributed import Client, performance_report
from ..command import CLIArgument
from ..utils.log_util import LogPlugin, LogUtil
from .default_runner import DefaultRunner
class DaskRunner(DefaultRunner):
    """
    A runner that executes pipelines using Dask for distributed computation.

    In addition to what ``DefaultRunner`` does on entry, this runner
    connects a ``dask.distributed.Client`` to the configured scheduler,
    registers a ``LogPlugin`` with the workers and enables log forwarding.
    A Dask performance report can optionally be captured for the run.

    Parameters
    ----------
    dask_scheduler
        The address of the Dask scheduler. When it is ``None`` (or empty)
        no client is created.
    with_report
        Whether to generate a Dask performance report.
    **kwargs
        Additional keyword arguments for the DefaultRunner.
    """

    # CLI options every Dask runner exposes. The double-underscore name is
    # mangled to ``_DaskRunner__mandatory``, so a subclass cannot replace it
    # by accidentally reusing the name.
    __mandatory = [
        CLIArgument(
            "--dask-scheduler",
            dest="dask_scheduler",
            type=str,
            default=None,
            help="""Optional dask scheduler address to which to submit jobs.
            If specified, any eligible pipeline step will be distributed
            on the associated Dask cluster. Without this option, the
            pipeline will perform dask computation locally using threads""",
        ),
        CLIArgument(
            "--with-report",
            dest="with_report",
            action=argparse.BooleanOptionalAction,
            default=False,
            help="""Capture performance report for dask distributed tasks.
            Defaults to false. If the flag is true, the diagnostic
            report would be saved as dask_report.html file in the
            run output folder.""",
        ),
    ]

    # Extra CLI options appended after the mandatory ones. ``cli_args`` reads
    # this through ``cls``, so a subclass that defines its own ``additional``
    # has it picked up automatically.
    additional: list[CLIArgument] = []

    @classmethod
    def cli_args(cls) -> list[CLIArgument]:
        """
        Return the combined list of mandatory and additional CLI arguments.

        Returns
        -------
        list
            A new list holding the mandatory CLIArgument objects followed
            by the entries of ``cls.additional``.
        """
        return cls.__mandatory + cls.additional

    def __init__(
        self,
        *,
        dask_scheduler: str | None = None,
        with_report: bool = False,
        **kwargs,
    ) -> None:
        """
        Initialize the Dask runner with scheduler and reporting options.

        Parameters
        ----------
        dask_scheduler
            The address of the Dask scheduler, or ``None`` to run without
            a distributed client.
        with_report
            Whether to generate a Dask performance report.
        **kwargs
            Forwarded unchanged to ``DefaultRunner.__init__``.
        """
        super().__init__(**kwargs)
        self.dask_scheduler = dask_scheduler
        self.with_report = with_report

    def __enter__(self) -> DaskRunner:
        """
        Set up the Dask client and performance reporting context.

        Every context manager opened here is entered through
        ``self._stack``.
        NOTE(review): ``self._stack`` is presumably an ``ExitStack`` owned
        and closed by ``DefaultRunner`` - confirm against the base class.

        Returns
        -------
        DaskRunner
            The initialized runner instance.
        """
        super().__enter__()

        if self.dask_scheduler:
            client = self._stack.enter_context(Client(self.dask_scheduler))

            # NOTE(review): recent ``distributed`` releases deprecate
            # ``register_worker_plugin`` in favour of ``register_plugin``;
            # verify that LogPlugin subclasses WorkerPlugin before migrating.
            log_configure_plugin = LogPlugin(verbose=LogUtil.verbose)
            client.register_worker_plugin(log_configure_plugin)
            client.forward_logging()

            # NOTE(review): the report is only captured when a scheduler is
            # configured, since ``performance_report`` needs a distributed
            # client - confirm this nesting is the intended behaviour.
            if self.with_report:
                self._stack.enter_context(
                    performance_report(
                        filename=f"{self.pipeline.output_dir}/dask_report.html"
                    )
                )

        return self