Source code for ska_sdp_datamodels.gridded_visibility.grid_vis_model

# pylint: disable=too-many-ancestors,too-many-locals,invalid-name
"""
Gridded visibility models.
"""

import numpy
import xarray

from ska_sdp_datamodels.science_data_model import (
    PolarisationFrame,
    QualityAssessment,
)
from ska_sdp_datamodels.xarray_accessor import XarrayAccessorMixin
from ska_sdp_datamodels.xarray_coordinate_support import (
    conv_func_wcs,
    griddata_wcs,
    image_wcs,
)


class GridData(xarray.Dataset):
    """Class to hold Gridded data for Fourier processing

    - Has four or more coordinates: [chan, pol, z, y, x] where x can be
      u, l; y can be v, m; z can be w, n.

    Note: current implementation only uses 4 coordinates:
    [nchan, npol, v, u]

    The conventions for indexing in WCS and numpy are opposite:

    - In astropy.wcs, the order is (longitude, latitude, polarisation,
      frequency);
    - in numpy, the order is (frequency, polarisation, depth, latitude,
      longitude).

    .. warning::
        ``constructor`` records the polarisation frame in two places: the
        ``polarisation`` coordinate (the frame's names) and the
        ``_polarisation_frame`` attribute (the frame's type). The attribute
        should be considered definitive.

    Here is an example::

        <xarray.GridData>
        Dimensions:       (frequency: 3, polarisation: 4, v: 256, u: 256)
        Coordinates:
          * frequency     (frequency) float64 1e+08 1.01e+08 1.02e+08
          * polarisation  (polarisation) <U1 'I' 'Q' 'U' 'V'
          * v             (v) float64 -3.333e+04 -3.307e+04 ... 3.307e+04
          * u             (u) float64 3.333e+04 3.307e+04 ... -3.307e+04
        Data variables:
            pixels        (frequency, polarisation, v, u) complex128 0j ...
        Attributes:
            data_model:           GridData
            _polarisation_frame:  stokesIQUV
    """

    __slots__ = ()

    def __init__(
        self,
        data_vars=None,
        coords=None,
        attrs=None,
    ):
        super().__init__(data_vars, coords=coords, attrs=attrs)

    @classmethod
    def constructor(cls, data, polarisation_frame=None, grid_wcs=None):
        """
        Create a GridData

        :param data: pixel data array;
                     dims: ["frequency", "polarisation", "v", "u"]
        :param polarisation_frame: PolarisationFrame object; its ``npol``
                                   must equal the length of the
                                   polarisation axis of ``data``
        :param grid_wcs: astropy WCS object; frequency values are read from
                         its fourth (1-based) axis, and the u / v reference
                         values and increments from ``wcs.crval`` /
                         ``wcs.cdelt`` entries 0 and 1
        :return: GridData
        :raises ValueError: if the polarisation axis length of ``data``
                            differs from ``polarisation_frame.npol``
        """
        nchan, npol, nv, nu = data.shape

        # wcs.sub([4]) selects the fourth (1-based) WCS axis; WCS axis order
        # is the reverse of numpy's, so it corresponds to numpy axis 0.
        frequency = grid_wcs.sub([4]).wcs_pix2world(range(nchan), 0)[0]

        if npol != polarisation_frame.npol:
            raise ValueError(
                "Polarisation dimensions of input PolarisationFrame "
                "does not match that of data polarisation dimensions: "
                f"{polarisation_frame.npol} != {npol}"
            )

        cu = grid_wcs.wcs.crval[0]
        cv = grid_wcs.wcs.crval[1]
        du = grid_wcs.wcs.cdelt[0]
        dv = grid_wcs.wcs.cdelt[1]

        dims = ["frequency", "polarisation", "v", "u"]

        # Each spatial axis has n samples spaced by cdelt, starting n / 2
        # steps below crval, so crval falls at index n // 2 for even n.
        coords = {
            "frequency": frequency,
            "polarisation": polarisation_frame.names,
            "v": numpy.linspace(
                cv - dv * nv / 2, cv + dv * nv / 2, nv, endpoint=False
            ),
            "u": numpy.linspace(
                cu - du * nu / 2, cu + du * nu / 2, nu, endpoint=False
            ),
        }

        attrs = {
            "data_model": "GridData",
            "_polarisation_frame": polarisation_frame.type,
        }

        data_vars = {
            "pixels": xarray.DataArray(data, dims=dims, coords=coords)
        }
        return cls(data_vars, coords=coords, attrs=attrs)

    def __sizeof__(self):
        """Override default method to return size of dataset
        :return: int
        """
        # Dask uses sizeof() class to get memory occupied by various data
        # objects. For custom data objects like this one, dask falls back to
        # sys.getsizeof() function to get memory usage. sys.getsizeof() in
        # turns calls __sizeof__() magic method to get memory size. Here we
        # override the default method (which gives size of reference table)
        # to return size of Dataset.
        return int(self.nbytes)
