# Source code for ska_sdp_batchlet.utils.dask_cluster.resources

import logging
import math
import os

import dask
from dask.system import CPU_COUNT
from dask.utils import parse_bytes
from distributed.deploy.utils import nprocesses_nthreads
from distributed.system import MEMORY_LIMIT
from tlz import valmap  # pylint: disable=E0611

logger = logging.getLogger(__name__)


def get_usable_cpus_per_node() -> int:
    """
    Report how many CPUs this job may use on the current node.

    The count is read from the ``SLURM_CPUS_ON_NODE`` environment
    variable. When that variable is not set, the ``CPU_COUNT`` value
    imported from ``dask.system`` is used instead. Every node of a slurm
    resource allocation is assumed to expose the same number of CPUs.

    Returns
    -------
    Number of usable CPUs on current node.
    """
    cpus_on_node = os.environ.get("SLURM_CPUS_ON_NODE", CPU_COUNT)
    # The environment hands back a string, the fallback is already numeric;
    # int() normalises both.
    return int(cpus_on_node)
def get_usable_memory_per_node() -> int:
    """
    Work out how many bytes of memory this job may use on one slurm node.

    The raw limit comes from the first source that is available:

    1. ``SLURM_MEM_PER_NODE`` (value in MiB),
    2. ``SLURM_MEM_PER_CPU`` (value in MiB) multiplied by the number of
       usable CPUs on the node,
    3. the ``MEMORY_LIMIT`` value imported from ``distributed.system``.

    A fraction of that limit can be held back as headroom through the
    ``BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC`` environment variable
    (default ``0.0``). The fraction is clamped to the range [0, 0.99].
    Every node of a slurm resource allocation is assumed to offer the
    same amount of memory.

    Returns
    -------
    Amount of memory available to use, in bytes.
    """
    bytes_per_mib = 1024 * 1024
    env = os.environ

    if "SLURM_MEM_PER_NODE" in env:
        node_limit = int(env["SLURM_MEM_PER_NODE"]) * bytes_per_mib
    elif "SLURM_MEM_PER_CPU" in env:
        node_limit = (
            get_usable_cpus_per_node()
            * int(env["SLURM_MEM_PER_CPU"])
            * bytes_per_mib
        )
    else:
        node_limit = MEMORY_LIMIT

    headroom = float(
        env.get("BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC", "0.0")
    )
    # Cap at 0.99 first, then floor at 0, so at least 1% of the limit
    # is always left for the workers.
    headroom = max(0, min(headroom, 0.99))

    usable_fraction = 1.0 - headroom
    return int(node_limit * usable_fraction)
