- Module code
- ska_sdp_batchlet.utils.dask_cluster.factory
-
Source code for ska_sdp_batchlet.utils.dask_cluster.factory
# pylint: disable = too-few-public-methods
import logging
import os
from distributed import SpecCluster
from ska_sdp_batchlet.utils.dask_cluster.local_cluster import DaskLocalCluster
from ska_sdp_batchlet.utils.dask_cluster.slurm.cluster import DaskSlurmCluster
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
[docs]
class DaskClusterFactory:
    """
    A factory class which returns a new instance of
    a SpecCluster based dask cluster instance.
    """

    @staticmethod
    def get_cluster(dask_params: dict) -> SpecCluster:
        """
        Create and return a dask cluster based on the execution environment.

        The selection logic is as follows:

        1. If inside a slurm job allocation (the ``SLURM_JOB_ID``
           environment variable is set and non-empty), return a
           :py:class:`DaskSlurmCluster` after waiting for all of its
           workers.
        2. Else, return a :py:class:`DaskLocalCluster`.

        Parameters
        ----------
        dask_params
            User provided dask parameters. Given to batchlet as json input,
            in the key "dask_params". They are forwarded as keyword
            arguments to the selected cluster class. ``None`` is treated
            as an empty set of parameters.

        Returns
        -------
        SpecCluster
            Cluster instance which inherits ``SpecCluster``.

        Raises
        ------
        Exception
            Any error raised while constructing a cluster propagates.
            If waiting for the slurm workers fails (or is interrupted),
            the partially started cluster is closed before the original
            exception is re-raised.
        """
        # A json ``null`` / missing "dask_params" would otherwise fail
        # with an opaque ``TypeError`` on ``**None``.
        params = dask_params if dask_params is not None else {}

        # Considering SLURM_JOB_ID env variable
        # as source of truth
        if os.getenv("SLURM_JOB_ID"):
            logger.info("Starting a slurm based dask cluster.")
            cluster = DaskSlurmCluster(**params)
            try:
                cluster.wait_for_all_workers()
            except BaseException:
                # Don't leak the scheduler/workers that were already
                # started when waiting fails or is interrupted.
                try:
                    cluster.close()
                except Exception:  # pylint: disable=broad-except
                    # Keep the original failure as the one that surfaces.
                    logger.warning(
                        "Failed to close the slurm dask cluster "
                        "after a startup error.",
                        exc_info=True,
                    )
                raise
            return cluster

        logger.info("Starting a local dask cluster.")
        return DaskLocalCluster(**params)