# pylint: disable=no-member,import-error
import logging
from pathlib import Path
import dask
from ska_sdp_piper.piper.command import CLIArgument
from ska_sdp_piper.piper.pipeline import Pipeline
from ska_sdp_piper.piper.runners import DaskRunner
from ska_sdp_piper.piper.utils import create_output_dir
from . import __version__
from .diagnosis.cli_arguments import DIAGNOSTIC_CLI_ARGS
from .diagnosis.spectral_line_diagnoser import SpectralLineDiagnoser
from .stages.flagging import flagging_stage
from .stages.imaging import imaging_stage
from .stages.load_data import load_data
from .stages.model import cont_sub, read_model, vis_stokes_conversion
from .stages.predict import predict_stage
from .upstream_output import UpstreamOutput
# Root logger: ``getLogger`` is called without a name, so records from this
# module are emitted under the "root" logger and reach whatever handlers are
# configured on the root.
logger = logging.getLogger(None)
class CustomDaskRunner(DaskRunner):
    """
    Dask runner that threads an ``UpstreamOutput`` through every stage.

    Each executable stage of the pipeline is called with the
    ``UpstreamOutput`` returned by the previous stage. After all stages
    have run, whatever they accumulated in ``compute_tasks`` is evaluated
    with a single ``dask.compute`` call.
    """

    def execute(self):
        """
        Run every executable stage of ``self.pipeline`` in order, then compute.

        A fresh ``UpstreamOutput()`` is passed to the first stage. Each
        stage's return value becomes the input of the next stage. Once the
        loop finishes, ``upstream_output.compute_tasks`` is evaluated with
        ``dask.compute(..., optimize_graph=True)``. The result of the
        compute call is not returned.
        """
        upstream_output = UpstreamOutput()
        for stage in self.pipeline.executable_stages:
            upstream_output = stage(upstream_output)

        # One compute call for everything the stages collected, so dask can
        # optimise the combined graph instead of computing piece by piece.
        dask.compute(upstream_output.compute_tasks, optimize_graph=True)
# Pipeline definition: stages are listed in the order they are handed to
# ``Pipeline``. NOTE(review): presumably this is also the execution order
# of ``executable_stages`` (which may filter stages) — confirm in piper.
spectral_line_imaging_pipeline = Pipeline(
    "spectral_line_imaging_pipeline",
    load_data,
    vis_stokes_conversion,
    read_model,
    predict_stage,
    cont_sub,
    flagging_stage,
    imaging_stage,
    version=__version__,
)

# Customise the run command: it takes a mandatory ``--input`` option
# (stored as ``input_path``) and uses ``CustomDaskRunner`` as its runner.
# NOTE(review): ``overide_run`` is spelled as called here; kept unchanged so
# it matches the piper API — confirm the upstream method name.
spectral_line_imaging_pipeline = spectral_line_imaging_pipeline.overide_run(
    CLIArgument(
        "--input",
        dest="input_path",
        type=str,
        required=True,
        help="Input visibility path",
    ),
    runner=CustomDaskRunner,
)
@spectral_line_imaging_pipeline.sub_command(
    "diagnose",
    *DIAGNOSTIC_CLI_ARGS,
    help="Diagnose the pipeline",
)
def pipeline_diagnostic(input_path, channel, output, dask_scheduler):
    """
    Pipeline diagnostics sub_command.

    Opens a ``CustomDaskRunner`` context for the spectral line pipeline. It
    then creates a run output directory with
    ``create_output_dir(output_dir, True, "pipeline-qa")`` and runs
    ``SpectralLineDiagnoser(...).diagnose()`` into that directory.

    The arguments are presumably filled in by piper from the parsed
    ``DIAGNOSTIC_CLI_ARGS`` (not visible here) — confirm names/types there.

    Parameters
    ----------
    input_path: str | os.PathLike
        Input path for the diagnoser; converted to ``pathlib.Path``.
    channel
        Forwarded unchanged to ``SpectralLineDiagnoser``.
    output: str | None
        Base output directory; ``"./diagnosis"`` is used when ``None``.
    dask_scheduler
        Forwarded to ``CustomDaskRunner`` as ``dask_scheduler``.
    """
    input_path = Path(input_path)
    output_dir = "./diagnosis" if output is None else output

    # Diagnosis runs inside the runner's context manager, with the
    # requested dask scheduler.
    with CustomDaskRunner(
        spectral_line_imaging_pipeline, dask_scheduler=dask_scheduler
    ):
        timestamped_output_dir = Path(
            create_output_dir(output_dir, True, "pipeline-qa")
        )

        logger.info("==========================================")
        logger.info("=============== DIAGNOSE =================")
        logger.info("==========================================")
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Current run output path : %s", timestamped_output_dir)

        diagnoser = SpectralLineDiagnoser(
            input_path, timestamped_output_dir, channel
        )
        diagnoser.diagnose()