""" Image operations visible to the Execution Framework as Components
"""
# Names exported by ``from <this module> import *``.
# NOTE(review): ifft_griddata_to_image and rotate_image are defined in this
# module but are not listed here - confirm that this is intentional.
__all__ = [
    "add_image",
    "average_image_over_frequency",
    "create_w_term_like",
    "create_window",
    "fft_image_to_griddata_with_wcs",
    "import_image_from_fits",
    "pad_image",
    "sub_image",
    "polarisation_frame_from_wcs",
    "remove_continuum_image",
    "reproject_image",
    "show_components",
    "show_image",
    "smooth_image",
    "scale_and_rotate_image",
    "apply_voltage_pattern_to_image",
]
import copy
import logging
import warnings
import numpy
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS, FITSFixedWarning
from astropy.wcs.utils import skycoord_to_pixel
from reproject import reproject_interp
from ska_sdp_datamodels.gridded_visibility import create_griddata_from_image
from ska_sdp_datamodels.image.image_model import Image
from ska_sdp_datamodels.science_data_model.polarisation_model import PolarisationFrame
from ska_sdp_func_python.calibration import apply_jones
from ska_sdp_func_python.fourier_transforms import fft, ifft, w_beam
from ska_sdp_func_python.image import (
convert_polimage_to_stokes,
convert_stokes_to_polimage,
)
from rascil.processing_components.parameters import get_parameter
# Ignore astropy's FITSFixedWarning for the whole process as soon as this
# module is imported.
warnings.simplefilter("ignore", FITSFixedWarning)
# Module-level logger, registered under the name "rascil-logger".
log = logging.getLogger("rascil-logger")
[docs]
def import_image_from_fits(fitsfile: str, fixpol=True) -> Image:
    """Read an Image from a FITS file

    Pixel data and header are taken from the primary HDU. If the header
    carries BMAJ, BMIN and BPA they are returned as the clean beam,
    otherwise the clean beam is None.

    :param fitsfile: FITS file in storage
    :param fixpol: If True, permute the polarisation axis from the FITS
        ordering to the data-model ordering
    :return: Image
    """
    warnings.simplefilter("ignore", FITSFixedWarning)

    # The context manager closes the file even if reading the header or
    # building the WCS raises.
    with fits.open(fitsfile) as hdulist:
        data = hdulist[0].data
        header = hdulist[0].header
        try:
            clean_beam = {
                "bmaj": header.cards["BMAJ"].value,
                "bmin": header.cards["BMIN"].value,
                "bpa": header.cards["BPA"].value,
            }
        except KeyError:
            # Any one of the three beam keywords missing means "no beam".
            clean_beam = None
        wcs = WCS(fitsfile)

    polarisation_frame = PolarisationFrame("stokesI")

    if len(data.shape) == 4:
        # RASCIL images are RA, DEC, STOKES, FREQ
        if wcs.axis_type_names[3] == "STOKES" or wcs.axis_type_names[2] == "FREQ":
            wcs = wcs.swapaxes(2, 3)
            data = numpy.transpose(data, (1, 0, 2, 3))
        try:
            polarisation_frame = polarisation_frame_from_wcs(wcs, data.shape)
            # FITS and RASCIL polarisation conventions differ
            if fixpol:
                permute = polarisation_frame.fits_to_datamodels[polarisation_frame.type]
                newim_data = data.copy()
                # Plane ip of the FITS data becomes plane p of the result.
                for ip, p in enumerate(permute):
                    newim_data[:, p, ...] = data[:, ip, ...]
                data = newim_data
        except ValueError:
            # Unrecognised polarisation codes: fall back to Stokes I.
            polarisation_frame = PolarisationFrame("stokesI")
    elif len(data.shape) == 2:
        ny, nx = data.shape
        # reshape() returns a new array; the result must be kept, otherwise
        # the data silently stay two-dimensional.
        # NOTE(review): the WCS of a 2-D file is still 2-D here - confirm
        # that Image.constructor accepts it.
        data = data.reshape([1, 1, ny, nx])

    log.debug(
        "import_image_from_fits: created %s image of shape %s",
        data.dtype,
        str(data.shape),
    )
    log.debug(
        "import_image_from_fits: Max, min in %s = %.6f, %.6f",
        fitsfile,
        data.max(),
        data.min(),
    )

    return Image.constructor(
        data=data, polarisation_frame=polarisation_frame, wcs=wcs, clean_beam=clean_beam
    )
