# Source code for src.imaging_prototype

"""
Prototype Continuum Imaging Pipeline (CIP).

Based on RASCIL's CIP, with the option of using
Processing Function Library functions instead of RASCIL
processing components where available.
"""

import logging
import os
import pprint
import sys

from rascil.processing_components import create_visibility_from_ms
from rascil.workflows import restore_skymodel_list_rsexecute_workflow
from rascil.workflows.rsexecute.execution_support import rsexecute
from ska_sdp_datamodels.sky_model.sky_model import SkyModel
from ska_sdp_func_python.imaging import (
    advise_wide_field,
    create_image_from_visibility,
    invert_visibility,
    predict_visibility,
)
from ska_sdp_func_python.visibility import convert_visibility_to_stokesI

from src.cli_parser import cli_parser
from src.dask_utils import (
    dask_wrapper,
    print_task_stream_timing,
    set_up_dask,
    tear_down_dask,
)
from src.imaging_utils import export_results, imaging_subtract_vis, time_run
from src.processing_function_integration import deconvolution, dft_visibility

# Pipeline logger: "prototype-pipeline", INFO level, with a handler that
# writes to stdout.
log = logging.getLogger("prototype-pipeline")
_stdout_handler = logging.StreamHandler(stream=sys.stdout)
log.addHandler(_stdout_handler)
log.setLevel(level=logging.INFO)


