- Module code
- ska_sdp_batchlet.utils.dask_cluster.factory
-
Source code for ska_sdp_batchlet.utils.dask_cluster.factory
# pylint: disable = too-few-public-methods
import logging
import os
from typing import Literal
from distributed import SpecCluster
from .local_cluster import DaskLocalCluster
from .slurm import DaskSlurmCluster
# Module-level logger; DaskClusterFactory.get_cluster reports which kind of
# dask cluster it is starting through it.
logger = logging.getLogger(__name__)
[docs]
class DaskClusterFactory:
    """
    A factory class which returns a new instance of
    a SpecCluster based dask cluster.

    This factory currently supports 2 kinds of clusters:

    1. Slurm-based cluster via :py:class:`DaskSlurmCluster`
    2. Local cluster via :py:class:`DaskLocalCluster`

    Which one is created is decided from environment variables; see
    :py:meth:`_determine_cluster_mode`.
    """

    # Values of BATCHLET_DASK_CLUSTER__FORCE_LOCAL that mean "do NOT force
    # local mode" (compared after stripping whitespace and lower-casing).
    # Without this, e.g. FORCE_LOCAL=0 would be a non-empty string and
    # would therefore *force* local mode.
    _FALSE_VALUES = frozenset({"", "0", "false", "no", "off"})

    @classmethod
    def get_cluster(cls, dask_cluster_params: dict) -> "SpecCluster":
        """
        Create and return a dask cluster based on the execution environment.

        Parameters
        ----------
        dask_cluster_params
            User provided dask cluster parameters. These are forwarded as is
            to the constructor of the dask cluster classes. If the
            ``BATCHLET_DASK_CLUSTER__DASHBOARD_ADDRESS`` environment variable
            is set, its value is used as the default ``dashboard_address``;
            an explicit ``dashboard_address`` key in this dict takes
            precedence. When neither is given, ``dashboard_address=None`` is
            passed to the cluster constructor.

        Returns
        -------
        SpecCluster
            Cluster instance which inherits ``SpecCluster``. For the slurm
            cluster, this call only returns after ``wait_for_all_workers()``
            has returned.

        Raises
        ------
        Exception
            Whatever the cluster constructor or ``wait_for_all_workers()``
            raises. If waiting for the slurm workers fails, the cluster is
            closed before the error is re-raised.
        """
        cluster_params = {
            # Environment default first so caller-supplied keys override it.
            "dashboard_address": os.getenv(
                "BATCHLET_DASK_CLUSTER__DASHBOARD_ADDRESS"
            ),
            **dask_cluster_params,
        }

        if cls._determine_cluster_mode() == "slurm":
            logger.info("Starting a slurm based dask cluster.")
            cluster = DaskSlurmCluster(**cluster_params)
            try:
                cluster.wait_for_all_workers()
            except BaseException:
                # The caller never receives a handle to this cluster, so
                # shut it down here instead of leaking it.
                try:
                    cluster.close()
                except Exception:  # pylint: disable=broad-except
                    logger.exception(
                        "Failed to close the slurm dask cluster after an error."
                    )
                raise
        else:
            logger.info("Starting a local dask cluster.")
            cluster = DaskLocalCluster(**cluster_params)
        return cluster

    @classmethod
    def _determine_cluster_mode(cls) -> Literal["slurm", "local"]:
        """
        Determine the cluster mode based on environment variables.

        1. If ``BATCHLET_DASK_CLUSTER__FORCE_LOCAL`` is set to a value other
           than an empty string, ``0``, ``false``, ``no`` or ``off``
           (case-insensitive), OR the process is not inside a slurm job
           allocation (``SLURM_JOB_ID`` unset or empty), return ``"local"``.
        2. (Default) Return ``"slurm"``.
        """
        force_local = os.getenv("BATCHLET_DASK_CLUSTER__FORCE_LOCAL", "")
        if force_local.strip().lower() not in cls._FALSE_VALUES:
            return "local"
        if not os.getenv("SLURM_JOB_ID"):
            return "local"
        return "slurm"