import asyncio
import bisect
import collections
import logging
import os
import sys
import time
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
AsyncIterable,
Callable,
ContextManager,
Dict,
Iterable,
Optional,
OrderedDict,
Sequence,
Tuple,
Union,
)
import numpy as np
from astropy.coordinates import Angle
from casacore import tables
from realtime.receive.core import baseline_utils, time_utils
from realtime.receive.core.antenna import Antenna
from realtime.receive.core.baselines import Baselines
from realtime.receive.core.channel_range import ChannelRange
from realtime.receive.core.pointing import Pointing
from realtime.receive.core.scan import (
Beam,
Channels,
Field,
FrequencyType,
PhaseDirection,
Polarisations,
Scan,
ScanType,
SpectralWindow,
StokesType,
)
logger = logging.getLogger(__name__)
# pylint: disable=too-many-lines,missing-function-docstring
@dataclass(frozen=True)
class TensorRef:
"""A reference to an ndarray in memory"""
oid: bytes
data: np.ndarray
@dataclass(frozen=True)
class MSScan:
"""
Relational scan model as supported by measurement set v2
"""
scan_number: int
beams: OrderedDict[str, Beam]
polarisations: OrderedDict[str, Polarisations]
fields: OrderedDict[str, Field]
channels: OrderedDict[str, Channels]
spectral_windows: OrderedDict[str, SpectralWindow]
datadesc: Dict[Tuple[str, str], int] # pol_name, sw_id
@classmethod
def from_scan(cls, scan: Scan):
"""
        Infers the measurement set relational model for a complete
        SKA ScanType. (A naive traversal of beams would otherwise
        create duplicate table entries.)
"""
beams = collections.OrderedDict()
datadesc = {}
polarizations = collections.OrderedDict()
fields = collections.OrderedDict()
channels = collections.OrderedDict()
spectral_windows = collections.OrderedDict()
dd_id = 0
for b in scan.scan_type.beams:
beams[b.beam_id] = b
fields[b.field.field_id] = b.field
polarizations[b.polarisations.polarisation_id] = b.polarisations
channels[b.channels.channels_id] = b.channels
for sw in b.channels.spectral_windows:
spectral_windows[sw.spectral_window_id] = sw
                # Generate a lookup table only for the pol-sw pairs
                # that actually occur in beams
dd = (b.polarisations.polarisation_id, sw.spectral_window_id)
if dd not in datadesc:
datadesc[dd] = dd_id
dd_id += 1
return cls(
scan.scan_number,
beams,
polarizations,
fields,
channels,
spectral_windows,
datadesc,
)
@classmethod
def from_main_table(cls, maintable: tables.table):
"""
        Infers the measurement set relational model as faithfully as
        possible from measurement set tables. Limitations for MSv2 include:
* Beam function is lost
* Channels created per beam
* Channels channels_id is lost
* Polarization name is lost
"""
scan_number = maintable.getcol("SCAN_NUMBER")[0]
beams = collections.OrderedDict()
polarizations = collections.OrderedDict()
fields = collections.OrderedDict()
channels = collections.OrderedDict()
spectral_windows = collections.OrderedDict()
poltable = _subtable(maintable, "POLARIZATION", readonly=True)
for row in range(poltable.nrows()):
stokes = poltable.getcell("CORR_TYPE", row)
# pol name is lost
polarizations[str(row)] = Polarisations(
",".join([StokesType(stoke).name for stoke in stokes]),
[StokesType(stoke) for stoke in stokes],
)
fieldtable = _subtable(maintable, "FIELD", readonly=True)
units = fieldtable.col("PHASE_DIR").getkeyword("QuantumUnits")
for row in range(fieldtable.nrows()):
name = fieldtable.getcell("NAME", row)
phase_dir = fieldtable.getcell("PHASE_DIR", row).swapaxes(0, 1)
fields[name] = Field(
name,
PhaseDirection(
Angle(phase_dir[0], units[0]),
Angle(phase_dir[1], units[1]),
"",
),
)
swtable = _subtable(maintable, "SPECTRAL_WINDOW", readonly=True)
for row in range(swtable.nrows()):
name = swtable.getcell("NAME", row)
strides = np.unique(
swtable.getcell("CHAN_WIDTH", row) / swtable.getcell("RESOLUTION", row)
)
if len(strides) == 1 and strides[0].is_integer():
stride = int(strides[0])
else:
logger.warning("Detected malformed strides in MS: %s. Defaulting to 1.", strides)
stride = 1
spectral_windows[name] = SpectralWindow(
name,
count=swtable.getcell("NUM_CHAN", row),
start=0,
freq_min=swtable.getcell("CHAN_FREQ", row)[0]
- swtable.getcell("RESOLUTION", row)[0] / 2,
freq_max=swtable.getcell("CHAN_FREQ", row)[-1]
+ swtable.getcell("RESOLUTION", row)[-1] / 2,
stride=stride,
)
# Extract beams
        # cache of distinct (field, spectral window, polarization) tuples
swbeamtable: tables.table = tables.taql(
(
"SELECT DISTINCT "
"FIELD_ID, "
"[SELECT SPECTRAL_WINDOW_ID FROM ::DATA_DESCRIPTION][DATA_DESC_ID] as SPECTRAL_WINDOW_ID, "
"[SELECT POLARIZATION_ID FROM ::DATA_DESCRIPTION][DATA_DESC_ID] as POLARIZATION_ID "
"FROM $1"
),
tables=[maintable],
locals={""},
)
        # TODO: In MSv2 a beam is approximated as a unique combination of FIELD_ID and
        # POLARIZATION_ID; any matching SPECTRAL_WINDOW_ID will be associated with a single beam.
beamtable: tables.table = tables.taql(
"SELECT DISTINCT FIELD_ID, POLARIZATION_ID FROM $1",
tables=[swbeamtable],
locals={""},
)
def swquery(field_idx: int, polarization_idx: int):
"""Query for all spectral windows associated with a single beam"""
return (
"SELECT DISTINCT SPECTRAL_WINDOW_ID "
"FROM $1 "
f"WHERE FIELD_ID == {field_idx} AND POLARIZATION_ID == {polarization_idx}"
)
polarray = list(polarizations.values())
fieldarray = list(fields.values())
swarray = list(spectral_windows.values())
for beam_id in range(beamtable.nrows()):
beam_id_str = str(beam_id)
field_idx = beamtable.getcell("FIELD_ID", beam_id)
pol_idx = beamtable.getcell("POLARIZATION_ID", beam_id)
beam_spectral_windows: tables.table = tables.taql(
swquery(field_idx, pol_idx),
tables=[swbeamtable],
locals={""},
)
# TODO: channels is not stored in MSv2 so this will create
# a unique channels container per beam
channels[beam_id_str] = Channels(
beam_id_str,
[
swarray[beam_spectral_windows.getcell("SPECTRAL_WINDOW_ID", row)]
for row in range(beam_spectral_windows.nrows())
],
)
beams[beam_id_str] = Beam(
beam_id_str,
"missing",
channels[beam_id_str],
polarray[pol_idx],
fieldarray[field_idx],
)
return cls(
scan_number,
beams,
polarizations,
fields,
channels,
spectral_windows,
datadesc={},
)
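# A minimal sketch of the DATA_DESC_ID enumeration performed by
# MSScan.from_scan() above: ids are handed out in first-seen order of
# (polarisation_id, spectral_window_id) pairs. The pair values below are
# hypothetical and purely illustrative.
def _example_datadesc_enumeration():
    pairs = [("XXYY", "sw0"), ("XXYY", "sw1"), ("XXYY", "sw0")]
    datadesc = {}
    dd_id = 0
    for dd in pairs:
        if dd not in datadesc:
            datadesc[dd] = dd_id
            dd_id += 1
    # -> {("XXYY", "sw0"): 0, ("XXYY", "sw1"): 1}
    return datadesc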
def _subtable(t: tables.table, name: str, readonly=True):
return tables.table(t.getkeyword(name), readonly=readonly, ack=False)
def _fill_observation_subtable(ms):
observation = _subtable(ms, "OBSERVATION", readonly=False)
username = os.getenv("USERNAME", "Unknown")
observation.addrows(1)
def write(col, v):
observation.putcell(col, 0, v)
write("SCHEDULE", [""])
write("PROJECT", "")
write("OBSERVER", username)
write("TELESCOPE_NAME", "cbf-sdp-emulator")
write("TIME_RANGE", (0, 0))
def _fill_antenna_subtable(ms, antennas: Sequence[Antenna]):
num_stations = len(antennas)
antenna_names = np.array([antenna.name for antenna in antennas])
antenna_positions = np.array([antenna.pos for antenna in antennas])
antenna_dish_diameters = np.array([antenna.dish_diameter for antenna in antennas])
antenna_flags = np.repeat(False, num_stations)
antenna_mounts = np.repeat("FIXED", num_stations)
antenna_t = _subtable(ms, "ANTENNA", readonly=False)
antenna_t.addrows(num_stations)
antenna_t.putcol("POSITION", antenna_positions)
antenna_t.putcol("DISH_DIAMETER", antenna_dish_diameters)
antenna_t.putcol("FLAG_ROW", antenna_flags)
antenna_t.putcol("MOUNT", antenna_mounts)
antenna_t.putcol("NAME", antenna_names)
antenna_t.putcol("STATION", antenna_names)
def _fill_history_subtable(ms):
history = _subtable(ms, "HISTORY", readonly=False)
history.addrows(1)
now = time.time()
now_str = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(now))
now_mjd = time_utils.unix_to_mjd(now)
def write(col, v):
history.putcell(col, 0, v)
write("MESSAGE", f"Measurement Set created at {now_str}")
write("APPLICATION", "cbf-sdp-emulator")
write("PRIORITY", "INFO")
write("ORIGIN", "cbf-sdp-emulator")
write("TIME", now_mjd)
write("OBSERVATION_ID", -1)
write("APP_PARAMS", sys.argv[1:])
write("CLI_COMMAND", sys.argv)
def _fill_feed_subtable(
ms: tables.table,
num_stations: int,
beams: OrderedDict[str, Beam],
spectral_windows: OrderedDict[str, SpectralWindow],
):
feed = _subtable(ms, "FEED", readonly=False)
num_receptors = 2
sw_names = list(spectral_windows.keys())
for beam_idx, beam in enumerate(beams.values()):
for sw in beam.channels.spectral_windows:
rownr = feed.nrows()
feed.addrows(num_stations)
feed.putcol("ANTENNA_ID", np.arange(num_stations), startrow=rownr)
feed.putcol(
"BEAM_ID",
np.repeat(beam_idx, num_stations),
startrow=rownr,
)
feed.putcol("FEED_ID", np.zeros(num_stations), startrow=rownr)
feed.putcol(
"SPECTRAL_WINDOW_ID",
np.repeat(sw_names.index(sw.spectral_window_id), num_stations),
startrow=rownr,
)
feed.putcol(
"BEAM_OFFSET",
np.zeros((num_stations, 2, num_receptors)),
startrow=rownr,
)
feed.putcol(
"NUM_RECEPTORS",
np.repeat(num_receptors, num_stations),
startrow=rownr,
)
feed.putcol(
"POLARIZATION_TYPE",
np.repeat(np.array([["X", "Y"]]), num_stations, axis=0),
startrow=rownr,
)
feed.putcol(
"POL_RESPONSE",
np.zeros(
(num_stations, num_receptors, num_receptors),
dtype="complex",
),
startrow=rownr,
)
feed.putcol(
"RECEPTOR_ANGLE",
np.zeros((num_stations, num_receptors)),
startrow=rownr,
)
def _fill_polarizations_subtable(ms: tables.table, pols: Sequence[Polarisations]):
poltable = _subtable(ms, "POLARIZATION", readonly=False)
for pol in pols:
pol_id = poltable.nrows()
poltable.addrows(1)
poltable.putcell(
"CORR_TYPE",
pol_id,
np.array([s.value for s in pol.correlation_type]),
)
poltable.putcell(
"CORR_PRODUCT",
pol_id,
np.array([s.product for s in pol.correlation_type]),
)
poltable.putcell("FLAG_ROW", pol_id, False)
poltable.putcell("NUM_CORR", pol_id, len(pol.correlation_type))
def _fill_spectral_window_subtable(ms: tables.table, spectral_windows: Sequence[SpectralWindow]):
swtable = _subtable(ms, "SPECTRAL_WINDOW", readonly=False)
for sw in spectral_windows:
sw_id = swtable.nrows()
swtable.addrows(1)
swtable.putcell("CHAN_FREQ", sw_id, sw.frequencies)
swtable.putcell("CHAN_WIDTH", sw_id, np.repeat(sw.channel_width, sw.count))
swtable.putcell("EFFECTIVE_BW", sw_id, np.repeat(sw.channel_bandwidth, sw.count))
swtable.putcell("FLAG_ROW", sw_id, 0)
swtable.putcell("FREQ_GROUP", sw_id, 0)
swtable.putcell("FREQ_GROUP_NAME", sw_id, "")
swtable.putcell("IF_CONV_CHAIN", sw_id, 0)
swtable.putcell("MEAS_FREQ_REF", sw_id, FrequencyType.TOPO.value)
swtable.putcell("NAME", sw_id, sw.spectral_window_id)
swtable.putcell("NET_SIDEBAND", sw_id, 0)
swtable.putcell("NUM_CHAN", sw_id, sw.count)
swtable.putcell("REF_FREQUENCY", sw_id, sw.freq_min)
swtable.putcell("RESOLUTION", sw_id, np.repeat(sw.channel_bandwidth, sw.count))
swtable.putcell("TOTAL_BANDWIDTH", sw_id, sw.freq_max - sw.freq_min)
def _fill_data_description_subtable(
ms: tables.table,
beams: Sequence[Beam],
spectral_windows: OrderedDict[str, SpectralWindow],
polarisations: OrderedDict[str, Polarisations],
):
ddtable = _subtable(ms, "DATA_DESCRIPTION", readonly=False)
sw_ids = list(spectral_windows.keys())
pol_names = list(polarisations.keys())
for beam in beams:
for sw in beam.channels.spectral_windows:
dd_id = ddtable.nrows()
ddtable.addrows(1)
ddtable.putcell(
"POLARIZATION_ID",
dd_id,
pol_names.index(beam.polarisations.polarisation_id),
)
ddtable.putcell(
"SPECTRAL_WINDOW_ID",
dd_id,
sw_ids.index(sw.spectral_window_id),
)
def _fill_state_subtable(ms: tables.table):
stable = _subtable(ms, "STATE", readonly=False)
stable.addrows(1)
def _fill_processor_subtable(ms: tables.table):
stable = _subtable(ms, "PROCESSOR", readonly=False)
stable.addrows(1)
def _fill_field_subtable(ms: tables.table, fields: Sequence[Field]):
"""
Writes a new entry to the field subtable.
YAN-1249 The FIELD table also needs the REFERENCE_DIR and
the DELAY_DIR.
PHASE_DIR - the direction of the phase center
DELAY_DIR - the direction the correlator delays point to.
REFERENCE_DIR - in single dish mode used if position switching is
being used - for interfermeters is the same as the PHASE and or DELAY
Stephen Ord (21-Apr-2023)
REF: https://casa.nrao.edu/Memos/229.html
The SDP Telmodel only has a PHASE_DIR in the FIELD schema
Technical Debt - we are assuming the DELAY/PHASE and reference DIR
are the same.
YAN-1324 has been raised to determine if this is correct - or
if an architecture change is required.
"""
    fieldtable = _subtable(ms, "FIELD", readonly=False)
    for field in fields:
        rownr = fieldtable.nrows()
        fieldtable.addrows(1)
        # PHASE_DIR, DELAY_DIR and REFERENCE_DIR are written with the same
        # value (see the technical-debt note in the docstring), so the
        # direction array is computed once and reused.
        direction = np.array(
            [
                np.reshape(field.phase_dir.ra.rad, -1),
                np.reshape(field.phase_dir.dec.rad, -1),
            ]
        ).swapaxes(0, 1)
        fieldtable.putcell("PHASE_DIR", rownr, direction)
        fieldtable.putcell("DELAY_DIR", rownr, direction)
        fieldtable.putcell("REFERENCE_DIR", rownr, direction)
        fieldtable.putcell("NAME", rownr, field.field_id)
class Mode(Enum):
"""MeasurementSet open modes"""
READONLY = 0
READWRITE = 1
CREATE = 2
class MeasurementSet(ContextManager):
"""
A simple wrapper around a Measurement Set.
    Writing is finalized when the underlying table (_t) is closed or
    destroyed.
"""
_t: Optional[tables.table]
_name: str
_readonly: bool
_use_plasmastman: bool
_plasma_socket: Optional[str]
_antennas: Optional[Sequence[Antenna]]
_model: MSScan
@property
def name(self) -> str:
return self._name
@property
def t(self) -> tables.table:
assert self._t is not None
return self._t
@property
def num_rows(self):
assert self._t is not None
return self._t.nrows()
@property
def num_stations(self):
return self._num_stations
@property
def num_channels(self):
"""Gets the number of channels used by the DATA column"""
return self._num_channels
@property
def channel_range(self):
warnings.warn(
"This actually returns the channel indexes that are available "
"to write to in the data. For a true insight into the channels "
"in the data look at the spectral windows returned by "
"calculate_scan_type()",
DeprecationWarning,
)
return ChannelRange(0, self.num_channels)
@property
def num_pols(self):
"""Gets the number of polarizations used by the DATA column"""
return self._num_pols
def calculate_scan_type(self):
"""
        Gets the SKA Scan Model from the MeasurementSet. (MS tables
        currently store only a subset of the scan model, so this
        should only be used for testing.)
"""
return ScanType("0", beams=list(self._model.beams.values()))
@classmethod
def open(cls, name: str, mode: Mode = Mode.READONLY):
"""Opens a MeasurementSet"""
return cls(name, mode=mode)
@classmethod
def create(
cls,
name: str,
scan: Scan,
antennas: Sequence[Antenna],
baselines: Baselines,
plasma_socket: Optional[str] = None,
):
"""Creates a new measurement set from scan and tm data"""
return MeasurementSet(
name,
mode=Mode.CREATE,
scan=scan,
antennas=antennas,
baselines=baselines,
plasma_socket=plasma_socket,
)
def __init__(
self,
name: str,
mode: Mode,
scan: Optional[Scan] = None,
antennas: Optional[Sequence[Antenna]] = None,
baselines: Optional[Baselines] = None,
plasma_socket: Optional[str] = None,
):
self._name = name
self._readonly = mode == Mode.READONLY
self._antennas = antennas
self._use_plasmastman = plasma_socket is not None
self._plasma_socket = plasma_socket
if mode in (Mode.READONLY, Mode.READWRITE):
self._open()
elif mode == Mode.CREATE:
if scan is None or antennas is None or baselines is None:
raise ValueError("scan, antennas and baselines required to create an MS")
baseline_utils.validate_antenna_and_baseline_counts(len(antennas), len(baselines))
self._a1 = baselines.antenna1
self._a2 = baselines.antenna2
self._create(scan)
else:
raise ValueError(mode)
def _open(self):
"""
        Opens an existing MeasurementSet. Called once when constructing
        in READONLY or READWRITE mode.
"""
self._t = tables.table(self._name, readonly=self._readonly, ack=False)
self._num_stations = self._subtable("ANTENNA").nrows()
swindow = self._subtable("SPECTRAL_WINDOW")
self._num_channels: int = int(swindow.getcol("NUM_CHAN")[0])
self._num_pols: int = self._subtable("POLARIZATION").getcol("NUM_CORR")[0]
# relational ms datamodel
self._model = MSScan.from_main_table(self._t)
def _create(self, scan: Scan):
"""
        Begins the creation of a MeasurementSet. Called once when
        constructing in CREATE mode.
"""
assert not self._readonly
assert self._antennas is not None
# relational datamodel of scan
self._model = MSScan.from_scan(scan)
self._num_stations = len(self._antennas)
# cached variables (for DATA dimensions)
self._num_channels = scan.scan_type.num_channels
self._num_pols = scan.scan_type.num_pols
if self._use_plasmastman:
self._t = None
else:
self._create_table(self._model)
def _create_table(self, model: MSScan, vis: Optional[TensorRef] = None):
"""
        Creates the MeasurementSet folder structure, using a tiled storage
        manager (TiledShapeStMan, or PlasmaStMan when a plasma socket is
        configured) for the DATA column, and populates the table member _t
"""
assert not self._readonly
assert self._antennas is not None
num_baselines = len(self._a1)
# Setting up a specific storage manager only for the DATA column
# for the time being; more can come in the future
if not self._use_plasmastman:
            assert vis is None, "vis must not be provided when creating a TiledData measurement set"
data_coldesc = tables.makearrcoldesc(
"DATA",
"",
valuetype="complex",
shape=(self._num_channels, self._num_pols),
datamanagergroup="TiledData",
datamanagertype="TiledShapeStMan",
keywords={"UNIT": "Jy"},
)
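            # Note: DEFAULTTILESHAPE is given in casacore (Fortran) axis
            # order - fastest axis first - i.e. (pol, channel, row), the
            # reverse of the numpy cell shape (channel, pol) declared above.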
data_dminfo = {
"DEFAULTTILESHAPE": np.array(
[
self.num_pols,
self._num_channels,
2 * num_baselines,
],
dtype=np.int32,
),
"MAXIMUMCACHESIZE": 0,
}
else:
assert vis is not None, "vis is required when creating a PlasmaData Measurement Set"
data_coldesc = tables.makearrcoldesc(
"DATA",
"",
valuetype="complex",
shape=(self._num_channels, self._num_pols),
datamanagergroup="TiledData",
datamanagertype="PlasmaStMan",
keywords={"UNIT": "Jy"},
)
data_dminfo = {
"PLASMASOCKET": self._plasma_socket,
"TENSOROBJECTIDS": {"DATA": vis.oid},
}
tabdesc = tables.maketabdesc(
[
data_coldesc,
tables.makearrcoldesc(
"FLAG",
"",
valuetype="bool",
shape=(self._num_channels, self._num_pols),
datamanagergroup="TiledFlag",
datamanagertype="TiledShapeStMan",
),
tables.makearrcoldesc(
"WEIGHT",
"",
valuetype="float",
shape=(self._num_pols,),
datamanagergroup="TiledWeight",
datamanagertype="TiledShapeStMan",
),
tables.makearrcoldesc(
"SIGMA",
"",
valuetype="float",
shape=(self._num_pols,),
datamanagergroup="TiledSigma",
datamanagertype="TiledShapeStMan",
),
]
)
dminfo = tables.makedminfo(
tabdesc,
{
"TiledData": data_dminfo,
"TiledFlag": {
"DEFAULTTILESHAPE": np.array(
[
self.num_pols,
self._num_channels,
2 * num_baselines,
],
dtype=np.int32,
),
"MAXIMUMCACHESIZE": 0,
},
"TiledWeight": {
"DEFAULTTILESHAPE": np.array(
[self._num_pols, 2 * num_baselines], dtype=np.int32
),
"MAXIMUMCACHESIZE": 0,
},
"TiledSigma": {
"DEFAULTTILESHAPE": np.array(
[self._num_pols, 2 * num_baselines], dtype=np.int32
),
"MAXIMUMCACHESIZE": 0,
},
},
)
self._t = tables.default_ms(self._name, tabdesc=tabdesc, dminfo=dminfo)
MeasurementSet._create_pointing_table(self._t)
_fill_observation_subtable(self._t)
_fill_antenna_subtable(self._t, self._antennas)
_fill_polarizations_subtable(self._t, tuple(model.polarisations.values()))
_fill_spectral_window_subtable(self._t, tuple(model.spectral_windows.values()))
_fill_data_description_subtable(
self._t,
tuple(model.beams.values()),
model.spectral_windows,
model.polarisations,
)
_fill_history_subtable(self._t)
_fill_feed_subtable(
self._t,
self._num_stations,
self._model.beams,
self._model.spectral_windows,
)
_fill_field_subtable(self._t, tuple(model.fields.values()))
_fill_state_subtable(self._t)
_fill_processor_subtable(self._t)
@staticmethod
def _create_pointing_table(main_table: tables.table):
def make_pointing_coldesc(name: str, comment: str, ref=None):
measinfo = {"type": "direction"}
if ref:
measinfo["Ref"] = ref
return tables.makearrcoldesc(
name,
0.0,
valuetype="double",
comment=comment,
shape=(1, 2),
datamanagertype="TiledColumnStMan",
keywords={
"QuantumUnits": ["rad", "rad"],
"MEASINFO": measinfo,
},
)
tabdesc = tables.maketabdesc(
[
make_pointing_coldesc("TARGET", "Target direction", "AZEL"),
make_pointing_coldesc("DIRECTION", "Antenna pointing direction", "AZEL"),
make_pointing_coldesc("POINTING_OFFSET", "A priori pointing correction in El/xEl"),
make_pointing_coldesc("SOURCE_OFFSET", "Offset from source in El/xEl"),
],
)
pointing_table = tables.default_ms_subtable(
"POINTING",
name=f"{main_table.name()}/POINTING",
tabdesc=tabdesc,
)
pointing_table.close()
def _subtable(self, name, readonly=True):
assert self._t is not None
return _subtable(self._t, name, readonly=readonly)
def close(self):
if self._t is not None:
self._t.close()
self._t = None # type: ignore
def __enter__(self):
"""context manager protocol"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""context manager protocol"""
self.close()
def read_cell(self, columnname, row: int) -> np.ndarray:
assert self._t is not None
return self._t.getcell(columnname, row)
def read_column(self, columnname, startrow=0, nrow=None) -> np.ndarray:
"""Reads num_rows rows from columnname starting at first_row"""
assert self._t is not None
nrow = nrow or self.num_rows
return self._t.getcol(columnname, startrow=startrow, nrow=nrow)
def read_coords(self, startrow, nrow) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
assert self._t is not None
uvw = self.read_column("UVW", startrow=startrow, nrow=nrow)
return uvw[:, 0], uvw[:, 1], uvw[:, 2]
def read_weight(self, start_row, start_pol, num_pols, num_baselines) -> np.ndarray:
assert self._t is not None
blc = [start_pol]
trc = [start_pol + num_pols - 1]
return self._t.getcolslice("WEIGHT", blc, trc, startrow=start_row, nrow=num_baselines)
def get_row_num(self, payload_time: float) -> int:
"""
        Args:
            payload_time (float): payload time to look up in the main
                table
Returns:
int: the first row number where the time matches
"""
assert self._t is not None
if self._t.nrows() == 0:
return 0
times: np.ndarray = self.read_column("TIME", 0, self._t.nrows())
if 0.0 in times:
logger.warning("read_column has returned an array with anomalous 0.0 entry - retrying")
# finds the insertion point to the left of any matching entry
row = bisect.bisect_left(times, payload_time) # type: ignore
row_interval = bisect.bisect_right(times, times[0]) # type: ignore
return row // row_interval
def write_defaulted_rows(self, num_baselines: int, end_row: int):
"""
Add new rows and override tables with default data
"""
assert self._t is not None
while end_row >= self._t.nrows():
nrows = self._t.nrows()
self._t.addrows(num_baselines)
            # Default flags are True, to mark visibilities that were dropped/invalid
            row_range = range(nrows, nrows + num_baselines)
            self._t.putcell(
                "FLAG",
                row_range,
                np.ones([self._num_channels, self._num_pols], dtype=bool),
            )
def write_field_id(self, row: int, field_id: int):
assert self._t is not None
self._t.putcell(
columnname="FIELD_ID",
rownr=row,
value=field_id,
)
def write_datadesc_id(self, row: int, sw_id: str, pol_id: str):
assert self._t is not None
self._t.putcell(
columnname="DATA_DESC_ID",
rownr=row,
value=self._model.datadesc[(pol_id, sw_id)],
)
def write_weight(
self,
start_row: int,
start_pol: int,
num_pols: int,
row_count: int,
weight: Optional[Union[np.ndarray, TensorRef]] = None,
):
assert not self._readonly
assert self._t is not None
if weight is None:
weight = np.ones((row_count, num_pols), dtype=float)
assert isinstance(weight, np.ndarray), "tileddata measurement sets require ndarray weights"
blc = [start_pol]
trc = [start_pol + num_pols - 1]
self._t.putcolslice("WEIGHT", weight, blc, trc, startrow=start_row, nrow=row_count)
def write_flag(
self,
start_row: int,
start_chan: int,
chan_count: int,
row_count: int,
flag: Optional[Union[np.ndarray, TensorRef]] = None,
):
"""
        Writes the FLAG column to the main table. If flag is None
        then False (unflagged) is written over the row range.
"""
assert not self._readonly
assert self._t is not None
if flag is None:
flag = np.zeros((row_count, chan_count, self.num_pols), dtype=bool)
assert isinstance(flag, np.ndarray), "tileddata measurement sets require ndarray flags"
blc = [start_chan, 0]
trc = [start_chan + chan_count - 1, self.num_pols - 1]
self._t.putcolslice("FLAG", flag, blc, trc, startrow=start_row, nrow=row_count)
def write_sigma(
self,
start_row: int,
start_pol: int,
pol_count: int,
row_count: int,
sigma: Optional[Union[np.ndarray, TensorRef]] = None,
):
assert not self._readonly
assert self._t is not None
if sigma is None:
sigma = np.ones((row_count, pol_count), dtype=float)
assert isinstance(sigma, np.ndarray), "tileddata measurement sets require ndarray sigmas"
blc = [start_pol]
trc = [start_pol + pol_count - 1]
self._t.putcolslice("SIGMA", sigma, blc, trc, startrow=start_row, nrow=row_count)
def write_scan(self, start_row: int, row_count: int, scan_id: int):
"""Writes the scan number"""
assert self._t is not None
self._t.putcol(
"SCAN_NUMBER",
np.repeat(scan_id, row_count),
startrow=start_row,
nrow=row_count,
)
# pylint: disable=unused-argument
def write_coords(self, start_row: int, row_count: int, uvw, interval, exposure, time):
assert not self._readonly
assert self._t is not None
times = np.repeat(time, row_count)
def put_col_slice(col_name, values):
self._t.putcol(col_name, values, startrow=start_row, nrow=row_count) # type: ignore
put_col_slice("UVW", uvw)
put_col_slice("ANTENNA1", self._a1[:row_count])
put_col_slice("ANTENNA2", self._a2[:row_count])
put_col_slice("EXPOSURE", np.repeat(exposure, row_count))
put_col_slice("INTERVAL", np.repeat(interval, row_count))
put_col_slice("TIME", times)
put_col_slice("TIME_CENTROID", times)
def read_vis(self, start_row, start_channel, num_channels, num_baselines) -> np.ndarray:
"""Reads a slice of the DATA table
Returns:
vis: visibilities of shape [pol, channels, baselines]
"""
assert self._t is not None
blc = [start_channel, 0]
trc = [start_channel + num_channels - 1, self.num_pols - 1]
return self._t.getcolslice("DATA", blc, trc, startrow=start_row, nrow=num_baselines)
def read_flags(self, start_row, start_channel, num_channels, num_baselines) -> np.ndarray:
"""Reads a slice of the FLAG table
Returns:
flag: flags of shape [pol, channels, baselines]
"""
assert self._t is not None
blc = [start_channel, 0]
trc = [start_channel + num_channels - 1, self.num_pols - 1]
return self._t.getcolslice("FLAG", blc, trc, startrow=start_row, nrow=num_baselines)
def write_vis(
self,
start_row: int,
start_chan: int,
chan_count: int,
row_count: int,
vis: Union[np.ndarray, TensorRef],
):
"""Writes visibilities to the DATA column of the maintable"""
assert not self._readonly
if self._use_plasmastman:
if self._t is None:
if isinstance(vis, TensorRef):
                    # plasma measurement sets only support being created
                    # with a single payload
self._create_table(self._model, vis)
self._ensure_rows(start_row + row_count)
else:
raise ValueError(vis)
else:
raise NotImplementedError("multiple writes to plasma ms not supported")
else:
assert self._t is not None
if isinstance(vis, TensorRef):
vis = vis.data
blc = [start_chan, 0]
trc = [start_chan + chan_count - 1, self.num_pols - 1]
self._t.putcolslice("DATA", vis, blc, trc, startrow=start_row, nrow=row_count)
def write_pointings(self, pointings: Sequence[Pointing]):
table = self._subtable("POINTING", readonly=False)
row_idx = table.nrows()
table.addrows(len(pointings))
def np_from_pointing_field(get_field: Callable[[Pointing], Any], dtype):
return np.fromiter(
(get_field(p) for p in pointings),
count=len(pointings),
dtype=dtype,
)
def np_from_pointing_azel_data(get_azel_data: Callable[[Pointing], tuple[float, float]]):
# Note that the last two dimensions here MUST match the
# dimensions we set in the column & storage manager definition
# in _create_pointing_table()
arr = np_from_pointing_field(get_azel_data, np.dtype((np.float64, 2)))
arr = np.expand_dims(arr, axis=1)
return arr
def putcol(col: str, value: np.ndarray):
table.putcol(
col,
value,
startrow=row_idx,
nrow=len(value),
)
putcol("INTERVAL", np.full(len(pointings), 0.0))
putcol("NUM_POLY", np.full(len(pointings), 0))
putcol("TRACKING", np_from_pointing_field(lambda p: p.tracking, bool))
putcol("ANTENNA_ID", np_from_pointing_field(lambda p: p.antenna_id, int))
putcol("DIRECTION", np_from_pointing_azel_data(lambda p: p.direction))
putcol("TARGET", np_from_pointing_azel_data(lambda p: p.target))
putcol(
"POINTING_OFFSET",
np_from_pointing_azel_data(lambda p: p.pointing_offset),
)
putcol(
"SOURCE_OFFSET",
np_from_pointing_azel_data(lambda p: p.source_offset),
)
times = np_from_pointing_field(lambda p: p.time, np.float64)
times = time_utils.unix_to_mjd(times)
putcol("TIME", times)
putcol("TIME_ORIGIN", times)
        max_len = max((len(p.name) for p in pointings), default=0)
if max_len:
putcol("NAME", np_from_pointing_field(lambda p: p.name, f"U{max_len}"))
def _ensure_rows(self, nrows):
"""Add missing rows"""
assert self._t is not None
current_rows = self._t.nrows()
if current_rows < nrows:
self._t.addrows(nrows - current_rows)
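# A minimal usage sketch for MeasurementSet (the path and slice sizes are
# hypothetical): open an existing MS read-only and read one timestamp's
# worth of visibilities across all channels.
def _example_read_ms(path: str = "observation.ms") -> np.ndarray:
    with MeasurementSet.open(path, Mode.READONLY) as ms:
        num_baselines = calc_baselines(ms)
        return ms.read_vis(
            start_row=0,
            start_channel=0,
            num_channels=ms.num_channels,
            num_baselines=num_baselines,
        )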
def calc_baselines(ms: MeasurementSet) -> int:
"""Number of baselines in input
Gets the number of baselines from the measurement set
:param ms: measurement set
:return: number of baselines
"""
row_num = 0
firsttime = ms.read_column("TIME", row_num, 1)
nexttime = firsttime
while nexttime == firsttime:
row_num = row_num + 1
if row_num == ms.num_rows:
break
nexttime = ms.read_column("TIME", row_num, 1)
return row_num
def clamp_num_chan(ms: MeasurementSet, start_chan: int, num_chan: int):
"""
Clamps the num_chan values so the range they define is
contained by the channel dimensionality of the MS.
:param ms: The input measurement set
:param start_chan: The start channel
:param num_chan: The number of channels
:return: The adjusted number of channels
"""
if num_chan == 0:
num_chan = ms.num_channels
if start_chan > ms.num_channels:
raise ValueError(f"start_chan = {start_chan} > num_chan in MS ({ms.num_channels})")
end_chan = start_chan + num_chan
if end_chan > ms.num_channels:
num_chan = ms.num_channels - start_chan
logger.warning(
"start_chan + num_chan = %d > num_chan in MS (%d), reducing num_chan to %d",
end_chan,
ms.num_channels,
num_chan,
)
return num_chan
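# A minimal sketch of clamp_num_chan behaviour, using a hypothetical stub
# in place of a real MeasurementSet: num_chan=0 selects all channels, and
# an over-long range is reduced so it stays within the MS.
def _example_clamp_num_chan():
    class _StubMS:  # stands in for a MeasurementSet with 64 channels
        num_channels = 64
    assert clamp_num_chan(_StubMS(), start_chan=0, num_chan=0) == 64
    assert clamp_num_chan(_StubMS(), start_chan=60, num_chan=16) == 4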
def vis_mjd_epoch(ms: MeasurementSet):
"""Find the MJD timestamp of the earliest visibility data"""
mjd_time = ms.t.select("TIME").sort("TIME", limit=1).getcell("TIME", 0)
return mjd_time
async def vis_reader(
ms,
start_chan=0,
num_chan=0,
num_timestamps=0,
timestamp_offset=0,
executor=None,
) -> AsyncIterable[tuple[np.ndarray, float]]:
"""Reads the visibilties and timestamps from a measurement set
:param ms: The input measurement set
:param start_chan: The start channel **index**
:param num_chan: The number of channels to read. If 0, all channels are read
:param num_timestamps: The number of timestamps to read. If 0, all timestamps are read
    :param timestamp_offset: The offset applied to time for simulating an infinite series of data
    :param executor: Optional executor in which the blocking reads are run
    :return: Yields a full set of baselines at a time for each timestamp
"""
needs_close = False
if isinstance(ms, str):
needs_close = True
ms = MeasurementSet.open(ms, Mode.READONLY)
# Channels
num_chan = clamp_num_chan(ms, start_chan, num_chan)
loop = asyncio.get_event_loop()
for mjd_timestamp, start_row_idx, row_count in ms_timestamp_rows(
ms, num_timestamps, timestamp_offset
):
vis = await loop.run_in_executor(
executor,
ms.read_vis,
start_row_idx,
start_chan,
num_chan,
row_count,
)
yield vis, mjd_timestamp
if needs_close:
ms.close()
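# A minimal usage sketch for vis_reader (the MS path is hypothetical):
# drive the async generator with asyncio and print each timestamp's
# visibility block shape.
async def _example_vis_reader(path: str = "observation.ms"):
    async for vis, mjd_time in vis_reader(path):
        print(mjd_time, vis.shape)
# e.g. asyncio.run(_example_vis_reader())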
def ms_timestamp_rows(
ms: MeasurementSet, num_timestamps=0, repeat_idx=0
) -> Iterable[tuple[float, int, int]]:
"""
Iterates over the timestamps in the MS, also returning the first row
index and row count associated with that timestamp in the ``MAIN``
    table. Timestamps can optionally be adjusted by a repeat_idx to allow for
looping over the MS multiple times with always increasing timestamps
(i.e. the first timestamp with repeat_idx=1 will be one timestep later
than the last timestamp with repeat_idx=0).
"""
# Baselines
num_baselines = calc_baselines(ms)
# Timestamps
timestamps_in_ms = ms.num_rows // num_baselines
if num_timestamps < 0:
raise ValueError(f"num_timestamps must be >= 0: {num_timestamps}")
if num_timestamps == 0:
num_timestamps = timestamps_in_ms
elif num_timestamps > timestamps_in_ms:
logger.warning(
"%d > num_timestamps in MS (%d), reducing to %d",
num_timestamps,
timestamps_in_ms,
timestamps_in_ms,
)
num_timestamps = timestamps_in_ms
# Pre-compute all times as integer/fraction values
# as required by the ICD payload
time_data = ms.read_column("TIME", 0, ms.num_rows)
diffs = np.diff(time_data)
if np.any(diffs < 0.0):
raise ValueError("Timestamps must be non-decreasing")
# Add an additional diff to the end so we verify the number of baselines in the last TS
baseline_ts_change_idxs = np.flatnonzero(np.concatenate([diffs, [1]])) + 1
num_baselines_per_ts = np.diff(baseline_ts_change_idxs)
if np.any(num_baselines_per_ts != num_baselines):
raise ValueError(
f"Data per timestamp ({num_baselines_per_ts}) doesn't "
f"match expected number of baselines ({num_baselines})"
)
time_offset = 0
if repeat_idx != 0:
if num_timestamps > 1:
time_range = time_data[num_timestamps * num_baselines - 1] - time_data[0]
time_increment = time_range / (num_timestamps - 1)
time_offset = repeat_idx * (time_range + time_increment)
else:
# no increment can be calculated so use 1s instead
time_offset = repeat_idx * 1.0
# Read until we exhaust the MS
start_row = 0
timestamp_count = 0
while start_row < ms.num_rows and timestamp_count < num_timestamps:
yield time_data[start_row] + time_offset, start_row, num_baselines
start_row += num_baselines
timestamp_count += 1
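# A worked sketch of the repeat_idx time offset above, assuming an MS with
# n evenly spaced timestamps [t0, t0 + dt, ..., t0 + (n - 1) * dt]: each
# repeat is offset by range + increment, so repeat 1 starts exactly one
# timestep after repeat 0 ends.
def _example_repeat_offset(t0=5000.0, dt=0.9, n=4, repeat_idx=1):
    time_range = (n - 1) * dt  # last timestamp minus first
    time_increment = time_range / (n - 1)  # equals dt for even spacing
    time_offset = repeat_idx * (time_range + time_increment)
    assert abs((t0 + time_offset) - (t0 + n * dt)) < 1e-9
    return time_offset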
class MSWriter(ContextManager):
"""
A class that handles the writing of data into a new MeasurementSet.
"""
def __init__(
self,
output_filename: str,
scan: Scan,
antennas: Sequence[Antenna],
baselines: Optional[Baselines] = None,
plasma_socket: Optional[str] = None,
):
"""
Handles the writing of a measurement set using icd payloads
Args:
output_filename (str): the output file name
tm (BaseTM): the telescope model
scan (Scan): the scan model
plasma_socket (Optional[str], optional): _description_. Defaults to None.
"""
num_stations = len(antennas)
if not baselines:
baselines = Baselines.generate(num_stations, True)
self._use_plasmastman = plasma_socket is not None
self._row_offset: Optional[int] = None
self.num_baselines = len(baselines)
self.num_pols = scan.scan_type.num_pols
self._scan = scan
ms_type = "Plasma" if self._use_plasmastman else "regular"
logger.info(
"Creating %s MS at %s for scan %i, %d stations, %d baselines and %d channels",
ms_type,
output_filename,
scan.scan_number,
num_stations,
self.num_baselines,
scan.scan_type.num_channels,
)
self.ms = MeasurementSet.create(output_filename, scan, antennas, baselines, plasma_socket)
def _write_data(
self,
row: int,
scan_id: int,
beam_id: str,
sw_id: str,
pol_id: str,
time: float,
interval: float,
exposure: float,
first_chan_idx: int,
chan_count: int,
uvw: np.ndarray,
vis: Union[TensorRef, np.ndarray],
):
"""Writes both coordinate and visibility data for a given timestamp dump"""
logger.debug(
"Writing visibility data at row %d for channel indices %d:%d",
row,
first_chan_idx,
first_chan_idx + chan_count,
)
# populate missing and current rows with flags
if not self._use_plasmastman:
self.ms.write_defaulted_rows(self.num_baselines, row)
self.ms.write_vis(row, first_chan_idx, chan_count, self.num_baselines, vis)
self.ms.write_flag(row, first_chan_idx, chan_count, self.num_baselines, None)
# TODO: unsafe assumption beam_idx matches field_idx
# TODO: unsafe assumption beam_id matches ms_beam_id'
# TODO: Shouldn't access private members
# pylint: disable=protected-access
beam = self.ms._model.beams[beam_id]
ms_field_id: int = list(self.ms._model.fields.keys()).index(beam.field.field_id)
self.ms.write_field_id(row, field_id=ms_field_id)
self.ms.write_datadesc_id(row, sw_id=sw_id, pol_id=pol_id)
self.ms.write_weight(row, 0, self.num_pols, self.num_baselines, None)
self.ms.write_sigma(row, 0, self.num_pols, self.num_baselines, None)
self.ms.write_scan(row, self.num_baselines, scan_id)
self.ms.write_coords(row, self.num_baselines, uvw, interval, exposure, time)
def _calculate_insertion_row(self, payload_seq_no: int, payload_time: float) -> int:
"""
        Calculates the row number at which to write a spectral window
        data slice. If payload_seq_no is negative then the row is
        calculated by bisecting the TIME column using payload_time.
Args:
payload_seq_no (int): Spectral window sequence number.
            payload_time (float): Payload time, used only when
                payload_seq_no is negative.
Returns:
int: The row number
"""
if payload_seq_no < 0:
if self.ms.num_rows is None:
row = 1
else:
row = self.ms.get_row_num(payload_time) + 1
else:
row = payload_seq_no * self.num_baselines
if self._row_offset is None:
self._row_offset = row
return row - self._row_offset
def write_data_row(
self,
scan_id: int,
beam: Beam,
sw: SpectralWindow,
payload_seq_no: int,
mjd_time: float,
interval: float,
exposure: float,
uvw: np.ndarray,
vis: Union[TensorRef, np.ndarray],
):
"""
Write a full spectral window's worth of data to the main table
using payload_seq_no (or mjd_time) to determine the row location.
"""
        data = vis.data if isinstance(vis, TensorRef) else vis
        num_channels = data.shape[1]
assert sw.count == num_channels, "Visibility channels must match spectral window size"
self.write_data(
scan_id,
beam,
sw,
payload_seq_no,
mjd_time,
interval,
exposure,
0,
num_channels,
uvw,
vis,
)
def write_data(
self,
scan_id: int,
beam: Beam,
sw: SpectralWindow,
payload_seq_no: int,
mjd_time: float,
interval: float,
exposure: float,
first_chan: int,
chan_count: int,
uvw: np.ndarray,
vis: Union[TensorRef, np.ndarray],
):
"""
Writes a spectral window data slice to the main table using
payload_seq_no (or mjd_time) to determine the row location.
Note: first_chan refers to the first channel **index** to be
be written.
"""
warnings.warn(
"Writing subsets of channel data is deprecated. Use write_data_row() instead",
DeprecationWarning,
)
# Find out the row for this payload/timestamp
row = self._calculate_insertion_row(payload_seq_no, mjd_time)
self._write_data(
row,
scan_id,
beam.beam_id,
sw.spectral_window_id,
beam.polarisations.polarisation_id,
mjd_time,
interval,
exposure,
first_chan,
chan_count,
uvw,
vis,
)
logger.info(
"Written data for %f %s %s to row %i",
mjd_time,
beam.beam_id,
sw.spectral_window_id,
row,
)
def write_pointings(self, pointings: Sequence[Pointing]):
self.ms.write_pointings(pointings)
def close(self):
self.ms.close()
def __enter__(self):
"""context manager protocol"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""context manager protocol"""
self.close()
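# A minimal usage sketch for MSWriter, assuming a Scan, antennas, uvw and
# vis obtained elsewhere (all argument values here are hypothetical, and
# the first beam/spectral window of the scan is picked arbitrarily).
def _example_write_ms(scan: Scan, antennas: Sequence[Antenna], uvw: np.ndarray, vis: np.ndarray):
    with MSWriter("output.ms", scan, antennas) as writer:
        beam = scan.scan_type.beams[0]
        sw = beam.channels.spectral_windows[0]
        writer.write_data_row(
            scan_id=scan.scan_number,
            beam=beam,
            sw=sw,
            payload_seq_no=0,
            mjd_time=time_utils.unix_to_mjd(time.time()),
            interval=0.9,
            exposure=0.9,
            uvw=uvw,
            vis=vis,
        )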