[docs] def parse_dask_resource_spec( resources_per_worker: dict | str | None, ) -> dict[str, float]: """ Parse worker resources specification in backward compatible way. This function also ensures that the values of the keys in the dictionary are always floats. Parameters ---------- resources_per_worker If dictionary, its a mapping of resource names and their values. If string, it should be in either of the following formats: - "resource1=value1,resource2=value2" - "resource1=value1 resource2=value2" If resources_per_worker is None, then we attempt to get ``distributed.worker.resources`` value from dask config, and proceed with the parsing as mentioned above. In any cases, the dict values must be convertible to ``float``. Returns ------- A dictionary mapping resource names (str) to their values (float). Can be an empty dictionary. """ if resources_per_worker is None: # Replicate the default behavior of distributed.Worker resources_per_worker = dask.config.get("distributed.worker.resources") if isinstance(resources_per_worker, str): # This logic is based on the implementation in # ``dask/distributed/cli/dask_worker.py#L362-L365`` resources_per_worker = resources_per_worker.replace(",", " ").split() resources_per_worker = dict( pair.split("=") for pair in resources_per_worker ) # At this point, resources_per_worker MUST be a dict resources_per_worker = valmap(float, resources_per_worker) return resources_per_worker
def resolve_worker_configuration(
    workers_per_node: int | None = None,
    threads_per_worker: int | None = None,
    memory_per_worker: str | int | None = "auto",
    custom_logger: logging.Logger | None = None,
) -> tuple[int, int, int]:
    """
    Derive the per-node worker layout from node resources and user hints.

    The usable CPU count and usable memory of a node are queried first.
    The number of workers per node, the number of threads per worker and
    the memory limit per worker are then filled in from whatever the
    caller supplied. Missing values are computed from the node resources.

    Parameters
    ----------
    workers_per_node
        Number of workers per node. If None (or 0), it is derived from
        ``memory_per_worker`` when that is a positive size, otherwise
        from ``threads_per_worker`` or the usable CPU count.
    threads_per_worker
        Number of threads per worker. If None (or 0), it is derived from
        the usable CPU count and the number of workers.
    memory_per_worker
        Memory limit per worker. Can be:

        - None: no explicit limit (uses all usable node memory)
        - "auto": usable node memory divided by the number of workers
        - str: parseable memory string (e.g., "4GB", "512MB")
        - int: memory in bytes

        A parsed size that is zero or negative is treated like None. A
        size larger than the usable node memory is reduced to the usable
        node memory and a warning is logged.
    custom_logger
        Logger used for the informational and warning messages. If None,
        the module level logger is used.

    Returns
    -------
    A tuple of (n_workers, n_threads, memory_limit) where:

    - n_workers (int): Number of workers per node
    - n_threads (int): Number of threads per worker
    - memory_limit (int): Memory limit per worker in bytes
    """
    log = custom_logger or logger

    node_cpus = get_usable_cpus_per_node()
    node_memory = get_usable_memory_per_node()
    log.info("Number of available cpus per node: %s", node_cpus)
    log.info("Available memory per node: %s bytes", node_memory)

    if memory_per_worker is None:
        log.info(
            "User has set memory_per_worker as None. Thus, "
            "cluster will not enforce any memory limit."
        )
        worker_memory = node_memory
    elif memory_per_worker == "auto":
        # Deferred: resolved at the end, once the worker count is known.
        worker_memory = None
    else:
        requested_bytes = parse_bytes(memory_per_worker)
        if requested_bytes <= 0:
            log.info(
                "User has set memory_per_worker as %s. Thus, "
                "cluster will not enforce any memory limit.",
                requested_bytes,
            )
            worker_memory = node_memory
        else:
            if requested_bytes > node_memory:
                log.warning(
                    "User has set memory_per_worker as %s, which is greater "
                    "than usable memory per node. Cluster will "
                    "limit the memory to maximum usable memory.",
                    requested_bytes,
                )
                worker_memory = node_memory
            else:
                worker_memory = requested_bytes

            if not workers_per_node:
                # A positive memory request without a worker count: fit as
                # many workers as the memory allows, then share the node
                # memory evenly between them.
                log.info(
                    "User has set memory_per_worker without setting "
                    "workers_per_node. Cluster will ensure "
                    "that every worker has at least "
                    "%s bytes of memory",
                    worker_memory,
                )
                workers_per_node = max(1, int(node_memory // worker_memory))
                log.info(
                    "Cluster has set workers_per_node to %s.",
                    workers_per_node,
                )
                worker_memory = int(node_memory // workers_per_node)
                log.info(
                    "Cluster has set each worker's memory_limit to %s bytes",
                    worker_memory,
                )

    n_workers = workers_per_node
    n_threads = threads_per_worker

    if not n_workers:
        if n_threads:
            n_workers = max(1, node_cpus // n_threads)
        else:
            n_workers, n_threads = nprocesses_nthreads(node_cpus)
    elif not n_threads:
        # Overcommit threads, similar to LocalCluster
        n_threads = max(1, int(math.ceil(node_cpus / n_workers)))

    if worker_memory is None:
        # "auto": split the usable node memory between the workers.
        worker_memory = int(node_memory // n_workers)

    return n_workers, n_threads, worker_memory