import logging
import os
from distributed import Scheduler, SpecCluster
from ..resources import parse_dask_resource_spec, resolve_worker_configuration
from .worker import SlurmWorkers
# Module-level logger. DaskSlurmCluster.__init__ sets this logger's level
# from its ``silence_logs`` argument.
logger = logging.getLogger(__name__)
# Lint suppressions for the rest of this module: W0223 abstract-method,
# R0913 too-many-arguments, R0914 too-many-locals, W0613 unused-argument,
# R0917 too-many-positional-arguments, R0912 too-many-branches.
# pylint: disable=W0223,R0913,R0914,W0613,R0917,R0912
class DaskSlurmCluster(SpecCluster):
    """
    Starts a dask cluster inside an existing slurm allocation.

    This class extends the :py:class:`distributed.SpecCluster` class,
    so it shares many of the features of other standard clusters
    like :py:class:`distributed.LocalCluster`.

    Parameters
    ----------
    nodes
        Number of slurm nodes on which SlurmWorkers will start.
        If ``None``, the cluster reads the node count from the
        ``SLURM_JOB_NUM_NODES`` environment variable, falling back to 1.
        When ``use_entry_node`` is false, one of these nodes is set aside
        for the scheduler (see ``use_entry_node``).
    workers_per_node
        Number of dask worker processes per node.
        If this is ``None`` and ``memory_per_worker`` is specified,
        then the cluster sets ``workers_per_node`` such that each
        worker can use at least ``memory_per_worker``
        amount of memory.
        If this is ``None`` and ``memory_per_worker`` is either "auto"
        or ``None``, then the number of workers per node is decided
        using dask's default heuristic for setting the number of
        dask workers and threads per worker,
        while respecting slurm's limit on CPUs (if any).
    threads_per_worker
        Number of threads per dask worker process.
        If this is ``None`` and ``workers_per_node`` is also ``None``,
        then the number of threads per worker is decided
        using dask's default heuristic for setting the number of
        dask workers and threads per worker,
        while respecting slurm's limit on CPUs (if any).
    memory_per_worker
        Memory limit per dask worker.
        Notes regarding argument data type:

        * If "auto", the cluster finds out the total usable memory per
          node, then divides the memory equally among all workers.
        * If ``None`` or 0, no per-worker limit is applied; the
          memory_per_worker is set to the total usable memory per node.
        * If a string, it should be a number followed by a unit,
          e.g. "128GB", "1 TiB".
        * If an int, it should indicate the number of bytes per worker.
    resources_per_worker
        Resources or task constraints per dask worker.
        If a string, it can be ``,``-separated key-value pairs,
        where key and value are separated using ``=``.
        e.g. ``"GPU=2"``; ``"TPU=1,MEM=10e9"``; ``{"CPU": 10}``
    worker_scratch_directory
        The directory to store dask workers' temporary files.
        Equivalent to the ``local_directory`` option
        in the :py:class:`distributed.Worker` class.
    use_entry_node
        Whether the node running the dask scheduler should also
        launch dask workers.
        If true, dask workers are also launched on the node running
        the dask scheduler. If false, dask workers are launched on
        nodes relative to the dask scheduler node. In that case, one
        node is subtracted from ``nodes``, whether it was passed
        explicitly or read from ``SLURM_JOB_NUM_NODES``. Worker groups
        are therefore created for ``nodes - 1`` nodes, and never fewer
        than zero.
    silence_logs
        An integer corresponding to the standard logging levels in
        the ``logging`` module. Determines the logging levels for
        scheduler and worker processes.
    name
        Name of the dask cluster. This also acts as the prefix for
        worker names.
    **ignored_kwargs
        Any other keyword arguments are accepted but ignored; their
        names are reported with a logged warning.

    Raises
    ------
    RuntimeError
        If the ``SLURM_JOB_ID`` environment variable is not set.

    Examples
    --------
    >>> cluster = DaskSlurmCluster(1, 2, 3, "4GB")
    >>> # get scheduler address
    >>> address = cluster.scheduler_address
    >>> # get a client
    >>> client = cluster.get_client()
    """

    def __init__(
        self,
        nodes: int | None = None,
        workers_per_node: int | None = None,
        threads_per_worker: int | None = None,
        memory_per_worker: str | int | None = "auto",
        resources_per_worker: dict | str | None = None,
        worker_scratch_directory: str | None = None,
        use_entry_node: bool = True,
        silence_logs: int = logging.INFO,
        name: str = "DaskSlurmCluster",
        **ignored_kwargs,
    ):
        if not os.getenv("SLURM_JOB_ID"):
            raise RuntimeError(
                "Could not get slurm job id. "
                "The DaskSlurmCluster can only be setup inside "
                "an existing slurm job allocation"
            )
        logger.setLevel(silence_logs)
        if ignored_kwargs:
            # These used to be dropped silently, which hid misspelled
            # options. Keep accepting them, but make them visible.
            logger.warning(
                "Ignoring unsupported keyword arguments: %s",
                ", ".join(sorted(ignored_kwargs)),
            )
        if nodes is None:
            nodes = int(os.getenv("SLURM_JOB_NUM_NODES", "1"))
        # Without the entry node, one node is left to the scheduler.
        worker_nodes = nodes if use_entry_node else nodes - 1
        if worker_nodes < 1:
            # A one-node job with use_entry_node=False (or nodes <= 0)
            # leaves no room for workers. Clamp at zero so neither the
            # worker dict nor wait_for_all_workers sees a negative count.
            logger.warning(
                "No slurm nodes are left for dask workers "
                "(nodes=%s, use_entry_node=%s); "
                "the cluster will start without workers",
                nodes,
                use_entry_node,
            )
            worker_nodes = 0
        self.__slurm_nodes = worker_nodes
        logger.info(
            "Number of slurm nodes available for dask workers: %s",
            self.__slurm_nodes,
        )
        _memory_limit, _n_workers, _nthreads = resolve_worker_configuration(
            logger, workers_per_node, threads_per_worker, memory_per_worker
        )
        _resources = parse_dask_resource_spec(resources_per_worker)
        # One SlurmWorkers entry per node. "group" gives each of the
        # _n_workers processes in that entry a distinct name suffix.
        worker_spec = {
            "cls": SlurmWorkers,
            "options": {
                "n_workers": _n_workers,
                "nthreads": _nthreads,
                "memory_limit": _memory_limit,
                "resources": _resources,
                "use_entry_node": use_entry_node,
                "local_directory": worker_scratch_directory,
                "silence_logs": silence_logs,
            },
            "group": [f"-{i}" for i in range(_n_workers)],
        }
        logger.info("Dask worker specification: %s", worker_spec)
        workers = {
            f"{name}-node-{i}-worker": worker_spec
            for i in range(self.__slurm_nodes)
        }
        # Port 0 / ":0" let the OS pick free ports for the scheduler and
        # its dashboard.
        scheduler_spec = {
            "cls": Scheduler,
            "options": {
                "port": 0,
                "dashboard": True,
                "dashboard_address": ":0",
            },
        }
        super().__init__(
            name=name,
            scheduler=scheduler_spec,
            worker=worker_spec,
            workers=workers,
            silence_logs=silence_logs,
        )

    def wait_for_all_workers(self, timeout: float | None = 100):
        """
        Wait for all workers to start.

        The expected total is the number of worker nodes computed in
        ``__init__`` times ``n_workers`` from the worker specification.

        Parameters
        ----------
        timeout
            Seconds to wait. The value is forwarded unchanged to
            ``self.wait_for_workers``. Defaults to 100, the previously
            hard-coded value.
        """
        workers_per_node = self.new_spec["options"]["n_workers"]
        total_workers = self.__slurm_nodes * workers_per_node
        logger.info("Waiting for total %s workers to start", total_workers)
        self.wait_for_workers(total_workers, timeout=timeout)