# Source code for ska_sdp_instrumental_calibration.xarray_processors.predict

import logging

import dask.array as da
import numpy as np
import xarray as xr
from astropy.coordinates import SkyCoord
from ska_sdp_datamodels.visibility import Visibility

from ..data_managers.beams import BeamsFactory
from ..data_managers.sky_model import GlobalSkyModel, LocalSkyModel
from ._utils import with_chunks

logger = logging.getLogger(__name__)


def _predict_vis_ufunc(
    uvw: np.ndarray,
    frequency: np.ndarray,
    station_rm: np.ndarray | None,
    polarisation: np.ndarray,
    antenna1: np.ndarray,
    antenna2: np.ndarray,
    phasecentre: SkyCoord,
    local_sky_model: LocalSkyModel,
    beams_factory: BeamsFactory = None,
    output_dtype: type = np.complex64,
) -> np.ndarray:
    """
    Adapter that lets ``xr.apply_ufunc`` (as used by ``predict_vis``) drive
    ``LocalSkyModel.create_vis`` on plain numpy blocks.

    The argument order here differs from ``create_vis``: the arrays that
    ``apply_ufunc`` supplies positionally come first, everything else is
    given by keyword.

    :param uvw: (time, frequency, baselineid, spatial). Axis 1 is dropped
        with a reshape, which only succeeds when that axis has length 1.
    :param frequency: (frequency,)
    :param station_rm: (nant,) or None
    :param polarisation: (polarisation,)
    :param antenna1: forwarded unchanged to ``create_vis``
        (NOTE(review): previously documented as (nant,); presumably one
        entry per baseline -- confirm)
    :param antenna2: forwarded unchanged to ``create_vis`` (same caveat as
        ``antenna1``)
    :param phasecentre: forwarded unchanged to ``create_vis``
    :param local_sky_model: object whose ``create_vis`` does the prediction
    :param beams_factory: forwarded unchanged to ``create_vis``
    :param output_dtype: forwarded unchanged to ``create_vis``

    returns: (time, frequency, baselineid, polarisation)
    """
    # Drop axis 1 (the broadcast frequency axis) while keeping every other
    # axis as it is.
    full_shape = uvw.shape
    squeezed_shape = (full_shape[0], *full_shape[2:])
    uvw_without_freq = uvw.reshape(squeezed_shape)

    predicted = local_sky_model.create_vis(
        uvw_without_freq,
        frequency,
        polarisation,
        phasecentre,
        antenna1,
        antenna2,
        beams_factory,
        station_rm,
        output_dtype,
    )

    # Swap axes 1 and 2 so the result is ordered
    # time, frequency, baselineid, polarisation
    return predicted.transpose(0, 2, 1, 3)


