Source code for ska_sdp_batchlet.utils.dask_cluster.factory

# pylint: disable = too-few-public-methods
import logging
import os
from typing import Literal

from distributed import SpecCluster

from .local_cluster import DaskLocalCluster
from .slurm import DaskSlurmCluster

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class DaskClusterFactory:
    """
    Factory that builds a new ``SpecCluster`` based dask cluster.

    Two kinds of cluster are currently supported:

    1. A Slurm-based cluster, created via :py:class:`DaskSlurmCluster`
    2. A local cluster, created via :py:func:`DaskLocalCluster`
    """

    @classmethod
    def get_cluster(cls, dask_cluster_params: dict) -> SpecCluster:
        """
        Build the dask cluster that matches the execution environment.

        The ``BATCHLET_DASK_CLUSTER__DASHBOARD_ADDRESS`` environment
        variable provides the ``dashboard_address`` argument unless the
        caller supplies that key in ``dask_cluster_params``. When the
        variable is unset and the caller gives no value,
        ``dashboard_address=None`` is forwarded.

        Parameters
        ----------
        dask_cluster_params
            User provided dask cluster parameters. These are forwarded
            as is, as keyword arguments, to the constructor of the
            selected dask cluster class.

        Returns
        -------
        Cluster instance which inherits ``SpecCluster``.
        """
        # The environment value goes first so that any caller-supplied
        # "dashboard_address" overrides it during the unpacking.
        constructor_kwargs = {
            "dashboard_address": os.getenv(
                "BATCHLET_DASK_CLUSTER__DASHBOARD_ADDRESS"
            ),
            **dask_cluster_params,
        }

        mode = cls._determine_cluster_mode()
        if mode != "slurm":
            logger.info("Starting a local dask cluster.")
            return DaskLocalCluster(**constructor_kwargs)

        logger.info("Starting a slurm based dask cluster.")
        slurm_cluster = DaskSlurmCluster(**constructor_kwargs)
        # NOTE(review): going by its name, this blocks until the workers
        # have joined; the method is defined on DaskSlurmCluster - confirm.
        slurm_cluster.wait_for_all_workers()
        return slurm_cluster

    @classmethod
    def _determine_cluster_mode(cls) -> Literal["slurm", "local"]:
        """
        Pick the cluster mode from environment variables.

        1. Return "local" if ``BATCHLET_DASK_CLUSTER__FORCE_LOCAL`` is
           set to a non-empty value, OR if ``SLURM_JOB_ID`` is unset or
           empty (i.e. not inside a slurm job allocation).
        2. (Default) Return "slurm".
        """
        force_local_flag = os.getenv("BATCHLET_DASK_CLUSTER__FORCE_LOCAL")
        slurm_job_id = os.getenv("SLURM_JOB_ID")

        # Slurm is chosen only inside an allocation and when local mode
        # has not been forced; any non-empty flag value counts as forced.
        if slurm_job_id and not force_local_flag:
            return "slurm"
        return "local"