"""
Prototype Continuum Imaging Pipeline (CIP).
Based on RASCIL's CIP, with the option of using
Processing Function Library functions instead of RASCIL
processing components where available.
"""
import logging
import pprint
import sys
from rascil.processing_components import create_visibility_from_ms
from rascil.workflows import restore_skymodel_list_rsexecute_workflow
from rascil.workflows.rsexecute.execution_support import rsexecute
from ska_sdp_datamodels.sky_model.sky_model import SkyModel
from ska_sdp_func_python.imaging import (
advise_wide_field,
create_image_from_visibility,
invert_visibility,
predict_visibility,
)
from ska_sdp_func_python.visibility import convert_visibility_to_stokesI
from src.cli_parser import cli_parser
from src.dask_utils import (
dask_wrapper,
print_task_stream_timing,
set_up_dask,
tear_down_dask,
)
from src.imaging_utils import export_results, imaging_subtract_vis, time_run
from src.processing_function_integration import deconvolution, dft_visibility
# Module-wide logger for the prototype pipeline.
log = logging.getLogger("prototype-pipeline")
# Echo log records to standard output ...
log.addHandler(logging.StreamHandler(sys.stdout))
# ... and let everything from INFO upwards through.
log.setLevel(logging.INFO)
class ContinuumImagingPipeline:
    """
    Prototype Continuum Imaging Pipeline.
    Doesn't use calibration.

    Typical use: construct the object, call it with the user arguments
    to load the data and initialise the model images, SkyModels and PSFs,
    then run ``predict``, ``invert`` and ``deconvolve`` as often as needed
    and finish with ``restore``.

    Every processing call is routed through ``dask_wrapper`` with the
    ``use_dask`` flag given at construction.
    """

    def __init__(self, imaging_context, use_dask=False):
        """
        :param imaging_context: Gridding type. Options: ng (nifty-gridder),
                                wg (wagg, GPU only), 2d or awprojection
        :param use_dask: whether to run computation with dask or not
        """
        # Results and intermediate products; all are filled in later
        # by __call__ and the individual pipeline steps.
        self.bvis_list = None
        self.model_image_list = None
        self.skymodel_list = None
        self.psf_list = None
        self.dirty_image_list = None
        self.restored_list = None

        self._imaging_context = imaging_context
        self._use_dask = use_dask

        self._tmp_zeroed_bvis_list = None
        self._psf_sumwt_list = None  # needed for restore

    def __call__(self, **input_args):
        """
        Set up for running the continuum imaging pipeline.

        Steps:
            1. load MeasurementSet into Visibilities
            2. Generate initial model images and SkyModels
            3. Initialize the point spread function (PSF)

        :param input_args: user arguments as keywords; the keys
            "input_ms", "input_nchan", "nchan_per_vis", "imaging_npixel"
            and "imaging_cellsize" are read here
        """
        self._load_bvis_from_ms(
            input_args["input_ms"],
            input_args["input_nchan"],
            input_args["nchan_per_vis"],
        )
        self._init_model_images_list(
            input_args["imaging_npixel"], input_args["imaging_cellsize"]
        )
        self._init_skymodel_list()
        self._init_psf_list()

        # helper list, contains copy of input bvis with zero as visibility data
        self._tmp_zeroed_bvis_list = [
            dask_wrapper(bvis.copy, use_dask=self._use_dask, nout=1)(
                deep=True, zero=True
            )
            for bvis in self.bvis_list
        ]

    @staticmethod
    def _add_model_vis(dft_vis, model_vis):
        """
        Add the "vis" data of model_vis onto dft_vis and return dft_vis.

        Returning the updated object lets the caller use the result of
        this call downstream, instead of relying on an in-place side effect.

        :param dft_vis: visibility that receives the sum
        :param model_vis: visibility whose "vis" data is added
        :return: dft_vis, after the update
        """
        _update_data(dft_vis["vis"], model_vis["vis"])
        return dft_vis

    def predict(self, processing_func_source):
        """
        Run the predict step, including subtracting the
        predicted model data from the input data.

        The model visibilities are the sum of the DFT of the
        SkyComponents and the prediction from the SkyModel image.

        :param processing_func_source:
                which type of processing functions to use;
                Options: 'rascil' - use RASCIL
                         'proc_func' - Use the Processing Function Library
        :return: updates self.bvis_list in place
        """
        # Run Direct Fourier Transform (DFT) on SkyComponents
        dft_vis_list = [
            dask_wrapper(dft_visibility, use_dask=self._use_dask, nout=1)(
                vis, self.skymodel_list[i].components, processing_func_source
            )
            for i, vis in enumerate(self._tmp_zeroed_bvis_list)
        ]

        # Predict
        updated_bvis_model = [
            dask_wrapper(predict_visibility, use_dask=self._use_dask, nout=1)(
                dft_v,
                self.skymodel_list[i].image,
                context=self._imaging_context,
            )
            for i, dft_v in enumerate(dft_vis_list)
        ]

        # Add the image prediction to the DFT visibilities.
        # The sum is taken from the return value of the wrapped call and
        # used below. A wrapped call whose result is dropped is not part
        # of any later computation, so with Dask the addition would
        # presumably never run (assumes dask_wrapper builds lazy tasks
        # when use_dask is True -- TODO confirm).
        model_vis_list = [
            dask_wrapper(self._add_model_vis, use_dask=self._use_dask, nout=1)(
                dft_v, image_v
            )
            for dft_v, image_v in zip(dft_vis_list, updated_bvis_model)
        ]

        # SUBTRACT model from input bvis; update bvis_list
        self.bvis_list = [
            dask_wrapper(
                imaging_subtract_vis, use_dask=self._use_dask, nout=1
            )(bvis, model_v)
            for bvis, model_v in zip(self.bvis_list, model_vis_list)
        ]

    def invert(self):
        """
        Run the invert step.

        :return: updates self.dirty_image_list in place
        """
        inverted_list = [
            dask_wrapper(invert_visibility, use_dask=self._use_dask, nout=2)(
                residual_vis,
                self.skymodel_list[i].image,
                dopsf=False,
                context=self._imaging_context,
            )
            for i, residual_vis in enumerate(self.bvis_list)
        ]
        # invert_visibility is wrapped with nout=2; only the first
        # output of each call is kept as the dirty image
        self.dirty_image_list = [inverted[0] for inverted in inverted_list]

    def deconvolve(self, fit_skymodel, component_threshold, clean_threshold):
        """
        Run deconvolution.

        :param fit_skymodel: True: fit the skymodel and extract sky components
                             False: run CLEAN-based deconvolution
        :param component_threshold: Sources with absolute flux > this level (Jy) are fitted
        :param clean_threshold: Clean stopping threshold (Jy/beam)
        :return: updates self.skymodel_list in place
        """
        # deconvolution uses RASCIL's rsexecute to run Dask at the moment,
        # this may change once other, non-rascil deconvolution functions are added
        self.skymodel_list = deconvolution(
            self.dirty_image_list,
            self.psf_list,
            self.skymodel_list,
            fit_skymodel=fit_skymodel,
            component_threshold=component_threshold,
            clean_threshold=clean_threshold,
        )

    def restore(self):
        """
        Restore images.
        Uses RASCIL's restore_skymodel_list_rsexecute_workflow function,
        which implicitly uses rsexecute to run with Dask.

        :return: updates self.restored_list in place
        """
        # the following rascil pipeline workflow coordinates multiple restoring options in RASCIL CIP:
        # rascil.workflows.rsexecute.pipelines.pipeline_skymodel_rsexecute.restore_skymodel_pipeline_rsexecute_workflow
        # we default to the "list" version, which is just the following restore call:
        self.restored_list = restore_skymodel_list_rsexecute_workflow(
            self.skymodel_list,
            self._psf_sumwt_list,
            self.dirty_image_list,
        )

    def _load_bvis_from_ms(self, input_ms, channels_in_data, nchan_per_vis):
        """
        Load MeasurementSet data into Visibility objects.
        (rascil.data_models.memory_data_models.Visibility)

        Note:
          - Polarization of data is always converted to Stokes I after loading.
          - Based on rascil.apps.rascil_imager.get_vis_list and
            rascil.workflows.rsexecute.visibility.visibility_rsexecute.create_visibility_from_ms_rsexecute
          - these RASCIL functions take multiple arguments, most of which we hardcode here.
            However, in the future, we may need to allow for users to specify these.
          - Channels are loaded in consecutive blocks of nchan_per_vis;
            channels left over after the last full block are not loaded.

        :param input_ms: path to input MeasurementSet
        :param channels_in_data: how many frequency channels does the data set contain
        :param nchan_per_vis: how many frequency channels to load into a
                              single Visibility
        :raises ValueError: if nchan_per_vis is not positive, or if
                            channels_in_data is smaller than nchan_per_vis
        """
        if nchan_per_vis <= 0:
            raise ValueError(
                f"nchan_per_vis must be a positive integer, got {nchan_per_vis}"
            )

        num_bvis = channels_in_data // nchan_per_vis
        if num_bvis < 1:
            raise ValueError(
                f"Cannot form a Visibility of {nchan_per_vis} channels "
                f"from a data set with {channels_in_data} channels"
            )

        leftover_channels = channels_in_data - num_bvis * nchan_per_vis
        if leftover_channels:
            log.warning(
                "The last %s channel(s) do not fill a block of %s channels "
                "and will not be loaded",
                leftover_channels,
                nchan_per_vis,
            )

        # Read MS data
        bvis_list = [
            dask_wrapper(
                create_visibility_from_ms, use_dask=self._use_dask, nout=1
            )(
                msname=input_ms,
                selected_dds=[0],
                start_chan=chan_block * nchan_per_vis,
                end_chan=(1 + chan_block) * nchan_per_vis - 1,
                average_channels=False,
            )[
                0
            ]  # only the first element of the loader's result is used
            for chan_block in range(num_bvis)
        ]
        log.info("Number of BlockVisibilities: %s", len(bvis_list))

        self.bvis_list = [
            dask_wrapper(
                convert_visibility_to_stokesI,
                use_dask=self._use_dask,
                nout=1,
            )(bvis)
            for bvis in bvis_list
        ]

    def _init_model_images_list(self, n_pixel, cell_size=None):
        """
        Initialize model images from input visibility data

        Note: According to RASCIL Continuum Imaging Pipeline,
        each model image needs to have a single frequency channel.
        This will make sure that when we run invert with ng,
        the resulting Image list will have a channel each and,
        hence deconvolution won't break (which expects 1 chan / Image).

        :param n_pixel: number of pixels on a side of the image
        :param cell_size: image cell size [rad]; if None, it is calculated
                          from the first Visibility in self.bvis_list
        """
        if cell_size is None:
            # Only the advice for the first Visibility is used, so the
            # advice is not requested for the remaining ones.
            advice = dask_wrapper(
                advise_wide_field, use_dask=self._use_dask, nout=1
            )(self.bvis_list[0], guard_band_image=3.0)
            cell_size = advice["cellsize"]
            log.info("Setting cellsize to %s rad", cell_size)

        self.model_image_list = [
            dask_wrapper(
                create_image_from_visibility, use_dask=self._use_dask, nout=1
            )(bvis, cellsize=cell_size, nchan=1, npixel=n_pixel)
            for bvis in self.bvis_list
        ]

    def _init_skymodel_list(self):
        """
        Initialize SkyModel for each model image
        (Note: we may need to allow it as input too, in the future)
        """
        self.skymodel_list = [
            dask_wrapper(SkyModel, use_dask=self._use_dask, nout=1)(
                image=model
            )
            for model in self.model_image_list
        ]

    def _init_psf_list(self):
        """
        Create Point Spread Functions (PSF)

        The complete result of every invert call is kept in
        self._psf_sumwt_list (needed for restore); the first output
        of each call goes into self.psf_list.
        """
        self._psf_sumwt_list = [
            dask_wrapper(invert_visibility, use_dask=self._use_dask, nout=2)(
                bvis,
                self.model_image_list[i],
                dopsf=True,
                context=self._imaging_context,
            )
            for i, bvis in enumerate(self.bvis_list)
        ]
        self.psf_list = [psf[0] for psf in self._psf_sumwt_list]
