# Source code for ska_sdp_wflow_pointing_offset.read_data

# pylint: disable=too-many-arguments,too-many-locals
# pylint: disable=too-many-positional-arguments

"""
Functions for reading data from Measurement Sets
and pulling the channel-dependent static RFI mask
from the SKA Telescope Model Data Repository
"""

import datetime
import logging
import warnings

import katpoint
import numpy
from casacore.tables import table  # pylint: disable=import-error
from ska_sdp_datamodels.visibility import create_visibility_from_ms
from ska_sdp_datamodels.visibility.vis_model import Visibility
from ska_telmodel.data import TMData

from ska_sdp_wflow_pointing_offset.array_data_func import (
    apply_rfi_mask,
    interp_timestamps,
)
from ska_sdp_wflow_pointing_offset.utils import (
    construct_antennas,
    select_channels_and_split,
)

warnings.simplefilter(action="ignore", category=FutureWarning)

LOG = logging.getLogger("ska-sdp-pointing-offset")


def _load_ms_tables(msname):
    """
    Opens the Measurement Set sub-tables used by this module.

    :param msname: Measurement set containing visibilities.

    :return: The DATA_DESCRIPTION, SPECTRAL_WINDOW and POINTING
        sub-tables, in that order.
    """
    subtable_names = ("DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POINTING")

    # Sub-tables are addressed with the "<ms>::<SUBTABLE>" syntax
    dd_table, spw_table, pointing_table = (
        table(msname + "::" + name, ack=False) for name in subtable_names
    )

    return dd_table, spw_table, pointing_table


def _get_pointing_unix(pointing_mjds):
    """
    Converts the pointing timestamps in MJD to UTC seconds
    since Unix epoch

    :param pointing_mjds: Pointing timestamps in MJD

    :return: The pointing timestamps in UTC seconds since
        Unix epoch
    """
    # MJD Start date: 1858-11-17
    base = datetime.datetime(1858, 11, 17)
    pointing_unix = numpy.zeros_like(pointing_mjds)
    for i, mjd in enumerate(pointing_mjds):
        obs_datetime = base + datetime.timedelta(days=mjd / 86400)
        pointing_unix[i] = obs_datetime.replace(
            tzinfo=datetime.timezone.utc
        ).timestamp()

    return pointing_unix