def predict_vis(
    vis: Visibility,
    gsm: GlobalSkyModel,
    soln_time: np.ndarray,
    soln_interval_slices: list[slice],
    beams_factory: BeamsFactory = None,
    station_rm: xr.DataArray = None,
) -> Visibility:
    """
    Predict visibilities from a Global Sky Model (Distributed).

    For every solution interval a local sky model is requested from ``gsm``
    and handed, together with the matching time slice of ``vis.uvw``, to the
    prediction kernel ``_predict_vis_ufunc`` through ``xr.apply_ufunc``
    (``dask="parallelized"``). The per-interval results are concatenated
    along ``time`` and stored as the ``vis`` variable of the returned
    dataset.

    The prediction is split along two axes:

    1. **Time:** one ``apply_ufunc`` call per entry of
       ``soln_interval_slices``.
    2. **Frequency:** the frequency coordinate is turned into a chunked
       array using ``vis.chunksizes``, so the kernel is invoked per
       frequency chunk.

    Parameters
    ----------
    vis
        The template visibility dataset. Its ``uvw``, ``frequency``,
        ``polarisation``, ``antenna1``, ``antenna2``, ``phasecentre`` and
        ``configuration.location`` are read; the dtype and attributes of
        ``vis.vis`` are reused for the predicted data.
    gsm
        The sky model source. ``gsm.get_local_sky_model`` is called once per
        solution time with that time and ``vis.configuration.location``.
    soln_time
        Array of timestamps, one per solution interval.
    soln_interval_slices
        A list of slice objects. Each slice corresponds to the timestamp at
        the same position in ``soln_time`` and selects the range of time
        indices in ``vis`` belonging to that interval.
    beams_factory
        Passed through unchanged to the prediction kernel. May be None.
    station_rm
        Station-based rotation measures. If provided, it is rechunked into
        a single chunk and passed to the prediction kernel as an array
        argument; a bare dask array is first wrapped in a DataArray with an
        ``id`` dimension. If None, the kernel receives ``station_rm=None``.

    Returns
    -------
    A new Visibility object in which the ``vis`` variable holds the
    predicted data (dims ``time, baselineid, frequency, polarisation``),
    carrying the attributes of the input ``vis.vis``.

    Raises
    ------
    AssertionError
        If the number of solution times does not match the number of
        interval slices.
    """
    # Raised explicitly instead of via ``assert`` so that the validation is
    # not stripped when Python runs with -O; the exception type and message
    # are unchanged.
    if len(soln_interval_slices) != len(soln_time):
        raise AssertionError(
            "lengths of soln_interval_slices and soln_time do not match"
        )

    # Loop-invariant: dtype of the data being replaced
    vis_dtype = vis.vis.dtype

    # Array inputs shared by every apply_ufunc call (after uvw), and their
    # core dimensions, kept in lock-step
    common_input_args = []
    common_input_core_dims = []

    # Non-array inputs, handed to the kernel by keyword
    input_kwargs = dict(
        polarisation=vis.polarisation,
        antenna1=vis.antenna1,
        antenna2=vis.antenna2,
        phasecentre=vis.phasecentre,
        beams_factory=beams_factory,
        output_dtype=vis_dtype,
    )

    # Process frequency
    # Convert frequency to a chunked dask array; it has no core dims, so
    # "frequency" stays a loop dimension of apply_ufunc
    frequency_xdr = xr.DataArray(vis.frequency, name="frequency_xdr").pipe(
        with_chunks, vis.chunksizes
    )
    common_input_args.append(frequency_xdr)
    common_input_core_dims.append([])

    # Process station_rm
    if station_rm is not None:
        # Ensure that it is not chunked across any dim
        # It can still be a dask array
        if isinstance(station_rm, da.Array):
            # "id" is a coordinate from Configuration dataset.
            # The dimension is named explicitly instead of relying on
            # xarray inferring it from the coords mapping.
            station_rm = xr.DataArray(
                station_rm,
                dims=["id"],
                coords={"id": np.arange(len(station_rm))},
            )
        station_rm = station_rm.chunk(-1)
        common_input_args.append(station_rm)
        # All of its dims are core dims, so the kernel sees the full array
        common_input_core_dims.append(list(station_rm.dims))
    else:
        input_kwargs["station_rm"] = None

    # Call predict ufunc, once per solution interval
    predicted_across_soln_time = []
    for interval_time, interval_slice in zip(soln_time, soln_interval_slices):
        local_sky_model = gsm.get_local_sky_model(
            interval_time, vis.configuration.location
        )

        predicted_per_soln_time: xr.DataArray = xr.apply_ufunc(
            _predict_vis_ufunc,
            vis.uvw.isel(time=interval_slice),
            *common_input_args,
            input_core_dims=[
                ["baselineid", "spatial"],
                *common_input_core_dims,
            ],
            output_core_dims=[
                ["baselineid", "polarisation"],
            ],
            dask="parallelized",
            output_dtypes=[vis_dtype],
            dask_gufunc_kwargs=dict(
                output_sizes={
                    "baselineid": vis.baselineid.size,
                    "polarisation": vis.polarisation.size,
                }
            ),
            kwargs=dict(
                **input_kwargs,
                local_sky_model=local_sky_model,
            ),
        )

        # apply_ufunc puts loop dims first and core dims last; reorder
        # explicitly to time, baselineid, frequency, polarisation
        predicted_per_soln_time = predicted_per_soln_time.transpose(
            "time", "baselineid", "frequency", "polarisation"
        )
        predicted_across_soln_time.append(predicted_per_soln_time)

    predicted: xr.DataArray = xr.concat(predicted_across_soln_time, dim="time")
    predicted = predicted.assign_attrs(vis.vis.attrs)

    return vis.assign({"vis": predicted})