[docs]
def reproject_image(im: Image, newwcs: WCS, shape=None) -> (Image, Image):
    """Re-project an image to a new coordinate system

    Currently uses the reproject python package. This seems to have some
    features, so be careful using this method.
    For timeslice imaging griddata is used.

    Complex images are handled by reprojecting the real and imaginary parts
    separately. NaNs in the result and in the footprint are replaced by
    numpy.nan_to_num.

    :param im: Image to be reprojected
    :param newwcs: New WCS
    :param shape: Desired shape of the output; defaults to the shape of the
        input pixel array
    :return: Reprojected Image, Footprint Image
    :raises ValueError: if the pixel array is neither 4D nor 2D
    """
    pixels = im["pixels"].data
    if shape is None:
        # Without this default a missing shape failed in numpy.zeros(shape)
        # and shape[2:] below.
        shape = pixels.shape
    in_wcs = im.image_acc.wcs
    is_complex = pixels.dtype == "complex"

    def _interp(plane, wcs_in, wcs_out, shape_out):
        # One bicubic reprojection; returns (data, footprint).
        return reproject_interp(
            (plane, wcs_in), wcs_out, shape_out, order="bicubic"
        )

    if len(pixels.shape) == 4:
        nchan, npol, _, _ = pixels.shape
        # Only the first two WCS axes take part in the reprojection; they do
        # not change from plane to plane, so derive them once.
        in_wcs2d = in_wcs.sub(2)
        out_wcs2d = newwcs.sub(2)
        plane_shape = shape[2:]

        foot = numpy.zeros(shape, dtype="float")
        rep = numpy.zeros(shape, dtype="complex" if is_complex else "float")
        for chan in range(nchan):
            for pol in range(npol):
                plane = pixels[chan, pol]
                if is_complex:
                    real, foot[chan, pol] = _interp(
                        plane.real, in_wcs2d, out_wcs2d, plane_shape
                    )
                    # Same geometry, hence the same footprint: discard it.
                    imag, _ = _interp(plane.imag, in_wcs2d, out_wcs2d, plane_shape)
                    rep[chan, pol] = real + 1j * imag
                else:
                    rep[chan, pol], foot[chan, pol] = _interp(
                        plane, in_wcs2d, out_wcs2d, plane_shape
                    )
    elif len(pixels.shape) == 2:
        if is_complex:
            rep_real, foot = _interp(pixels.real, in_wcs, newwcs, shape)
            rep_imag, _ = _interp(pixels.imag, in_wcs, newwcs, shape)
            rep = rep_real + 1j * rep_imag
        else:
            rep, foot = _interp(pixels, in_wcs, newwcs, shape)
    else:
        raise ValueError(
            "Cannot reproject image with shape {}".format(im["pixels"].shape)
        )

    if numpy.sum(foot) < 1e-12:
        log.warning("reproject_image: no valid points in reprojection")

    rep = numpy.nan_to_num(rep)
    foot = numpy.nan_to_num(foot)
    return (
        Image.constructor(rep, im.image_acc.polarisation_frame, newwcs),
        Image.constructor(foot, im.image_acc.polarisation_frame, newwcs),
    )
[docs]
def add_image(im1: Image, im2: Image) -> Image:
    """Sum the pixel data of two images

    The polarisation frame and WCS of the result are taken from im1.

    :param im1: First Image
    :param im2: Second Image
    :return: New Image holding im1 + im2
    """
    summed = im1["pixels"].data + im2["pixels"].data
    frame = im1.image_acc.polarisation_frame
    wcs = im1.image_acc.wcs
    return Image.constructor(data=summed, polarisation_frame=frame, wcs=wcs)
[docs]
def show_image(
    im: Image,
    fig=None,
    title: str = "",
    pol=0,
    chan=0,
    cm="Greys",
    components=None,
    vmin=None,
    vmax=None,
    vscale=1.0,
):
    """Show an Image with coordinates using matplotlib, optionally with components

    Only the real part of the pixel data is displayed.

    :param im: Image
    :param fig: Matplotlib figure to draw into; a new one is created if None
    :param title: String for title of plot
    :param pol: Polarisation to show (index)
    :param chan: Channel to show (index)
    :param cm: Matplotlib colour map
    :param components: Optional components to be overlaid; each must provide
        a ``direction``
    :param vmin: Clip to this minimum
    :param vmax: Clip to this maximum
    :param vscale: scale max, min by this amount (only used for limits that
        are not given explicitly)
    :return: The matplotlib figure
    """
    import matplotlib.pyplot as plt

    # Honour a figure supplied by the caller; previously it was always
    # replaced by a new one.
    if fig is None:
        fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection=im.image_acc.wcs.sub([1, 2]))

    if len(im["pixels"].data.shape) == 4:
        data_array = numpy.real(im["pixels"].data[chan, pol, :, :])
    else:
        data_array = numpy.real(im["pixels"].data)

    if vmax is None:
        vmax = vscale * numpy.max(data_array)
    if vmin is None:
        vmin = vscale * numpy.min(data_array)

    # Keep the rendered image in its own name instead of overwriting the
    # ``cm`` colour-map argument.
    mappable = ax.imshow(data_array, origin="lower", cmap=cm, vmax=vmax, vmin=vmin)

    ax.set_xlabel(im.image_acc.wcs.wcs.ctype[0])
    ax.set_ylabel(im.image_acc.wcs.wcs.ctype[1])
    ax.set_title(title)

    fig.colorbar(mappable, orientation="vertical", shrink=0.7)

    if components is not None:
        for sc in components:
            x, y = skycoord_to_pixel(sc.direction, im.image_acc.wcs, 0, "wcs")
            ax.plot(x, y, marker="+", color="red")

    return fig
