"""Functions to manage skymodels.
"""
# Public API of this module: the names exported by ``from <module> import *``.
# Kept as a literal list so documentation and static-analysis tools can read it.
__all__ = [
"partition_skymodel_by_flux",
"show_skymodel",
"initialize_skymodel_voronoi",
"calculate_skymodel_equivalent_image",
"update_skymodel_from_gaintables",
"update_skymodel_from_image",
"expand_skymodel_by_skycomponents",
"create_skymodel_from_skycomponents_gaintables",
"extract_skycomponents_from_skymodel",
]
import logging
import matplotlib.pyplot as plt
import numpy
from astropy.wcs.utils import skycoord_to_pixel
from ska_sdp_datamodels.sky_model.sky_model import SkyModel
from ska_sdp_func_python.sky_component import (
filter_skycomponents_by_flux,
find_skycomponents,
fit_skycomponent,
image_voronoi_iter,
insert_skycomponent,
)
from rascil.processing_components.image.operations import smooth_image
from rascil.processing_components.parameters import get_parameter
# Logger used by the functions in this module (looked up by the fixed name
# "rascil-logger" rather than by ``__name__``).
log = logging.getLogger("rascil-logger")
def partition_skymodel_by_flux(sc, model, flux_threshold=-numpy.inf):
    """Partition skymodel according to flux

    Bright skycomponents are put into a SkyModel as a list, and weak skycomponents
    are inserted into SkyModel as an image.

    :param sc: List of skycomponents
    :param model: Model image used as the template for the weak-component image;
        it is copied, never modified
    :param flux_threshold: Flux level passed as ``flux_min`` when selecting the
        bright components and as ``flux_max`` when selecting the weak ones (see
        ``filter_skycomponents_by_flux`` for how the boundary itself is treated)
    :return: SkyModel holding copies of the bright components and an image
        containing the weak components

    For example::

        fluxes = numpy.linspace(0, 1.0, 11)
        sc = [create_skycomponent(direction=phasecentre, flux=numpy.array([[f]]), frequency=frequency,
                                  polarisation_frame=PolarisationFrame('stokesI')) for f in fluxes]

        sm = partition_skymodel_by_flux(sc, model, flux_threshold=0.31)
        assert len(sm.components) == 7, len(sm.components)
    """
    brightsc = filter_skycomponents_by_flux(sc, flux_min=flux_threshold)
    weaksc = filter_skycomponents_by_flux(sc, flux_max=flux_threshold)
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    log.info(
        "Converted %d components into %d bright components and one image containing %d components",
        len(sc),
        len(brightsc),
        len(weaksc),
    )
    # Insert the weak components into a private deep copy of the model. The
    # result is already independent of ``model``, so it does not need to be
    # copied a second time before being stored in the SkyModel.
    im = insert_skycomponent(model.copy(deep=True), weaksc)
    return SkyModel(
        components=[comp.copy() for comp in brightsc],
        image=im,
        mask=None,
        fixed=False,
    )