class ContinuumImagingPipeline:
    """
    Prototype Continuum Imaging Pipeline.
    Doesn't use calibration.

    Every step builds its per-Visibility results through dask_wrapper, so
    with use_dask=True the lists hold lazily evaluated objects that must be
    computed by the caller.
    """

    def __init__(self, imaging_context, use_dask=False):
        """
        :param imaging_context: Gridding type. Options: ng (nifty-gridder),
                                wg (wagg, GPU only), 2d or awprojection
        :param use_dask: whether to run computation with dask or not
        """
        self.bvis_list = None
        self.model_image_list = None
        self.skymodel_list = None
        self.psf_list = None
        self.dirty_image_list = None
        self.restored_list = None

        self._imaging_context = imaging_context
        self._use_dask = use_dask

        self._tmp_zeroed_bvis_list = None
        self._psf_sumwt_list = None  # needed for restore

    def __call__(self, **input_args):
        """
        Set up for running the continuum imaging pipeline.
        Steps:
            1. load MeasurementSet into Visibilities
            2. Generate initial model images and SkyModels
            3. Initialize the point spread function (PSF)

        :param input_args: keyword arguments; the keys read here are
            "input_ms", "input_nchan", "nchan_per_vis", "imaging_npixel"
            and "imaging_cellsize" (other keys are ignored)
        """
        self._load_bvis_from_ms(
            input_args["input_ms"],
            input_args["input_nchan"],
            input_args["nchan_per_vis"],
        )
        self._init_model_images_list(
            input_args["imaging_npixel"], input_args["imaging_cellsize"]
        )
        self._init_skymodel_list()
        self._init_psf_list()

        # helper list, contains copy of input bvis with zero as visibility data
        self._tmp_zeroed_bvis_list = [
            dask_wrapper(bvis.copy, use_dask=self._use_dask, nout=1)(
                deep=True, zero=True
            )
            for bvis in self.bvis_list
        ]

    def predict(self, processing_func_source):
        """
        Run the predict step, including subtracting the predicted
        model data from the input data.

        :param processing_func_source: which type of processing functions
            to use; Options:
                'rascil' - use RASCIL
                'proc_func' - Use the Processing Function Library
        :return: updates self.bvis_list in place
        """
        # Run Direct Fourier Transform (DFT) on SkyComponents
        dft_vis_list = [
            dask_wrapper(dft_visibility, use_dask=self._use_dask, nout=1)(
                vis, self.skymodel_list[i].components, processing_func_source
            )
            for i, vis in enumerate(self._tmp_zeroed_bvis_list)
        ]

        # Predict visibilities from the model images
        model_vis_list = [
            dask_wrapper(predict_visibility, use_dask=self._use_dask, nout=1)(
                dft_v,
                self.skymodel_list[i].image,
                context=self._imaging_context,
            )
            for i, dft_v in enumerate(dft_vis_list)
        ]

        # Add the image prediction onto the DFT prediction. The task returns
        # the updated visibility and that result is what gets subtracted, so
        # with Dask the addition is part of the computed graph (a task whose
        # result is discarded would never be executed).
        combined_vis_list = [
            dask_wrapper(
                self._add_model_visibility, use_dask=self._use_dask, nout=1
            )(dft_v, model_v)
            for dft_v, model_v in zip(dft_vis_list, model_vis_list)
        ]

        # SUBTRACT model from input bvis; update bvis_list
        self.bvis_list = [
            dask_wrapper(
                imaging_subtract_vis, use_dask=self._use_dask, nout=1
            )(bvis, combined_vis)
            for bvis, combined_vis in zip(self.bvis_list, combined_vis_list)
        ]

    @staticmethod
    def _add_model_visibility(dft_vis, model_vis):
        """
        Add the "vis" data of model_vis onto the "vis" data of dft_vis
        (in place, via _update_data) and return dft_vis.
        """
        _update_data(dft_vis["vis"], model_vis["vis"])
        return dft_vis

    def invert(self):
        """
        Run the invert step.

        :return: updates self.dirty_image_list in place
        """
        inverted_list = [
            dask_wrapper(invert_visibility, use_dask=self._use_dask, nout=2)(
                residual_vis,
                self.skymodel_list[i].image,
                dopsf=False,
                context=self._imaging_context,
            )
            for i, residual_vis in enumerate(self.bvis_list)
        ]
        # invert_visibility is wrapped with nout=2; only element 0 is kept
        self.dirty_image_list = [inverted[0] for inverted in inverted_list]

    def deconvolve(self, fit_skymodel, component_threshold, clean_threshold):
        """
        Run deconvolution.

        :param fit_skymodel: True: fit the skymodel and extract sky components
                             False: run CLEAN-based deconvolution
        :param component_threshold: Sources with absolute flux > this level
                                    (Jy) are fitted
        :param clean_threshold: Clean stopping threshold (Jy/beam)
        :return: updates self.skymodel_list in place
        """
        # deconvolution uses RASCIL's rsexecute to run Dask at the moment,
        # this may change once other, non-rascil deconvolution functions are added
        self.skymodel_list = deconvolution(
            self.dirty_image_list,
            self.psf_list,
            self.skymodel_list,
            fit_skymodel=fit_skymodel,
            component_threshold=component_threshold,
            clean_threshold=clean_threshold,
        )

    def restore(self):
        """
        Restore images.

        Uses RASCIL's restore_skymodel_list_rsexecute_workflow function,
        which implicitly uses rsexecute to run with Dask.

        :return: updates self.restored_list in place
        """
        # the following rascil pipeline workflow coordinates multiple restoring
        # options in RASCIL CIP:
        # rascil.workflows.rsexecute.pipelines.pipeline_skymodel_rsexecute.restore_skymodel_pipeline_rsexecute_workflow
        # we default to the "list" version, which is just the following
        # restore call:
        self.restored_list = restore_skymodel_list_rsexecute_workflow(
            self.skymodel_list,
            self._psf_sumwt_list,
            self.dirty_image_list,
        )

    def _load_bvis_from_ms(self, input_ms, channels_in_data, nchan_per_vis):
        """
        Load MeasurementSet data into Visibility objects.

        Note:
            - Polarization of data is always converted to Stokes I after
              loading.
            - Based on rascil.apps.rascil_imager.get_vis_list and
              rascil.workflows.rsexecute.visibility.visibility_rsexecute.create_visibility_from_ms_rsexecute
            - these RASCIL functions take multiple arguments, most of which
              we hardcode here. However, in the future, we may need to allow
              for users to specify these.
            - Channels are loaded in consecutive blocks of nchan_per_vis;
              trailing channels that do not fill a whole block are skipped
              (a warning is logged).

        :param input_ms: path to input MeasurementSet
        :param channels_in_data: how many frequency channels
                                 does the data set contain
        :param nchan_per_vis: how many frequency channels to load into
                              a single Visibility
        :raises ValueError: if nchan_per_vis is not positive, or not even
            one block of nchan_per_vis channels fits in channels_in_data
        """
        if nchan_per_vis <= 0:
            raise ValueError(
                f"nchan_per_vis must be positive, got {nchan_per_vis}"
            )
        num_bvis = channels_in_data // nchan_per_vis
        if num_bvis < 1:
            raise ValueError(
                f"Cannot load blocks of {nchan_per_vis} channels from data "
                f"with {channels_in_data} channels"
            )
        leftover = channels_in_data % nchan_per_vis
        if leftover:
            log.warning(
                "%s trailing channel(s) do not fill a block of %s channels "
                "and will not be loaded",
                leftover,
                nchan_per_vis,
            )

        # Read MS data; end_chan is passed as the last channel of each block
        bvis_list = [
            dask_wrapper(
                create_visibility_from_ms, use_dask=self._use_dask, nout=1
            )(
                msname=input_ms,
                selected_dds=[0],
                start_chan=chan_block * nchan_per_vis,
                end_chan=(1 + chan_block) * nchan_per_vis - 1,
                average_channels=False,
            )[0]
            for chan_block in range(num_bvis)
        ]

        log.info("Number of BlockVisibilities: %s", len(bvis_list))
        self.bvis_list = [
            dask_wrapper(
                convert_visibility_to_stokesI,
                use_dask=self._use_dask,
                nout=1,
            )(bvis)
            for bvis in bvis_list
        ]

    def _init_model_images_list(self, n_pixel, cell_size=None):
        """
        Initialize model images from input visibility data

        Note: According to RASCIL Continuum Imaging Pipeline,
        each model image needs to have a single frequency channel.
        This will make sure that when we run invert with ng,
        the resulting Image list will have a channel each and,
        hence deconvolution won't break (which expects 1 chan / Image).

        :param cell_size: image cell size [rad]; if None, it is calculated
                          from the first Visibility with advise_wide_field
        :param n_pixel: number of pixels on a side of the image
        :raises ValueError: if cell_size is None and no Visibilities are
            loaded to derive it from
        """
        if cell_size is None:
            if not self.bvis_list:
                raise ValueError(
                    "Cannot derive the cell size: no Visibilities loaded"
                )
            # Only the advice for the first Visibility is needed.
            advice = dask_wrapper(
                advise_wide_field, use_dask=self._use_dask, nout=1
            )(self.bvis_list[0], guard_band_image=3.0)
            cell_size = advice["cellsize"]
            log.info("Setting cellsize to %s rad", cell_size)

        self.model_image_list = [
            dask_wrapper(
                create_image_from_visibility, use_dask=self._use_dask, nout=1
            )(bvis, cellsize=cell_size, nchan=1, npixel=n_pixel)
            for bvis in self.bvis_list
        ]

    def _init_skymodel_list(self):
        """
        Initialize SkyModel for each model image
        (Note: we may need to allow it as input too, in the future)
        """
        self.skymodel_list = [
            dask_wrapper(SkyModel, use_dask=self._use_dask, nout=1)(
                image=model
            )
            for model in self.model_image_list
        ]

    def _init_psf_list(self):
        """
        Create Point Spread Functions (PSF)

        The full invert_visibility outputs are kept in
        self._psf_sumwt_list (used by restore); element 0 of each goes
        into self.psf_list.
        """
        self._psf_sumwt_list = [
            dask_wrapper(invert_visibility, use_dask=self._use_dask, nout=2)(
                bvis,
                self.model_image_list[i],
                dopsf=True,
                context=self._imaging_context,
            )
            for i, bvis in enumerate(self.bvis_list)
        ]
        self.psf_list = [psf[0] for psf in self._psf_sumwt_list]
