Source code for realtime.receive.core.msutils

import asyncio
import bisect
import collections
import logging
import os
import sys
import time
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import (
    Any,
    AsyncIterable,
    Callable,
    ContextManager,
    Dict,
    Iterable,
    Optional,
    OrderedDict,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
from astropy.coordinates import Angle
from casacore import tables

from realtime.receive.core import baseline_utils, time_utils
from realtime.receive.core.antenna import Antenna
from realtime.receive.core.baselines import Baselines
from realtime.receive.core.channel_range import ChannelRange
from realtime.receive.core.pointing import Pointing
from realtime.receive.core.scan import (
    Beam,
    Channels,
    Field,
    FrequencyType,
    PhaseDirection,
    Polarisations,
    Scan,
    ScanType,
    SpectralWindow,
    StokesType,
)

logger = logging.getLogger(__name__)

# pylint: disable=too-many-lines,missing-function-docstring


@dataclass(frozen=True)
class TensorRef:
    """
    Immutable pairing of an object id with the ndarray it refers to.

    The dataclass is frozen, so neither attribute can be reassigned once the
    reference has been constructed.
    """

    # Identifier of the tensor; a PlasmaStMan-backed MeasurementSet passes it
    # as the DATA entry of TENSOROBJECTIDS when building its table.
    oid: bytes
    # The array payload itself.
    data: np.ndarray


@dataclass(frozen=True)
class MSScan:
    """
    Relational scan model as supported by measurement set v2

    The ordered mappings below keep insertion order. When the model is written
    out, the ``_fill_*_subtable`` helpers append one subtable row per value in
    that order, so a value's position in its mapping is its row index in the
    corresponding MS subtable.
    """

    scan_number: int
    # beam id -> Beam
    beams: OrderedDict[str, Beam]
    # polarisation id -> Polarisations (row index as a string when read from an MS)
    polarisations: OrderedDict[str, Polarisations]
    # field id (FIELD NAME when read from an MS) -> Field
    fields: OrderedDict[str, Field]
    # channels id -> Channels (one per beam, keyed by beam id, when read from an MS)
    channels: OrderedDict[str, Channels]
    # spectral window id (SPECTRAL_WINDOW NAME when read from an MS) -> SpectralWindow
    spectral_windows: OrderedDict[str, SpectralWindow]

    # (polarisation id, spectral window id) -> sequential data description id,
    # numbered in first-seen order by from_scan; left empty by from_main_table.
    datadesc: Dict[Tuple[str, str], int]  # pol_name, sw_id

    @classmethod
    def from_scan(cls, scan: Scan) -> "MSScan":
        """
        Infers the measurement set relational model for a complete
        ska ScanType (Naive traversal of beams would otherwise create
        duplicate table entries without this).

        Beams, fields, polarisations, channels and spectral windows are
        de-duplicated by their ids (a later entry with an already-seen id
        replaces the value but keeps the original position). Data description
        ids are assigned sequentially to each distinct
        (polarisation id, spectral window id) pair in the order the pairs are
        first met while walking the beams.

        :param scan: the scan whose ``scan_type.beams`` are traversed
        :return: the relational model
        """
        beams = collections.OrderedDict()
        datadesc = {}
        polarizations = collections.OrderedDict()
        fields = collections.OrderedDict()
        channels = collections.OrderedDict()
        spectral_windows = collections.OrderedDict()

        # next data description id to hand out
        dd_id = 0
        for b in scan.scan_type.beams:
            beams[b.beam_id] = b
            fields[b.field.field_id] = b.field
            polarizations[b.polarisations.polarisation_id] = b.polarisations
            channels[b.channels.channels_id] = b.channels
            for sw in b.channels.spectral_windows:
                spectral_windows[sw.spectral_window_id] = sw

                # Generate a lookuptable only for pol-sw pairs
                # that exist in beams
                dd = (b.polarisations.polarisation_id, sw.spectral_window_id)
                if dd not in datadesc:
                    datadesc[dd] = dd_id
                    dd_id += 1

        return cls(
            scan.scan_number,
            beams,
            polarizations,
            fields,
            channels,
            spectral_windows,
            datadesc,
        )

    @classmethod
    def from_main_table(cls, maintable: tables.table) -> "MSScan":
        """
        Infers the measurement set relational model as best as
        possible from measurement set tables. Limitations for MSv2 include:
        * Beam function is lost
        * Channels created per beam
        * Channels channels_id is lost
        * Polarization name is lost
        * The datadesc lookup is not reconstructed (it is returned empty)

        The scan number is taken from the first row of SCAN_NUMBER.

        :param maintable: an open MS main table
        :return: the reconstructed relational model
        """
        scan_number = maintable.getcol("SCAN_NUMBER")[0]
        beams = collections.OrderedDict()
        polarizations = collections.OrderedDict()
        fields = collections.OrderedDict()
        channels = collections.OrderedDict()
        spectral_windows = collections.OrderedDict()

        # POLARIZATION rows are keyed by their row index (as a string) so the
        # position in `polarizations` matches POLARIZATION_ID.
        poltable = _subtable(maintable, "POLARIZATION", readonly=True)
        for row in range(poltable.nrows()):
            stokes = poltable.getcell("CORR_TYPE", row)
            # pol name is lost
            polarizations[str(row)] = Polarisations(
                ",".join([StokesType(stoke).name for stoke in stokes]),
                [StokesType(stoke) for stoke in stokes],
            )

        fieldtable = _subtable(maintable, "FIELD", readonly=True)
        units = fieldtable.col("PHASE_DIR").getkeyword("QuantumUnits")
        for row in range(fieldtable.nrows()):
            name = fieldtable.getcell("NAME", row)
            # _fill_field_subtable stores PHASE_DIR as (n, 2) [ra, dec] pairs;
            # swapping axes makes index 0 the RA values and index 1 the Dec.
            phase_dir = fieldtable.getcell("PHASE_DIR", row).swapaxes(0, 1)
            # NOTE(review): fields are keyed by NAME; duplicate names would
            # collapse entries and misalign fieldarray[FIELD_ID] below.
            fields[name] = Field(
                name,
                PhaseDirection(
                    Angle(phase_dir[0], units[0]),
                    Angle(phase_dir[1], units[1]),
                    "",
                ),
            )

        swtable = _subtable(maintable, "SPECTRAL_WINDOW", readonly=True)
        for row in range(swtable.nrows()):
            name = swtable.getcell("NAME", row)

            # _fill_spectral_window_subtable writes channel_width to CHAN_WIDTH
            # and channel_bandwidth to RESOLUTION, so their per-channel ratio
            # recovers the stride; it must be one single integer value.
            strides = np.unique(
                swtable.getcell("CHAN_WIDTH", row) / swtable.getcell("RESOLUTION", row)
            )
            if len(strides) == 1 and strides[0].is_integer():
                stride = int(strides[0])
            else:
                logger.warning("Detected malformed strides in MS: %s. Defaulting to 1.", strides)
                stride = 1

            # Band edges are rebuilt from the first/last CHAN_FREQ entries
            # widened by half of the respective RESOLUTION (presumably
            # CHAN_FREQ holds channel centres -- TODO confirm).
            # NOTE(review): windows are keyed by NAME; duplicate names would
            # collapse entries and misalign swarray[SPECTRAL_WINDOW_ID] below.
            spectral_windows[name] = SpectralWindow(
                name,
                count=swtable.getcell("NUM_CHAN", row),
                start=0,
                freq_min=swtable.getcell("CHAN_FREQ", row)[0]
                - swtable.getcell("RESOLUTION", row)[0] / 2,
                freq_max=swtable.getcell("CHAN_FREQ", row)[-1]
                + swtable.getcell("RESOLUTION", row)[-1] / 2,
                stride=stride,
            )

        # Extract beams
        # cache of distinct field, spectral window, polarization tuples
        # (DATA_DESC_ID is resolved through the DATA_DESCRIPTION subtable).
        # NOTE(review): `locals` receives a non-empty placeholder set rather
        # than a dict; presumably this stops taql from inspecting the caller's
        # frame for $-variable substitution -- confirm against python-casacore.
        swbeamtable: tables.table = tables.taql(
            (
                "SELECT DISTINCT "
                "FIELD_ID, "
                "[SELECT SPECTRAL_WINDOW_ID FROM ::DATA_DESCRIPTION][DATA_DESC_ID] as SPECTRAL_WINDOW_ID, "
                "[SELECT POLARIZATION_ID FROM ::DATA_DESCRIPTION][DATA_DESC_ID] as POLARIZATION_ID "
                "FROM $1"
            ),
            tables=[maintable],
            locals={""},
        )
        # TODO: In MSv2 a beam is approximated as a unique combination of FIELD_ID and
        # POLARIZATION_ID, any matching SPECTRAL_WINDOW_ID will be associated to a single beam.
        beamtable: tables.table = tables.taql(
            "SELECT DISTINCT FIELD_ID, POLARIZATION_ID FROM $1",
            tables=[swbeamtable],
            locals={""},
        )

        def swquery(field_idx: int, polarization_idx: int) -> str:
            """Query for all spectral windows associated with a single beam"""
            return (
                "SELECT DISTINCT SPECTRAL_WINDOW_ID "
                "FROM $1 "
                f"WHERE FIELD_ID == {field_idx} AND POLARIZATION_ID == {polarization_idx}"
            )

        # Positional views so subtable row ids can index the parsed values.
        polarray = list(polarizations.values())
        fieldarray = list(fields.values())
        swarray = list(spectral_windows.values())
        for beam_id in range(beamtable.nrows()):
            # Beams get synthetic ids: their row index in beamtable.
            beam_id_str = str(beam_id)
            field_idx = beamtable.getcell("FIELD_ID", beam_id)
            pol_idx = beamtable.getcell("POLARIZATION_ID", beam_id)
            beam_spectral_windows: tables.table = tables.taql(
                swquery(field_idx, pol_idx),
                tables=[swbeamtable],
                locals={""},
            )

            # TODO: channels is not stored in MSv2 so this will create
            # a unique channels container per beam
            channels[beam_id_str] = Channels(
                beam_id_str,
                [
                    swarray[beam_spectral_windows.getcell("SPECTRAL_WINDOW_ID", row)]
                    for row in range(beam_spectral_windows.nrows())
                ],
            )
            # The beam function is not stored in MSv2, hence "missing".
            beams[beam_id_str] = Beam(
                beam_id_str,
                "missing",
                channels[beam_id_str],
                polarray[pol_idx],
                fieldarray[field_idx],
            )

        return cls(
            scan_number,
            beams,
            polarizations,
            fields,
            channels,
            spectral_windows,
            datadesc={},
        )


def _subtable(t: tables.table, name: str, readonly=True):
    """
    Open the subtable of ``t`` whose path is stored in keyword ``name``.

    :param t: the table holding the subtable keyword (usually the MS main table)
    :param name: keyword naming the subtable, e.g. ``"ANTENNA"``
    :param readonly: open read-only (default) or read-write
    :return: a newly opened casacore table, opened without the "ack" message
    """
    subtable_path = t.getkeyword(name)
    return tables.table(subtable_path, readonly=readonly, ack=False)


def _fill_observation_subtable(ms):
    """
    Append a single placeholder row to the OBSERVATION subtable of ``ms``.

    OBSERVER comes from the USERNAME environment variable ("Unknown" when it
    is unset); every other column receives a fixed placeholder value.
    """
    table = _subtable(ms, "OBSERVATION", readonly=False)
    observer = os.getenv("USERNAME", "Unknown")
    table.addrows(1)
    row_values = (
        ("SCHEDULE", [""]),
        ("PROJECT", ""),
        ("OBSERVER", observer),
        ("TELESCOPE_NAME", "cbf-sdp-emulator"),
        ("TIME_RANGE", (0, 0)),
    )
    for column, value in row_values:
        table.putcell(column, 0, value)


def _fill_antenna_subtable(ms, antennas: Sequence[Antenna]):
    """
    Add one ANTENNA subtable row per entry of ``antennas``.

    NAME and STATION both receive the antenna name. MOUNT is always "FIXED"
    and FLAG_ROW is always False. Columns are written starting at row 0.
    """
    count = len(antennas)
    names = np.array([ant.name for ant in antennas])
    columns = {
        "POSITION": np.array([ant.pos for ant in antennas]),
        "DISH_DIAMETER": np.array([ant.dish_diameter for ant in antennas]),
        "FLAG_ROW": np.repeat(False, count),
        "MOUNT": np.repeat("FIXED", count),
        "NAME": names,
        "STATION": names,
    }

    table = _subtable(ms, "ANTENNA", readonly=False)
    table.addrows(count)
    for column, values in columns.items():
        table.putcol(column, values)


def _fill_history_subtable(ms):
    """
    Append one HISTORY row recording when, and with which command line, the
    Measurement Set was created.

    TIME is the creation instant converted with ``time_utils.unix_to_mjd``.
    The message carries the same instant formatted as a UTC ISO-8601 string.
    """
    table = _subtable(ms, "HISTORY", readonly=False)
    table.addrows(1)
    created = time.time()
    created_iso = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(created))
    entries = {
        "MESSAGE": f"Measurement Set created at {created_iso}",
        "APPLICATION": "cbf-sdp-emulator",
        "PRIORITY": "INFO",
        "ORIGIN": "cbf-sdp-emulator",
        "TIME": time_utils.unix_to_mjd(created),
        "OBSERVATION_ID": -1,
        "APP_PARAMS": sys.argv[1:],
        "CLI_COMMAND": sys.argv,
    }
    for column, value in entries.items():
        table.putcell(column, 0, value)


