Source code for ska_sdp_spectral_line_imaging.util

from typing import Tuple

import dask
import dask.array
import numpy as np
import xarray as xr
from astropy import units as au
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS
from ska_sdp_datamodels.science_data_model.polarisation_model import (
    PolarisationFrame,
)

from ska_sdp_spectral_line_imaging.constants import (
    FITS_AXIS_TO_IMAGE_DIM,
    FITS_CODE_TO_POL_NAME,
    POL_NAME_TO_FITS_CODE,
)


[docs] class SafeDict(dict): """ Class to allow for selective format_map function """ def __missing__(self, key): return "{" + key + "}"
[docs] def rechunk(target, ref, dim): """ Rechunk a target DataArray based on a ref DataArray Parameters ---------- target: xr.DataArray DataArray to be rechunked ref: xr.DataArray Reference DataArray dim: dict Dimensions to be expanded along Returns ------- xr.DataArray """ return ( target.expand_dims(dim=dim).transpose(*ref.dims).chunk(ref.chunksizes) )
[docs] def export_image_as(image, output_path, export_format="fits"): """ Export data in the desired export_format Parameters ---------- image: ska_sdp_datamodels.image.image_model.Image Image to be exported output_path: str Output file name export_format: str Data format for the data. Allowed values: fits|zarr Returns ------- dask.delayed.Delayed Raises ------ ValueError: If the provided data format is not in fits or zarr """ if export_format == "fits": return export_to_fits(image, output_path) elif export_format == "zarr": return export_to_zarr(image, output_path) else: raise ValueError(f"Unsupported format: {export_format}")
[docs] def export_to_zarr(data, output_path, clear_attrs=False): """ Lazily export xarray dataset/dataarray to zarr file format. Parameters ---------- data: xarray.DataArray | xarray.Dataset Xarray data to be exported output_path: str Output file path. A ".zarr" is appended to this path. clear_attrs: bool = False Whether to clear attributes of the data before writing to zarr. Returns ------- dask.delayed.Delayed A dask delayed object which represents the task of writing data to zarr. """ data_to_export = data.copy(deep=False) if clear_attrs: data_to_export.attrs.clear() return data_to_export.to_zarr(store=f"{output_path}.zarr", compute=False)
@dask.delayed def _write_array_to_fits_delayed(data, header, output_path): """ A helper function to write numpy array or dask array to fits image. The caller has to call "compute()" on the returned delayed object, in order to initiate the write call. Parameters ---------- data: numpy.ndarray | dask.array Data to write to the file. header: astropy.io.fits.header.Header FITS Header object, typically generated from WCS information. output_path: str Path to write to. Returns ------- dask.delayed.Delayed """ fits.writeto( filename=output_path, data=data, header=header, overwrite=True, )
[docs] def export_to_fits(image, output_path): """ Exports Image instance to multiple FITS files, one per polarisation. The caller is expected to call ``dask.compute()`` on the returned delayed objects for the actual writes to happen. Parameters ---------- image: ska_sdp_datamodels.image.Image Image instance containing all the polarisations. output_path: str Path to write FITS image to. The output_path is appended with ".{pol}.fits" where "pol" is the polarization being written. Returns ------- List[dask.delayed.Delayed] """ wcs = image.image_acc.wcs # Since Image dimensions are: ["frequency", "polarisation", "y", "x"] pol_axis_in_np = 1 # WCS / FITS dimensions must be: ["RA", "DEC", "POL", "CHAN"] pol_axis_in_wcs = 2 # clean_beam must be a dictionary with keys {"bmaj", "bmin", "bpa"}. clean_beam = image.attrs["clean_beam"] tasks = [] for pol in image.polarisation.values: data = ( image.pixels.sel(polarisation=pol) .expand_dims(dim="polarisation", axis=pol_axis_in_np) .data ) new_wcs = wcs.deepcopy() new_wcs.wcs.crval[pol_axis_in_wcs] = POL_NAME_TO_FITS_CODE[pol] new_wcs.wcs.cdelt[pol_axis_in_wcs] = 1.0 header = new_wcs.to_header() if clean_beam: header.append( card=fits.Card( "BMAJ", clean_beam["bmaj"], "[deg] CLEAN beam major axis", ) ) header.append( card=fits.Card( "BMIN", clean_beam["bmin"], "[deg] CLEAN beam minor axis", ) ) header.append( card=fits.Card( "BPA", clean_beam["bpa"], "[deg] CLEAN beam position angle", ) ) tasks.append( _write_array_to_fits_delayed( data, header, f"{output_path}.{pol}.fits" ) ) return tasks
[docs] def estimate_cell_size_in_arcsec( baseline: float, wavelength: float, factor=3.0 ) -> float: """ A generalized function which estimates cell size for given baseline value. This function is dask compatible i.e. can take dask arrays as input, and return dask array as output. Parameters ---------- baseline: float Baseline length in meters. For better estimation, this has to be the maximum baseline length in any direction. wavelength: float Wavelength in meters. For better estimation, it has to be the minimum wavelength observed. factor: float Scaling factor. Returns ------- float Cell size in arcsecond. **The output is rounded** to the 2 decimal places. """ baseline = baseline / wavelength cell_size_rad = 1.0 / (2.0 * factor * baseline) cell_size_arcsec = np.rad2deg(cell_size_rad) * 3600 # Rounded to 2 decimals return cell_size_arcsec.round(2)
[docs] def estimate_image_size( wavelength: float, antenna_diameter: float, cell_size: float ) -> int: """ Estimates dimension of the image which will be used in the imaging stage. This function is dask compatible i.e. can take dask arrays as input, and return dask array as output. Parameters ---------- wavelength: float Wavelength in meters. For better estimation, this has to be the maximum wavelength observed. antenna_diameter: float Diameter of the antenna in meters. For better estimation, this has to be the minimum of the diameters of all antennas. cell_size: float Cell size in arcsecond. Returns ------- int Size of the image. **The output is rounded** to the nearest multiple of 100 greater than the calculated image size. """ cell_size_rad = np.deg2rad(cell_size / 3600) image_size = (1.5 * wavelength) / (cell_size_rad * antenna_diameter) # Rounding to the nearest multiple of 100 return np.ceil(image_size / 100) * 100
[docs] def get_polarization_frame_from_observation( observation: xr.Dataset, ) -> PolarisationFrame: """ Reads an observation from the xradio processing set, and generates a PolarizationFrame instance from the polarization coordinates. This is required to generate an instance of Image class. Parameters ---------- observation: xarray.Dataset Observation from xradio processing set Returns ------- PolarisationFrame """ polarization_lookup = { "_".join(value): key for key, value in PolarisationFrame.polarisation_frames.items() } polarization_frame = PolarisationFrame( polarization_lookup["_".join(observation.polarization.data)] ) return polarization_frame
# NOTE: This does not handle MOMENT images, only FREQ
[docs] def get_wcs_from_observation(obs, cell_size, nx, ny) -> WCS: """ Reads an observation from the xradio processing set, and extracts WCS information. This is required to create an instance of Image class defined in `ska_sdp_datamodels.image.Image`. Since Image dimensions are fixed to ["frequency", "polarisation", "y", "x"], the sequence of axes in WCS is ["RA", "DEC", "STOKES", "FREQ"]. **NOTE:** Polarization axis is defaulted to stokes, with crval = 1.0 and cdelta = 1.0. This is done due to the difference in the sequence of linear and circular polarizations values in FITS and in processing set. Consumer of this function is expected to populate correct values for polarizations present in the processing set. **NOTE:** field_and_source_base_xds is assumed to be the same across all observations in the processing set. The first observation is used to extract the field phase center. Parameters ---------- obs: xarray.Dataset Concatenated observation dataset from upstream output cell_size: float Cell size in arcseconds. nx: int Image size X ny: int Image size Y Returns ------- WCS """ field_phase_center = ( obs.VISIBILITY.field_and_source_xds.FIELD_PHASE_CENTER_DIRECTION ) if set(field_phase_center.sky_dir_label.values) != {"ra", "dec"}: raise ValueError( "Phase field center coordinates labels are not equal to RA / DEC." ) if field_phase_center.units != "rad": raise ValueError("Phase field center value is not defined in radian.") # computes FIELD_PHASE_CENTER_DIRECTION if its a dask array fp_center = { label: value for label, value in zip( field_phase_center.sky_dir_label.values, field_phase_center.to_numpy().flatten(), ) } fp_frame = field_phase_center.frame.lower() # TODO: Verify: Is the fp_frame equal to frame? coord = SkyCoord( ra=fp_center["ra"] * au.rad, dec=fp_center["dec"] * au.rad, frame=fp_frame, ) cell_size_degree = cell_size / 3600 freq_channel_width = obs.frequency.channel_width["data"] ref_freq = float(obs.frequency.data[0]) freq_unit = obs.frequency.units # NOTE: Hardcoding to stokes polarization, # consumer should fill-in correct polarization value. ref_pol = 1.0 del_pol = 1.0 new_wcs = WCS(naxis=4) # computes nx and ny if those are dask arrays new_wcs.wcs.crpix = [nx // 2, ny // 2, 1, 1] new_wcs.wcs.cunit = ["deg", "deg", "", freq_unit] # computes cell_size_degree if its a dask array new_wcs.wcs.cdelt = [ -cell_size_degree, cell_size_degree, del_pol, freq_channel_width, ] new_wcs.wcs.crval = [coord.ra.deg, coord.dec.deg, ref_pol, ref_freq] new_wcs.wcs.ctype = ["RA---SIN", "DEC--SIN", "STOKES", "FREQ"] # NOTE: "ICRS" since sdp-datamodels also have fixed radesys new_wcs.wcs.radesys = "ICRS" # new_wcs.wcs.radesys = coord.frame.name.upper() # NOTE: "2000.0" since sdp-datamodels also have fixed equinox new_wcs.wcs.equinox = 2000.0 # new_wcs.wcs.equinox = coord.frame.equinox.jyear # NOTE: Verify this assignment is correct new_wcs.wcs.specsys = obs.frequency.observer return new_wcs
@dask.delayed def read_fits_memmapped_delayed(image_path, hduid=0): with fits.open( image_path, mode="denywrite", memmap=True, lazy_load_hdus=True ) as hdul: hdu = hdul[hduid] data = hdu.data return data
[docs] def get_dask_array_from_fits( image_path: str, hduid: int, shape: Tuple, dtype: type, ): data = dask.array.from_delayed( read_fits_memmapped_delayed(image_path, hduid), shape=shape, dtype=dtype, ) return data
[docs] def get_dataarray_from_fits(image_path, hduid=0): """ Reads FITS image and returns an xarray dataarray with dimensions ["polarization", "frequency", "y", "x"] or only ["y", "x"] if data is 2 dimensionsional. Function can also read coordinte values for dimensions "polarization" and "frequency". Spatial coordinates "y" and "x" are linear, and the their coordinate values are not populatedin output dataarray. If needed, those can be populated later. Refer ska_sdp_datamodels.image.Image.constructor. The image data is read as a dask array using delayed read calls to astropy.fits.open. Parameters ---------- image_path: str Path to FITS image hduid: int The HDU number in the HDUList read from FITS image. Returns ------- xarray.DataArray Raises ------ NotImplementedError If chunksizes are passed as parameter """ # opening image only to get metadata with fits.open(image_path, memmap=True) as hdul: hdu = hdul[hduid] shape = hdu.data.shape dtype = hdu.data.dtype wcs = WCS(image_path) dimensions = [ FITS_AXIS_TO_IMAGE_DIM[axis] for axis in reversed(wcs.axis_type_names) ] coordinates = {} if "frequency" in dimensions: spectral_wcs = wcs.sub(["spectral"]) frequency_range = spectral_wcs.wcs_pix2world( range(spectral_wcs.pixel_shape[0]), 0 )[0] coordinates["frequency"] = frequency_range if "polarization" in dimensions: pol_wcs = wcs.sub(["stokes"]) pol_codes = pol_wcs.wcs_pix2world(range(pol_wcs.pixel_shape[0]), 0)[0] pol_names = [FITS_CODE_TO_POL_NAME[code] for code in pol_codes] coordinates["polarization"] = pol_names data = get_dask_array_from_fits(image_path, hduid, shape, dtype) return xr.DataArray( data, dims=dimensions, coords=coordinates, name="fits_image_arr", )