Source code for ska_sdp_spectral_line_imaging.pipeline

# pylint: disable=no-member,import-error
import logging
from pathlib import Path

import dask
from ska_sdp_piper.piper.command import CLIArgument
from ska_sdp_piper.piper.pipeline import Pipeline
from ska_sdp_piper.piper.runners import DaskRunner
from ska_sdp_piper.piper.utils import create_output_dir

from . import __version__
from .diagnosis.cli_arguments import DIAGNOSTIC_CLI_ARGS
from .diagnosis.spectral_line_diagnoser import SpectralLineDiagnoser
from .stages.flagging import flagging_stage
from .stages.imaging import imaging_stage
from .stages.load_data import load_data
from .stages.model import cont_sub, read_model, vis_stokes_conversion
from .stages.predict import predict_stage
from .upstream_output import UpstreamOutput

# Module logger. No name is passed, so this is the *root* logger rather than
# a logger namespaced to this module.
logger = logging.getLogger(None)


class CustomDaskRunner(DaskRunner):
    """
    Dask runner that chains stage outputs through a shared
    ``UpstreamOutput`` and computes all collected tasks in one call.
    """

    def execute(self):
        """
        Run every executable stage of ``self.pipeline`` in sequence.

        Each stage is called with the ``UpstreamOutput`` returned by the
        previous stage (the first stage receives a fresh, empty one). Once
        all stages have run, the accumulated ``compute_tasks`` are handed to
        a single ``dask.compute`` call with graph optimisation enabled.
        The computed results are not returned.
        """
        carried = UpstreamOutput()

        for current_stage in self.pipeline.executable_stages:
            # The return value replaces the carried output, so a stage may
            # hand a new/modified UpstreamOutput to the next stage.
            carried = current_stage(carried)

        # One compute call lets dask optimise the whole task graph together.
        dask.compute(carried.compute_tasks, optimize_graph=True)
# Pipeline definition. The positional stage arguments are listed in the
# order they are passed to ``Pipeline``; presumably this is also the
# execution order of ``executable_stages`` -- confirm in ska_sdp_piper.
spectral_line_imaging_pipeline = Pipeline(
    "spectral_line_imaging_pipeline",
    load_data,
    vis_stokes_conversion,
    read_model,
    predict_stage,
    cont_sub,
    flagging_stage,
    imaging_stage,
    version=__version__,
)

# Replace the default "run" command: it requires an ``--input`` argument
# (stored as ``input_path``) and executes through ``CustomDaskRunner``.
# NOTE(review): ``overide_run`` is spelled as in the original; verify that
# this matches the method name exposed by ska_sdp_piper's ``Pipeline``.
spectral_line_imaging_pipeline = spectral_line_imaging_pipeline.overide_run(
    CLIArgument(
        "--input",
        dest="input_path",
        type=str,
        required=True,
        help="Input visibility path",
    ),
    runner=CustomDaskRunner,
)
@spectral_line_imaging_pipeline.sub_command(
    "diagnose",
    *DIAGNOSTIC_CLI_ARGS,
    help="Diagnose the pipeline",
)
def pipeline_diagnostic(input_path, channel, output, dask_scheduler):
    """
    Pipeline diagnostics sub_command.

    Creates an output directory for this run, then runs
    ``SpectralLineDiagnoser`` on ``input_path`` while a
    ``CustomDaskRunner`` context for the pipeline is active.

    Parameters
    ----------
    input_path: str or os.PathLike
        Path handed to ``SpectralLineDiagnoser`` as its input; converted
        to ``pathlib.Path`` before use.
    channel:
        Channel selection forwarded unchanged to ``SpectralLineDiagnoser``
        (type/meaning defined by ``DIAGNOSTIC_CLI_ARGS`` -- confirm there).
    output: str or None
        Base directory for diagnostic output. ``None`` falls back to
        ``"./diagnosis"``.
    dask_scheduler:
        Scheduler setting forwarded to ``CustomDaskRunner`` as
        ``dask_scheduler`` (presumably an address, or ``None`` for a local
        scheduler -- confirm against ska_sdp_piper's ``DaskRunner``).
    """
    input_path = Path(input_path)
    output_dir = "./diagnosis" if output is None else output

    with CustomDaskRunner(
        spectral_line_imaging_pipeline, dask_scheduler=dask_scheduler
    ):
        # The ``True`` flag presumably requests a timestamped sub-directory
        # prefixed with "pipeline-qa" -- confirm against create_output_dir.
        timestamped_output_dir = Path(
            create_output_dir(output_dir, True, "pipeline-qa")
        )

        logger.info("==========================================")
        logger.info("=============== DIAGNOSE =================")
        logger.info("==========================================")
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted (same rendered text as the former f-string).
        logger.info("Current run output path : %s", timestamped_output_dir)

        diagnoser = SpectralLineDiagnoser(
            input_path, timestamped_output_dir, channel
        )
        diagnoser.diagnose()