def _get_middle_timestamp(commanded_azel, pointing_mjds, nants):
    """
    Extracts the middle time of a pointing scan, its index, and reshapes
    and converts the commanded AzEl pointings to degrees

    :param commanded_azel: Commanded AzEl pointing in radians
        [nants*ntimes, 1, 2]
    :param pointing_mjds: Pointing timestamps in MJD
    :param nants: Number of antennas

    :return: The index of the middle pointing timestamp, middle pointing
        timestamp in MJD and the reshaped commanded AzEl pointing in degrees
    """

    # Get the middle timestamp
    middle_timestamp_index = numpy.argpartition(
        pointing_mjds, len(pointing_mjds) // 2
    )[len(pointing_mjds) // 2]

    # Reshape commanded pointings to (nants, ntimes, 2) and convert
    # to degrees
    commanded_azel = numpy.degrees(
        commanded_azel.reshape(nants, len(pointing_mjds), 2)
    )
    middle_timestamp = pointing_mjds[middle_timestamp_index]

    return middle_timestamp_index, middle_timestamp, commanded_azel


def _fix_antenna_timestamp_pointing_table(pointing_tab):
    """
    Fix pointing table when some antennas are missing data for some timestamps.
    We keep the data where each antenna has a timestamp and delete the rest.

    :param pointing_tab: Pointing table (CASACORE table)

    :return: Fixed Pointing table (CASACORE table)
    """
    antid_time = {}
    ant_time_cache = []
    for row in pointing_tab.iter(["ANTENNA_ID", "TIME"]):
        antid, idtime = row.getcol("ANTENNA_ID")[0], row.getcol("TIME")[0]
        if not antid_time.get(antid):
            antid_time[antid] = set()
        antid_time[antid].update([idtime])

        ant_time_cache.append((antid, idtime))

    overlap_time = set.intersection(*antid_time.values())

    overlap_dict = {}
    for ovt in overlap_time:
        overlap_dict[str(ovt)] = True

    select_rowid = []
    for rid, (antid, idtime) in enumerate(ant_time_cache):
        if overlap_dict.get(str(idtime), False):
            select_rowid.append(rid)

    fixed_table = pointing_tab.selectrows(select_rowid)

    return fixed_table


def _calc_on_sky_offsets_from_target_column(
    pointing_table,
    ants,
    pointing_mjds,
    source_ra,
    source_dec,
):
    """
    Calculates the commanded pointings relative to the target from
    all the pointing scans

    :param pointing_table: Pointing table (CASACORE table)
    :param ants: Sequence of antennas (one entry per antenna); each
        entry is passed to the katpoint target's ``azel`` call
    :param pointing_mjds: Pointing timestamps in MJD
    :param source_ra: Source RA in radians
    :param source_dec: Source DEC in radians

    :return: The commanded pointings relative to the target in degrees,
        with shape (nants, ntimes, 2) holding (cross-elevation, elevation)
    """
    # Get the commanded pointings in azimuth and elevation from
    # the TARGET column. They are in units of radians and have
    # shape (ntimes*nants, 1, 2). Convert the shape to
    # (nants, ntimes, 2)
    commanded_azel = numpy.squeeze(pointing_table.getcol("TARGET"))
    commanded_azel = commanded_azel.reshape(len(ants), len(pointing_mjds), 2)
    pointing_unix = _get_pointing_unix(pointing_mjds)

    on_sky_offsets = numpy.zeros(commanded_azel.shape)
    target_radec = katpoint.construct_radec_target(
        ra=source_ra, dec=source_dec
    )
    two_pi = 2.0 * numpy.pi
    for i, antenna in enumerate(ants):
        # Compute the on-sky offset in cross-elevation and elevation
        # with crossel = (commanded_az - target_az) * cos(target_el)
        target_az, target_el = target_radec.azel(pointing_unix, antenna)

        # Wrap the azimuth difference into [-pi, pi). The commanded and
        # computed azimuths may use different wrap conventions (e.g. a
        # negative commanded azimuth versus a target azimuth just below
        # 2*pi), which would otherwise yield an offset of ~360 degrees.
        # Genuine scan offsets are far smaller than pi, so they are
        # unaffected by the wrap.
        delta_az = commanded_azel[i, :, 0] - target_az
        delta_az = (delta_az + numpy.pi) % two_pi - numpy.pi

        dcross_el = delta_az * numpy.cos(target_el)
        delev = commanded_azel[i, :, 1] - target_el
        on_sky_offsets[i] = numpy.column_stack(
            (numpy.degrees(dcross_el), numpy.degrees(delev))
        )

    return on_sky_offsets


def _calc_track_duration(pointing_mjds):
    """
    Calculates how long each scan position was tracked for in seconds.
    We refer to this length of observation as track duration. Multiple
    timestamps are required to calculate the integration time and hence
    the track duration. If the time sample is just a single value, None
    is returned as the track duration because it cannot be calculated.

    :param pointing_mjds: Pointing timestamps in MJD for all
        antennas

    :return: The pointing timestamps in MJD seconds and track
        duration in seconds
    """
    # The timestamps for only a single antenna is needed
    pointing_mjds = pointing_mjds[0]

    # observation length = integration time * number of time samples
    if len(pointing_mjds) == 1:
        LOG.warning(
            "Pointing timestamps has length 1 so the integration "
            "time cannot be determined so track_duration is set to None"
        )
        track_duration = None
    else:
        int_time = pointing_mjds[1] - pointing_mjds[0]
        track_duration = int_time * len(pointing_mjds)

    return pointing_mjds, track_duration


def _read_pointing_data(pointing_table, vis, use_source_offset_column=False):
    """
    Read pointing data from MS pointing table.

    :param pointing_table: casacore table, Pointing table
    :param vis: Visibility object from the same MS
    :param use_source_offset_column: Read the on-sky offsets
        in cross-elevation and elevation from the SOURCE_OFFSET
        column of the pointing sub-tables? If False, they are
        computed from the TARGET column instead.

    :return: Tuple of
        (on-sky offsets in cross-elevation and elevation aligned with
        the visibility timestamps,
        list of katpoint Antennas,
        index of the middle pointing timestamp,
        the middle pointing timestamp in MJD,
        commanded AzEl in degrees for all antennas and timestamps,
        track duration in seconds or None)
    """
    # Build katpoint Antenna from antenna configuration
    antenna_positions = vis.configuration.data_vars["xyz"].data
    antenna_diameters = vis.configuration.data_vars["diameter"].data
    antenna_names = vis.configuration.data_vars["names"].data
    ants = construct_antennas(
        xyz=antenna_positions,
        diameter=antenna_diameters,
        station=antenna_names,
    )
    nants = len(ants)

    # Get the antenna pointing timestamps. Has shape (ntimes*nants) so
    # we can select the timestamps for just one antenna
    pointing_mjds = pointing_table.getcol("TIME")

    # When some antennas are missing data at some timestamp(s),
    # it then means the number of rows (nants*ntimestamps) will be
    # fewer than expected so the reshaping would cause a ValueError.
    # NOTE(review): the reshape only fails when the row count is not a
    # multiple of the number of antennas; missing rows that leave a
    # multiple go undetected here.
    try:
        pointing_mjds = pointing_mjds.reshape(
            nants, pointing_mjds.shape[0] // nants
        )
    except ValueError:
        LOG.warning(
            "Inconsistent antenna-time shape in the pointing table. "
            "Dropping pointing data where some antennas are missing "
            "data for a timestamp."
        )
        pointing_table = _fix_antenna_timestamp_pointing_table(pointing_table)
        pointing_mjds = pointing_table.getcol("TIME")
        pointing_mjds = pointing_mjds.reshape(
            nants, pointing_mjds.shape[0] // nants
        )

    # From here on pointing_mjds holds the timestamps of one antenna only
    pointing_mjds, track_duration = _calc_track_duration(pointing_mjds)
    middle_timestamp_index, middle_timestamp, commanded_azel = (
        _get_middle_timestamp(
            pointing_table.getcol("TARGET"), pointing_mjds, nants
        )
    )
    if use_source_offset_column:
        # Get the commanded on-sky offsets in cross-elevation
        # and elevation from the SOURCE_OFFSET column. They are in
        # units of radians and have shape (ntimes*nants, 1, 2).
        # Convert the shape to (nants, ntimes, 2)
        on_sky_offsets = numpy.squeeze(pointing_table.getcol("SOURCE_OFFSET"))
        on_sky_offsets = numpy.degrees(
            on_sky_offsets.reshape(nants, len(pointing_mjds), 2)
        )
    else:
        # Read the commanded pointings from the TARGET column and
        # convert to commanded pointings relative to the target
        # in cross-elevation and elevation
        on_sky_offsets = _calc_on_sky_offsets_from_target_column(
            pointing_table,
            ants,
            pointing_mjds,
            vis.phasecentre.ra.rad,
            vis.phasecentre.dec.rad,
        )

    # Perform the interpolation of the on-sky_offsets by aligning
    # them with visibility timestamps
    on_sky_offsets = interp_timestamps(
        on_sky_offsets, pointing_mjds, vis.time.data
    )

    return (
        on_sky_offsets,
        ants,
        middle_timestamp_index,
        middle_timestamp,
        commanded_azel,
        track_duration,
    )


def _read_visibilities(
    msname,
    dd_table,
    spw_table,
    apply_mask=False,
    start_freq=None,
    end_freq=None,
    num_chunks=16,
):
    """
    Loads the visibilities of a measurement set, one Visibility
    object per frequency chunk.

    :param msname: Name of Measurement set file.
    :param dd_table: Data description table from MS.
    :param spw_table: Spectral window table from MS.
    :param apply_mask: Apply RFI mask?
    :param start_freq: Starting frequency for selection in MHz.
        If no selection needed, use None
    :param end_freq: Ending frequency for selection in MHz.
        If no selection needed, use None
    :param num_chunks: Number of frequency chunks

    :return: List of Visibility object(s)
    """
    # Channel frequencies of the spectral window referenced by the
    # first data description row
    first_spw = dd_table.getcol("SPECTRAL_WINDOW_ID")[0]
    all_freqs = spw_table.getcol("CHAN_FREQ")[first_spw] / 1.0e6  # Hz -> MHz
    all_channels = numpy.arange(len(all_freqs))

    # Optional frequency selection, then split into chunks
    freqs, channels = select_channels_and_split(
        all_freqs, all_channels, start_freq, end_freq, num_chunks
    )

    # Optional static RFI mask
    if apply_mask:
        freqs, channels = apply_rfi_mask(
            freqs, channels, start_freq, end_freq, num_chunks
        )

    chunked_vis = []
    for chunk_freqs, chunk_channels in zip(freqs, channels):
        first_chan = chunk_channels[0]
        last_chan = chunk_channels[-1]
        LOG.info(
            "Selected channel numbers are %s to %s", first_chan, last_chan
        )
        loaded = create_visibility_from_ms(
            msname=msname,
            channum=None,
            start_chan=first_chan,
            end_chan=last_chan,
            ack=False,
            datacolumn="DATA",
            selected_sources=None,
            selected_dds=None,
            average_channels=False,
        )[0]

        if not apply_mask:
            chunked_vis.append(loaded)
            continue

        # The loaded Visibility holds every channel between first_chan
        # and last_chan, so rebuild it with only the channels that
        # survived the RFI mask.
        loaded_freqs = loaded.frequency.data / 1.0e6  # Hz -> MHz
        keep = []
        for frequency in chunk_freqs:
            matches = numpy.where(frequency == loaded_freqs)[0]
            if matches.size > 0:
                keep.append(matches[0])

        masked_vis = Visibility.constructor(
            frequency=chunk_freqs * 1.0e6,  # MHz -> Hz
            channel_bandwidth=loaded.channel_bandwidth[keep],
            phasecentre=loaded.phasecentre,
            configuration=loaded.configuration,
            uvw=loaded.uvw.data,
            time=loaded.time.data,
            vis=loaded.vis.data[:, :, keep],
            weight=loaded.weight.data[:, :, keep],
            integration_time=loaded.integration_time.data,
            flags=loaded.flags.data[:, :, keep],
            baselines=loaded.baselines,
            polarisation_frame=loaded.visibility_acc.polarisation_frame,
            source=loaded.source,
            meta=loaded.meta,
        )
        chunked_vis.append(masked_vis)

    return chunked_vis


def _get_reference_pointing(
    middle_timestamp_index_list,
    commanded_azel_list,
    middle_timestamps_list,
    on_sky_offsets_list,
):
    """
    Identify the on-source scan and extract its commanded AzEl
    (reference_commanded_azel) and corresponding middle timestamp
    (reference_timestamp).

    :param middle_timestamp_index_list: List of indices of the middle
        pointing timestamps
    :param commanded_azel_list: list of commanded azel pointing
        at all timestamps for each antenna. List has one
        element per scan.
    :param middle_timestamps_list: list of middle pointing timestamps.
        List has one element per scan.
    :param on_sky_offsets_list: list of on-sky offsets in xel-el,
        for each antenna and time. List has one element per scan.

    :return: reference_commanded_azel, reference_timestamp
    """

    for i, on_sky_offset in enumerate(on_sky_offsets_list):
        # Average each on-sky offsets in time/antennas
        on_sky_offset_xel = abs(on_sky_offset.mean(axis=(0, 1))[0])
        on_sky_offset_el = abs(on_sky_offset.mean(axis=(0, 1))[1])
        if numpy.isclose(
            on_sky_offset_xel.round(1), 0.0, rtol=1e-06
        ) and numpy.isclose(on_sky_offset_el.round(1), 0.0, rtol=1e-06):
            LOG.info(
                "Scan %s identified as the on-source scan so "
                "commanded AzEl and reference time at centre "
                "of the observation can be determined...",
                i + 1,
            )
            reference_timestamp = middle_timestamps_list[i]
            reference_commanded_azel = commanded_azel_list[i][
                :, middle_timestamp_index_list[i]
            ]
        else:
            # Store the median timestamp
            reference_timestamp = numpy.median(
                numpy.array(middle_timestamps_list)
            )

            # Cannot provide the commanded pointing
            reference_commanded_azel = None

    if reference_commanded_azel is None:
        LOG.warning(
            "On-source scan could not be identified. Hence the "
            "commanded AzEl and reference time at centre of "
            "the observation could not be determined!"
        )

    return reference_commanded_azel, reference_timestamp


def _split_two_dish_modes_ms(msfiles):
    """
    Split MS files into two groups for two-dish scan mode based on
    the TIME in POINTING subtable.
    The two scan orders are defined as:
      1. dish A on source and dish B doing the scans
      2. dish B on source and dish A doing the scans

    The files are ordered by the first timestamp of their POINTING
    sub-table; the first half forms one group and the remainder the
    other. Files sharing the same first timestamp are all kept, in
    their input order.

    :param msfiles: Lists of measurement sets file names

    :return: List of two lists of measurement set file names, each
        sorted by time
    """
    timed_files = []
    for msfile in msfiles:
        pointing_tab = table(f"{msfile}/POINTING", ack=False)
        try:
            first_timestamp = pointing_tab.getcol("TIME")[0]
        finally:
            # Only one value is needed, so release the table right away
            pointing_tab.close()
        timed_files.append((first_timestamp, msfile))

    # Sort on the timestamp only: the sort is stable, so files with equal
    # timestamps keep their input order instead of overwriting each other
    timed_files.sort(key=lambda entry: entry[0])
    files_list = [msfile for _, msfile in timed_files]
    pos = len(files_list) // 2

    return [files_list[:pos], files_list[pos:]]


def read_batch_visibilities(
    ms_all_files,
    apply_mask=False,
    use_source_offset_column=False,
    rfi_filename=None,
    start_freq=None,
    end_freq=None,
    fit_to_vis=False,
    num_chunks=16,
):
    """
    Extracts parameters from multiple measurement sets required for
    computing the pointing offsets.

    The measurement sets are processed in groups: a single group holding
    all files, or two groups when ``fit_to_vis`` is set. Every returned
    list has one entry per group.

    :param ms_all_files: List of all measurement set files to be loaded
    :param apply_mask: Apply RFI mask?
    :param use_source_offset_column: Read on-sky offsets in
        cross-elevation and elevation from the SOURCE_OFFSET column
        of the pointing sub-tables? If False, antenna pointings in
        azimuth and elevation are read from the TARGET column of the
        pointing table and the on-sky offsets in cross-elevation and
        elevation are computed thereafter.
    :param rfi_filename: Name of RFI mask file (in .h5)
    :param start_freq: Start frequency for selection in MHz.
        If no selection needed, use None
    :param end_freq: End frequency for selection in MHz.
        If no selection needed, use None
    :param fit_to_vis: Fit primary beam to cross-correlation
        visibility amplitudes instead of the antenna gain amplitudes
    :param num_chunks: Number of frequency chunks

    :return: List of Visibility, list of on-sky offsets in
        cross-elevation and elevation, list of katpoint Antennas,
        the reference timestamp, commanded pointings at all timestamps,
        and list of track duration

    :raises ValueError: If ``apply_mask`` is set without
        ``rfi_filename``, or if a group of measurement sets is empty
    :raises FileNotFoundError: If the RFI mask file cannot be found
    """
    if apply_mask and rfi_filename is None:
        raise ValueError("RFI File is required!!")

    if apply_mask:
        # Download the channel-dependent static RFI mask just once
        LOG.info("Pulling static RFI mask from telmodel store...")
        tmdata = TMData()
        try:
            tmdata.get(rfi_filename).copy("rfi_mask.h5")
            LOG.info("Static RFI mask downloaded")
        except KeyError as exc:
            raise FileNotFoundError("RFI mask file not found!") from exc

    vis_lists = []
    on_sky_offsets_lists = []
    reference_timestamps_lists = []
    reference_commanded_azel_lists = []
    ants_lists = []
    track_duration_lists = []

    if fit_to_vis:
        ms_all_files = _split_two_dish_modes_ms(ms_all_files)
    else:
        ms_all_files = [ms_all_files]

    for msfiles in ms_all_files:
        # Without at least one MS there is no antenna list or reference
        # pointing for this group, so fail with a clear message
        if len(msfiles) == 0:
            raise ValueError(
                "At least one Measurement Set is required in each "
                "group of scans"
            )

        vis_list = []
        on_sky_offsets_list = []
        middle_timestamp_index_list = []
        middle_timestamps_list = []
        commanded_azel_list = []
        track_duration_list = []
        for msname in msfiles:
            LOG.info("Reading MS: %s", msname)
            dd_table, spw_table, pointing_table = _load_ms_tables(msname)
            vis = _read_visibilities(
                msname,
                dd_table,
                spw_table,
                apply_mask,
                start_freq,
                end_freq,
                num_chunks,
            )
            (
                on_sky_offsets,
                ants,
                middle_timestamp_index,
                middle_timestamp,
                commanded_azel,
                track_duration,
            ) = _read_pointing_data(
                pointing_table,
                vis[0],  # only need the configuration and time
                # so any vis can be used
                use_source_offset_column=use_source_offset_column,
            )
            vis_list.append(vis)
            on_sky_offsets_list.append(on_sky_offsets)
            middle_timestamp_index_list.append(middle_timestamp_index)
            middle_timestamps_list.append(middle_timestamp)
            commanded_azel_list.append(commanded_azel)
            track_duration_list.append(track_duration)

        (
            reference_commanded_azel,
            reference_timestamp,
        ) = _get_reference_pointing(
            middle_timestamp_index_list,
            commanded_azel_list,
            middle_timestamps_list,
            on_sky_offsets_list,
        )

        vis_lists.append(vis_list)
        on_sky_offsets_lists.append(on_sky_offsets_list)
        reference_timestamps_lists.append(reference_timestamp)
        reference_commanded_azel_lists.append(reference_commanded_azel)
        # The antennas read from the last MS of the group are kept
        ants_lists.append(ants)
        track_duration_lists.append(track_duration_list)

    return (
        vis_lists,
        on_sky_offsets_lists,
        ants_lists,
        reference_timestamps_lists,
        reference_commanded_azel_lists,
        track_duration_lists,
    )