def _fill_feed_subtable(
    ms: tables.table,
    num_stations: int,
    beams: OrderedDict[str, Beam],
    spectral_windows: OrderedDict[str, SpectralWindow],
):
    """
    Append FEED subtable rows: ``num_stations`` rows (one per antenna) for
    every spectral window of every beam.

    BEAM_ID is the beam's position in ``beams`` and SPECTRAL_WINDOW_ID is the
    window's position in ``spectral_windows``. Every feed gets FEED_ID 0, two
    receptors labelled "X" and "Y", zero beam offsets and receptor angles,
    and an all-zero polarisation response.

    :param ms: main table whose FEED subtable is appended to
    :param num_stations: number of antennas (rows per beam/window block)
    :param beams: beams of the scan, in BEAM_ID order
    :param spectral_windows: all spectral windows, in SPECTRAL_WINDOW row order
    :raises ValueError: if a beam references a spectral window that is not a
        key of ``spectral_windows`` (raised by ``list.index``)
    """
    feed = _subtable(ms, "FEED", readonly=False)
    num_receptors = 2

    sw_names = list(spectral_windows.keys())

    # These columns are identical for every (beam, window) block of rows, so
    # build them once instead of on every iteration.
    antenna_ids = np.arange(num_stations)
    feed_ids = np.zeros(num_stations, dtype=int)
    beam_offsets = np.zeros((num_stations, 2, num_receptors))
    receptor_counts = np.repeat(num_receptors, num_stations)
    pol_types = np.repeat(np.array([["X", "Y"]]), num_stations, axis=0)
    pol_responses = np.zeros(
        (num_stations, num_receptors, num_receptors),
        dtype="complex",
    )
    receptor_angles = np.zeros((num_stations, num_receptors))

    for beam_idx, beam in enumerate(beams.values()):
        for sw in beam.channels.spectral_windows:
            sw_idx = sw_names.index(sw.spectral_window_id)
            # New rows are appended after whatever the table already holds.
            rownr = feed.nrows()
            feed.addrows(num_stations)
            block = {
                "ANTENNA_ID": antenna_ids,
                "BEAM_ID": np.repeat(beam_idx, num_stations),
                "FEED_ID": feed_ids,
                "SPECTRAL_WINDOW_ID": np.repeat(sw_idx, num_stations),
                "BEAM_OFFSET": beam_offsets,
                "NUM_RECEPTORS": receptor_counts,
                "POLARIZATION_TYPE": pol_types,
                "POL_RESPONSE": pol_responses,
                "RECEPTOR_ANGLE": receptor_angles,
            }
            for column, values in block.items():
                feed.putcol(column, values, startrow=rownr)