@xarray.register_dataset_accessor("griddata_acc")
class GridDataAccessor(XarrayAccessorMixin):
    """
    GridDataAccessor property accessor

    Registered on xarray.Dataset under the accessor name ``griddata_acc``.
    """

    @property
    def polarisation_frame(self):
        """Polarisation frame, rebuilt from the ``_polarisation_frame``
        attribute of the dataset"""
        return PolarisationFrame(self._obj.attrs["_polarisation_frame"])

    @property
    def npol(self):
        """Number of polarisations in the dataset's polarisation frame"""
        return self.polarisation_frame.npol

    @property
    def nchan(self):
        """Number of channels (length of the frequency coordinate)"""
        return len(self._obj.frequency)

    @property
    def shape(self):
        """Shape of the ``pixels`` data array"""
        return self._obj["pixels"].data.shape

    @property
    def griddata_wcs(self):
        """Return the equivalent WCS coordinates (via ``griddata_wcs``)"""
        return griddata_wcs(self._obj)

    @property
    def projection_wcs(self):
        """Return the projected WCS coordinates on image
        (via ``image_wcs``)"""
        return image_wcs(self._obj)

    def qa_grid_data(self, context="") -> QualityAssessment:
        """Assess the quality of a griddata

        Summary statistics are computed over the ``pixels`` data; note that
        the "rms" entry is ``numpy.std``, i.e. the spread about the mean.

        :param context: passed through as the QualityAssessment context
        :return: QualityAssessment
        """
        pixels = self._obj["pixels"].data
        summary = {
            "shape": str(pixels.shape),
            "max": numpy.max(pixels),
            "min": numpy.min(pixels),
            "rms": numpy.std(pixels),
            "sum": numpy.sum(pixels),
            "medianabs": numpy.median(numpy.abs(pixels)),
            "median": numpy.median(pixels),
        }
        return QualityAssessment(
            origin="qa_grid_data", data=summary, context=context
        )
class ConvolutionFunction(xarray.Dataset):
    """
    Class to hold Convolution function for Fourier processing

    - Has four or more coordinates: [chan, pol, z, y, x] where x can be
      u, l; y can be v, m; z can be w, n.

    The pixel data built by ``constructor`` has seven axes: frequency,
    polarisation, w, two oversampling axes (dv, du) and two support axes.
    The order in the WCS is reversed with respect to numpy; ``constructor``
    reads WCS axes 0 and 1 for the support axes, axis 4 for w and the
    spectral axis for frequency.

    The axes UU, VV have the same physical stride as the Image.
    The axes DUU, DVV are sub-sampled.

    Note: the example below lists ``grid_wcs``, ``projection_wcs`` and
    ``polarisation_frame`` attributes; ``constructor`` itself only sets
    ``data_model`` and ``_polarisation_frame``.

    Here is an example::

        <xarray.ConvolutionFunction>
        Dimensions:       (du: 8, dv: 8, frequency: 1, polarisation: 1, u: 16, v: 16, w: 1)
        Coordinates:
          * frequency     (frequency) float64 1e+08
          * polarisation  (polarisation) <U1 'I'
          * w             (w) float64 0.0
          * dv            (dv) float64 -1.031e+05 -7.735e+04 ... 7.735e+04
          * du            (du) float64 -1.031e+05 -7.735e+04 ... 7.735e+04
          * v             (v) float64 -1.65e+06 -1.444e+06 ... 1.444e+06
          * u             (u) float64 -1.65e+06 -1.444e+06 ... 1.444e+06
        Data variables:
            pixels        (frequency, polarisation, w, dv, du, v, u) complex128 0j ......
        Attributes:
            data_model:          ConvolutionFunction
            grid_wcs:            WCS Keywords Number of WCS axes: 7 CTYPE : 'UU' ...
            projection_wcs:      WCS Keywords Number of WCS axes: 4 CTYPE : 'RA--...
            polarisation_frame:  stokesI
    """  # noqa:E501

    __slots__ = ()

    def __init__(
        self,
        data_vars=None,
        coords=None,
        attrs=None,
    ):
        super().__init__(data_vars, coords=coords, attrs=attrs)

    @classmethod
    def constructor(cls, data, cf_wcs=None, polarisation_frame=None):
        """
        Create ConvolutionFunction

        :param data: Data for cf; a 7-dimensional array whose shape is
                     unpacked as (nchan, npol, nw, oversampling,
                     oversampling, support, support)
        :param cf_wcs: Astropy WCS object for the grid; must have a
                       spectral axis, and its ``wcs.crval`` /
                       ``wcs.cdelt`` entries 0, 1 and 4 are read
        :param polarisation_frame: Polarisation_frame e.g.
                                   PolarisationFrame('linear')
        :return: ConvolutionFunction
        """
        # ``data`` is required: its shape defines every axis length.
        nchan, npol, nw, oversampling, _, support, _ = data.shape
        frequency = cf_wcs.sub(["spectral"]).wcs_pix2world(range(nchan), 0)[0]

        assert npol == polarisation_frame.npol, (
            "Mismatch between requested image polarisation "
            "and actual visibility polarisation"
        )

        # Support-axis increments and their oversampled (sub-pixel) steps.
        du = cf_wcs.wcs.cdelt[0]
        dv = cf_wcs.wcs.cdelt[1]
        ddu = du / oversampling
        ddv = dv / oversampling
        cu = cf_wcs.wcs.crval[0]
        cv = cf_wcs.wcs.crval[1]
        # NOTE(review): cdu/cdv are sub-pixel *indices*, yet below they are
        # combined with ddu/ddv, which are in cdelt units; confirm whether
        # the du/dv coordinates are meant to be centred on 0 instead.
        cdu = oversampling // 2
        cdv = oversampling // 2
        wstep = numpy.abs(cf_wcs.wcs.cdelt[4])

        coords = {
            "frequency": frequency,
            "polarisation": polarisation_frame.names,
            "dv": numpy.linspace(
                cdv - ddv * oversampling / 2,
                cdv + ddv * oversampling / 2,
                oversampling,
                endpoint=False,
            ),
            "du": numpy.linspace(
                cdu - ddu * oversampling / 2,
                cdu + ddu * oversampling / 2,
                oversampling,
                endpoint=False,
            ),
            "w": numpy.linspace(
                -wstep * nw / 2, wstep * nw / 2, nw, endpoint=False
            ),
            "v": numpy.linspace(
                cv - dv * support / 2,
                cv + dv * support / 2,
                support,
                endpoint=False,
            ),
            "u": numpy.linspace(
                cu - du * support / 2,
                cu + du * support / 2,
                support,
                endpoint=False,
            ),
        }

        if nw == 1:
            # A single w-plane is placed at w = 0 rather than at -wstep / 2.
            coords["w"] = numpy.zeros([1])

        # NOTE(review): data axis 5 is labelled "u" and axis 6 "v", whereas
        # the class example and GridData use (..., v, u). Both support axes
        # have the same length, so the shape check cannot catch a swap;
        # confirm which order callers rely on before changing it.
        dims = ["frequency", "polarisation", "w", "dv", "du", "u", "v"]

        # Degenerate support axes (support <= 1 or zero cdelt) would give
        # identical first and last coordinates.
        assert coords["u"][0] != coords["u"][-1]
        assert coords["v"][0] != coords["v"][-1]

        attrs = {
            "data_model": "ConvolutionFunction",
            "_polarisation_frame": polarisation_frame.type,
        }

        # The shape was unpacked with placeholders for the second
        # oversampling axis and second support axis; require each pair to
        # have matching lengths.
        expected_shape = (
            nchan,
            npol,
            nw,
            oversampling,
            oversampling,
            support,
            support,
        )
        assert data.shape == expected_shape, (
            f"Polarisation frame {polarisation_frame.type} and data shape "
            f"{data.shape} are incompatible"
        )

        data_vars = {
            "pixels": xarray.DataArray(data, dims=dims, coords=coords)
        }
        return cls(data_vars, coords=coords, attrs=attrs)

    def __sizeof__(self):
        """Override default method to return size of dataset
        :return: int
        """
        # Dask uses sizeof() class to get memory occupied by various data
        # objects. For custom data objects like this one, dask falls back to
        # sys.getsizeof() function to get memory usage. sys.getsizeof() in
        # turns calls __sizeof__() magic method to get memory size. Here we
        # override the default method (which gives size of reference table)
        # to return size of Dataset.
        return int(self.nbytes)