[docs]
def show_components(im, comps, npixels=128, fig=None, vmax=None, vmin=None, title=""):
    """Show each component against a cut-out of an image

    For every component a window reaching npixels // 2 pixels either side of
    the component position is cut from a deep copy of the image, plane [0, 0]
    of that window is displayed with matplotlib, and the component position
    is marked with a red "+".

    :param im: Image; must be canonical (asserted below)
    :param comps: Iterable of components; each must provide ``direction``,
        ``name`` and ``flux``
    :param npixels: Size of the cut-out window in pixels
    :param fig: Matplotlib figure; a new one is created if this is falsy
    :param vmax: Upper colour limit; defaults to the maximum of plane [0, 0]
    :param vmin: Lower colour limit; defaults to the minimum of plane [0, 0]
    :param title: Not used by the current implementation
    :return: None
    """
    import matplotlib.pyplot as plt

    # Default colour limits come from the full first channel/polarisation
    # plane, so every cut-out is shown on the same scale.
    if vmax is None:
        vmax = numpy.max(im["pixels"].data[0, 0, ...])
    if vmin is None:
        vmin = numpy.min(im["pixels"].data[0, 0, ...])

    if not fig:
        fig = plt.figure()
    plt.clf()

    # NOTE(review): these checks run after ``im`` has already been indexed
    # above, and they disappear under ``python -O``.
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()

    for isc, sc in enumerate(comps):
        # Work on a copy so that cropping and the crpix shift below do not
        # touch the caller's image.
        newim = im.copy(deep=True)
        plt.subplot(111, projection=newim.image_acc.wcs.sub([1, 2]))
        # Nearest pixel to the component (origin 1).
        centre = numpy.round(
            skycoord_to_pixel(sc.direction, newim.image_acc.wcs, 1, "wcs")
        ).astype("int")
        # NOTE(review): the window bounds are not clamped; a component closer
        # than npixels // 2 to the lower edge yields a negative slice start.
        # NOTE(review): this assigns an array of a different shape to the
        # "pixels" variable - confirm the container accepts that.
        newim["pixels"].data = newim["pixels"].data[
            :,
            :,
            (centre[1] - npixels // 2) : (centre[1] + npixels // 2),
            (centre[0] - npixels // 2) : (centre[0] + npixels // 2),
        ]
        # Shift the reference pixel by the window origin so the WCS still
        # describes the cropped data.
        newim.image_acc.wcs.wcs.crpix[0] -= centre[0] - npixels // 2
        newim.image_acc.wcs.wcs.crpix[1] -= centre[1] - npixels // 2
        plt.imshow(
            newim["pixels"].data[0, 0, ...],
            origin="lower",
            cmap="Greys",
            vmax=vmax,
            vmin=vmin,
        )
        # Component position in the cropped frame (origin 0).
        x, y = skycoord_to_pixel(sc.direction, newim.image_acc.wcs, 0, "wcs")
        plt.plot(x, y, marker="+", color="red")
        plt.title("Name = %s, flux = %s" % (sc.name, sc.flux))
        plt.show()
[docs]
def smooth_image(model: Image, width=1.0, normalise=True):
    """Convolve every plane of an image with a 2D Gaussian

    Each (channel, polarisation) plane is convolved independently with an
    astropy Gaussian2DKernel using FFT convolution.

    :param model: Image to be smoothed; must be canonical
    :param width: Kernel width in pixels
    :param normalise: If True, scale the result by 2 pi width^2, i.e.
        normalise the kernel peak to unity
    :return: New smoothed Image with the same dtype as the input
    """
    assert isinstance(model, Image), model
    assert model.image_acc.is_canonical()

    from astropy.convolution import convolve_fft
    from astropy.convolution.kernels import Gaussian2DKernel

    gaussian = Gaussian2DKernel(width)
    source = model["pixels"].data
    source_dtype = source.dtype

    smoothed = Image.constructor(
        data=numpy.zeros_like(source),
        polarisation_frame=model.image_acc.polarisation_frame,
        wcs=model.image_acc.wcs,
        clean_beam=model.attrs["clean_beam"],
    )

    nchan, npol, _, _ = model.image_acc.shape
    # The planes are independent, so the visiting order does not matter.
    for chan, pol in numpy.ndindex(nchan, npol):
        smoothed["pixels"].data[chan, pol, :, :] = convolve_fft(
            model["pixels"].data[chan, pol, :, :],
            gaussian,
            normalize_kernel=False,
            allow_huge=True,
        )

    # The convolve_fft step seems to return an object dtype, so restore the
    # dtype of the input.
    smoothed["pixels"].data = smoothed["pixels"].data.astype(source_dtype)

    if normalise and isinstance(gaussian, Gaussian2DKernel):
        smoothed["pixels"].data *= 2 * numpy.pi * width**2

    return smoothed
[docs]
def average_image_over_frequency(im: Image) -> Image:
    """Collapse the frequency axis of an image to its mean

    The result has a single channel whose reference frequency is the average
    of the input frequencies.

    :param im: Image; must be canonical and free of NaNs
    :return: Image with one frequency channel
    """
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()

    # keepdims retains the (now length-1) frequency axis.
    averaged = numpy.mean(im["pixels"].data, axis=0, keepdims=True)
    assert not numpy.isnan(numpy.sum(im["pixels"].data)), "NaNs present in image data"

    collapsed_wcs = copy.deepcopy(im.image_acc.wcs)
    # WCS axis index 3 is the one re-referenced to the mean frequency.
    collapsed_wcs.wcs.crval[3] = numpy.average(im.frequency.data)
    collapsed_wcs.wcs.crpix[3] = 1

    return Image.constructor(
        data=averaged,
        polarisation_frame=im.image_acc.polarisation_frame,
        wcs=collapsed_wcs,
    )
[docs]
def remove_continuum_image(im: Image, degree=1, mask=None):
    """Fit and remove the continuum from an image cube, in place

    A polynomial in frequency is fitted to the spectrum of every pixel and
    subtracted from it. The frequency axis is centred on the middle channel
    and scaled to at most unit magnitude before fitting.

    NOTE(review): channels selected by ``mask`` are given zero weight, i.e.
    they are excluded from the fit, whereas the channel-count check below
    counts the selected channels as the ones available for fitting. Confirm
    which polarity is intended.

    :param im: Image; must be canonical. Its pixel data are modified
    :param degree: Degree passed to numpy.polyfit (0 is a constant, 1 adds a
        slope, etc.)
    :param mask: Frequency mask (boolean or index array over channels)
    :return: The same Image, with the fitted continuum subtracted
    """
    assert isinstance(im, Image), im
    assert im.image_acc.is_canonical()

    if mask is not None:
        assert numpy.sum(mask) > 2 * degree, "Insufficient channels for fit"

    pixels = im["pixels"].data
    nchan = pixels.shape[0]
    channels = numpy.arange(nchan)
    frequency = im.image_acc.wcs.sub(["spectral"]).wcs_pix2world(channels, 0)[0]
    frequency -= frequency[nchan // 2]
    # Scale by the largest magnitude and skip the scaling when it is zero:
    # with one or two channels the plain maximum of the centred axis can be
    # zero, which used to turn the axis into inf/NaN.
    span = numpy.max(numpy.abs(frequency))
    if span > 0.0:
        frequency /= span

    wt = numpy.ones_like(frequency)
    if mask is not None:
        wt[mask] = 0.0

    # numpy.polyfit fits every column of a 2-D ``y`` in one call, so treat
    # each pixel spectrum as a column instead of looping over pixels.
    spectra = pixels.reshape(nchan, -1)
    fit = numpy.polyfit(frequency, spectra, w=wt, deg=degree)
    # vander() orders powers highest-first, matching polyfit's coefficients.
    prediction = numpy.vander(frequency, degree + 1) @ fit
    pixels -= prediction.reshape(pixels.shape)

    return im
[docs]
def create_window(template, window_type, **kwargs):
    """Create a window image using one of a number of methods

    The window is 1.0 inside the selected region and 0.0 elsewhere.

    window types:
        'quarter': Inner quarter of the image
        'no_edge': 'window_edge' pixels around edge set to zero
            (default 16)
        'threshold': template image pixels below 'window_threshold' set to
            zero (default: 10 times the standard deviation of the template)
        None: the window covers the entire image

    :param template: Template image; must be canonical
    :param window_type: 'quarter' | 'no_edge' | 'threshold' | None
    :param kwargs: May carry 'window_edge' or 'window_threshold'
    :return: New image containing window
    :raises ValueError: if window_type is not recognised

    See also
        :py:func:`rascil.processing_components.image.deconvolution.deconvolve_cube`
    """
    assert isinstance(template, Image), template
    assert template.image_acc.is_canonical()

    window = Image.constructor(
        data=numpy.zeros_like(template["pixels"].data),
        polarisation_frame=template.image_acc.polarisation_frame,
        wcs=template.image_acc.wcs,
        clean_beam=template.attrs["clean_beam"],
    )

    if window_type == "quarter":
        qx = template["pixels"].shape[3] // 4
        qy = template["pixels"].shape[2] // 4
        window["pixels"].data[..., (qy + 1) : 3 * qy, (qx + 1) : 3 * qx] = 1.0
        log.info("create_mask: Cleaning inner quarter of each sky plane")
    elif window_type == "no_edge":
        edge = get_parameter(kwargs, "window_edge", 16)
        nx = template["pixels"].shape[3]
        ny = template["pixels"].shape[2]
        window["pixels"].data[
            ..., (edge + 1) : (ny - edge), (edge + 1) : (nx - edge)
        ] = 1.0
        log.info("create_mask: Window omits %d-pixel edge of each sky plane", edge)
    elif window_type == "threshold":
        window_threshold = get_parameter(kwargs, "window_threshold", None)
        if window_threshold is None:
            window_threshold = 10.0 * numpy.std(template["pixels"].data)
        # NOTE(review): the comparison uses the signed pixel value, not its
        # absolute value - confirm which is intended.
        window["pixels"].data[template["pixels"].data >= window_threshold] = 1.0
        log.info("create_mask: Window omits all points below %g", window_threshold)
    elif window_type is None:
        # "Covers entire image" means every pixel is inside the window; the
        # window used to be left all-zero here.
        window["pixels"].data[...] = 1.0
        log.info("create_mask: Mask covers entire image")
    else:
        raise ValueError("Window shape %s is not recognized" % window_type)

    return window
[docs]
def polarisation_frame_from_wcs(wcs, shape) -> PolarisationFrame:
    """Work out the polarisation frame described by a WCS

    The world coordinates along the stokes axis are compared with the FITS
    code lists of the known polarisation frames. See Table 29 of
    https://fits.gsfc.nasa.gov/standard40/fits_standard40draft1.pdf
    or a subsequent revision::

         1  I   Standard Stokes unpolarized
         2  Q   Standard Stokes linear
         3  U   Standard Stokes linear
         4  V   Standard Stokes circular
        -1  RR  Right-right circular
        -2  LL  Left-left circular
        -3  RL  Right-left cross-circular
        -4  LR  Left-right cross-circular
        -5  XX  X parallel linear
        -6  YY  Y parallel linear
        -7  XY  XY cross linear
        -8  YX  YX cross linear

        stokesI     [1]
        stokesIQUV  [1,2,3,4]
        circular    [-1,-2,-3,-4]
        linear      [-5,-6,-7,-8]

    For example::

        pol_frame = polarisation_frame_from_wcs(
            im.image_acc.wcs, im["pixels"].data.shape
        )

    :param wcs: World Coordinate System
    :param shape: Shape corresponding to wcs; a 2D shape is always stokesI,
        otherwise shape[1] is the number of polarisations
    :returns: PolarisationFrame object
    :raises ValueError: if the codes match no known polarisation frame
    """
    if len(shape) == 2:
        return PolarisationFrame("stokesI")

    npol = shape[1]
    world = wcs.sub(["stokes"]).wcs_pix2world(range(npol), 0)[0]
    codes = numpy.array(world, dtype="int")

    # The first frame whose FITS codes match exactly wins.
    for frame_name, frame_codes in PolarisationFrame.fits_codes.items():
        if numpy.array_equal(codes, numpy.array(frame_codes)):
            return PolarisationFrame(frame_name)

    raise ValueError("Cannot determine polarisation code")
[docs]
def fft_image_to_griddata_with_wcs(im):
    """Transform a 4D image into GridData, choosing grid axes from the WCS

    The image axis types select the axis types of the grid:

        RA---SIN, DEC--SIN          -> UU, VV
        XX, YY                      -> KX, KY
        AZELGEO long, AZELGEO lati  -> KX, KY

    For example::

        from rascil.processing_components import
            create_test_image, fft_image_to_griddata_with_wcs
        im = create_test_image()
        print(im)

            Image:
                Shape: (1, 1, 256, 256)
                WCS: WCS Keywords
            Number of WCS axes: 4
            CTYPE : 'RA---SIN'  'DEC--SIN'  'STOKES'  'FREQ'
            CRVAL : 0.0  35.0  1.0  100000000.0
            CRPIX : 129.0  129.0  1.0  1.0
            CDELT : -0.000277777791  0.000277777791  1.0  100000.0
                Polarisation frame: stokesI

        print(fft_image_to_griddata_with_wcs(im))

            Image:
                Shape: (1, 1, 256, 256)
                WCS: WCS Keywords
            Number of WCS axes: 4
            CTYPE : 'UU'  'VV'  'STOKES'  'FREQ'
            CRVAL : 0.0  0.0  1.0  100000000.0
            CRPIX : 129.0  129.0  1.0  1.0
            CDELT : -805.7218610503596  805.7218610503596  1.0  100000.0
                Polarisation frame: stokesI

    :param im: Image with a 4D pixel array
    :return: GridData holding the transformed pixels
    :raises NotImplementedError: if the first two axis types are not one of
        the supported pairs

    See also
        :py:func:`ska_sdp_func_python.fourier_transforms.fft_support.fft`
        :py:func:`ska_sdp_func_python.fourier_transforms.fft_support.ifft`
    """
    assert im.attrs["data_model"] == "Image"
    assert len(im["pixels"].data.shape) == 4

    # Supported image-plane axis pairs and the grid axes they transform to.
    grid_axes_for = {
        ("RA---SIN", "DEC--SIN"): ["UU", "VV"],
        ("XX", "YY"): ["KX", "KY"],
        ("AZELGEO long", "AZELGEO lati"): ["KX", "KY"],
    }

    ctype = im.image_acc.wcs.wcs.ctype
    first_axis, second_axis = ctype[0], ctype[1]
    ft_types = grid_axes_for.get((first_axis, second_axis))
    if ft_types is None:
        raise NotImplementedError(
            "Cannot FFT specified axes {0}, {1}".format(first_axis, second_axis)
        )

    gd = create_griddata_from_image(im, ft_types=ft_types)
    gd["pixels"].data = ifft(im["pixels"].data.astype("complex"))
    return gd
def ifft_griddata_to_image(gd, template):
    """Transform 4D GridData back into an Image on the template's WCS

    This is the counterpart of fft_image_to_griddata_with_wcs. The grid must
    have one of these axis-type pairs:

        UU, VV
        KX, KY
        UU_AZELGEO, VV_AZELGEO

    The returned image carries the WCS of ``template`` unchanged and the
    polarisation frame of ``gd``.

    :param gd: Input GridData with a 4D pixel array
    :param template: Template image supplying the output WCS
    :return: Image
    :raises NotImplementedError: if the grid axis types are not supported

    See also
        :py:func:`rascil.processing_components.fourier_transforms.fft_support.fft`
        :py:func:`rascil.processing_components.fourier_transforms.fft_support.ifft`
    """
    assert len(gd["pixels"].data.shape) == 4

    grid_ctype = gd.griddata_acc.griddata_wcs.wcs.ctype
    template_wcs = template.image_acc.wcs

    # The output WCS is always the template's; the grid axis types are only
    # checked for support. (A deep copy of the template WCS used to be built
    # and patched here but was never used.)
    supported_axes = (
        ("UU", "VV"),
        ("KX", "KY"),
        ("UU_AZELGEO", "VV_AZELGEO"),
    )
    if (grid_ctype[0], grid_ctype[1]) not in supported_axes:
        raise NotImplementedError(
            "Cannot IFFT specified axes {0}, {1}".format(grid_ctype[0], grid_ctype[1])
        )

    ft_data = fft(gd["pixels"].data.astype("complex"))
    return Image.constructor(
        data=ft_data,
        polarisation_frame=gd.griddata_acc.polarisation_frame,
        wcs=template_wcs,
    )
[docs]
def pad_image(im: Image, shape):
    """Embed an image centrally in a larger array of zeros

    Intended for the standard 4D image with axes (freq, pol, y, x); the
    padding is split equally between both sides of y and x. The WCS
    reference pixel is shifted to follow the data.

    :param im: Image to be padded
    :param shape: Target shape in 4 dimensions
    :return: Padded image, or ``im`` itself if it already has that shape
    :raises ValueError: if any target dimension is smaller than the input
    """
    old_shape = im["pixels"].data.shape
    if old_shape == shape:
        return im

    old_wcs = im.image_acc.wcs
    padded_wcs = copy.deepcopy(old_wcs)
    # WCS axis 0 runs along x (array axis 3), WCS axis 1 along y (axis 2).
    padded_wcs.wcs.crpix[0] = old_wcs.wcs.crpix[0] + shape[3] // 2 - old_shape[3] // 2
    padded_wcs.wcs.crpix[1] = old_wcs.wcs.crpix[1] + shape[2] // 2 - old_shape[2] // 2

    for axis, current in enumerate(old_shape):
        if shape[axis] < current:
            raise ValueError(
                "Padded shape %s is smaller than input shape %s"
                % (shape, im["pixels"].data.shape)
            )

    padded = numpy.zeros(shape, dtype=im["pixels"].dtype)
    y_lo = shape[2] // 2 - old_shape[2] // 2
    x_lo = shape[3] // 2 - old_shape[3] // 2
    y_hi = y_lo + old_shape[2]
    x_hi = x_lo + old_shape[3]
    padded[..., y_lo:y_hi, x_lo:x_hi] = im["pixels"][...]

    return Image.constructor(
        data=padded, polarisation_frame=im.image_acc.polarisation_frame, wcs=padded_wcs
    )
[docs]
def sub_image(im: Image, shape):
    """Subsection an image to desired shape, cutting equally from all edges

    Appropriate for standard 4D image with axes (freq, pol, y, x). Only the
    y and x axes are cut; all channels and polarisations are kept.
    The wcs crpix is adjusted appropriately.

    :param im: Image to be cut
    :param shape: Requested shape, either 4 dimensions or just (ny, nx)
    :return: Cut-out image, or ``im`` itself if it already has that shape
    :raises ValueError: if any requested dimension is larger than the input
    """
    old_shape = im["pixels"].data.shape

    # Normalise first so that a list or a 2D shape is also recognised by the
    # "nothing to do" test below.
    shape = tuple(shape)
    if len(shape) == 2:
        shape = (1, 1, shape[0], shape[1])

    if old_shape == shape:
        return im

    newwcs = copy.deepcopy(im.image_acc.wcs)
    # WCS axis 0 runs along x (array axis 3), WCS axis 1 along y (axis 2).
    newwcs.wcs.crpix[0] = (
        im.image_acc.wcs.wcs.crpix[0] + shape[3] // 2 - old_shape[3] // 2
    )
    newwcs.wcs.crpix[1] = (
        im.image_acc.wcs.wcs.crpix[1] + shape[2] // 2 - old_shape[2] // 2
    )

    for axis, current in enumerate(old_shape):
        if shape[axis] > current:
            # The message used to say "Padded shape", copied from pad_image.
            raise ValueError(
                "Requested shape %s is larger than input shape %s"
                % (shape, old_shape)
            )

    ystart = old_shape[2] // 2 - shape[2] // 2
    yend = ystart + shape[2]
    xstart = old_shape[3] // 2 - shape[3] // 2
    xend = xstart + shape[3]
    newdata = im["pixels"][..., ystart:yend, xstart:xend]

    return Image.constructor(
        data=newdata, polarisation_frame=im.image_acc.polarisation_frame, wcs=newwcs
    )
[docs]
def create_w_term_like(
    im: Image, w, phasecentre=None, remove_shift=False, dopol=False
) -> Image:
    """Create an image with a w term phase term in it:

    .. math::

        I(l,m) = e^{-2 \\pi j w (\\sqrt{1-l^2-m^2}-1)}

    The phasecentre is used as the delay centre for the w term (i.e. where
    n==0). If it is not a SkyCoord, the WCS reference pixel is used instead.

    :param im: template image
    :param w: w value to evaluate
    :param phasecentre: SkyCoord definition of phasecentre
    :param remove_shift: Passed through to w_beam
    :param dopol: Do screen in polarisation? If False the result has a
        single polarisation plane
    :return: Image
    """
    fim_shape = list(im["pixels"].data.shape)
    if not dopol:
        # NOTE(review): the result still carries the template's polarisation
        # frame although it has only one plane - confirm this is accepted.
        fim_shape[1] = 1

    wcs = im.image_acc.wcs
    fim_array = numpy.zeros(fim_shape, dtype="complex")

    # cdelt is in degrees; the cell size is needed in radians.
    cellsize = abs(wcs.wcs.cdelt[0]) * numpy.pi / 180.0
    npixel = fim_shape[3]

    # isinstance, not ``is``: comparing an instance with the SkyCoord class
    # by identity is never true, so the phase centre used to be ignored.
    if isinstance(phasecentre, SkyCoord):
        wcentre = phasecentre.to_pixel(wcs, origin=0)
    else:
        # crpix is one-based; convert to zero-based pixel coordinates.
        wcentre = [wcs.wcs.crpix[0] - 1.0, wcs.wcs.crpix[1] - 1.0]

    # The same 2D screen is broadcast over all channels and polarisations.
    fim_array[...] = w_beam(
        npixel,
        npixel * cellsize,
        w=w,
        cx=wcentre[0],
        cy=wcentre[1],
        remove_shift=remove_shift,
    )[numpy.newaxis, numpy.newaxis, ...]

    fim = Image.constructor(
        data=fim_array, polarisation_frame=im.image_acc.polarisation_frame, wcs=wcs
    )

    fov = npixel * cellsize
    fresnel = numpy.abs(w) * (0.5 * fov) ** 2
    log.debug(
        "create_w_term_image: For w = %.1f, field of view = %.6f, Fresnel number = %.2f",
        w,
        fov,
        fresnel,
    )

    return fim
[docs]
def scale_and_rotate_image(im, angle=0.0, scale=None, order=5):
    """Scale and then rotate an image in x, y axes

    Every (channel, polarisation) plane is resampled about the image centre
    with scipy.ndimage.affine_transform. Complex planes are transformed as
    separate real and imaginary parts.

    NOTE(review): ``scale`` forms the diagonal of the matrix handed to
    affine_transform, which maps output coordinates to input coordinates -
    confirm the intended sense of the scale factors.

    :param im: Image with pixel dtype "float" or "complex"
    :param angle: Angle in radians
    :param scale: Scale [scale_y, scale_x] in array-axis order
        (default [1.0, 1.0])
    :param order: Order of interpolation (0-5)
    :return: New Image on the same WCS
    :raises ValueError: if the pixel dtype is neither float nor complex
    """
    # Public location of affine_transform; scipy.ndimage.interpolation is a
    # deprecated alias namespace.
    from scipy.ndimage import affine_transform

    nchan, npol, ny, nx = im["pixels"].data.shape
    c_in = 0.5 * numpy.array([ny, nx])
    c_out = 0.5 * numpy.array([ny, nx])
    rot = numpy.array(
        [[numpy.cos(angle), -numpy.sin(angle)], [numpy.sin(angle), numpy.cos(angle)]]
    )
    # The inverse of a rotation matrix is its transpose.
    inv_rot = rot.T
    if scale is None:
        scale = [1.0, 1.0]

    newim = Image.constructor(
        data=numpy.zeros_like(im["pixels"].data),
        polarisation_frame=im.image_acc.polarisation_frame,
        wcs=im.image_acc.wcs,
        clean_beam=im.attrs["clean_beam"],
    )

    inv_scale = numpy.diag(scale)
    inv_transform = numpy.dot(inv_scale, inv_rot)
    # Choose the offset so that the image centre maps onto itself.
    offset = c_in - numpy.dot(inv_transform, c_out)

    def _transform(plane):
        # Resample one real-valued 2D plane.
        return affine_transform(
            plane,
            inv_transform,
            offset=offset,
            order=order,
            output_shape=(ny, nx),
        ).astype("float")

    for chan in range(nchan):
        for pol in range(npol):
            plane = im["pixels"].data[chan, pol]
            if im["pixels"].data.dtype == "complex":
                newim["pixels"].data[chan, pol] = _transform(
                    plane.real
                ) + 1.0j * _transform(plane.imag)
            elif im["pixels"].data.dtype == "float":
                newim["pixels"].data[chan, pol] = _transform(plane.real)
            else:
                raise ValueError(
                    "Cannot process data type {}".format(im["pixels"].data.dtype)
                )

    return newim
def rotate_image(im, angle=0.0, order=5):
    """Rotate an image in x, y axes

    The rotation is applied in the plane of the last two array axes. Complex
    data are rotated as separate real and imaginary parts. The output keeps
    the shape of the input; anything rotated outside the frame is lost.

    :param im: Image
    :param angle: Angle in radians
    :param order: Order of interpolation (0-5)
    :return: New rotated Image (deep copy of ``im`` with new pixels)
    """
    # Public location of rotate; scipy.ndimage.interpolation is a deprecated
    # alias namespace.
    from scipy.ndimage import rotate

    angle_deg = numpy.rad2deg(angle)

    def _rotate(values):
        # reshape=False keeps the input shape. The scipy default enlarges
        # the array for general angles, and such a result cannot be written
        # back onto the fixed pixel grid of the image.
        return rotate(
            values, angle=angle_deg, axes=(-2, -1), order=order, reshape=False
        )

    newim = im.copy(deep=True)
    if newim["pixels"].data.dtype == "complex":
        newim["pixels"].data = _rotate(im["pixels"].data.real) + 1j * _rotate(
            im["pixels"].data.imag
        )
    else:
        newim["pixels"].data = _rotate(im["pixels"].data)

    return newim
[docs]
def apply_voltage_pattern_to_image(
    im: Image, vp: Image, inverse=False, min_det=1e-1, **kwargs
) -> Image:
    """Apply a voltage pattern to an image

    For each pixel, the application is as follows:

        I_{corrected}(l,m) = vp(l,m) I(l,m) vp(l,m).H

    For a scalar (single polarisation) voltage pattern this reduces to
    multiplying by the power pattern |vp|^2, or, for the inverse, dividing
    by it wherever it is positive (other pixels are left at zero).

    :param im: Image to have jones applied
    :param vp: Jones image to be applied; must have the same shape as im
    :param inverse: Apply the inverse (default=False)
    :param min_det: Minimum determinant to correct (full Jones case only)
    :param kwargs: Not used
    :return: new Image with Jones applied
    """
    newim = Image.constructor(
        data=numpy.zeros_like(im["pixels"].data),
        polarisation_frame=im.image_acc.polarisation_frame,
        wcs=im.image_acc.wcs,
        clean_beam=im.attrs["clean_beam"],
    )

    if inverse:
        log.debug("apply_gaintable: Apply inverse voltage pattern image")
    else:
        log.debug("apply_gaintable: Apply voltage pattern image")

    is_scalar = vp.image_acc.shape[1] == 1
    nchan, npol, ny, nx = im["pixels"].data.shape

    assert im["pixels"].data.shape == vp["pixels"].data.shape

    if is_scalar:
        log.debug("apply_voltage_pattern_to_image: Scalar voltage pattern")
        for chan in range(nchan):
            vp_plane = vp["pixels"].data[chan, 0, ...]
            pb = (vp_plane * numpy.conjugate(vp_plane)).real
            im_plane = im["pixels"].data[chan, 0, ...]
            # View into the output array: writing to it fills newim.
            out_plane = newim["pixels"].data[chan, 0, ...]
            # The result is computed from the input pixels. It used to be
            # computed in place on the all-zero output, which always gave
            # zeros, and the two directions were swapped.
            if inverse:
                mask = pb > 0.0
                out_plane[mask] = im_plane[mask] / pb[mask]
            else:
                out_plane[...] = im_plane * pb
    else:
        log.debug("apply_voltage_pattern_to_image: Full Jones voltage pattern")
        polim = convert_stokes_to_polimage(im, vp.image_acc.polarisation_frame)
        assert npol == 4
        # Move polarisation last and view each pixel as a 2 x 2 matrix.
        im_t = numpy.transpose(polim["pixels"].data, (0, 2, 3, 1)).reshape(
            [nchan, ny, nx, 2, 2]
        )
        vp_t = numpy.transpose(vp["pixels"].data, (0, 2, 3, 1)).reshape(
            [nchan, ny, nx, 2, 2]
        )
        newim_t = numpy.zeros([nchan, ny, nx, 2, 2], dtype="complex")
        for chan in range(nchan):
            for y in range(ny):
                for x in range(nx):
                    newim_t[chan, y, x] = apply_jones(
                        vp_t[chan, y, x], im_t[chan, y, x], inverse, min_det=min_det
                    )
        # Back to (chan, pol, y, x) ordering.
        newim = Image.constructor(
            data=newim_t.reshape([nchan, ny, nx, 4]).transpose((0, 3, 1, 2)),
            polarisation_frame=vp.image_acc.polarisation_frame,
            wcs=im.image_acc.wcs,
        )
        newim = convert_polimage_to_stokes(newim)

    return newim