def _fill_polarizations_subtable(ms: tables.table, pols: Sequence[Polarisations]):
    """
    Append one POLARIZATION subtable row per entry of ``pols``, in order.

    CORR_TYPE holds each correlation's ``value`` and CORR_PRODUCT each
    correlation's ``product``. NUM_CORR is the number of correlations and
    FLAG_ROW is False.
    """
    table = _subtable(ms, "POLARIZATION", readonly=False)
    for pol in pols:
        row = table.nrows()
        table.addrows(1)
        correlations = pol.correlation_type
        corr_types = np.array([corr.value for corr in correlations])
        corr_products = np.array([corr.product for corr in correlations])
        table.putcell("CORR_TYPE", row, corr_types)
        table.putcell("CORR_PRODUCT", row, corr_products)
        table.putcell("FLAG_ROW", row, False)
        table.putcell("NUM_CORR", row, len(correlations))


def _fill_spectral_window_subtable(ms: tables.table, spectral_windows: Sequence[SpectralWindow]):
    """
    Append one SPECTRAL_WINDOW subtable row per entry of ``spectral_windows``.

    Per-channel columns are expanded to ``count`` entries. CHAN_WIDTH takes
    ``channel_width``, while EFFECTIVE_BW and RESOLUTION take
    ``channel_bandwidth``. The reference frequency is ``freq_min`` and the
    total bandwidth is ``freq_max - freq_min``. The frame is always TOPO.
    """
    table = _subtable(ms, "SPECTRAL_WINDOW", readonly=False)
    for window in spectral_windows:
        row = table.nrows()
        table.addrows(1)
        n_chan = window.count
        row_values = {
            "CHAN_FREQ": window.frequencies,
            "CHAN_WIDTH": np.repeat(window.channel_width, n_chan),
            "EFFECTIVE_BW": np.repeat(window.channel_bandwidth, n_chan),
            "FLAG_ROW": 0,
            "FREQ_GROUP": 0,
            "FREQ_GROUP_NAME": "",
            "IF_CONV_CHAIN": 0,
            "MEAS_FREQ_REF": FrequencyType.TOPO.value,
            "NAME": window.spectral_window_id,
            "NET_SIDEBAND": 0,
            "NUM_CHAN": n_chan,
            "REF_FREQUENCY": window.freq_min,
            "RESOLUTION": np.repeat(window.channel_bandwidth, n_chan),
            "TOTAL_BANDWIDTH": window.freq_max - window.freq_min,
        }
        for column, value in row_values.items():
            table.putcell(column, row, value)


def _fill_data_description_subtable(
    ms: tables.table,
    beams: Sequence[Beam],
    spectral_windows: OrderedDict[str, SpectralWindow],
    polarisations: OrderedDict[str, Polarisations],
):
    """
    Append one DATA_DESCRIPTION row per distinct (polarisation, spectral
    window) pair used by ``beams``.

    Pairs are visited beam by beam, then window by window within each beam.
    A pair that has already been written is skipped. This is the same
    first-seen numbering that ``MSScan.from_scan`` uses for its ``datadesc``
    lookup, which supplies the DATA_DESC_ID values written to the main table.
    Because of that, DATA_DESC_ID ``i`` refers to row ``i`` here. Writing
    duplicate rows (one per beam) would shift every later row and make those
    ids point at the wrong polarisation/window.

    :param ms: main table whose DATA_DESCRIPTION subtable is appended to
    :param beams: beams in the same order used to build the scan model
    :param spectral_windows: all spectral windows, in SPECTRAL_WINDOW row order
    :param polarisations: all polarisations, in POLARIZATION row order
    :raises ValueError: if a beam references an id missing from
        ``spectral_windows`` or ``polarisations`` (raised by ``list.index``)
    """
    ddtable = _subtable(ms, "DATA_DESCRIPTION", readonly=False)

    sw_ids = list(spectral_windows.keys())
    pol_names = list(polarisations.keys())
    written = set()
    for beam in beams:
        pol_id = beam.polarisations.polarisation_id
        for sw in beam.channels.spectral_windows:
            pair = (pol_id, sw.spectral_window_id)
            if pair in written:
                continue
            written.add(pair)
            dd_id = ddtable.nrows()
            ddtable.addrows(1)
            ddtable.putcell(
                "POLARIZATION_ID",
                dd_id,
                pol_names.index(pol_id),
            )
            ddtable.putcell(
                "SPECTRAL_WINDOW_ID",
                dd_id,
                sw_ids.index(sw.spectral_window_id),
            )


def _fill_state_subtable(ms: tables.table):
    """Append a single row, with no explicit values, to the STATE subtable of ``ms``."""
    _subtable(ms, "STATE", readonly=False).addrows(1)


def _fill_processor_subtable(ms: tables.table):
    """Append a single row, with no explicit values, to the PROCESSOR subtable of ``ms``."""
    _subtable(ms, "PROCESSOR", readonly=False).addrows(1)


def _fill_field_subtable(ms: tables.table, fields: Sequence[Field]):
    """
    Writes a new entry to the field subtable.

    YAN-1249 The FIELD table also needs the REFERENCE_DIR and
    the DELAY_DIR.

    PHASE_DIR - the direction of the phase center
    DELAY_DIR - the direction the correlator delays point to.
    REFERENCE_DIR - in single dish mode used if position switching is
    being used - for interferometers is the same as the PHASE and or DELAY
    Stephen Ord (21-Apr-2023)
    REF: https://casa.nrao.edu/Memos/229.html

    The SDP Telmodel only has a PHASE_DIR in the FIELD schema

    Technical Debt - we are assuming the DELAY/PHASE and reference DIR
    are the same.

    YAN-1324 has been raised to determine if this is correct - or
    if an architecture change is required.

    One row is appended per entry of ``fields``. NAME receives the field id,
    and each direction column receives an (n, 2) array of [ra, dec] pairs in
    radians, where n is the number of values in the field's RA/Dec angles.
    The subtable is opened through ``_subtable`` like every other subtable,
    so no "table opened" acknowledgement is printed.
    """
    fieldtable = _subtable(ms, "FIELD", readonly=False)
    for field in fields:
        # Flatten (possibly scalar) angles to 1-D, stack to (2, n) and swap to
        # (n, 2) so each row is one [ra, dec] pair.
        direction = np.array(
            [
                np.reshape(field.phase_dir.ra.rad, -1),
                np.reshape(field.phase_dir.dec.rad, -1),
            ]
        ).swapaxes(0, 1)

        rownr = fieldtable.nrows()
        fieldtable.addrows(1)
        # Technical debt (see docstring): all three directions are the
        # phase centre.
        for column in ("PHASE_DIR", "DELAY_DIR", "REFERENCE_DIR"):
            fieldtable.putcell(column, rownr, direction)
        fieldtable.putcell("NAME", rownr, field.field_id)


class Mode(Enum):
    """How a MeasurementSet is opened: read-only, read-write, or freshly created."""

    READONLY = 0  # open an existing MS without write access
    READWRITE = 1  # open an existing MS for modification
    CREATE = 2  # build a new MS from scan, antenna and baseline data


