"""
Node and worker resource helpers for Dask clusters.

Module: ``ska_sdp_batchlet.utils.dask_cluster.resources``.
"""

import logging
import math
import os

from dask.system import CPU_COUNT
from dask.utils import parse_bytes
from distributed.deploy.utils import nprocesses_nthreads
from distributed.system import MEMORY_LIMIT
from tlz import valmap  # pylint: disable=E0611


def get_usable_cpus_per_node() -> int:
    """
    Return how many CPUs the current node may use.

    Inside a SLURM allocation the value is read from ``SLURM_CPUS_ON_NODE``;
    otherwise Dask's ``CPU_COUNT`` is used. Every allocated node is assumed
    to expose the same number of CPUs.

    Returns
    -------
    Number of usable CPUs on current node.
    """
    cpus_on_node = os.environ.get("SLURM_CPUS_ON_NODE", CPU_COUNT)
    return int(cpus_on_node)
def get_usable_memory_per_node() -> int:
    """
    Return the number of bytes of memory usable on each SLURM node.

    The node total is taken from the first available source:

    1. ``SLURM_MEM_PER_NODE`` (MiB);
    2. ``SLURM_MEM_PER_CPU`` (MiB) multiplied by the usable CPU count;
    3. distributed's ``MEMORY_LIMIT`` when neither variable is set.

    A fraction of that total, read from
    ``BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC`` (default ``0.0``,
    clamped to ``[0, 0.99]``), is held back as headroom. Every allocated
    node is assumed to offer the same amount of memory.

    Returns
    -------
    Amount of memory available to use, in bytes.
    """
    bytes_per_mib = 1024 * 1024
    env = os.environ

    if "SLURM_MEM_PER_NODE" in env:
        total_bytes = int(env["SLURM_MEM_PER_NODE"]) * bytes_per_mib
    elif "SLURM_MEM_PER_CPU" in env:
        total_bytes = (
            get_usable_cpus_per_node()
            * int(env["SLURM_MEM_PER_CPU"])
            * bytes_per_mib
        )
    else:
        total_bytes = MEMORY_LIMIT

    requested_frac = float(
        os.getenv("BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC", "0.0")
    )
    # Clamp into [0, 0.99] so some memory always remains usable. The
    # max/min nesting order also maps a NaN fraction to 0.
    headroom_frac = max(0, min(requested_frac, 0.99))
    return int(total_bytes * (1.0 - headroom_frac))