def show_skymodel(sms, psf_width=1.75, cm="Greys", vmax=None, vmin=None):
    """Show a list of SkyModels

    Each SkyModel is drawn in its own (cleared) figure. The left panel shows the
    SkyModel image with its components inserted and then smoothed, with the
    component positions marked by red crosses. If the SkyModel has a gaintable, the
    right panel shows gain phases relative to the first Dish/Station column.

    :param sms: List of SkyModels
    :param psf_width: Width of PSF in pixels (passed to smooth_image)
    :param cm: matplotlib colormap
    :param vmax: Maximum in image display; if None, the maximum of each smoothed
        image is used
    :param vmin: Minimum in image display; if None, the minimum of each smoothed
        image is used
    :return: None
    """
    for ism, sm in enumerate(sms):
        wcs = sm.image.image_acc.wcs
        plt.clf()
        plt.subplot(121, projection=wcs.sub([1, 2]))

        smodel = sm.image.copy(deep=True)
        smodel = insert_skycomponent(smodel, sm.components)
        smodel = smooth_image(smodel, psf_width)
        plane = smodel["pixels"].data[0, 0, ...]

        # Resolve the display range per skymodel without overwriting the caller's
        # vmax/vmin. Overwriting them would make every later skymodel reuse the
        # first skymodel's autoscaled range.
        plane_vmax = numpy.max(plane) if vmax is None else vmax
        plane_vmin = numpy.min(plane) if vmin is None else vmin
        plt.imshow(
            plane,
            origin="lower",
            cmap=cm,
            vmax=plane_vmax,
            vmin=plane_vmin,
        )
        plt.xlabel(wcs.wcs.ctype[0])
        plt.ylabel(wcs.wcs.ctype[1])
        plt.title("SkyModel%d" % ism)

        if sm.components is not None:
            for sc in sm.components:
                # imshow/WCSAxes pixel coordinates are 0-based, so request
                # origin=0. With origin=1 the markers were offset by one pixel.
                x, y = skycoord_to_pixel(sc.direction, wcs, 0, "wcs")
                plt.plot(x, y, marker="+", color="red")

        if sm.gaintable is not None:
            plt.subplot(122)
            # Phase of gain[:, :, 0, 0, 0], referenced to the first Dish/Station
            # column so that column shows zero phase.
            phase = numpy.angle(sm.gaintable.gain[:, :, 0, 0, 0])
            phase -= phase[:, 0][:, numpy.newaxis]
            plt.imshow(phase, origin="lower")
            plt.xlabel("Dish/Station")
            plt.ylabel("Integration")
        plt.show()
def initialize_skymodel_voronoi(model, comps, gt=None):
    """Create a skymodel by Voronoi partitioning of the components

    ``image_voronoi_iter(model, comps)`` yields one mask per region; the i-th
    mask is paired with ``comps[i]``. Each resulting SkyModel holds:

    * a deep copy of ``model`` multiplied by that mask,
    * the mask itself,
    * no components,
    * when ``gt`` is given, a deep copy of ``gt`` whose ``phasecentre``
      attribute is set to the direction of the paired component.

    :param model: Model image (copied, not modified)
    :param comps: SkyComponents defining the partition
    :param gt: Gaintable to copy into every SkyModel, or None
    :return: List of SkyModels, one per mask yielded by image_voronoi_iter
    """

    def _gaintable_for(icell):
        # Per-region gaintable copy, re-centred on that region's component.
        if gt is None:
            return None
        cell_gt = gt.copy(deep=True)
        cell_gt.attrs["phasecentre"] = comps[icell].direction
        return cell_gt

    cells = []
    for icell, cell_mask in enumerate(image_voronoi_iter(model, comps)):
        cell_image = model.copy(deep=True)
        cell_image["pixels"].data *= cell_mask["pixels"].data
        cells.append(
            SkyModel(
                image=cell_image,
                components=None,
                gaintable=_gaintable_for(icell),
                mask=cell_mask,
            )
        )
    return cells
def calculate_skymodel_equivalent_image(sm):
    """Calculate an equivalent image for a list of skymodels

    The image of the first skymodel that has one is used as the template. The
    result is the sum of every skymodel image, each multiplied by its mask when a
    mask is present. Skymodels without an image (e.g. component-only ones)
    contribute nothing. Components are not added to the result.

    :param sm: List of skymodels
    :return: Image
    :raises ValueError: if no skymodel in ``sm`` has an image
    """
    with_image = [th for th in sm if th.image is not None]
    if not with_image:
        # Previously this surfaced as an AttributeError on ``None.copy``.
        raise ValueError(
            "calculate_skymodel_equivalent_image: no skymodel has an image"
        )
    # Deep copy so the template skymodel's own image is not modified.
    combined_model = with_image[0].image.copy(deep=True)
    combined_model["pixels"].data[...] = 0.0
    for th in with_image:
        if th.mask is not None:
            combined_model["pixels"].data += (
                th.mask["pixels"].data * th.image["pixels"].data
            )
        else:
            combined_model["pixels"].data += th.image["pixels"].data
    return combined_model