class MeasurementSet(ContextManager):
    """
    A simple wrapper around a Measurement Set. Writing is finalized during _t
    destructor.

    Instances are normally obtained through :meth:`open` (existing MS) or
    :meth:`create` (new MS). The object is a context manager: leaving the
    ``with`` block calls :meth:`close`.

    When created with a ``plasma_socket``, the DATA column is backed by the
    PlasmaStMan storage manager. The underlying table is only built by the
    first :meth:`write_vis` call, which must pass a :class:`TensorRef`.
    """

    # Main table; None for a plasma MS before its first write_vis, and after close().
    _t: Optional[tables.table]
    _name: str
    _readonly: bool
    _use_plasmastman: bool
    _plasma_socket: Optional[str]
    _antennas: Optional[Sequence[Antenna]]
    _model: MSScan

    @property
    def name(self) -> str:
        """Table name/path this MeasurementSet was opened or created with."""
        return self._name

    @property
    def t(self) -> tables.table:
        """The underlying casacore main table (must be open)."""
        assert self._t is not None
        return self._t

    @property
    def num_rows(self):
        """Number of rows currently in the main table."""
        assert self._t is not None
        return self._t.nrows()

    @property
    def num_stations(self):
        """Number of antennas (ANTENNA rows, or antennas given to create)."""
        return self._num_stations

    @property
    def num_channels(self):
        """Gets the number of channels used by the DATA column"""
        return self._num_channels

    @property
    def channel_range(self):
        """Deprecated: the writable channel index range ``[0, num_channels)``."""
        warnings.warn(
            "This actually returns the channel indexes that are available "
            "to write to in the data. For a true insight into the channels "
            "in the data look at the spectral windows returned by "
            "calculate_scan_type()",
            DeprecationWarning,
            # attribute the warning to the caller, not to this property
            stacklevel=2,
        )
        return ChannelRange(0, self.num_channels)

    @property
    def num_pols(self):
        """Gets the number of polarizations used by the DATA column"""
        return self._num_pols

    def calculate_scan_type(self):
        """
        Gets the SKA Scan Model from the MeasurementSet.
        (MS tables only store a subset of the scan model currently thus this
        should only be used for testing)
        """
        return ScanType("0", beams=list(self._model.beams.values()))

    @classmethod
    def open(cls, name: str, mode: Mode = Mode.READONLY):
        """Opens a MeasurementSet"""
        return cls(name, mode=mode)

    @classmethod
    def create(
        cls,
        name: str,
        scan: Scan,
        antennas: Sequence[Antenna],
        baselines: Baselines,
        plasma_socket: Optional[str] = None,
    ):
        """
        Creates a new measurement set from scan and tm data.

        Passing ``plasma_socket`` selects the PlasmaStMan-backed DATA column.
        Uses ``cls`` so subclasses get instances of themselves.
        """
        return cls(
            name,
            mode=Mode.CREATE,
            scan=scan,
            antennas=antennas,
            baselines=baselines,
            plasma_socket=plasma_socket,
        )

    def __init__(
        self,
        name: str,
        mode: Mode,
        scan: Optional[Scan] = None,
        antennas: Optional[Sequence[Antenna]] = None,
        baselines: Optional[Baselines] = None,
        plasma_socket: Optional[str] = None,
    ):
        """
        Open or create a MeasurementSet according to ``mode``.

        :raises ValueError: for CREATE without scan/antennas/baselines, or for
            an unknown mode
        """
        # Defined up front so close() is safe even if opening/creating fails.
        self._t = None
        self._name = name
        self._readonly = mode == Mode.READONLY
        self._antennas = antennas
        self._use_plasmastman = plasma_socket is not None
        self._plasma_socket = plasma_socket
        if mode in (Mode.READONLY, Mode.READWRITE):
            self._open()
        elif mode == Mode.CREATE:
            if scan is None or antennas is None or baselines is None:
                raise ValueError("scan, antennas and baselines required to create an MS")
            baseline_utils.validate_antenna_and_baseline_counts(len(antennas), len(baselines))
            # Per-baseline antenna indices, written by write_coords.
            self._a1 = baselines.antenna1
            self._a2 = baselines.antenna2
            self._create(scan)
        else:
            raise ValueError(mode)

    def _open(self):
        """
        Begins the opening of a MeasurementSet.
        Must be called once after constructing with readonly=True.

        Channel and polarisation counts are taken from the first
        SPECTRAL_WINDOW and POLARIZATION rows.
        """
        self._t = tables.table(self._name, readonly=self._readonly, ack=False)
        self._num_stations = self._subtable("ANTENNA").nrows()
        swindow = self._subtable("SPECTRAL_WINDOW")
        self._num_channels: int = int(swindow.getcol("NUM_CHAN")[0])
        self._num_pols: int = self._subtable("POLARIZATION").getcol("NUM_CORR")[0]
        # relational ms datamodel
        self._model = MSScan.from_main_table(self._t)

    def _create(self, scan: Scan):
        """
        Begins the creation of a MeasurementSet.
        Must be called once after constructing with readonly=False.

        For a plasma MS the table is deferred until the first write_vis.
        """
        assert not self._readonly
        assert self._antennas is not None
        # relational datamodel of scan
        self._model = MSScan.from_scan(scan)
        self._num_stations = len(self._antennas)
        # cached variables (for DATA dimensions)
        self._num_channels = scan.scan_type.num_channels
        self._num_pols = scan.scan_type.num_pols
        if self._use_plasmastman:
            self._t = None
        else:
            self._create_table(self._model)

    def _create_table(self, model: MSScan, vis: Optional[TensorRef] = None):
        """
        Creates a MeasurementSet folder structure using a TiledColumn Storage
        Manager for storing data on the filesystem and populates the table
        and antenna array members _t, _a1 and _a2

        :param model: relational model used to fill the subtables
        :param vis: required (and only allowed) for a plasma MS; its ``oid``
            identifies the tensor backing the DATA column
        """
        assert not self._readonly
        assert self._antennas is not None
        num_baselines = len(self._a1)

        # Setting up a specific storage manager only for the DATA column
        # for the time being; more can come in the future
        if not self._use_plasmastman:
            assert vis is None, "vis not required when creating a TiledData measurement set"
            data_coldesc = tables.makearrcoldesc(
                "DATA",
                "",
                valuetype="complex",
                shape=(self._num_channels, self._num_pols),
                datamanagergroup="TiledData",
                datamanagertype="TiledShapeStMan",
                keywords={"UNIT": "Jy"},
            )
            data_dminfo = {
                "DEFAULTTILESHAPE": np.array(
                    [
                        self.num_pols,
                        self._num_channels,
                        2 * num_baselines,
                    ],
                    dtype=np.int32,
                ),
                "MAXIMUMCACHESIZE": 0,
            }
        else:
            assert vis is not None, "vis is required when creating a PlasmaData Measurement Set"
            data_coldesc = tables.makearrcoldesc(
                "DATA",
                "",
                valuetype="complex",
                shape=(self._num_channels, self._num_pols),
                datamanagergroup="TiledData",
                datamanagertype="PlasmaStMan",
                keywords={"UNIT": "Jy"},
            )
            data_dminfo = {
                "PLASMASOCKET": self._plasma_socket,
                "TENSOROBJECTIDS": {"DATA": vis.oid},
            }

        tabdesc = tables.maketabdesc(
            [
                data_coldesc,
                tables.makearrcoldesc(
                    "FLAG",
                    "",
                    valuetype="bool",
                    shape=(self._num_channels, self._num_pols),
                    datamanagergroup="TiledFlag",
                    datamanagertype="TiledShapeStMan",
                ),
                tables.makearrcoldesc(
                    "WEIGHT",
                    "",
                    valuetype="float",
                    shape=(self._num_pols,),
                    datamanagergroup="TiledWeight",
                    datamanagertype="TiledShapeStMan",
                ),
                tables.makearrcoldesc(
                    "SIGMA",
                    "",
                    valuetype="float",
                    shape=(self._num_pols,),
                    datamanagergroup="TiledSigma",
                    datamanagertype="TiledShapeStMan",
                ),
            ]
        )

        dminfo = tables.makedminfo(
            tabdesc,
            {
                "TiledData": data_dminfo,
                "TiledFlag": {
                    "DEFAULTTILESHAPE": np.array(
                        [
                            self.num_pols,
                            self._num_channels,
                            2 * num_baselines,
                        ],
                        dtype=np.int32,
                    ),
                    "MAXIMUMCACHESIZE": 0,
                },
                "TiledWeight": {
                    "DEFAULTTILESHAPE": np.array(
                        [self._num_pols, 2 * num_baselines], dtype=np.int32
                    ),
                    "MAXIMUMCACHESIZE": 0,
                },
                "TiledSigma": {
                    "DEFAULTTILESHAPE": np.array(
                        [self._num_pols, 2 * num_baselines], dtype=np.int32
                    ),
                    "MAXIMUMCACHESIZE": 0,
                },
            },
        )

        self._t = tables.default_ms(self._name, tabdesc=tabdesc, dminfo=dminfo)
        MeasurementSet._create_pointing_table(self._t)
        _fill_observation_subtable(self._t)
        _fill_antenna_subtable(self._t, self._antennas)
        _fill_polarizations_subtable(self._t, tuple(model.polarisations.values()))
        _fill_spectral_window_subtable(self._t, tuple(model.spectral_windows.values()))
        _fill_data_description_subtable(
            self._t,
            tuple(model.beams.values()),
            model.spectral_windows,
            model.polarisations,
        )
        _fill_history_subtable(self._t)
        _fill_feed_subtable(
            self._t,
            self._num_stations,
            model.beams,
            model.spectral_windows,
        )
        _fill_field_subtable(self._t, tuple(model.fields.values()))
        _fill_state_subtable(self._t)
        _fill_processor_subtable(self._t)

    @staticmethod
    def _create_pointing_table(main_table: tables.table):
        """
        Create the POINTING subtable under ``main_table``.

        TARGET, DIRECTION, POINTING_OFFSET and SOURCE_OFFSET are defined as
        (1, 2) double arrays in radians. The table is closed immediately
        after it is created.
        """

        def make_pointing_coldesc(name: str, comment: str, ref=None):
            measinfo = {"type": "direction"}
            if ref:
                measinfo["Ref"] = ref
            return tables.makearrcoldesc(
                name,
                0.0,
                valuetype="double",
                comment=comment,
                shape=(1, 2),
                datamanagertype="TiledColumnStMan",
                keywords={
                    "QuantumUnits": ["rad", "rad"],
                    "MEASINFO": measinfo,
                },
            )

        tabdesc = tables.maketabdesc(
            [
                make_pointing_coldesc("TARGET", "Target direction", "AZEL"),
                make_pointing_coldesc("DIRECTION", "Antenna pointing direction", "AZEL"),
                make_pointing_coldesc("POINTING_OFFSET", "A priori pointing correction in El/xEl"),
                make_pointing_coldesc("SOURCE_OFFSET", "Offset from source in El/xEl"),
            ],
        )
        pointing_table = tables.default_ms_subtable(
            "POINTING",
            name=f"{main_table.name()}/POINTING",
            tabdesc=tabdesc,
        )
        pointing_table.close()

    def _subtable(self, name, readonly=True):
        """Open subtable ``name`` of the main table (which must be open)."""
        assert self._t is not None
        return _subtable(self._t, name, readonly=readonly)

    def close(self):
        """Close the main table if it is open; safe to call repeatedly."""
        if self._t is not None:
            self._t.close()
            self._t = None  # type: ignore

    def __enter__(self):
        """context manager protocol"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """context manager protocol"""
        self.close()

    def read_cell(self, columnname, row: int) -> np.ndarray:
        """Read a single cell of ``columnname`` at ``row``."""
        assert self._t is not None
        return self._t.getcell(columnname, row)

    def read_column(self, columnname, startrow=0, nrow=None) -> np.ndarray:
        """
        Reads ``nrow`` rows from ``columnname`` starting at ``startrow``.

        When ``nrow`` is None (or 0) it defaults to :attr:`num_rows`.
        """
        assert self._t is not None
        nrow = nrow or self.num_rows
        return self._t.getcol(columnname, startrow=startrow, nrow=nrow)

    def read_coords(self, startrow, nrow) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Read UVW for ``nrow`` rows and return its three components (u, v, w)."""
        assert self._t is not None
        uvw = self.read_column("UVW", startrow=startrow, nrow=nrow)
        return uvw[:, 0], uvw[:, 1], uvw[:, 2]

    def read_weight(self, start_row, start_pol, num_pols, num_baselines) -> np.ndarray:
        """Read WEIGHT for pols ``[start_pol, start_pol + num_pols)`` over ``num_baselines`` rows."""
        assert self._t is not None
        blc = [start_pol]
        trc = [start_pol + num_pols - 1]
        return self._t.getcolslice("WEIGHT", blc, trc, startrow=start_row, nrow=num_baselines)

    def get_row_num(self, payload_time: float) -> int:
        """
        Args:
            payload_time (float): payload time to lookup in the maintable row

        Returns:
            int: the first row number where the time matches

        NOTE(review): the value returned is ``row // rows_per_timestamp``,
        where ``rows_per_timestamp`` is the number of rows sharing the first
        TIME value. That is a timestamp index rather than a row number unless
        each timestamp has a single row -- confirm what callers expect.
        The bisect also assumes the TIME column is sorted.
        """
        assert self._t is not None
        if self._t.nrows() == 0:
            return 0
        times: np.ndarray = self.read_column("TIME", 0, self._t.nrows())
        if 0.0 in times:
            # NOTE(review): despite the message, no retry happens; the bisect
            # below runs on these times as they are. Rows added by
            # write_defaulted_rows but not yet written presumably still hold
            # 0.0 here, which breaks the sortedness bisect relies on.
            logger.warning("read_column has returned an array with anomalous 0.0 entry - retrying")
        # finds the insertion point to the left of any matching entry
        row = bisect.bisect_left(times, payload_time)  # type: ignore
        row_interval = bisect.bisect_right(times, times[0])  # type: ignore
        return row // row_interval

    def write_defaulted_rows(self, num_baselines: int, end_row: int):
        """
        Add new rows and override tables with default data

        Rows are appended in blocks of ``num_baselines`` until ``end_row``
        exists. Every new row gets an all-True FLAG.

        :raises ValueError: if rows are needed but ``num_baselines`` is not
            positive (adding no rows would otherwise loop forever)
        """
        assert self._t is not None
        while end_row >= self._t.nrows():
            if num_baselines <= 0:
                raise ValueError(
                    f"num_baselines must be positive to add rows, got {num_baselines}"
                )
            nrows = self._t.nrows()
            self._t.addrows(num_baselines)
            # Default flags is True to indicate where visibilities are dropped/invalid
            row_range = range(nrows, nrows + num_baselines)
            self._t.putcell(
                "FLAG",
                row_range,
                np.ones([self._num_channels, self._num_pols]),
            )

    def write_field_id(self, row: int, field_id: int):
        """Write FIELD_ID for a single main-table row."""
        assert self._t is not None
        self._t.putcell(
            columnname="FIELD_ID",
            rownr=row,
            value=field_id,
        )

    def write_datadesc_id(self, row: int, sw_id: str, pol_id: str):
        """
        Write DATA_DESC_ID for ``row`` using the model's (pol_id, sw_id) lookup.

        :raises KeyError: if the pair is unknown. A model read back from an
            existing MS has an empty lookup, so this always raises there.
        """
        assert self._t is not None
        self._t.putcell(
            columnname="DATA_DESC_ID",
            rownr=row,
            value=self._model.datadesc[(pol_id, sw_id)],
        )

    def write_weight(
        self,
        start_row: int,
        start_pol: int,
        num_pols: int,
        row_count: int,
        weight: Optional[Union[np.ndarray, TensorRef]] = None,
    ):
        """
        Write WEIGHT for pols ``[start_pol, start_pol + num_pols)``.

        When ``weight`` is None, ones of shape (row_count, num_pols) are written.
        """
        assert not self._readonly
        assert self._t is not None
        if weight is None:
            weight = np.ones((row_count, num_pols), dtype=float)
        assert isinstance(weight, np.ndarray), "tileddata measurement sets require ndarray weights"
        blc = [start_pol]
        trc = [start_pol + num_pols - 1]
        self._t.putcolslice("WEIGHT", weight, blc, trc, startrow=start_row, nrow=row_count)

    def write_flag(
        self,
        start_row: int,
        start_chan: int,
        chan_count: int,
        row_count: int,
        flag: Optional[Union[np.ndarray, TensorRef]] = None,
    ):
        """
        Writes flag column to the main table.
        If flag is None then 0 will be written over the row range.

        The slice covers channels ``[start_chan, start_chan + chan_count)`` and
        all polarisations. The default array has shape
        (row_count, chan_count, num_pols).
        """
        assert not self._readonly
        assert self._t is not None
        if flag is None:
            flag = np.zeros((row_count, chan_count, self.num_pols), dtype=bool)
        assert isinstance(flag, np.ndarray), "tileddata measurement sets require ndarray flags"
        blc = [start_chan, 0]
        trc = [start_chan + chan_count - 1, self.num_pols - 1]
        self._t.putcolslice("FLAG", flag, blc, trc, startrow=start_row, nrow=row_count)

    def write_sigma(
        self,
        start_row: int,
        start_pol: int,
        pol_count: int,
        row_count: int,
        sigma: Optional[Union[np.ndarray, TensorRef]] = None,
    ):
        """
        Write SIGMA for pols ``[start_pol, start_pol + pol_count)``.

        When ``sigma`` is None, ones of shape (row_count, pol_count) are written.
        """
        assert not self._readonly
        assert self._t is not None
        if sigma is None:
            sigma = np.ones((row_count, pol_count), dtype=float)
        assert isinstance(sigma, np.ndarray), "tileddata measurement sets require ndarray sigmas"
        blc = [start_pol]
        trc = [start_pol + pol_count - 1]
        self._t.putcolslice("SIGMA", sigma, blc, trc, startrow=start_row, nrow=row_count)

    def write_scan(self, start_row: int, row_count: int, scan_id: int):
        """Writes the scan number"""
        assert self._t is not None
        self._t.putcol(
            "SCAN_NUMBER",
            np.repeat(scan_id, row_count),
            startrow=start_row,
            nrow=row_count,
        )

    # pylint: disable=unused-argument
    def write_coords(self, start_row: int, row_count: int, uvw, interval, exposure, time):
        """
        Write UVW, antenna indices, EXPOSURE, INTERVAL, TIME and TIME_CENTROID
        for ``row_count`` rows.

        ANTENNA1/ANTENNA2 take the first ``row_count`` baseline antenna
        indices. These are only set for an MS built via :meth:`create`.
        """
        assert not self._readonly
        assert self._t is not None
        times = np.repeat(time, row_count)

        def put_col_slice(col_name, values):
            self._t.putcol(col_name, values, startrow=start_row, nrow=row_count)  # type: ignore

        put_col_slice("UVW", uvw)
        put_col_slice("ANTENNA1", self._a1[:row_count])
        put_col_slice("ANTENNA2", self._a2[:row_count])
        put_col_slice("EXPOSURE", np.repeat(exposure, row_count))
        put_col_slice("INTERVAL", np.repeat(interval, row_count))
        put_col_slice("TIME", times)
        put_col_slice("TIME_CENTROID", times)

    def read_vis(self, start_row, start_channel, num_channels, num_baselines) -> np.ndarray:
        """Reads a slice of the DATA table

        Covers channels ``[start_channel, start_channel + num_channels)`` and
        all polarisations of ``num_baselines`` rows.

        Returns:
            vis: visibilities. NOTE(review): presumably laid out as
            (baselines, channels, pols), the same layout write_flag uses for
            its default array, rather than the [pol, channels, baselines]
            previously documented -- confirm.
        """
        assert self._t is not None
        blc = [start_channel, 0]
        trc = [start_channel + num_channels - 1, self.num_pols - 1]
        return self._t.getcolslice("DATA", blc, trc, startrow=start_row, nrow=num_baselines)

    def read_flags(self, start_row, start_channel, num_channels, num_baselines) -> np.ndarray:
        """Reads a slice of the FLAG table

        Covers channels ``[start_channel, start_channel + num_channels)`` and
        all polarisations of ``num_baselines`` rows.

        Returns:
            flag: flags. NOTE(review): presumably laid out as
            (baselines, channels, pols), matching write_flag's default
            array -- confirm.
        """
        assert self._t is not None
        blc = [start_channel, 0]
        trc = [start_channel + num_channels - 1, self.num_pols - 1]
        return self._t.getcolslice("FLAG", blc, trc, startrow=start_row, nrow=num_baselines)

    def write_vis(
        self,
        start_row: int,
        start_chan: int,
        chan_count: int,
        row_count: int,
        vis: Union[np.ndarray, TensorRef],
    ):
        """
        Writes visibilities to the DATA column of the maintable

        For a plasma MS, the first call must pass a TensorRef and builds the
        table around it. Any later call raises NotImplementedError.

        :raises ValueError: first plasma write without a TensorRef
        :raises NotImplementedError: repeated plasma writes
        """
        assert not self._readonly
        if self._use_plasmastman:
            if self._t is None:
                if isinstance(vis, TensorRef):
                    # plasma measurement only support being created
                    # with a single payload.
                    self._create_table(self._model, vis)
                    self._ensure_rows(start_row + row_count)
                else:
                    raise ValueError(vis)
            else:
                raise NotImplementedError("multiple writes to plasma ms not supported")
        else:
            assert self._t is not None
            if isinstance(vis, TensorRef):
                vis = vis.data
            blc = [start_chan, 0]
            trc = [start_chan + chan_count - 1, self.num_pols - 1]
            self._t.putcolslice("DATA", vis, blc, trc, startrow=start_row, nrow=row_count)

    def write_pointings(self, pointings: Sequence[Pointing]):
        """
        Append one POINTING row per entry of ``pointings``.

        TIME and TIME_ORIGIN are the pointing times converted to MJD.
        INTERVAL and NUM_POLY are 0. NAME is written only if at least one
        pointing has a non-empty name. An empty sequence is a no-op; it used
        to fail in ``max()`` on an empty sequence.
        """
        if not pointings:
            return

        table = self._subtable("POINTING", readonly=False)
        row_idx = table.nrows()
        table.addrows(len(pointings))

        def np_from_pointing_field(get_field: Callable[[Pointing], Any], dtype):
            return np.fromiter(
                (get_field(p) for p in pointings),
                count=len(pointings),
                dtype=dtype,
            )

        def np_from_pointing_azel_data(get_azel_data: Callable[[Pointing], tuple[float, float]]):
            # Note that the last two dimensions here MUST match the
            # dimensions we set in the column & storage manager definition
            # in _create_pointing_table()
            arr = np_from_pointing_field(get_azel_data, np.dtype((np.float64, 2)))
            arr = np.expand_dims(arr, axis=1)
            return arr

        def putcol(col: str, value: np.ndarray):
            table.putcol(
                col,
                value,
                startrow=row_idx,
                nrow=len(value),
            )

        putcol("INTERVAL", np.full(len(pointings), 0.0))
        putcol("NUM_POLY", np.full(len(pointings), 0))
        putcol("TRACKING", np_from_pointing_field(lambda p: p.tracking, bool))
        putcol("ANTENNA_ID", np_from_pointing_field(lambda p: p.antenna_id, int))
        putcol("DIRECTION", np_from_pointing_azel_data(lambda p: p.direction))
        putcol("TARGET", np_from_pointing_azel_data(lambda p: p.target))
        putcol(
            "POINTING_OFFSET",
            np_from_pointing_azel_data(lambda p: p.pointing_offset),
        )
        putcol(
            "SOURCE_OFFSET",
            np_from_pointing_azel_data(lambda p: p.source_offset),
        )
        times = np_from_pointing_field(lambda p: p.time, np.float64)
        times = time_utils.unix_to_mjd(times)
        putcol("TIME", times)
        putcol("TIME_ORIGIN", times)
        max_len = max(len(p.name) for p in pointings)
        if max_len:
            putcol("NAME", np_from_pointing_field(lambda p: p.name, f"U{max_len}"))

    def _ensure_rows(self, nrows):
        """Add missing rows"""
        assert self._t is not None
        current_rows = self._t.nrows()
        if current_rows < nrows:
            self._t.addrows(nrows - current_rows)
