import logging
import math
import os
from dask.system import CPU_COUNT
from dask.utils import parse_bytes
from distributed.deploy.utils import nprocesses_nthreads
from distributed.system import MEMORY_LIMIT
from tlz import valmap # pylint: disable=E0611
[docs]
def get_usable_cpus_per_node() -> int:
    """
    Return how many CPUs may be used on the current Slurm node.

    ``SLURM_CPUS_ON_NODE`` is read from the environment when present;
    otherwise ``CPU_COUNT`` (imported from ``dask.system``) is used.
    Every node in a Slurm allocation is assumed to offer the same number
    of CPUs, so the locally observed value stands for all nodes.

    Returns
    -------
    Number of usable CPUs on current node.

    Raises
    ------
    ValueError
        If ``SLURM_CPUS_ON_NODE`` is set to a string that is not an integer.
    """
    cpus_on_node = os.environ.get("SLURM_CPUS_ON_NODE")
    if cpus_on_node is None:
        cpus_on_node = CPU_COUNT
    return int(cpus_on_node)
[docs]
def get_usable_memory_per_node() -> int:
    """
    Get number of bytes of memory available to use per slurm node.

    The per-node memory budget is taken, in order of preference, from:

    1. ``SLURM_MEM_PER_NODE`` (in MiB), when set;
    2. ``SLURM_MEM_PER_CPU`` (in MiB) multiplied by
       ``get_usable_cpus_per_node()``, when set;
    3. ``MEMORY_LIMIT`` (imported from ``distributed.system``) otherwise.

    A Slurm-provided value of zero or less is not a usable budget and also
    falls back to ``MEMORY_LIMIT``. (Slurm documents ``--mem=0`` as granting
    access to all memory on the node; whether it is exported as ``0`` may
    depend on the Slurm version.)

    A fraction of the budget can be held back as headroom through
    ``BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC``. The fraction is clamped
    to ``[0, 0.99]``. An unset or blank variable means no headroom.

    It is assumed that, in a slurm resource allocation, the amount of memory
    to use is identical across all allocated nodes.

    Returns
    -------
    Amount of memory available to use, in bytes.

    Raises
    ------
    ValueError
        If one of the environment variables above is set to a value that
        cannot be parsed as a number.
    """
    mib = 1024 * 1024  # Slurm exports memory sizes in MiB
    mem_limit = None
    if "SLURM_MEM_PER_NODE" in os.environ:
        mem_limit = int(os.environ["SLURM_MEM_PER_NODE"]) * mib
    elif "SLURM_MEM_PER_CPU" in os.environ:
        mem_limit = (
            get_usable_cpus_per_node()
            * int(os.environ["SLURM_MEM_PER_CPU"])
            * mib
        )
    if mem_limit is None or mem_limit <= 0:
        # No Slurm memory information, or a non-positive value that cannot
        # serve as a budget (a 0 budget would later cause division by zero
        # when sizing workers): use the system-detected limit instead.
        mem_limit = MEMORY_LIMIT

    # Treat an unset *or blank* variable as "no headroom"; float("") would
    # otherwise raise.
    raw_headroom = (
        os.getenv("BATCHLET_DASK_CLUSTER__MEMORY_HEADROOM_FRAC") or ""
    ).strip()
    mem_headroom_frac = float(raw_headroom or "0.0")
    # Never hold back everything: keep at least 1% of the budget.
    mem_headroom_frac = max(0, min(mem_headroom_frac, 0.99))
    return int(mem_limit * (1.0 - mem_headroom_frac))
