import logging
import math
import os
import dask
from dask.system import CPU_COUNT
from dask.utils import parse_bytes
from distributed.deploy.utils import nprocesses_nthreads
from distributed.system import MEMORY_LIMIT
from tlz import valmap # pylint: disable=E0611
# Module-level logger; resolve_worker_configuration falls back to it
# when the caller does not pass a custom logger.
logger = logging.getLogger(__name__)
[docs]
def get_usable_cpus_per_node() -> int:
    """
    Return the number of CPUs usable on one node of the allocation.

    The value of the ``SLURM_CPUS_ON_NODE`` environment variable is used
    when it is set; otherwise ``CPU_COUNT`` from ``dask.system`` is used.
    All nodes of a slurm allocation are assumed to expose the same
    number of CPUs.

    Returns
    -------
    Number of usable CPUs on current node.
    """
    cpu_count = os.environ.get("SLURM_CPUS_ON_NODE", CPU_COUNT)
    return int(cpu_count)
[docs]
def get_usable_memory_per_node() -> int:
    """
    Return the number of bytes of memory usable on one slurm node.

    The total is taken from the first available source, in this order:

    1. ``SLURM_MEM_PER_NODE`` (MiB),
    2. ``SLURM_MEM_PER_CPU`` (MiB) multiplied by the usable CPU count,
    3. ``MEMORY_LIMIT`` from ``distributed.system``.

    A fraction of that total can be held back as headroom through the
    ``BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC`` environment variable
    (default ``0.0``); the fraction is clamped to the range [0, 0.99].

    All nodes of a slurm allocation are assumed to offer the same
    amount of memory.

    Returns
    -------
    Amount of memory available to use, in bytes.
    """
    bytes_per_mib = 1024 * 1024
    env = os.environ

    # Both SLURM memory variables are expressed in MiB.
    if "SLURM_MEM_PER_NODE" in env:
        total_bytes = int(env["SLURM_MEM_PER_NODE"]) * bytes_per_mib
    elif "SLURM_MEM_PER_CPU" in env:
        total_bytes = (
            get_usable_cpus_per_node()
            * int(env["SLURM_MEM_PER_CPU"])
            * bytes_per_mib
        )
    else:
        total_bytes = MEMORY_LIMIT

    raw_headroom = os.getenv(
        "BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC", "0.0"
    )
    # Clamp so that at least 1% of the memory always remains usable.
    headroom = max(0, min(float(raw_headroom), 0.99))
    return int(total_bytes * (1.0 - headroom))