def calc_baselines(ms: MeasurementSet) -> int:
    """Number of baselines in input

    Gets the number of baselines from the measurement set by counting how
    many leading rows of the MAIN table share the TIME value of row 0
    (stopping at the end of the table).

    :param ms: measurement set
    :return: number of baselines
    """
    row_num = 0
    firsttime = ms.read_column("TIME", row_num, 1)
    nexttime = firsttime
    while nexttime == firsttime:
        row_num = row_num + 1
        if row_num == ms.num_rows:
            break
        nexttime = ms.read_column("TIME", row_num, 1)
    return row_num


def clamp_num_chan(ms: MeasurementSet, start_chan: int, num_chan: int):
    """
    Clamps the num_chan values so the range they define is contained by the
    channel dimensionality of the MS.

    :param ms: The input measurement set
    :param start_chan: The start channel **index**; must satisfy
        ``0 <= start_chan < ms.num_channels``
    :param num_chan: The number of channels. If 0, all channels from
        ``start_chan`` to the end of the MS are selected
    :return: The adjusted number of channels
    :raises ValueError: if ``start_chan`` is not a valid channel index or
        ``num_chan`` is negative
    """
    if start_chan < 0:
        raise ValueError(f"start_chan must be >= 0: {start_chan}")
    # start_chan is an index, so start_chan == ms.num_channels is already past the end
    if start_chan >= ms.num_channels:
        raise ValueError(f"start_chan = {start_chan} >= num_chan in MS ({ms.num_channels})")
    if num_chan < 0:
        raise ValueError(f"num_chan must be >= 0: {num_chan}")
    if num_chan == 0:
        # "All channels" means all remaining channels; this is not a
        # user error, so no warning is logged.
        return ms.num_channels - start_chan

    end_chan = start_chan + num_chan
    if end_chan > ms.num_channels:
        num_chan = ms.num_channels - start_chan
        logger.warning(
            "start_chan + num_chan = %d > num_chan in MS (%d), reducing num_chan to %d",
            end_chan,
            ms.num_channels,
            num_chan,
        )
    return num_chan


