Source code for ska_sdp_batchlet.utils.dask_cluster.slurm.cluster

import logging
import os

from distributed import Scheduler, SpecCluster

from ..resources import parse_dask_resource_spec, resolve_worker_configuration
from .worker import SlurmWorkers

logger = logging.getLogger(__name__)


# pylint: disable=W0223,R0913,R0914,W0613,R0917,R0912
class DaskSlurmCluster(SpecCluster):
    """
    Starts a dask cluster inside an existing slurm allocation.

    This class extends the :py:class:`distributed.SpecCluster` class, so it
    offers the functionality that comes with other standard clusters like
    :py:class:`distributed.LocalCluster`.

    Parameters
    ----------
    nodes
        Number of slurm nodes on which SlurmWorkers will start. If ``None``,
        the value of the ``SLURM_JOB_NUM_NODES`` environment variable is
        used, falling back to 1 when that variable is not set.
    workers_per_node
        Number of dask worker processes per node.
        If this is ``None`` and ``memory_per_worker`` is specified, then
        cluster will set ``workers_per_node`` such that each worker can use
        at least ``memory_per_worker`` amount of memory.
        If this is ``None`` and ``memory_per_worker`` is either "auto" or
        ``None``, then number of workers per node is decided using dask's
        default heuristic for setting number of dask workers and threads per
        worker, while respecting slurm's limit on CPUs (if any).
    threads_per_worker
        Number of threads per dask worker process.
        If this is ``None`` and ``workers_per_node`` is also ``None``, then
        number of threads per worker is decided using dask's default
        heuristic for setting number of dask workers and threads per worker,
        while respecting slurm's limit on CPUs (if any).
    memory_per_worker
        Memory limit per dask worker.
        Notes regarding argument data type:

        * If "auto", cluster finds out total usable memory per node, then
          divides the memory equally among all workers.
        * If ``None`` or 0, no limit is applied. The memory_per_worker is
          set to total usable memory per node.
        * If string, it should be a number followed by a unit,
          e.g. "128GB", "1 TiB".
        * If an int, it should indicate number of bytes per worker.
    resources_per_worker
        Resources or task constraints per dask worker. If string, it can be
        ``,`` separated key-value pairs, where key and value are separated
        using ``=``, e.g. ``"GPU=2"``; ``"TPU=1,MEM=10e9"``; ``{"CPU":10}``
    worker_scratch_directory
        The directory to store dask workers' temporary files. Equivalent to
        the ``local_directory`` option in :py:class:`distributed.Worker`.
    use_entry_node
        Whether to include the node running the dask scheduler to also
        launch dask workers. If true, dask workers are also launched on the
        node running the dask scheduler. If false, the number of nodes used
        for workers is reduced by one (``nodes - 1``) and the flag is
        forwarded to the worker specification.
        NOTE(review): earlier documentation stated that the slurm
        allocation must then hold at least one node more than ``nodes``;
        this could not be confirmed from this class alone - verify against
        ``SlurmWorkers``.
    silence_logs
        An integer corresponding to the standard logging levels in the
        ``logging`` module. Determines the logging levels for scheduler and
        worker processes; it is also applied to this module's logger.
    name
        Name of the dask cluster. This also acts as the prefix for worker
        names.

    Raises
    ------
    RuntimeError
        If the ``SLURM_JOB_ID`` environment variable is unset or empty.

    Example
    -------
    >>> cluster = DaskSlurmCluster(1, 2, 3, "4GB")
    >>> # get scheduler address
    >>> address = cluster.scheduler_address
    >>> # get a client
    >>> client = cluster.get_client()
    """

    def __init__(
        self,
        nodes: int | None = None,
        workers_per_node: int | None = None,
        threads_per_worker: int | None = None,
        memory_per_worker: str | int | None = "auto",
        resources_per_worker: dict | str | None = None,
        worker_scratch_directory: str | None = None,
        use_entry_node: bool = True,
        silence_logs: int = logging.INFO,
        name: str = "DaskSlurmCluster",
        **_,
    ):
        # Extra keyword arguments are accepted and ignored (``**_``).
        if not os.getenv("SLURM_JOB_ID"):
            raise RuntimeError(
                "Could not get slurm job id. \n"
                "The DaskSlurmCluster can only be setup inside "
                "an existing slurm job allocation"
            )

        logger.setLevel(silence_logs)

        if nodes is None:
            nodes = int(os.getenv("SLURM_JOB_NUM_NODES", "1"))

        self.__slurm_nodes = nodes
        if not use_entry_node:
            # One node is kept free of workers when the entry node is
            # excluded.
            self.__slurm_nodes -= 1

        logger.info(
            "Number of slurm nodes available for dask workers: %s",
            self.__slurm_nodes,
        )
        if self.__slurm_nodes < 1:
            # range() below yields nothing in this case, so no worker spec
            # is registered; surface the misconfiguration instead of
            # silently building an empty cluster.
            logger.warning(
                "No slurm nodes are available for dask workers "
                "(nodes=%s, use_entry_node=%s); "
                "the cluster is created without any worker",
                nodes,
                use_entry_node,
            )

        _memory_limit, _n_workers, _nthreads = resolve_worker_configuration(
            logger, workers_per_node, threads_per_worker, memory_per_worker
        )
        _resources = parse_dask_resource_spec(resources_per_worker)

        worker_spec = {
            "cls": SlurmWorkers,
            "options": {
                "n_workers": _n_workers,
                "nthreads": _nthreads,
                "memory_limit": _memory_limit,
                "resources": _resources,
                "use_entry_node": use_entry_node,
                "local_directory": worker_scratch_directory,
                "silence_logs": silence_logs,
            },
            # One name suffix per worker process of a node.
            "group": [f"-{i}" for i in range(_n_workers)],
        }
        logger.info("Dask worker specification: %s", worker_spec)

        # One spec entry per slurm node; every entry refers to the same
        # specification dictionary.
        workers = {
            f"{name}-node-{i}-worker": worker_spec
            for i in range(self.__slurm_nodes)
        }

        scheduler_spec = {
            "cls": Scheduler,
            "options": {
                "port": 0,
                "dashboard": True,
                "dashboard_address": ":0",
            },
        }

        super().__init__(
            name=name,
            scheduler=scheduler_spec,
            worker=worker_spec,
            workers=workers,
            silence_logs=silence_logs,
        )

    def wait_for_all_workers(self, timeout: float | None = 100):
        """
        Wait for all workers to start.

        The expected worker count is the number of slurm nodes used for
        workers multiplied by the ``n_workers`` option of the worker
        specification.

        Parameters
        ----------
        timeout
            Value forwarded as ``timeout`` to ``wait_for_workers``.
            Defaults to 100, the value that was previously hard-coded.
        """
        workers_per_node = self.new_spec["options"]["n_workers"]
        total_workers = self.__slurm_nodes * workers_per_node
        logger.info("Waiting for total %s workers to start", total_workers)
        self.wait_for_workers(total_workers, timeout=timeout)