# pylint: disable=too-many-arguments,too-many-locals
# pylint: disable=too-many-positional-arguments
"""
Functions for reading data from Measurement Sets
and pulling static channel-dependent static RFI mask
from the SKA Telescope Model Data Repository
"""
import datetime
import logging
import warnings
import katpoint
import numpy
from casacore.tables import table # pylint: disable=import-error
from ska_sdp_datamodels.visibility import create_visibility_from_ms
from ska_sdp_datamodels.visibility.vis_model import Visibility
from ska_telmodel.data import TMData
from ska_sdp_wflow_pointing_offset.array_data_func import (
apply_rfi_mask,
interp_timestamps,
)
from ska_sdp_wflow_pointing_offset.utils import (
construct_antennas,
select_channels_and_split,
)
# Install a process-wide filter that ignores FutureWarning messages
warnings.simplefilter("ignore", category=FutureWarning)
# Logger used by every function in this module
LOG = logging.getLogger(name="ska-sdp-pointing-offset")
def _load_ms_tables(msname):
    """
    Open the Measurement Set sub-tables needed to read the frequency
    set-up and the pointing information.
    :param msname: Measurement set containing visibilities.
    :return: Tuple of casacore tables, in this order: DATA_DESCRIPTION,
        SPECTRAL_WINDOW and POINTING. They are returned open; closing
        them is up to the caller.
    """
    sub_table_names = ("DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POINTING")
    # casacore addresses a sub-table of an MS as "<ms>::<SUBTABLE>"
    return tuple(
        table(msname + "::" + name, ack=False) for name in sub_table_names
    )