[docs] def parse_dask_resource_spec( resources_per_worker: dict | str | None, ) -> dict | None: """ Handle worker resources in backward compatible way. If input is dictionary or None, return unchanged. If string, convert a comma or space-separated resource string to a dictionary, and return the dictionary. This function also ensures that the values of the keys in the dictionary are always floats. Parameters ---------- resources_per_worker Either a string specifying resources in either of the following formats: - "resource1=value1,resource2=value2" - "resource1=value1 resource2=value2" or a dictionary: - {"resource1": value1, "resource2": value2} In both cases, the "value" must be convertible to ``float``. Returns ------- A dictionary mapping resource names (str) to their values (float), or None if input is None. Notes ----- This function is based on the implementation in ``dask/distributed/cli/dask_worker.py#L362-L365`` """ if isinstance(resources_per_worker, str): resources_per_worker = resources_per_worker.replace(",", " ").split() resources_per_worker = dict( pair.split("=") for pair in resources_per_worker ) if isinstance(resources_per_worker, dict): resources_per_worker = valmap(float, resources_per_worker) return resources_per_worker
def resolve_worker_configuration(
    logger: logging.Logger,
    workers_per_node: int | None = None,
    threads_per_worker: int | None = None,
    memory_per_worker: str | int | None = "auto",
) -> tuple[int, int, int]:
    """
    Calculate worker resources for a SLURM cluster based on node constraints.

    Determines the number of workers per node, threads per worker, and
    memory limit per worker from the node's usable CPUs and memory (as
    reported by ``get_usable_cpus_per_node`` and
    ``get_usable_memory_per_node``) and the user-specified constraints.
    If user constraints conflict with available resources, this function
    logs warnings and applies reasonable defaults.

    Resolution rules:

    - ``memory_per_worker`` explicit and positive, ``workers_per_node``
      unset: as many workers as fit in node memory (at least one), after
      which node memory is split evenly among them.
    - ``workers_per_node`` set, ``threads_per_worker`` unset: CPUs are
      divided among workers, rounding up (threads may be overcommitted,
      similar to ``LocalCluster``).
    - ``threads_per_worker`` set, ``workers_per_node`` unset: as many
      workers as the CPUs allow (at least one).
    - neither set: delegated to distributed's ``nprocesses_nthreads``.

    Parameters
    ----------
    logger
        Logger instance for logging calculation steps and warnings.
    workers_per_node
        Number of workers per node. If None, will be calculated based on
        other parameters or system resources.
    threads_per_worker
        Number of threads per worker. If None, will be calculated based on
        other parameters or system resources.
    memory_per_worker
        Memory limit per worker. Can be:

        - None: no explicit limit; each worker's limit is set to the
          node's full usable memory
        - "auto": node memory split evenly across the resolved workers
        - str: parseable memory string (e.g., "4GB", "512MB")
        - int: memory in bytes

        An explicit value <= 0 is treated like None; a value above the
        node's usable memory is capped to it, with a warning.

    Returns
    -------
    A tuple of (memory_limit, n_workers, n_threads) where:

    - memory_limit (int): Memory limit per worker in bytes
    - n_workers (int): Number of workers per node
    - n_threads (int): Number of threads per worker
    """
    usable_cpus_per_node = get_usable_cpus_per_node()
    usable_memory_per_node = get_usable_memory_per_node()
    logger.info(
        "Number of available cpus per node: %s",
        usable_cpus_per_node,
    )
    logger.info(
        "Available memory per node: %s bytes",
        usable_memory_per_node,
    )

    _memory_limit = None
    if memory_per_worker is None:
        logger.info(
            "User has set memory_per_worker as None. Thus, "
            "cluster will not enforce any memory limit."
        )
        _memory_limit = usable_memory_per_node
    elif memory_per_worker != "auto":
        parsed_memory_bytes = parse_bytes(memory_per_worker)
        if parsed_memory_bytes <= 0:
            logger.info(
                "User has set memory_per_worker as %s. Thus, "
                "cluster will not enforce any memory limit.",
                parsed_memory_bytes,
            )
            _memory_limit = usable_memory_per_node
        elif parsed_memory_bytes > usable_memory_per_node:
            logger.warning(
                "User has set memory_per_worker as %s, which is greater "
                "than usable memory per node. Cluster will "
                "limit the memory to maximum usable memory.",
                parsed_memory_bytes,
            )
            _memory_limit = usable_memory_per_node
        else:
            _memory_limit = parsed_memory_bytes

        if (parsed_memory_bytes > 0) and (not workers_per_node):
            logger.info(
                "User has set memory_per_worker without setting "
                "workers_per_node. Cluster will ensure "
                "that every worker has at least "
                "%s bytes of memory",
                _memory_limit,
            )
            # _memory_limit is 0 only when the node reports no usable
            # memory; avoid ZeroDivisionError and fall back to one worker.
            if _memory_limit > 0:
                workers_per_node = max(
                    1, int(usable_memory_per_node // _memory_limit)
                )
            else:
                workers_per_node = 1
            logger.info(
                "Cluster has set workers_per_node to %s.",
                workers_per_node,
            )
            # Spread node memory evenly over the chosen workers; since
            # workers = floor(node / limit), each share is >= the limit.
            _memory_limit = int(usable_memory_per_node // workers_per_node)
            logger.info(
                "Cluster has set each worker's memory_limit to %s bytes",
                _memory_limit,
            )
        elif (parsed_memory_bytes > 0) and (
            _memory_limit * workers_per_node > usable_memory_per_node
        ):
            # Both knobs were set explicitly and together they exceed the
            # node: keep the user's values but make the conflict visible.
            logger.warning(
                "User has set memory_per_worker as %s bytes and "
                "workers_per_node as %s, which together exceed usable "
                "memory per node (%s bytes). Node memory may be "
                "overcommitted.",
                _memory_limit,
                workers_per_node,
                usable_memory_per_node,
            )

    _n_workers = workers_per_node
    _nthreads = threads_per_worker
    if _n_workers and not _nthreads:
        # Overcommit threads, similar to LocalCluster
        _nthreads = max(1, math.ceil(usable_cpus_per_node / _n_workers))
    elif not _n_workers and _nthreads:
        _n_workers = max(1, usable_cpus_per_node // _nthreads)
    elif not _n_workers and not _nthreads:
        _n_workers, _nthreads = nprocesses_nthreads(usable_cpus_per_node)

    if memory_per_worker == "auto":
        # Split node memory evenly across the final worker count.
        _memory_limit = int(usable_memory_per_node // _n_workers)

    return _memory_limit, _n_workers, _nthreads