def vis_mjd_epoch(ms: MeasurementSet):
    """Find the MJD timestamp of the earliest visibility data

    Sorts the TIME column of the MAIN table and returns its smallest value.
    """
    mjd_time = ms.t.select("TIME").sort("TIME", limit=1).getcell("TIME", 0)
    return mjd_time


async def vis_reader(
    ms,
    start_chan=0,
    num_chan=0,
    num_timestamps=0,
    timestamp_offset=0,
    executor=None,
) -> AsyncIterable[tuple[np.ndarray, float]]:
    """Reads the visibilties and timestamps from a measurement set

    :param ms: The input measurement set, or a path to one. A path is opened
        read-only and closed again when the generator finishes, fails or is
        closed early.
    :param start_chan: The start channel **index**
    :param num_chan: The number of channels to read. If 0, all channels are read
    :param num_timestamps: The number of timestamps to read. If 0, all timestamps are read
    :param timestamp_offset: The offset applied to time for simulating an infinite series of data
    :param executor: The executor used to run the blocking ``read_vis`` calls;
        ``None`` selects the event loop's default executor
    :return: Yields a full set of baselines at a time for each timestamp
    """
    needs_close = False
    if isinstance(ms, str):
        needs_close = True
        ms = MeasurementSet.open(ms, Mode.READONLY)

    # try/finally so an MS opened here is closed even if the consumer stops
    # iterating early or an error is raised mid-way.
    try:
        # Channels
        num_chan = clamp_num_chan(ms, start_chan, num_chan)

        loop = asyncio.get_running_loop()
        for mjd_timestamp, start_row_idx, row_count in ms_timestamp_rows(
            ms, num_timestamps, timestamp_offset
        ):
            # read_vis blocks on table I/O, so run it off the event loop.
            vis = await loop.run_in_executor(
                executor,
                ms.read_vis,
                start_row_idx,
                start_chan,
                num_chan,
                row_count,
            )
            yield vis, mjd_timestamp
    finally:
        if needs_close:
            ms.close()