[docs]
def parse_dask_resource_spec(
resources_per_worker: dict | str | None,
) -> dict | None:
"""
Handle worker resources in backward compatible way.
If input is dictionary or None, return unchanged.
If string, convert a comma or space-separated
resource string to a dictionary, and return the dictionary.
This function also ensures that the values of the keys
in the dictionary are always floats.
Parameters
----------
resources_per_worker
Either a string specifying resources in either of the following
formats:
- "resource1=value1,resource2=value2"
- "resource1=value1 resource2=value2"
or a dictionary:
- {"resource1": value1, "resource2": value2}
In both cases, the "value" must be convertible to ``float``.
Returns
-------
A dictionary mapping resource names (str) to their values (float),
or None if input is None.
Notes
-----
This function is based on the implementation in
``dask/distributed/cli/dask_worker.py#L362-L365``
"""
if isinstance(resources_per_worker, str):
resources_per_worker = resources_per_worker.replace(",", " ").split()
resources_per_worker = dict(
pair.split("=") for pair in resources_per_worker
)
if isinstance(resources_per_worker, dict):
resources_per_worker = valmap(float, resources_per_worker)
return resources_per_worker
[docs]
def resolve_worker_configuration(
    logger: logging.Logger,
    workers_per_node: int | None = None,
    threads_per_worker: int | None = None,
    memory_per_worker: str | int | None = "auto",
) -> tuple[int, int, int]:
    """
    Calculate worker resources for a SLURM cluster based on node constraints.

    Determines the number of workers per node, threads per worker, and
    memory limit per worker from the usable node resources
    (``get_usable_cpus_per_node()`` and ``get_usable_memory_per_node()``)
    and the user-specified constraints. If user constraints conflict with
    available resources, this function logs warnings and applies
    reasonable defaults.

    Memory resolution:

    - ``None``, or a value that parses to ``<= 0`` bytes: each worker's
      limit is the whole usable node memory.
    - A positive value larger than usable node memory is capped to it
      (with a warning).
    - A positive value without ``workers_per_node``: ``workers_per_node``
      is derived as ``usable_memory // memory`` (at least 1). The per-worker
      limit is then reset to ``usable_memory // workers_per_node``, so it is
      at least the requested amount unless it was capped.
    - A positive value together with ``workers_per_node``: the value is used
      as-is. A warning is logged if ``workers_per_node`` times the value
      exceeds usable node memory.
    - ``"auto"``: usable node memory is divided evenly across the workers.

    Worker/thread resolution (``0`` counts as "not set"):

    - only workers known: threads = ``ceil(cpus / workers)``, at least 1
      (threads are overcommitted, similar to ``LocalCluster``);
    - only threads known: workers = ``cpus // threads``, at least 1;
    - neither known: both come from ``nprocesses_nthreads(cpus)``;
    - both known: used unchanged.

    Parameters
    ----------
    logger
        Logger instance for logging calculation steps and warnings.
    workers_per_node
        Number of workers per node.
        If None, will be calculated based on other parameters
        or system resources.
    threads_per_worker
        Number of threads per worker.
        If None, will be calculated based on other parameters
        or system resources.
    memory_per_worker
        Memory limit per worker. Can be:
        - None: no explicit limit (uses all available node memory)
        - "auto": automatically calculate based on number of workers
        - str: parseable memory string (e.g., "4GB", "512MB")
        - int: memory in bytes

    Returns
    -------
    A tuple of (memory_limit, n_workers, n_threads) where:
    - memory_limit (int): Memory limit per worker in bytes
    - n_workers (int): Number of workers per node
    - n_threads (int): Number of threads per worker
    """
    usable_cpus_per_node = get_usable_cpus_per_node()
    usable_memory_per_node = get_usable_memory_per_node()
    logger.info(
        "Number of available cpus per node: %s",
        usable_cpus_per_node,
    )
    logger.info(
        "Available memory per node: %s bytes",
        usable_memory_per_node,
    )

    _memory_limit = None
    if memory_per_worker is None:
        logger.info(
            "User has set memory_per_worker as None. Thus, "
            "cluster will not enforce any memory limit."
        )
        _memory_limit = usable_memory_per_node
    elif memory_per_worker != "auto":
        parsed_memory_bytes = parse_bytes(memory_per_worker)
        if parsed_memory_bytes <= 0:
            logger.info(
                "User has set memory_per_worker as %s. Thus, "
                "cluster will not enforce any memory limit.",
                parsed_memory_bytes,
            )
            _memory_limit = usable_memory_per_node
        elif parsed_memory_bytes > usable_memory_per_node:
            logger.warning(
                "User has set memory_per_worker as %s, which is greater "
                "than usable memory per node. Cluster will "
                "limit the memory to maximum usable memory.",
                parsed_memory_bytes,
            )
            _memory_limit = usable_memory_per_node
        else:
            _memory_limit = parsed_memory_bytes

        if parsed_memory_bytes > 0:
            if not workers_per_node:
                logger.info(
                    "User has set memory_per_worker without setting "
                    "workers_per_node. Cluster will ensure "
                    "that every worker has at least "
                    "%s bytes of memory",
                    _memory_limit,
                )
                # _memory_limit can only be 0 here if the node reports no
                # usable memory; avoid dividing by zero in that case.
                if _memory_limit > 0:
                    workers_per_node = max(
                        1, int(usable_memory_per_node // _memory_limit)
                    )
                else:
                    workers_per_node = 1
                logger.info(
                    "Cluster has set workers_per_node to %s.",
                    workers_per_node,
                )
                # Hand any leftover memory out evenly across the workers.
                _memory_limit = int(usable_memory_per_node // workers_per_node)
                logger.info(
                    "Cluster has set each worker's memory_limit to %s bytes",
                    _memory_limit,
                )
            elif workers_per_node * _memory_limit > usable_memory_per_node:
                # Both constraints were given explicitly and together they
                # oversubscribe the node; honour them but say so.
                logger.warning(
                    "User has set memory_per_worker as %s bytes and "
                    "workers_per_node as %s, which together exceed usable "
                    "memory per node (%s bytes). The node may run out of "
                    "memory before workers reach their individual limits.",
                    _memory_limit,
                    workers_per_node,
                    usable_memory_per_node,
                )

    _n_workers = workers_per_node
    _nthreads = threads_per_worker
    if _n_workers and not _nthreads:
        # Overcommit threads, similar to LocalCluster
        _nthreads = max(1, int(math.ceil(usable_cpus_per_node / _n_workers)))
    elif not _n_workers and _nthreads:
        _n_workers = max(1, usable_cpus_per_node // _nthreads)
    elif not _n_workers and not _nthreads:
        _n_workers, _nthreads = nprocesses_nthreads(usable_cpus_per_node)

    if memory_per_worker == "auto":
        # Worker count is final only now, so the even split happens last.
        _memory_limit = int(usable_memory_per_node // _n_workers)
    return _memory_limit, _n_workers, _nthreads