Source code for ska_sdp_batchlet.utils.dask_cluster.factory

# pylint: disable = too-few-public-methods
import logging
import os

from distributed import SpecCluster

from ska_sdp_batchlet.utils.dask_cluster.local_cluster import DaskLocalCluster
from ska_sdp_batchlet.utils.dask_cluster.slurm.cluster import DaskSlurmCluster

logger = logging.getLogger(__name__)


class DaskClusterFactory:
    """
    A factory class which returns a new instance of a
    SpecCluster based dask cluster instance.
    """

    @staticmethod
    def get_cluster(dask_params: dict) -> SpecCluster:
        """
        Create and return a dask cluster based on the execution environment.

        The selection logic is as follows:

        1. If inside a slurm job allocation (the ``SLURM_JOB_ID``
           environment variable is set and non-empty), create a
           :py:class:`DaskSlurmCluster`, wait for all of its workers,
           and return it.
        2. Else, return a :py:class:`DaskLocalCluster`.

        Parameters
        ----------
        dask_params
            User provided dask parameters. Given to batchlet as json input,
            in the key "dask_params". Forwarded as keyword arguments to the
            selected cluster's constructor. ``None`` is treated as an empty
            dict.

        Returns
        -------
        Cluster instance which inherits ``SpecCluster``.

        Raises
        ------
        Exception
            Any error raised while constructing the cluster, or while waiting
            for the slurm workers, is propagated unchanged. If waiting for
            the slurm workers fails, the already-created cluster is closed
            (best effort) before the error is re-raised.
        """
        # Guard against a missing/null "dask_params" entry: ``**None`` would
        # otherwise raise an unhelpful TypeError.
        params = dask_params or {}

        # Considering SLURM_JOB_ID env variable
        # as source of truth
        if os.getenv("SLURM_JOB_ID"):
            logger.info("Starting a slurm based dask cluster.")
            cluster = DaskSlurmCluster(**params)
            try:
                cluster.wait_for_all_workers()
            except BaseException:
                # The cluster object already exists at this point; close it
                # so its resources are not leaked, then propagate the
                # original error. A failure while closing must not mask it.
                try:
                    cluster.close()
                except Exception:  # pylint: disable=broad-except
                    logger.warning(
                        "Failed to close slurm dask cluster "
                        "after a startup error.",
                        exc_info=True,
                    )
                raise
            return cluster

        logger.info("Starting a local dask cluster.")
        return DaskLocalCluster(**params)