@xarray.register_dataset_accessor("convolutionfunction_acc")
class ConvolutionFunctionAccessor(XarrayAccessorMixin):
    """
    ConvolutionFunction property accessor

    Registered on xarray.Dataset under the accessor name
    ``convolutionfunction_acc``.
    """

    @property
    def polarisation_frame(self):
        """Polarisation frame, rebuilt from the ``_polarisation_frame``
        attribute of the dataset"""
        return PolarisationFrame(self._obj.attrs["_polarisation_frame"])

    @property
    def npol(self):
        """Number of polarisations in the dataset's polarisation frame"""
        return self.polarisation_frame.npol

    @property
    def nchan(self):
        """Number of channels (length of the frequency coordinate)"""
        return len(self._obj.frequency)

    @property
    def shape(self):
        """Shape of the ``pixels`` data array"""
        return self._obj["pixels"].data.shape

    @property
    def cf_wcs(self):
        """Return the equivalent WCS coordinates (via ``conv_func_wcs``)"""
        return conv_func_wcs(self._obj)

    def qa_convolution_function(self, context="") -> QualityAssessment:
        """Assess the quality of a ConvolutionFunction

        Summary statistics are computed over the ``pixels`` data; note that
        the "rms" entry is ``numpy.std``, i.e. the spread about the mean.

        :param context: passed through as the QualityAssessment context
        :return: QualityAssessment
        """
        values = self._obj["pixels"].data
        summary = {
            "shape": str(values.shape),
            "max": numpy.max(values),
            "min": numpy.min(values),
            "rms": numpy.std(values),
            "sum": numpy.sum(values),
            "medianabs": numpy.median(numpy.abs(values)),
            "median": numpy.median(values),
        }
        return QualityAssessment(
            origin="qa_convolution_function", data=summary, context=context
        )