def _update_data(input_vis, model_data):
    """
    Add the values of model_data onto input_vis.

    The sum is built first and then assigned to ``input_vis.data``;
    ``model_data`` is only read. Nothing is returned.

    :param input_vis: object whose ``data`` attribute is replaced by the sum
    :param model_data: object whose ``data`` attribute is added
    """
    combined = input_vis.data + model_data.data
    input_vis.data = combined
def _output_file_prefix(input_ms, output_dir, processing_func):
    """
    Build the path prefix shared by all exported result files.

    The prefix is the MeasurementSet name with every ".ms" replaced by
    "_test_", followed by the processing function source, inside output_dir.
    Trailing slashes of input_ms are ignored, so "obs.ms/" and "obs.ms"
    give the same prefix.

    :param input_ms: path to the input MeasurementSet
    :param output_dir: directory the results are written to
    :param processing_func: name of the processing function source
    :return: prefix string
    """
    ms_name = input_ms.rstrip("/").split("/")[-1]
    return output_dir + "/" + ms_name.replace(".ms", "_test_") + processing_func


def run_cip(args):
    """
    Run the continuum imaging pipeline

    Steps: initial setup, one predict/invert/deconvolve pass that fits
    SkyComponents, args.n_major major cycles with CLEAN-based
    deconvolution, a final predict/invert, restore and export.
    With args.use_dask set, the Dask graph is computed before exporting,
    and the Dask client is torn down even if a step fails.

    :param args: user-defined arguments
    :return: value returned by export_results
    """
    log.info(
        "Running with user input arguments:\n %s", pprint.pformat(vars(args))
    )

    client = None
    try:
        if args.use_dask:
            # start Dask
            client = set_up_dask()
            rsexecute.set_client(client=client, use_dask=True)
        else:
            rsexecute.set_client(use_dask=False)

        log.info("Running initial setup")
        cont_img = ContinuumImagingPipeline(
            args.imaging_context, use_dask=args.use_dask
        )
        time_run(cont_img, use_dask=args.use_dask)(**vars(args))

        # Pre-major cycle calls
        log.info("Running initial Predict and Subtract")
        time_run(cont_img.predict, use_dask=args.use_dask)(
            args.processing_func
        )

        log.info("Running initial Invert")
        time_run(cont_img.invert, use_dask=args.use_dask)()

        log.info("Running initial Deconvolution (Finding SkyComponents)")
        time_run(cont_img.deconvolve, use_dask=args.use_dask)(
            True, args.component_threshold, args.clean_threshold
        )

        # Run major cycle
        log.info("Entering major cycles.")
        for j in range(args.n_major):
            log.info("Running major cycle %s", j)

            log.info("PREDICT and SUBTRACT")
            time_run(cont_img.predict, use_dask=args.use_dask)(
                args.processing_func
            )

            log.info("INVERT")
            time_run(cont_img.invert, use_dask=args.use_dask)()

            log.info("DECONVOLUTION (msclean)")
            time_run(cont_img.deconvolve, use_dask=args.use_dask)(
                False, args.component_threshold, args.clean_threshold
            )
        log.info("Finished running major cycles.")

        # Finalize images
        log.info("Running final Predict and Subtract")
        time_run(cont_img.predict, use_dask=args.use_dask)(
            args.processing_func
        )

        log.info("Running final Invert")
        time_run(cont_img.invert, use_dask=args.use_dask)()

        log.info("Restoring images.")
        # NOTE(review): unlike the calls above, use_dask is not passed to
        # time_run here -- confirm that this is intended.
        time_run(cont_img.restore)()

        if args.use_dask:
            log.info("Computing Dask Graph")
            (
                cont_img.restored_list,
                cont_img.skymodel_list,
                cont_img.dirty_image_list,
            ) = time_run(client.compute)(
                [
                    cont_img.restored_list,
                    cont_img.skymodel_list,
                    cont_img.dirty_image_list,
                ],
                sync=True,
            )
            log.info("Dask Graph computation finished.")

        log.info("Exporting images and SkyModel.")
        file_prefix = _output_file_prefix(
            args.input_ms, args.output_dir, args.processing_func
        )
        restored_name = export_results(
            cont_img.restored_list,
            cont_img.dirty_image_list,
            cont_img.skymodel_list,
            file_prefix,
        )

        if args.use_dask:
            task_stream = client.get_task_stream()
            print_task_stream_timing(task_stream)

        return restored_name
    finally:
        # close Dask, also when one of the steps above raised
        if client is not None:
            tear_down_dask(client)
def main():
    """
    Command line entry point.

    Parses the command line arguments, turns the textual ``use_dask``
    value into a bool (only the exact string "True" enables Dask) and
    runs the pipeline.
    """
    args = cli_parser().parse_args()
    dask_requested = args.use_dask == "True"
    args.use_dask = dask_requested
    run_cip(args)
# Run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()