[docs]
def parse_dask_resource_spec(
resources_per_worker: dict | str | None,
) -> dict[str, float]:
"""
Parse worker resources specification in backward compatible way.
This function also ensures that the values of the keys
in the dictionary are always floats.
Parameters
----------
resources_per_worker
If dictionary, its a mapping of resource names and their values.
If string, it should be in either of the following formats:
- "resource1=value1,resource2=value2"
- "resource1=value1 resource2=value2"
If resources_per_worker is None, then we attempt to get
``distributed.worker.resources`` value from dask config,
and proceed with the parsing as mentioned above.
In any cases, the dict values must be convertible to ``float``.
Returns
-------
A dictionary mapping resource names (str) to their values (float).
Can be an empty dictionary.
"""
if resources_per_worker is None:
# Replicate the default behavior of distributed.Worker
resources_per_worker = dask.config.get("distributed.worker.resources")
if isinstance(resources_per_worker, str):
# This logic is based on the implementation in
# ``dask/distributed/cli/dask_worker.py#L362-L365``
resources_per_worker = resources_per_worker.replace(",", " ").split()
resources_per_worker = dict(
pair.split("=") for pair in resources_per_worker
)
# At this point, resources_per_worker MUST be a dict
resources_per_worker = valmap(float, resources_per_worker)
return resources_per_worker
[docs]
def resolve_worker_configuration(
    workers_per_node: int | None = None,
    threads_per_worker: int | None = None,
    memory_per_worker: str | int | None = "auto",
    custom_logger: logging.Logger | None = None,
) -> tuple[int, int, int]:
    """
    Calculate worker configuration based on hardware constraints.

    Determines the number of workers per node, threads per worker,
    and memory limit per worker based on SLURM node resources and
    user-specified constraints.
    If user constraints conflict with available resources,
    this function logs warnings and applies reasonable defaults.

    Parameters
    ----------
    workers_per_node
        Number of workers per node.
        If None (or 0), will be calculated based on other parameters
        or system resources.
    threads_per_worker
        Number of threads per worker.
        If None (or 0), will be calculated based on other parameters
        or system resources.
    memory_per_worker
        Memory limit per worker. Can be:

        - None: no explicit limit (uses all available node memory)
        - "auto": usable node memory divided by the number of workers
        - str: parseable memory string (e.g., "4GB", "512MB")
        - int: memory in bytes

        A parsed value <= 0 is handled like None. A parsed value larger
        than the usable node memory is capped to it, with a warning.
        When a positive value is given and ``workers_per_node`` is not,
        the number of workers is derived from how many such limits fit
        in the usable node memory, and the limit is then re-computed as
        an equal share of the usable node memory.
    custom_logger
        Logger instance for logging calculation steps and warnings.
        Defaults to this module's logger.

    Returns
    -------
    A tuple of (n_workers, n_threads, memory_limit) where:

    - n_workers (int): Number of workers per node
    - n_threads (int): Number of threads per worker
    - memory_limit (int): Memory limit per worker in bytes

    Notes
    -----
    Errors raised by ``parse_bytes`` for an uninterpretable
    ``memory_per_worker`` are not caught here and propagate to the caller.
    """
    custom_logger = custom_logger or logger

    usable_cpus_per_node = get_usable_cpus_per_node()
    usable_memory_per_node = get_usable_memory_per_node()
    custom_logger.info(
        "Number of available cpus per node: %s",
        usable_cpus_per_node,
    )
    custom_logger.info(
        "Available memory per node: %s bytes",
        usable_memory_per_node,
    )

    if memory_per_worker is None:
        custom_logger.info(
            "User has set memory_per_worker as None. Thus, "
            "cluster will not enforce any memory limit."
        )
        _memory_limit = usable_memory_per_node
    elif memory_per_worker == "auto":
        # This will be computed post number of workers are finalised
        _memory_limit = None
    else:
        parsed_memory_bytes = parse_bytes(memory_per_worker)
        if parsed_memory_bytes <= 0:
            custom_logger.info(
                "User has set memory_per_worker as %s. Thus, "
                "cluster will not enforce any memory limit.",
                parsed_memory_bytes,
            )
            _memory_limit = usable_memory_per_node
        elif parsed_memory_bytes > usable_memory_per_node:
            custom_logger.warning(
                "User has set memory_per_worker as %s, which is greater "
                "than usable memory per node. Cluster will "
                "limit the memory to maximum usable memory.",
                parsed_memory_bytes,
            )
            _memory_limit = usable_memory_per_node
        else:
            _memory_limit = parsed_memory_bytes

        if (parsed_memory_bytes > 0) and (not workers_per_node):
            custom_logger.info(
                "User has set memory_per_worker without setting "
                "workers_per_node. Cluster will ensure "
                "that every worker has at least "
                "%s bytes of memory",
                _memory_limit,
            )
            if _memory_limit > 0:
                workers_per_node = max(
                    1, usable_memory_per_node // _memory_limit
                )
            else:
                # Usable node memory was reported as 0, so the capped
                # limit is 0 too; avoid dividing by zero and fall back
                # to a single worker.
                workers_per_node = 1
            custom_logger.info(
                "Cluster has set workers_per_node to %s.",
                workers_per_node,
            )
            # Re-balance so the workers share the node memory equally.
            _memory_limit = usable_memory_per_node // workers_per_node
            custom_logger.info(
                "Cluster has set each worker's memory_limit to %s bytes",
                _memory_limit,
            )

    _n_workers = workers_per_node
    _nthreads = threads_per_worker

    if (_n_workers) and (not _nthreads):
        # Overcommit threads, similar to LocalCluster
        _nthreads = max(1, math.ceil(usable_cpus_per_node / _n_workers))
    elif (not _n_workers) and (_nthreads):
        _n_workers = max(1, usable_cpus_per_node // _nthreads)
    elif (not _n_workers) and (not _nthreads):
        _n_workers, _nthreads = nprocesses_nthreads(usable_cpus_per_node)

    if _memory_limit is None:
        # If _memory_limit is still unresolved till this point, then
        # compute it based on available memory and number of workers
        _memory_limit = int(usable_memory_per_node // _n_workers)

    return _n_workers, _nthreads, _memory_limit