def _get_pointing_unix(pointing_mjds):
"""
Converts the pointing timestamps in MJD to UTC seconds
since Unix epoch
:param pointing_mjds: Pointing timestamps in MJD
:return: The pointing timestamps in UTC seconds since
Unix epoch
"""
# MJD Start date: 1858-11-17
base = datetime.datetime(1858, 11, 17)
pointing_unix = numpy.zeros_like(pointing_mjds)
for i, mjd in enumerate(pointing_mjds):
obs_datetime = base + datetime.timedelta(days=mjd / 86400)
pointing_unix[i] = obs_datetime.replace(
tzinfo=datetime.timezone.utc
).timestamp()
return pointing_unix
def _get_middle_timestamp(commanded_azel, pointing_mjds, nants):
"""
Extracts the middle time of a pointing scan, its index, and reshapes
and converts the commanded AzEl pointings to degrees
:param commanded_azel: Commanded AzEl pointing in radians
[nants*ntimes, 1, 2]
:param pointing_mjds: Pointing timestamps in MJD
:param nants: Number of antennas
:return: The index of the middle pointing timestamp, middle pointing
timestamp in MJD and the reshaped commanded AzEl pointing in degrees
"""
# Get the middle timestamp
middle_timestamp_index = numpy.argpartition(
pointing_mjds, len(pointing_mjds) // 2
)[len(pointing_mjds) // 2]
# Reshape commanded pointings to (nants, ntimes, 2) and convert
# to degrees
commanded_azel = numpy.degrees(
commanded_azel.reshape(nants, len(pointing_mjds), 2)
)
middle_timestamp = pointing_mjds[middle_timestamp_index]
return middle_timestamp_index, middle_timestamp, commanded_azel
def _fix_antenna_timestamp_pointing_table(pointing_tab):
"""
Fix pointing table when some antennas are missing data for some timestamps.
We keep the data where each antenna has a timestamp and delete the rest.
:param pointing_tab: Pointing table (CASACORE table)
:return: Fixed Pointing table (CASACORE table)
"""
antid_time = {}
ant_time_cache = []
for row in pointing_tab.iter(["ANTENNA_ID", "TIME"]):
antid, idtime = row.getcol("ANTENNA_ID")[0], row.getcol("TIME")[0]
if not antid_time.get(antid):
antid_time[antid] = set()
antid_time[antid].update([idtime])
ant_time_cache.append((antid, idtime))
overlap_time = set.intersection(*antid_time.values())
overlap_dict = {}
for ovt in overlap_time:
overlap_dict[str(ovt)] = True
select_rowid = []
for rid, (antid, idtime) in enumerate(ant_time_cache):
if overlap_dict.get(str(idtime), False):
select_rowid.append(rid)
fixed_table = pointing_tab.selectrows(select_rowid)
return fixed_table
def _calc_on_sky_offsets_from_target_column(
    pointing_table,
    ants,
    pointing_mjds,
    source_ra,
    source_dec,
):
    """
    Calculates the commanded pointings relative to the target from
    all the pointing scans.

    For each antenna the commanded (az, el) from the TARGET column is
    compared with the target's (az, el) at every pointing timestamp,
    giving on-sky offsets
        xel = wrap(commanded_az - target_az) * cos(target_el)
        el = commanded_el - target_el
    The azimuth difference is wrapped into [-pi, pi) so that a
    commanded azimuth and a target azimuth on opposite sides of the
    0 / 2*pi boundary do not produce a spurious ~360 degree offset.

    :param pointing_table: Pointing table (CASACORE table)
    :param ants: List of katpoint Antennas; entry i is used for the
        i-th antenna block of the TARGET column
    :param pointing_mjds: Pointing timestamps in MJD seconds for a
        single antenna
    :param source_ra: Source RA in radians
    :param source_dec: Source DEC in radians
    :return: The commanded pointings relative to the target in degrees,
        shape (nants, ntimes, 2) with the last axis (xel, el)
    """
    nants = len(ants)
    ntimes = len(pointing_mjds)
    # The TARGET column holds commanded AzEl in radians with shape
    # (ntimes*nants, 1, 2); convert the shape to (nants, ntimes, 2)
    commanded_azel = numpy.squeeze(pointing_table.getcol("TARGET"))
    commanded_azel = commanded_azel.reshape(nants, ntimes, 2)
    pointing_unix = _get_pointing_unix(pointing_mjds)
    on_sky_offsets = numpy.zeros(commanded_azel.shape)
    target_radec = katpoint.construct_radec_target(
        ra=source_ra, dec=source_dec
    )
    two_pi = 2.0 * numpy.pi
    for i, antenna in enumerate(ants):
        target_az, target_el = target_radec.azel(pointing_unix, antenna)
        # Wrap the azimuth difference into [-pi, pi)
        daz = (
            numpy.mod(commanded_azel[i, :, 0] - target_az + numpy.pi, two_pi)
            - numpy.pi
        )
        dcross_el = daz * numpy.cos(target_el)
        delev = commanded_azel[i, :, 1] - target_el
        on_sky_offsets[i] = numpy.column_stack(
            (numpy.degrees(dcross_el), numpy.degrees(delev))
        )
    return on_sky_offsets
def _calc_track_duration(pointing_mjds):
"""
Calculates how long each scan position was tracked for in seconds.
We refer to this length of observation as track duration. Multiple
timestamps are required to calculate the integration time and hence
the track duration. If the time sample is just a single value, None
is returned as the track duration because it cannot be calculated.
:param pointing_mjds: Pointing timestamps in MJD for all
antennas
:return: The pointing timestamps in MJD seconds and track
duration in seconds
"""
# The timestamps for only a single antenna is needed
pointing_mjds = pointing_mjds[0]
# observation length = integration time * number of time samples
if len(pointing_mjds) == 1:
LOG.warning(
"Pointing timestamps has length 1 so the integration "
"time cannot be determined so track_duration is set to None"
)
track_duration = None
else:
int_time = pointing_mjds[1] - pointing_mjds[0]
track_duration = int_time * len(pointing_mjds)
return pointing_mjds, track_duration
def _read_pointing_data(pointing_table, vis, use_source_offset_column=False):
    """
    Read pointing data from MS pointing table.

    :param pointing_table: casacore table, Pointing table
    :param vis: Visibility object from the same MS; its configuration,
        phase centre and time axis are used
    :param use_source_offset_column: Read the on-sky offsets
        in cross-elevation and elevation from the SOURCE_OFFSET
        column of the pointing sub-tables? If False, they are computed
        from the commanded AzEl in the TARGET column.
    :return: Tuple of
        - on-sky offsets in cross-elevation and elevation (degrees),
          interpolated onto the visibility timestamps,
        - list of katpoint Antennas,
        - index of the middle pointing timestamp,
        - the middle pointing timestamp (MJD seconds),
        - commanded AzEl at all pointing timestamps in degrees,
          shape (nants, ntimes, 2),
        - track duration in seconds (None if it cannot be determined)
    """
    # Build katpoint Antenna from antenna configuration
    config = vis.configuration.data_vars
    ants = construct_antennas(
        xyz=config["xyz"].data,
        diameter=config["diameter"].data,
        station=config["names"].data,
    )
    nants = len(ants)

    # The TIME column has one row per antenna and timestamp; reshape it
    # to (nants, ntimes). When some antennas are missing data at some
    # timestamp(s) the row count may not be a multiple of nants and the
    # reshape raises ValueError, so the incomplete timestamps are
    # dropped first.
    pointing_mjds = pointing_table.getcol("TIME")
    try:
        pointing_mjds = pointing_mjds.reshape(
            nants, pointing_mjds.shape[0] // nants
        )
    except ValueError:
        LOG.warning(
            "Inconsistent antenna-time shape in the pointing table. "
            "Dropping pointing data where some antennas are missing "
            "data for a timestamp."
        )
        pointing_table = _fix_antenna_timestamp_pointing_table(
            pointing_table
        )
        pointing_mjds = pointing_table.getcol("TIME")
        pointing_mjds = pointing_mjds.reshape(
            nants, pointing_mjds.shape[0] // nants
        )

    # From here on pointing_mjds holds the timestamps of one antenna
    pointing_mjds, track_duration = _calc_track_duration(pointing_mjds)
    middle_timestamp_index, middle_timestamp, commanded_azel = (
        _get_middle_timestamp(
            pointing_table.getcol("TARGET"), pointing_mjds, nants
        )
    )

    if use_source_offset_column:
        # SOURCE_OFFSET holds the commanded on-sky offsets in
        # cross-elevation and elevation in radians with shape
        # (ntimes*nants, 1, 2); convert to degrees, (nants, ntimes, 2)
        on_sky_offsets = numpy.squeeze(
            pointing_table.getcol("SOURCE_OFFSET")
        )
        on_sky_offsets = numpy.degrees(
            on_sky_offsets.reshape(nants, len(pointing_mjds), 2)
        )
    else:
        # Compute the offsets relative to the target from the commanded
        # pointings in the TARGET column
        on_sky_offsets = _calc_on_sky_offsets_from_target_column(
            pointing_table,
            ants,
            pointing_mjds,
            vis.phasecentre.ra.rad,
            vis.phasecentre.dec.rad,
        )

    # Align the on-sky offsets with the visibility timestamps
    on_sky_offsets = interp_timestamps(
        on_sky_offsets, pointing_mjds, vis.time.data
    )
    return (
        on_sky_offsets,
        ants,
        middle_timestamp_index,
        middle_timestamp,
        commanded_azel,
        track_duration,
    )
def _select_vis_frequencies(vis, freqs_mhz):
    """
    Build a new Visibility holding only the requested frequencies.

    The Visibility read from the MS contains every channel of the
    requested channel range, so after the RFI mask has removed
    channels it must be rebuilt with only the channels that remain.

    Channels are matched by exact equality of their frequency in MHz
    (``vis.frequency / 1e6``) with the requested values; requested
    values that match no channel are skipped with a warning. The
    frequency axis of the result is taken from the matched channels of
    ``vis``, so it always has the same length as the vis/weight/flags
    data and keeps the original values in Hz.

    :param vis: Visibility read from the MS
    :param freqs_mhz: Frequencies (MHz) to keep
    :return: New Visibility containing only the matched channels
    """
    vis_freqs_mhz = vis.frequency.data / 1.0e6
    indices = []
    for frequency in freqs_mhz:
        matches = numpy.flatnonzero(vis_freqs_mhz == frequency)
        if matches.size > 0:
            indices.append(matches[0])
    if len(indices) != len(freqs_mhz):
        LOG.warning(
            "%s of %s selected frequencies were not found in the "
            "Visibility and are ignored",
            len(freqs_mhz) - len(indices),
            len(freqs_mhz),
        )
    # Pass plain arrays (.data) throughout so that no xarray coordinates
    # are carried into the new object
    return Visibility.constructor(
        frequency=vis.frequency.data[indices],
        channel_bandwidth=vis.channel_bandwidth.data[indices],
        phasecentre=vis.phasecentre,
        configuration=vis.configuration,
        uvw=vis.uvw.data,
        time=vis.time.data,
        vis=vis.vis.data[:, :, indices],
        weight=vis.weight.data[:, :, indices],
        integration_time=vis.integration_time.data,
        flags=vis.flags.data[:, :, indices],
        baselines=vis.baselines,
        polarisation_frame=vis.visibility_acc.polarisation_frame,
        source=vis.source,
        meta=vis.meta,
    )


def _read_visibilities(
    msname,
    dd_table,
    spw_table,
    apply_mask=False,
    start_freq=None,
    end_freq=None,
    num_chunks=16,
):
    """
    Extracts parameters from a measurement set required for
    computing the pointing offsets.
    :param msname: Name of Measurement set file.
    :param dd_table: Data description table from MS.
    :param spw_table: Spectral window table from MS.
    :param apply_mask: Apply RFI mask?
    :param start_freq: Starting frequency for selection in MHz.
        If no selection needed, use None
    :param end_freq: Ending frequency for selection in MHz.
        If no selection needed, use None
    :param num_chunks: Number of frequency chunks
    :return: List of Visibility object(s), one per frequency chunk
    """
    # Get the frequencies of the spectral window used by the first
    # data description
    spw_id = dd_table.getcol("SPECTRAL_WINDOW_ID")[0]
    freqs = spw_table.getcol("CHAN_FREQ")[spw_id] / 1.0e6  # Hz -> MHz
    channels = numpy.arange(len(freqs))

    # Optionally select frequency channels and split frequencies and
    # corresponding channels into chunks
    freqs, channels = select_channels_and_split(
        freqs, channels, start_freq, end_freq, num_chunks
    )

    # Optionally apply static RFI mask
    if apply_mask:
        freqs, channels = apply_rfi_mask(
            freqs, channels, start_freq, end_freq, num_chunks
        )

    new_vis_list = []
    for nu, chan in zip(freqs, channels):
        LOG.info("Selected channel numbers are %s to %s", chan[0], chan[-1])
        # create_visibility_from_ms returns a list; only the first
        # element is used
        vis = create_visibility_from_ms(
            msname=msname,
            channum=None,
            start_chan=chan[0],
            end_chan=chan[-1],
            ack=False,
            datacolumn="DATA",
            selected_sources=None,
            selected_dds=None,
            average_channels=False,
        )[0]
        if apply_mask:
            # The Visibility covers the whole channel range, including
            # masked channels, so keep only the selected frequencies
            vis = _select_vis_frequencies(vis, nu)
        new_vis_list.append(vis)
    return new_vis_list
def _get_reference_pointing(
    middle_timestamp_index_list,
    commanded_azel_list,
    middle_timestamps_list,
    on_sky_offsets_list,
):
    """
    Identify the on-source scan and extract its commanded AzEl
    (reference_commanded_azel) and corresponding middle timestamp
    (reference_timestamp).

    A scan counts as on-source when the absolute value of its mean
    on-sky offset (averaged over axes 0 and 1, i.e. antennas and time)
    rounds to 0.0 at one decimal place in both cross-elevation and
    elevation. The first such scan is used and the search stops there,
    so a later off-source scan cannot discard it. If no scan qualifies,
    the commanded AzEl cannot be provided (None) and the median of the
    middle timestamps is used as the reference timestamp.

    :param middle_timestamp_index_list: List of indices of the middle
        pointing timestamps
    :param commanded_azel_list: list of commanded azel pointing
        at all timestamps for each antenna. List has one
        element per scan.
    :param middle_timestamps_list: list of middle pointing timestamps.
        List has one element per scan.
    :param on_sky_offsets_list: list of on-sky offsets in xel-el,
        for each antenna and time. List has one element per scan.
    :return: reference_commanded_azel, reference_timestamp
    """
    for i, on_sky_offset in enumerate(on_sky_offsets_list):
        # Average each on-sky offset over antennas/time
        mean_offset = on_sky_offset.mean(axis=(0, 1))
        on_sky_offset_xel = abs(mean_offset[0])
        on_sky_offset_el = abs(mean_offset[1])
        if numpy.isclose(
            on_sky_offset_xel.round(1), 0.0, rtol=1e-06
        ) and numpy.isclose(on_sky_offset_el.round(1), 0.0, rtol=1e-06):
            LOG.info(
                "Scan %s identified as the on-source scan so "
                "commanded AzEl and reference time at centre "
                "of the observation can be determined...",
                i + 1,
            )
            reference_commanded_azel = commanded_azel_list[i][
                :, middle_timestamp_index_list[i]
            ]
            return reference_commanded_azel, middle_timestamps_list[i]

    LOG.warning(
        "On-source scan could not be identified. Hence the "
        "commanded AzEl and reference time at centre of "
        "the observation could not be determined!"
    )
    # Fall back to the median timestamp; the commanded pointing
    # cannot be provided
    return None, numpy.median(numpy.array(middle_timestamps_list))
def _split_two_dish_modes_ms(msfiles):
    """
    Split MS files into two groups for two-dish scan mode based on
    the first TIME in the POINTING sub-table of each MS.
    The two scan orders are defined as:
    1. dish A on source and dish B doing the scans
    2. dish B on source and dish A doing the scans

    The files are ordered by that start time (files with equal start
    times keep their input order, and none are dropped) and split in
    half; with an odd number of files the second group gets the extra
    one.

    :param msfiles: Lists of measurement sets file names
    :return: [first group, second group], each a list of measurement
        set file names in start-time order
    """

    def _first_pointing_time(msname):
        # Read only the first TIME value and always close the table
        pointing_tab = table(f"{msname}/POINTING", ack=False)
        try:
            return pointing_tab.getcol("TIME", startrow=0, nrow=1)[0]
        finally:
            pointing_tab.close()

    # sorted() is stable and, unlike a dict keyed on the start time,
    # keeps every file when two share the same start time
    files_list = sorted(msfiles, key=_first_pointing_time)
    pos = len(files_list) // 2
    return [files_list[:pos], files_list[pos:]]
def _read_single_ms(
    msname,
    apply_mask,
    use_source_offset_column,
    start_freq,
    end_freq,
    num_chunks,
):
    """
    Read the visibilities and pointing data of one Measurement Set.

    The DATA_DESCRIPTION, SPECTRAL_WINDOW and POINTING sub-tables are
    closed once everything has been read, even if reading fails.

    :param msname: Name of the Measurement Set
    :param apply_mask: Apply RFI mask?
    :param use_source_offset_column: Read on-sky offsets from the
        SOURCE_OFFSET column instead of computing them from TARGET?
    :param start_freq: Start frequency for selection in MHz, or None
    :param end_freq: End frequency for selection in MHz, or None
    :param num_chunks: Number of frequency chunks
    :return: Tuple (list of Visibility, one per frequency chunk;
        tuple returned by _read_pointing_data)
    """
    dd_table, spw_table, pointing_table = _load_ms_tables(msname)
    try:
        vis = _read_visibilities(
            msname,
            dd_table,
            spw_table,
            apply_mask,
            start_freq,
            end_freq,
            num_chunks,
        )
        pointing_data = _read_pointing_data(
            pointing_table,
            # All chunks come from the same MS, so any of them provides
            # the configuration, phase centre and time axis needed here
            vis[0],
            use_source_offset_column=use_source_offset_column,
        )
    finally:
        for sub_table in (dd_table, spw_table, pointing_table):
            sub_table.close()
    return vis, pointing_data


def read_batch_visibilities(
    ms_all_files,
    apply_mask=False,
    use_source_offset_column=False,
    rfi_filename=None,
    start_freq=None,
    end_freq=None,
    fit_to_vis=False,
    num_chunks=16,
):
    """
    Extracts parameters from multiple measurement sets required
    for computing the pointing offsets.

    The measurement sets are processed in scan groups: a single group
    containing all of ``ms_all_files`` or, when ``fit_to_vis`` is True,
    the two groups produced by _split_two_dish_modes_ms (first and
    second half of the files ordered by their first pointing time).

    :param ms_all_files: List of all measurement set files to be loaded
    :param apply_mask: Apply RFI mask?
    :param use_source_offset_column: Read on-sky offsets in
        cross-elevation and elevation from the SOURCE_OFFSET
        column of the pointing sub-tables? If False, antenna
        pointings in azimuth and elevation are read from the
        TARGET column of the pointing table and the on-sky
        offsets in cross-elevation and elevation are computed
        thereafter.
    :param rfi_filename: Name of RFI mask file (in .h5)
    :param start_freq: Start frequency for selection in MHz.
        If no selection needed, use None
    :param end_freq: End frequency for selection in MHz.
        If no selection needed, use None
    :param fit_to_vis: Fit primary beam to cross-correlation visibility
        amplitudes instead of the antenna gain amplitudes
    :param num_chunks: Number of frequency chunks each MS is split into
    :return: Tuple of six lists, each with one element per scan group:
        - lists (one per MS) of Visibility lists (one per frequency
          chunk),
        - lists (one per MS) of on-sky offsets in cross-elevation and
          elevation,
        - katpoint Antennas (taken from the last MS of the group),
        - reference timestamps,
        - reference commanded AzEl (None when no on-source scan was
          identified),
        - lists (one per MS) of track durations in seconds
    :raises ValueError: If apply_mask is set without rfi_filename, or
        a scan group contains no measurement sets
    :raises FileNotFoundError: If the RFI mask file is not found in
        the telmodel store
    """
    if apply_mask and rfi_filename is None:
        raise ValueError("RFI File is required!!")

    if apply_mask:
        # Download the channel-dependent static RFI mask just once
        LOG.info("Pulling static RFI mask from telmodel store...")
        tmdata = TMData()
        try:
            tmdata.get(rfi_filename).copy("rfi_mask.h5")
            LOG.info("Static RFI mask downloaded")
        except KeyError as exc:
            raise FileNotFoundError("RFI mask file not found!") from exc

    vis_lists = []
    on_sky_offsets_lists = []
    reference_timestamps_lists = []
    reference_commanded_azel_lists = []
    ants_lists = []
    track_duration_lists = []

    ms_groups = (
        _split_two_dish_modes_ms(ms_all_files)
        if fit_to_vis
        else [ms_all_files]
    )
    for msfiles in ms_groups:
        if not msfiles:
            # An empty group has no antennas or scans to report (this
            # happens e.g. with fit_to_vis and a single MS)
            raise ValueError(
                "No measurement sets to read for a scan group "
                "(fit_to_vis requires at least two measurement sets)"
            )
        vis_list = []
        on_sky_offsets_list = []
        middle_timestamp_index_list = []
        middle_timestamps_list = []
        commanded_azel_list = []
        track_duration_list = []
        for msname in msfiles:
            LOG.info("Reading MS: %s", msname)
            vis, (
                on_sky_offsets,
                ants,
                middle_timestamp_index,
                middle_timestamp,
                commanded_azel,
                track_duration,
            ) = _read_single_ms(
                msname,
                apply_mask,
                use_source_offset_column,
                start_freq,
                end_freq,
                num_chunks,
            )
            vis_list.append(vis)
            on_sky_offsets_list.append(on_sky_offsets)
            middle_timestamp_index_list.append(middle_timestamp_index)
            middle_timestamps_list.append(middle_timestamp)
            commanded_azel_list.append(commanded_azel)
            track_duration_list.append(track_duration)

        (
            reference_commanded_azel,
            reference_timestamp,
        ) = _get_reference_pointing(
            middle_timestamp_index_list,
            commanded_azel_list,
            middle_timestamps_list,
            on_sky_offsets_list,
        )
        vis_lists.append(vis_list)
        on_sky_offsets_lists.append(on_sky_offsets_list)
        reference_timestamps_lists.append(reference_timestamp)
        reference_commanded_azel_lists.append(reference_commanded_azel)
        ants_lists.append(ants)
        track_duration_lists.append(track_duration_list)

    return (
        vis_lists,
        on_sky_offsets_lists,
        ants_lists,
        reference_timestamps_lists,
        reference_commanded_azel_lists,
        track_duration_lists,
    )