def _update_data(input_vis, model_data):
    """
    Update input_vis visibility data with model_data.

    input_vis.data is rebound to input_vis.data + model_data.data;
    nothing is returned.
    """
    input_vis.data = input_vis.data + model_data.data


def _output_file_prefix(input_ms, output_dir, processing_func):
    """
    Build the path prefix used for the exported files.

    The basename of input_ms is taken after normalising the path, so a
    trailing separator (e.g. "data/obs.ms/") is ignored. A trailing ".ms"
    is replaced by "_test_", processing_func is appended, and the result
    is placed inside output_dir.

    :param input_ms: path to the input MeasurementSet
    :param output_dir: directory the results are written to
    :param processing_func: processing function source label
    :return: output path prefix, e.g. "out/obs_test_rascil"
    """
    ms_name = os.path.basename(os.path.normpath(input_ms))
    if ms_name.endswith(".ms"):
        ms_name = ms_name[: -len(".ms")] + "_test_"
    return os.path.join(output_dir, ms_name + processing_func)


def run_cip(args):
    """
    Run the continuum imaging pipeline

    Steps: initial setup; initial predict/subtract, invert and
    component-finding deconvolution; args.n_major major cycles with
    CLEAN deconvolution; final predict/subtract and invert; restore;
    export. With args.use_dask the lazily built graph is computed before
    export, and the Dask client is torn down afterwards, also when a
    step raises.

    :param args: user-defined arguments
    :return: the value returned by export_results
        (presumably the restored image name)
    """
    log.info(
        "Running with user input arguments:\n %s", pprint.pformat(vars(args))
    )

    client = None
    if args.use_dask:
        # start Dask
        client = set_up_dask()

    try:
        if args.use_dask:
            rsexecute.set_client(client=client, use_dask=True)
        else:
            rsexecute.set_client(use_dask=False)

        log.info("Running initial setup")
        cont_img = ContinuumImagingPipeline(
            args.imaging_context, use_dask=args.use_dask
        )
        time_run(cont_img, use_dask=args.use_dask)(**vars(args))

        # Pre-major cycle calls
        log.info("Running initial Predict and Subtract")
        time_run(cont_img.predict, use_dask=args.use_dask)(
            args.processing_func
        )
        log.info("Running initial Invert")
        time_run(cont_img.invert, use_dask=args.use_dask)()
        log.info("Running initial Deconvolution (Finding SkyComponents)")
        time_run(cont_img.deconvolve, use_dask=args.use_dask)(
            True, args.component_threshold, args.clean_threshold
        )

        # Run major cycle
        log.info("Entering major cycles.")
        for j in range(args.n_major):
            log.info("Running major cycle %s", j)
            log.info("PREDICT and SUBTRACT")
            time_run(cont_img.predict, use_dask=args.use_dask)(
                args.processing_func
            )
            log.info("INVERT")
            time_run(cont_img.invert, use_dask=args.use_dask)()
            log.info("DECONVOLUTION (msclean)")
            time_run(cont_img.deconvolve, use_dask=args.use_dask)(
                False, args.component_threshold, args.clean_threshold
            )
        log.info("Finished running major cycles.")

        # Finalize images
        log.info("Running final Predict and Subtract")
        time_run(cont_img.predict, use_dask=args.use_dask)(
            args.processing_func
        )
        log.info("Running final Invert")
        time_run(cont_img.invert, use_dask=args.use_dask)()
        log.info("Restoring images.")
        # NOTE(review): unlike the other steps, time_run is called here
        # without use_dask — confirm this is intended.
        time_run(cont_img.restore)()

        if args.use_dask:
            log.info("Computing Dask Graph")
            (
                cont_img.restored_list,
                cont_img.skymodel_list,
                cont_img.dirty_image_list,
            ) = time_run(client.compute)(
                [
                    cont_img.restored_list,
                    cont_img.skymodel_list,
                    cont_img.dirty_image_list,
                ],
                sync=True,
            )
            log.info("Dask Graph computation finished.")

        log.info("Exporting images and SkyModel.")
        file_prefix = _output_file_prefix(
            args.input_ms, args.output_dir, args.processing_func
        )
        restored_name = export_results(
            cont_img.restored_list,
            cont_img.dirty_image_list,
            cont_img.skymodel_list,
            file_prefix,
        )

        if args.use_dask:
            task_stream = client.get_task_stream()
            print_task_stream_timing(task_stream)
    finally:
        if client is not None:
            # close Dask, also when a pipeline step failed
            tear_down_dask(client)

    return restored_name


def main():
    """Parse the command line and run the continuum imaging pipeline."""
    parser = cli_parser()
    args = parser.parse_args()
    # use_dask is compared against the text "True" (presumably the CLI
    # yields a string); accept any capitalisation and a value that is
    # already a bool.
    args.use_dask = str(args.use_dask).strip().lower() == "true"
    run_cip(args)


if __name__ == "__main__":
    main()