def ms_timestamp_rows(
    ms: MeasurementSet, num_timestamps=0, repeat_idx=0
) -> Iterable[tuple[float, int, int]]:
    """
    Iterates over the timestamps in the MS, also returning the first row index
    and row count associated with that timestamp in the ``MAIN`` table.

    Timestamps can optionally be adjusted by a repeat_idx to allow for looping
    over the MS multiple times with always increasing timestamps (i.e. the
    first timestamp with repeat_idx=1 will be one timestep later than the last
    timestamp with repeat_idx=0).

    :param ms: The input measurement set
    :param num_timestamps: The number of timestamps to yield. If 0, all
        timestamps are yielded; larger values are reduced (with a warning)
    :param repeat_idx: The repetition index used to offset the timestamps
    :return: Yields ``(time, start_row, row_count)`` tuples
    :raises ValueError: if ``num_timestamps`` is negative, TIME values decrease,
        or any timestamp does not have exactly one row per baseline
    """
    # Baselines
    num_baselines = calc_baselines(ms)

    # Timestamps
    timestamps_in_ms = ms.num_rows // num_baselines
    if num_timestamps < 0:
        raise ValueError(f"num_timestamps must be >= 0: {num_timestamps}")
    if num_timestamps == 0:
        num_timestamps = timestamps_in_ms
    elif num_timestamps > timestamps_in_ms:
        logger.warning(
            "%d > num_timestamps in MS (%d), reducing to %d",
            num_timestamps,
            timestamps_in_ms,
            timestamps_in_ms,
        )
        num_timestamps = timestamps_in_ms

    # Pre-compute all times as integer/fraction values
    # as required by the ICD payload
    time_data = ms.read_column("TIME", 0, ms.num_rows)
    diffs = np.diff(time_data)
    if np.any(diffs < 0.0):
        raise ValueError("Timestamps must be non-decreasing")
    # Add an additional diff to the end so we verify the number of baselines in the last TS
    baseline_ts_change_idxs = np.flatnonzero(np.concatenate([diffs, [1]])) + 1
    num_baselines_per_ts = np.diff(baseline_ts_change_idxs)
    if np.any(num_baselines_per_ts != num_baselines):
        raise ValueError(
            f"Data per timestamp ({num_baselines_per_ts}) doesn't "
            f"match expected number of baselines ({num_baselines})"
        )

    time_offset = 0
    if repeat_idx != 0:
        if num_timestamps > 1:
            # Span of the timestamps being yielded, plus one average step, so
            # that each repetition continues one timestep after the previous one.
            time_range = time_data[num_timestamps * num_baselines - 1] - time_data[0]
            time_increment = time_range / (num_timestamps - 1)
            time_offset = repeat_idx * (time_range + time_increment)
        else:
            # no increment can be calculated so use 1s instead
            time_offset = repeat_idx * 1.0

    # Read until we exhaust the MS
    start_row = 0
    timestamp_count = 0
    while start_row < ms.num_rows and timestamp_count < num_timestamps:
        yield time_data[start_row] + time_offset, start_row, num_baselines
        start_row += num_baselines
        timestamp_count += 1