def update_skymodel_from_image(sm, im, damping=0.5):
    """Update a list of skymodels in place from an image, applying a damping factor

    For each skymodel, ``damping * im`` is added to the skymodel's image pixels.
    When the skymodel has a mask, the update is first multiplied by that mask.

    :param sm: List of skymodels; their images are modified in place
    :param im: Image holding the update; it is not modified
    :param damping: Factor applied to the update before it is added
    :return: The same list of SkyModels
    """
    update = im["pixels"].data
    for th in sm:
        increment = update
        if th.mask is not None:
            # Out-of-place multiply creates a new array, so ``im`` is never
            # modified and no deep copy of the whole image is needed.
            increment = increment * th.mask["pixels"].data
        th.image["pixels"].data += damping * increment
    return sm
def update_skymodel_from_gaintables(sm, gt_list, calibration_context="T", damping=0.5):
    """Update the gaintables of a list of skymodels in place from a list of gaintables

    Only the phase of each solution is used: every skymodel gain is multiplied by
    ``exp(1j * damping * angle(gain))``. Amplitudes in ``gt_list`` are ignored.

    :param sm: List of skymodels; their gaintables are modified in place
    :param gt_list: List of gain tables, one per skymodel, each indexed by
        ``calibration_context``
    :param calibration_context: Type of gaintable e.g. 'T', 'G'
    :param damping: Fraction of the solved phase that is applied
    :return: List of skymodels (the same list, updated)
    :raises ValueError: if ``sm`` and ``gt_list`` differ in length
    """
    # Explicit check: an ``assert`` would be stripped when running with -O.
    if len(sm) != len(gt_list):
        raise ValueError(
            "update_skymodel_from_gaintables: %d skymodels but %d gaintables"
            % (len(sm), len(gt_list))
        )
    for th, gts in zip(sm, gt_list):
        phase = numpy.angle(gts[calibration_context].gain)
        th.gaintable["gain"].data *= numpy.exp(damping * 1j * phase)
    return sm
def expand_skymodel_by_skycomponents(sm, **kwargs):
    """Expand a sky model so that all components and the image are in separate skymodels

    One SkyModel is produced per component, holding only that component. When
    ``sm.image`` is not None, one more SkyModel holds only a copy of the image.
    The mask and gaintable are taken to apply to all new skymodels: each new
    skymodel gets its own deep copy of them and inherits ``sm.fixed``.

    :param sm: SkyModel
    :param kwargs: Unused; accepted for interface compatibility
    :return: List of SkyModels
    """

    def _deep_copy_or_none(obj):
        """Return a deep copy of an image/gaintable, or None if it is None."""
        # ``is None`` rather than truthiness: the truth value of a dataset is not
        # the same thing as "is present".
        return None if obj is None else obj.copy(deep=True)

    def _new_skymodel(components, image):
        # Independent copies of the shared mask and gaintable, so that in-place
        # updates (e.g. update_skymodel_from_gaintables) to one skymodel do not
        # affect the others.
        return SkyModel(
            components=components,
            image=image,
            gaintable=_deep_copy_or_none(sm.gaintable),
            mask=_deep_copy_or_none(sm.mask),
            fixed=sm.fixed,
        )

    # Components are shared with ``sm`` (not copied); a None list means none.
    result = [_new_skymodel([comp], None) for comp in (sm.components or [])]
    if sm.image is not None:
        result.append(_new_skymodel(None, _deep_copy_or_none(sm.image)))
    return result
def create_skymodel_from_skycomponents_gaintables(components, gaintables, **kwargs):
    """Create a list of sky models from lists of components and gaintables

    The i-th SkyModel holds a copy of the i-th component and a deep copy of the
    i-th gaintable, with no image and no mask.

    :param components: List of SkyComponents
    :param gaintables: List of gaintables, one per component
    :param kwargs: Unused; accepted for interface compatibility
    :return: List of SkyModels
    :raises ValueError: if ``components`` and ``gaintables`` differ in length
    """
    # Explicit check: an ``assert`` would be stripped when running with -O.
    if len(components) != len(gaintables):
        raise ValueError(
            "create_skymodel_from_skycomponents_gaintables: "
            "%d components but %d gaintables" % (len(components), len(gaintables))
        )
    return [
        SkyModel(
            components=[comp.copy()],
            image=None,
            mask=None,
            gaintable=gt.copy(deep=True),
        )
        for comp, gt in zip(components, gaintables)
    ]