class MSWriter(ContextManager):
    """
    A class that handles the writing of data into a new MeasurementSet.

    Used as a context manager, the underlying MeasurementSet is closed on exit.
    """

    def __init__(
        self,
        output_filename: str,
        scan: Scan,
        antennas: Sequence[Antenna],
        baselines: Optional[Baselines] = None,
        plasma_socket: Optional[str] = None,
    ):
        """
        Handles the writing of a measurement set using icd payloads

        Args:
            output_filename (str): the output file name
            scan (Scan): the scan model
            antennas (Sequence[Antenna]): the antennas; their count is used
                as the number of stations
            baselines (Optional[Baselines], optional): the baselines to write.
                When not given (or empty) they are generated with
                ``Baselines.generate(len(antennas), True)``. Defaults to None.
            plasma_socket (Optional[str], optional): when given, a Plasma-backed
                MS is created instead of a regular one. Defaults to None.
        """
        num_stations = len(antennas)
        if not baselines:
            baselines = Baselines.generate(num_stations, True)
        self._use_plasmastman = plasma_socket is not None
        # Set from the first row computed by _calculate_insertion_row and
        # subtracted from every later row.
        self._row_offset: Optional[int] = None
        self.num_baselines = len(baselines)
        self.num_pols = scan.scan_type.num_pols
        self._scan = scan
        ms_type = "Plasma" if self._use_plasmastman else "regular"
        logger.info(
            "Creating %s MS at %s for scan %i, %d stations, %d baselines and %d channels",
            ms_type,
            output_filename,
            scan.scan_number,
            num_stations,
            self.num_baselines,
            scan.scan_type.num_channels,
        )
        self.ms = MeasurementSet.create(output_filename, scan, antennas, baselines, plasma_socket)

    def _write_data(
        self,
        row: int,
        scan_id: int,
        beam_id: str,
        sw_id: str,
        pol_id: str,
        time: float,
        interval: float,
        exposure: float,
        first_chan_idx: int,
        chan_count: int,
        uvw: np.ndarray,
        vis: Union[TensorRef, np.ndarray],
    ):
        """Writes both coordinate and visibility data for a given timestamp dump"""
        logger.debug(
            "Writing visibility data at row %d for channel indices %d:%d",
            row,
            first_chan_idx,
            first_chan_idx + chan_count,
        )

        # populate missing and current rows with flags
        if not self._use_plasmastman:
            self.ms.write_defaulted_rows(self.num_baselines, row)

        self.ms.write_vis(row, first_chan_idx, chan_count, self.num_baselines, vis)
        self.ms.write_flag(row, first_chan_idx, chan_count, self.num_baselines, None)

        # TODO: unsafe assumption beam_idx matches field_idx
        # TODO: unsafe assumption beam_id matches ms_beam_id'
        # TODO: Shouldn't access private members
        # pylint: disable=protected-access
        beam = self.ms._model.beams[beam_id]
        ms_field_id: int = list(self.ms._model.fields.keys()).index(beam.field.field_id)
        self.ms.write_field_id(row, field_id=ms_field_id)
        self.ms.write_datadesc_id(row, sw_id=sw_id, pol_id=pol_id)
        self.ms.write_weight(row, 0, self.num_pols, self.num_baselines, None)
        self.ms.write_sigma(row, 0, self.num_pols, self.num_baselines, None)
        self.ms.write_scan(row, self.num_baselines, scan_id)
        self.ms.write_coords(row, self.num_baselines, uvw, interval, exposure, time)

    def _calculate_insertion_row(self, payload_seq_no: int, payload_time: float) -> int:
        """
        Calculates the row number to write a spectral window data slice.

        If payload_seq_no is negative then the row is looked up from the TIME
        column via ``self.ms.get_row_num(payload_time)``; otherwise it is
        ``payload_seq_no * num_baselines``. The first row ever calculated
        becomes an offset that is subtracted from every result, so the first
        write lands on row 0.

        Args:
            payload_seq_no (int): Spectral window sequence number.
            payload_time (float): Payload time used only when payload_seq_no
                is negative.

        Returns:
            int: The row number
        """
        if payload_seq_no < 0:
            if self.ms.num_rows is None:
                row = 1
            else:
                row = self.ms.get_row_num(payload_time) + 1
        else:
            row = payload_seq_no * self.num_baselines
        if self._row_offset is None:
            self._row_offset = row
        return row - self._row_offset

    def write_data_row(
        self,
        scan_id: int,
        beam: Beam,
        sw: SpectralWindow,
        payload_seq_no: int,
        mjd_time: float,
        interval: float,
        exposure: float,
        uvw: np.ndarray,
        vis: Union[TensorRef, np.ndarray],
    ):
        """
        Write a full spectral window's worth of data to the main table using
        payload_seq_no (or mjd_time) to determine the row location.

        ``vis`` may be a plain array or a ``TensorRef``; its second axis must
        have ``sw.count`` channels.
        """
        # TensorRef has no .shape of its own; inspect the wrapped array.
        vis_array = vis.data if isinstance(vis, TensorRef) else vis
        num_channels = vis_array.shape[1]
        assert sw.count == num_channels, "Visibility channels must match spectral window size"
        # Call the shared implementation directly so the supported API does
        # not trigger write_data()'s deprecation warning.
        self._write_slice(
            scan_id,
            beam,
            sw,
            payload_seq_no,
            mjd_time,
            interval,
            exposure,
            0,
            num_channels,
            uvw,
            vis,
        )

    def write_data(
        self,
        scan_id: int,
        beam: Beam,
        sw: SpectralWindow,
        payload_seq_no: int,
        mjd_time: float,
        interval: float,
        exposure: float,
        first_chan: int,
        chan_count: int,
        uvw: np.ndarray,
        vis: Union[TensorRef, np.ndarray],
    ):
        """
        Writes a spectral window data slice to the main table using
        payload_seq_no (or mjd_time) to determine the row location.

        Deprecated: use write_data_row() instead.

        Note: first_chan refers to the first channel **index** to be be written.
        """
        warnings.warn(
            "Writing subsets of channel data is deprecated. Use write_data_row() instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        self._write_slice(
            scan_id,
            beam,
            sw,
            payload_seq_no,
            mjd_time,
            interval,
            exposure,
            first_chan,
            chan_count,
            uvw,
            vis,
        )

    def _write_slice(
        self,
        scan_id: int,
        beam: Beam,
        sw: SpectralWindow,
        payload_seq_no: int,
        mjd_time: float,
        interval: float,
        exposure: float,
        first_chan: int,
        chan_count: int,
        uvw: np.ndarray,
        vis: Union[TensorRef, np.ndarray],
    ):
        """Locates the target row and writes a channel slice (no deprecation warning)."""
        # Find out the row for this payload/timestamp
        row = self._calculate_insertion_row(payload_seq_no, mjd_time)
        self._write_data(
            row,
            scan_id,
            beam.beam_id,
            sw.spectral_window_id,
            beam.polarisations.polarisation_id,
            mjd_time,
            interval,
            exposure,
            first_chan,
            chan_count,
            uvw,
            vis,
        )
        logger.info(
            "Written data for %f %s %s to row %i",
            mjd_time,
            beam.beam_id,
            sw.spectral_window_id,
            row,
        )

    def write_pointings(self, pointings: Sequence[Pointing]):
        """Forwards ``pointings`` to the underlying MeasurementSet."""
        self.ms.write_pointings(pointings)

    def close(self):
        """Closes the underlying MeasurementSet."""
        self.ms.close()

    def __enter__(self):
        """context manager protocol"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """context manager protocol"""
        self.close()