import asyncio
import bisect
import collections
import logging
import os
import sys
import time
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import (

import numpy as np
from astropy.coordinates import Angle
from casacore import tables

from realtime.receive.core import baseline_utils, time_utils
from realtime.receive.core.antenna import Antenna
from realtime.receive.core.baselines import Baselines
from realtime.receive.core.channel_range import ChannelRange
from realtime.receive.core.pointing import Pointing
from realtime.receive.core.scan import (

logger = logging.getLogger(__name__)

@dataclass(eq=False)
class TensorRef:
    """A reference to an ndarray in memory

    Declared as a dataclass so that instances can be built with
    ``TensorRef(oid, data)``. ``eq=False`` keeps identity-based equality and
    hashing, which avoids element-wise ``ndarray`` comparisons when two
    references are compared.
    """

    # Identifier of the referenced object.
    # NOTE(review): presumably the id of the object in the shared store - confirm
    oid: bytes
    # The array being referenced
    data: np.ndarray

class MSScan:
    Relational scan model as supported by measurement set v2

    scan_number: int
    beams: OrderedDict[str, Beam]
    polarisations: OrderedDict[str, Polarisations]
    fields: OrderedDict[str, Field]
    channels: OrderedDict[str, Channels]
    spectral_windows: OrderedDict[str, SpectralWindow]

    datadesc: Dict[Tuple[str, str], int]  # pol_name, sw_id

    def from_scan(cls, scan: Scan):
        Infers the measurement set relational model for a complete
        ska ScanType (Naive traversal of beams would otherwise create
        duplicate table entries without this).
        beams = collections.OrderedDict()
        datadesc = {}
        polarizations = collections.OrderedDict()
        fields = collections.OrderedDict()
        channels = collections.OrderedDict()
        spectral_windows = collections.OrderedDict()

        dd_id = 0
        for b in scan.scan_type.beams:
            beams[b.beam_id] = b
            fields[b.field.field_id] = b.field
            polarizations[b.polarisations.polarisation_id] = b.polarisations
            channels[b.channels.channels_id] = b.channels
            for sw in b.channels.spectral_windows:
                spectral_windows[sw.spectral_window_id] = sw

                # Generate a lookuptable only for pol-sw pairs
                # that exist in beams
                dd = (b.polarisations.polarisation_id, sw.spectral_window_id)
                if dd not in datadesc:
                    datadesc[dd] = dd_id
                    dd_id += 1

        return cls(

    def from_main_table(cls, maintable: tables.table):
        Infers the measurement set relational model as best as
        possible from measurement set tables. Limitations for MSv2 include:
        * Beam function is lost
        * Channels created per beam
        * Channels channels_id is lost
        * Polarization name is lost
        scan_number = maintable.getcol("SCAN_NUMBER")[0]
        beams = collections.OrderedDict()
        polarizations = collections.OrderedDict()
        fields = collections.OrderedDict()
        channels = collections.OrderedDict()
        spectral_windows = collections.OrderedDict()

        poltable = _subtable(maintable, "POLARIZATION", readonly=True)
        for row in range(poltable.nrows()):
            stokes = poltable.getcell("CORR_TYPE", row)
            # pol name is lost
            polarizations[str(row)] = Polarisations(
                ",".join([StokesType(stoke).name for stoke in stokes]),
                [StokesType(stoke) for stoke in stokes],

        fieldtable = _subtable(maintable, "FIELD", readonly=True)
        units = fieldtable.col("PHASE_DIR").getkeyword("QuantumUnits")
        for row in range(fieldtable.nrows()):
            name = fieldtable.getcell("NAME", row)
            phase_dir = fieldtable.getcell("PHASE_DIR", row).swapaxes(0, 1)
            fields[name] = Field(
                    Angle(phase_dir[0], units[0]),
                    Angle(phase_dir[1], units[1]),

        swtable = _subtable(maintable, "SPECTRAL_WINDOW", readonly=True)
        for row in range(swtable.nrows()):
            name = swtable.getcell("NAME", row)

            strides = np.unique(
                swtable.getcell("CHAN_WIDTH", row) / swtable.getcell("RESOLUTION", row)
            if len(strides) == 1 and strides[0].is_integer():
                stride = int(strides[0])
                logger.warning("Detected malformed strides in MS: %s. Defaulting to 1.", strides)
                stride = 1

            spectral_windows[name] = SpectralWindow(
                count=swtable.getcell("NUM_CHAN", row),
                freq_min=swtable.getcell("CHAN_FREQ", row)[0]
                - swtable.getcell("RESOLUTION", row)[0] / 2,
                freq_max=swtable.getcell("CHAN_FREQ", row)[-1]
                + swtable.getcell("RESOLUTION", row)[-1] / 2,

        # Extract beams
        # cache of distinct field, spectral window, polarization tuples
        swbeamtable: tables.table = tables.taql(
                "SELECT DISTINCT "
                "FIELD_ID, "
                "FROM $1"
        # TODO: In MSv2 a beam is approximated as a unique combination of FIELD_ID and
        # POLARIZATION_ID, any matching SPECTRAL_WINDOW_ID will be associated to a single beam.
        beamtable: tables.table = tables.taql(

        def swquery(field_idx: int, polarization_idx: int):
            """Query for all spectral windows associated with a single beam"""
            return (
                "FROM $1 "
                f"WHERE FIELD_ID == {field_idx} AND POLARIZATION_ID == {polarization_idx}"

        polarray = list(polarizations.values())
        fieldarray = list(fields.values())
        swarray = list(spectral_windows.values())
        for beam_id in range(beamtable.nrows()):
            beam_id_str = str(beam_id)
            field_idx = beamtable.getcell("FIELD_ID", beam_id)
            pol_idx = beamtable.getcell("POLARIZATION_ID", beam_id)
            beam_spectral_windows: tables.table = tables.taql(
                swquery(field_idx, pol_idx),

            # TODO: channels is not stored in MSv2 so this will create
            # a unique channels container per beam
            channels[beam_id_str] = Channels(
                    swarray[beam_spectral_windows.getcell("SPECTRAL_WINDOW_ID", row)]
                    for row in range(beam_spectral_windows.nrows())
            beams[beam_id_str] = Beam(

        return cls(

def _subtable(t: tables.table, name: str, readonly=True):
    """Open the subtable that keyword ``name`` of table ``t`` points to.

    The subtable is opened without the acknowledgement message
    (``ack=False``); ``readonly`` is forwarded unchanged.
    """
    subtable_ref = t.getkeyword(name)
    return tables.table(subtable_ref, readonly=readonly, ack=False)

def _fill_observation_subtable(ms):
    """Write the fixed observation metadata into row 0 of OBSERVATION.

    The observer is taken from the ``USERNAME`` environment variable and
    falls back to ``"Unknown"`` when it is not set.
    """
    observation = _subtable(ms, "OBSERVATION", readonly=False)
    username = os.getenv("USERNAME", "Unknown")

    cells = (
        ("SCHEDULE", [""]),
        ("PROJECT", ""),
        ("OBSERVER", username),
        ("TELESCOPE_NAME", "cbf-sdp-emulator"),
        ("TIME_RANGE", (0, 0)),
    )
    for column, value in cells:
        observation.putcell(column, 0, value)

def _fill_antenna_subtable(ms, antennas: Sequence[Antenna]):
    """Populate the ANTENNA subtable of ``ms`` from ``antennas``.

    Every antenna is written with ``FLAG_ROW`` set to False and a ``"FIXED"``
    mount; the antenna name is stored in both the NAME and STATION columns.

    Args:
        ms: main table whose ``ANTENNA`` keyword references the subtable
        antennas: the antennas to describe, one per subtable row
    """
    num_stations = len(antennas)
    # NOTE(review): assumes Antenna exposes a ``name`` attribute - confirm
    antenna_names = np.array([antenna.name for antenna in antennas])
    antenna_positions = np.array([antenna.pos for antenna in antennas])
    antenna_dish_diameters = np.array([antenna.dish_diameter for antenna in antennas])
    antenna_flags = np.repeat(False, num_stations)
    antenna_mounts = np.repeat("FIXED", num_stations)

    antenna_t = _subtable(ms, "ANTENNA", readonly=False)
    antenna_t.putcol("POSITION", antenna_positions)
    antenna_t.putcol("DISH_DIAMETER", antenna_dish_diameters)
    antenna_t.putcol("FLAG_ROW", antenna_flags)
    antenna_t.putcol("MOUNT", antenna_mounts)
    antenna_t.putcol("NAME", antenna_names)
    antenna_t.putcol("STATION", antenna_names)

def _fill_history_subtable(ms):
    """Record a creation entry in row 0 of the HISTORY subtable.

    The entry carries the current UTC time (formatted in the message and
    converted with ``time_utils.unix_to_mjd`` for the TIME column) together
    with the command line of the running process.
    """
    history = _subtable(ms, "HISTORY", readonly=False)
    created_at = time.time()
    created_str = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(created_at))
    created_mjd = time_utils.unix_to_mjd(created_at)

    entries = (
        ("MESSAGE", f"Measurement Set created at {created_str}"),
        ("APPLICATION", "cbf-sdp-emulator"),
        ("PRIORITY", "INFO"),
        ("ORIGIN", "cbf-sdp-emulator"),
        ("TIME", created_mjd),
        ("OBSERVATION_ID", -1),
        ("APP_PARAMS", sys.argv[1:]),
        ("CLI_COMMAND", sys.argv),
    )
    for column, value in entries:
        history.putcell(column, 0, value)

def _fill_feed_subtable(
    ms: tables.table,
    num_stations: int,
    beams: OrderedDict[str, Beam],
    spectral_windows: OrderedDict[str, SpectralWindow],
    feed = _subtable(ms, "FEED", readonly=False)
    num_receptors = 2

    sw_names = list(spectral_windows.keys())
    for beam_idx, beam in enumerate(beams.values()):
        for sw in beam.channels.spectral_windows:
            rownr = feed.nrows()
            feed.putcol("ANTENNA_ID", np.arange(num_stations), startrow=rownr)
                np.repeat(beam_idx, num_stations),
            feed.putcol("FEED_ID", np.zeros(num_stations), startrow=rownr)
                np.repeat(sw_names.index(sw.spectral_window_id), num_stations),
                np.zeros((num_stations, 2, num_receptors)),
                np.repeat(num_receptors, num_stations),
                np.repeat(np.array([["X", "Y"]]), num_stations, axis=0),
                    (num_stations, num_receptors, num_receptors),
                np.zeros((num_stations, num_receptors)),
            rownr += num_stations

def _fill_polarizations_subtable(ms: tables.table, pols: Sequence[Polarisations]):
    poltable = _subtable(ms, "POLARIZATION", readonly=False)
    for pol in pols:
        pol_id = poltable.nrows()
            np.array([s.value for s in pol.correlation_type]),
            np.array([s.product for s in pol.correlation_type]),
        poltable.putcell("FLAG_ROW", pol_id, False)
        poltable.putcell("NUM_CORR", pol_id, len(pol.correlation_type))

def _fill_spectral_window_subtable(ms: tables.table, spectral_windows: Sequence[SpectralWindow]):
    """Write one set of SPECTRAL_WINDOW cells per spectral window.

    For each window the target row is the subtable's current row count.
    Per-channel columns (CHAN_WIDTH, EFFECTIVE_BW, RESOLUTION) repeat a
    single value ``sw.count`` times.
    """
    swtable = _subtable(ms, "SPECTRAL_WINDOW", readonly=False)
    for window in spectral_windows:
        row = swtable.nrows()

        def put(column, value, row=row):
            swtable.putcell(column, row, value)

        put("CHAN_FREQ", window.frequencies)
        put("CHAN_WIDTH", np.repeat(window.channel_width, window.count))
        put("EFFECTIVE_BW", np.repeat(window.channel_bandwidth, window.count))
        put("FLAG_ROW", 0)
        put("FREQ_GROUP", 0)
        put("FREQ_GROUP_NAME", "")
        put("IF_CONV_CHAIN", 0)
        put("MEAS_FREQ_REF", FrequencyType.TOPO.value)
        put("NAME", window.spectral_window_id)
        put("NET_SIDEBAND", 0)
        put("NUM_CHAN", window.count)
        put("REF_FREQUENCY", window.freq_min)
        put("RESOLUTION", np.repeat(window.channel_bandwidth, window.count))
        put("TOTAL_BANDWIDTH", window.freq_max - window.freq_min)

def _fill_data_description_subtable(
    ms: tables.table,
    beams: Sequence[Beam],
    spectral_windows: OrderedDict[str, SpectralWindow],
    polarisations: OrderedDict[str, Polarisations],
    ddtable = _subtable(ms, "DATA_DESCRIPTION", readonly=False)

    sw_ids = list(spectral_windows.keys())
    pol_names = list(polarisations.keys())
    for beam in beams:
        for sw in beam.channels.spectral_windows:
            dd_id = ddtable.nrows()

def _fill_state_subtable(ms: tables.table):
    """Open the STATE subtable of ``ms`` for writing.

    Nothing is written to the subtable here.
    """
    state_table = _subtable(ms, "STATE", readonly=False)

def _fill_processor_subtable(ms: tables.table):
    """Open the PROCESSOR subtable of ``ms`` for writing.

    Nothing is written to the subtable here.
    """
    processor_table = _subtable(ms, "PROCESSOR", readonly=False)

def _fill_field_subtable(ms: tables.table, fields: Sequence[Field]):
    Writes a new entry to the field subtable.

    YAN-1249 The FIELD table also needs the REFERENCE_DIR and
    the DELAY_DIR.

    PHASE_DIR - the direction of the phase center
    DELAY_DIR - the direction the correlator delays point to.
    REFERENCE_DIR - in single dish mode, used if position switching is
    being used - for interferometers it is the same as the PHASE and/or DELAY
    Stephen Ord (21-Apr-2023)

    The SDP Telmodel only has a PHASE_DIR in the FIELD schema

    Technical Debt - we are assuming the DELAY/PHASE and reference DIR
    are the same.

    YAN-1324 has been raised to determine if this is correct - or
    if an architecture change is required.

    fieldtable = tables.table(ms.getkeyword("FIELD"), readonly=False)
    for field in fields:
        rownr = fieldtable.nrows()
                    np.reshape(field.phase_dir.ra.rad, -1),
                    np.reshape(field.phase_dir.dec.rad, -1),
            ).swapaxes(0, 1),
                    np.reshape(field.phase_dir.ra.rad, -1),
                    np.reshape(field.phase_dir.dec.rad, -1),
            ).swapaxes(0, 1),

                    np.reshape(field.phase_dir.ra.rad, -1),
                    np.reshape(field.phase_dir.dec.rad, -1),
            ).swapaxes(0, 1),

        fieldtable.putcell("NAME", rownr, field.field_id)

class Mode(Enum):
    """MeasurementSet open modes"""

    # Open an existing measurement set without write access
    READONLY = 0
    # Open an existing measurement set with write access; this member is
    # referenced by MeasurementSet.__init__ alongside READONLY
    READWRITE = 1
    # Create a new measurement set
    CREATE = 2

class MeasurementSet(ContextManager):
    """
    A simple wrapper around a Measurement Set.

    Writing is finalized during _t destructor.
    """

    _t: Optional[tables.table]
    _name: str
    _readonly: bool
    _use_plasmastman: bool
    _plasma_socket: Optional[str]
    _antennas: Optional[Sequence[Antenna]]
    _model: MSScan

    @property
    def name(self) -> str:
        """The name this measurement set was opened or created with"""
        return self._name

    @property
    def t(self) -> tables.table:
        """The main table; asserts that it has been opened or created"""
        assert self._t is not None
        return self._t

    @property
    def num_rows(self):
        """Gets the number of rows currently in the main table"""
        assert self._t is not None
        return self._t.nrows()

    @property
    def num_stations(self):
        """Gets the number of stations (antennas)"""
        return self._num_stations

    @property
    def num_channels(self):
        """Gets the number of channels used by the DATA column"""
        return self._num_channels

    @property
    def channel_range(self):
        """Deprecated: the range of channel indexes ``[0, num_channels)``"""
        warnings.warn(
            "This actually returns the channel indexes that are available "
            "to write to in the data. For a true insight into the channels "
            "in the data look at the spectral windows returned by "
            "calculate_scan_type()",
            DeprecationWarning,
        )
        return ChannelRange(0, self.num_channels)

    @property
    def num_pols(self):
        """Gets the number of polarizations used by the DATA column"""
        return self._num_pols

    def calculate_scan_type(self):
        """
        Gets the SKA Scan Model from the MeasurementSet.

        (MS tables only store a subset of the scan model currently
        thus this should only be used for testing)
        """
        return ScanType("0", beams=list(self._model.beams.values()))

    @classmethod
    def open(cls, name: str, mode: Mode = Mode.READONLY):
        """Opens a MeasurementSet"""
        return cls(name, mode=mode)

    @classmethod
    def create(
        cls,
        name: str,
        scan: Scan,
        antennas: Sequence[Antenna],
        baselines: Baselines,
        plasma_socket: Optional[str] = None,
    ):
        """Creates a new measurement set from scan and tm data"""
        return MeasurementSet(
            name,
            mode=Mode.CREATE,
            scan=scan,
            antennas=antennas,
            baselines=baselines,
            plasma_socket=plasma_socket,
        )

    def __init__(
        self,
        name: str,
        mode: Mode,
        scan: Optional[Scan] = None,
        antennas: Optional[Sequence[Antenna]] = None,
        baselines: Optional[Baselines] = None,
        plasma_socket: Optional[str] = None,
    ):
        """
        Opens or creates the measurement set called ``name``.

        Args:
            name: name of the measurement set
            mode: READONLY and READWRITE open an existing set, CREATE
                builds a new one
            scan: scan model, required when creating
            antennas: antennas, required when creating
            baselines: baselines, required when creating
            plasma_socket: when given, the set is plasma-backed

        Raises:
            ValueError: if ``mode`` is not a known mode, or if CREATE is
                requested without scan, antennas and baselines
        """
        self._name = name
        self._readonly = mode == Mode.READONLY
        self._antennas = antennas
        self._use_plasmastman = plasma_socket is not None
        self._plasma_socket = plasma_socket

        if mode in (Mode.READONLY, Mode.READWRITE):
            self._open()
        elif mode == Mode.CREATE:
            if scan is None or antennas is None or baselines is None:
                raise ValueError("scan, antennas and baselines required to create an MS")
            baseline_utils.validate_antenna_and_baseline_counts(len(antennas), len(baselines))
            self._a1 = baselines.antenna1
            self._a2 = baselines.antenna2
            self._create(scan)
        else:
            raise ValueError(mode)

    def _open(self):
        """
        Begins the opening of a MeasurementSet.

        Must be called once after constructing with readonly=True.
        """
        self._t = tables.table(self._name, readonly=self._readonly, ack=False)
        self._num_stations = self._subtable("ANTENNA").nrows()
        swindow = self._subtable("SPECTRAL_WINDOW")
        # DATA dimensions are cached from the first row of each subtable
        self._num_channels: int = int(swindow.getcol("NUM_CHAN")[0])
        self._num_pols: int = self._subtable("POLARIZATION").getcol("NUM_CORR")[0]
        # relational ms datamodel
        self._model = MSScan.from_main_table(self._t)

    def _create(self, scan: Scan):
        """
        Begins the creation of a MeasurementSet.

        Must be called once after constructing with readonly=False.
        """
        assert not self._readonly
        assert self._antennas is not None

        # relational datamodel of scan
        self._model = MSScan.from_scan(scan)
        self._num_stations = len(self._antennas)

        # cached variables (for DATA dimensions)
        self._num_channels = scan.scan_type.num_channels
        self._num_pols = scan.scan_type.num_pols

        if self._use_plasmastman:
            # the table is created later, on the first write_vis() call
            self._t = None
        else:
            self._create_table(self._model)

    def __enter__(self):
        """context manager protocol"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """context manager protocol

        Closes the measurement set on leaving the ``with`` block. The
        exception details are not inspected, and the implicit ``None`` return
        lets any in-flight exception propagate.
        """
        self.close()

    def read_cell(self, columnname, row: int) -> np.ndarray:
        assert self._t is not None
        return self._t.getcell(columnname, row)

    def read_column(self, columnname, startrow=0, nrow=None) -> np.ndarray:
        """Reads num_rows rows from columnname starting at first_row"""
        assert self._t is not None
        nrow = nrow or self.num_rows
        return self._t.getcol(columnname, startrow=startrow, nrow=nrow)

    def read_coords(self, startrow, nrow) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        assert self._t is not None
        uvw = self.read_column("UVW", startrow=startrow, nrow=nrow)
        return uvw[:, 0], uvw[:, 1], uvw[:, 2]

    def read_weight(self, start_row, start_pol, num_pols, num_baselines) -> np.ndarray:
        assert self._t is not None
        blc = [start_pol]
        trc = [start_pol + num_pols - 1]
        return self._t.getcolslice("WEIGHT", blc, trc, startrow=start_row, nrow=num_baselines)

    def get_row_num(self, payload_time: float) -> int:
        """
        Args:
            payload_time (float): payload time to lookup in the maintable row

        Returns:
            int: the first row number where the time matches
        """
        assert self._t is not None
        if self._t.nrows() == 0:
            return 0

        times: np.ndarray = self.read_column("TIME", 0, self._t.nrows())
        if 0.0 in times:
            logger.warning("read_column has returned an array with anomalous 0.0 entry - retrying") row = bisect.bisect_left(times, payload_time)
        row_interval = bisect.bisect_right(times, times[0])
        return row // row_interval

    def write_defaulted_rows(self, num_baselines: int, end_row: int):
        """
        Add new rows and override tables with default data
        """
        assert self._t is not None
        while end_row >= self._t.nrows():
            nrows = self._t.nrows()
            self._t.addrows(num_baselines)
            # Default flags is True to indicate where visibilities are dropped/invalid
            row_range = range(nrows, nrows + num_baselines)
            self._t.putcell(
                "FLAG",
                row_range,
                np.ones([self._num_channels, self._num_pols]),
            )

    def write_field_id(self, row: int, field_id: int):
        assert self._t is not None
        self._t.putcell(
            columnname="FIELD_ID",
            rownr=row,
            value=field_id,
        )

    def write_datadesc_id(self, row: int, sw_id: str, pol_id: str):
        assert self._t is not None
        self._t.putcell(
            columnname="DATA_DESC_ID",
            rownr=row,
            value=self._model.datadesc[(pol_id, sw_id)],
        )

    def write_weight(
        self,
        start_row: int,
        start_pol: int,
        num_pols: int,
        row_count: int,
        weight: Optional[Union[np.ndarray, TensorRef]] = None,
    ):
        """Write a slice of the WEIGHT column of the main table.

        When ``weight`` is None, ones of shape ``(row_count, num_pols)`` are
        written. The slice covers polarisation indexes ``start_pol`` to
        ``start_pol + num_pols - 1`` inclusive.
        """
        assert not self._readonly
        main_table = self._t
        assert main_table is not None
        if weight is None:
            weight = np.ones((row_count, num_pols), dtype=float)
        assert isinstance(weight, np.ndarray), "tileddata measurement sets require ndarray weights"
        first_pol = [start_pol]
        last_pol = [start_pol + num_pols - 1]
        main_table.putcolslice(
            "WEIGHT",
            weight,
            first_pol,
            last_pol,
            startrow=start_row,
            nrow=row_count,
        )

    def write_flag(
        self,
        start_row: int,
        start_chan: int,
        chan_count: int,
        row_count: int,
        flag: Optional[Union[np.ndarray, TensorRef]] = None,
    ):
        """
        Writes flag column to the main table. If flag is None then
        0 will be written over the row range.

        The slice covers channel indexes ``start_chan`` to
        ``start_chan + chan_count - 1`` inclusive and all polarisations.
        """
        assert not self._readonly
        assert self._t is not None
        if flag is None:
            flag = np.zeros((row_count, chan_count, self.num_pols), dtype=bool)
        assert isinstance(flag, np.ndarray), "tileddata measurement sets require ndarray flags"
        blc = [start_chan, 0]
        trc = [start_chan + chan_count - 1, self.num_pols - 1]
        self._t.putcolslice("FLAG", flag, blc, trc, startrow=start_row, nrow=row_count)

    def write_sigma(
        self,
        start_row: int,
        start_pol: int,
        pol_count: int,
        row_count: int,
        sigma: Optional[Union[np.ndarray, TensorRef]] = None,
    ):
        """
        Writes a slice of the SIGMA column of the main table. If sigma is
        None then ones of shape ``(row_count, pol_count)`` are written.
        """
        assert not self._readonly
        assert self._t is not None
        if sigma is None:
            sigma = np.ones((row_count, pol_count), dtype=float)
        assert isinstance(sigma, np.ndarray), "tileddata measurement sets require ndarray sigmas"
        blc = [start_pol]
        trc = [start_pol + pol_count - 1]
        self._t.putcolslice("SIGMA", sigma, blc, trc, startrow=start_row, nrow=row_count)

    def write_scan(self, start_row: int, row_count: int, scan_id: int):
        """Writes the scan number"""
        assert self._t is not None
        self._t.putcol(
            "SCAN_NUMBER",
            np.repeat(scan_id, row_count),
            startrow=start_row,
            nrow=row_count,
        )

    # pylint: disable=unused-argument
    def write_coords(self, start_row: int, row_count: int, uvw, interval, exposure, time):
        """
        Writes the per-row coordinate columns for ``row_count`` rows starting
        at ``start_row``: UVW, ANTENNA1/ANTENNA2 (the first ``row_count``
        baseline antennas), EXPOSURE, INTERVAL, TIME and TIME_CENTROID.
        """
        assert not self._readonly
        assert self._t is not None
        times = np.repeat(time, row_count)

        def put_col_slice(col_name, values):
            self._t.putcol(col_name, values, startrow=start_row, nrow=row_count)  # type: ignore

        put_col_slice("UVW", uvw)
        put_col_slice("ANTENNA1", self._a1[:row_count])
        put_col_slice("ANTENNA2", self._a2[:row_count])
        put_col_slice("EXPOSURE", np.repeat(exposure, row_count))
        put_col_slice("INTERVAL", np.repeat(interval, row_count))
        put_col_slice("TIME", times)
        put_col_slice("TIME_CENTROID", times)

    def read_vis(self, start_row, start_channel, num_channels, num_baselines) -> np.ndarray:
        """Reads a slice of the DATA table

        Returns:
            vis: visibilities of shape [pol, channels, baselines]
        """
        assert self._t is not None
        blc = [start_channel, 0]
        trc = [start_channel + num_channels - 1, self.num_pols - 1]
        return self._t.getcolslice("DATA", blc, trc, startrow=start_row, nrow=num_baselines)

    def read_flags(self, start_row, start_channel, num_channels, num_baselines) -> np.ndarray:
        """Reads a slice of the FLAG table

        Returns:
            flag: flags of shape [pol, channels, baselines]
        """
        assert self._t is not None
        blc = [start_channel, 0]
        trc = [start_channel + num_channels - 1, self.num_pols - 1]
        return self._t.getcolslice("FLAG", blc, trc, startrow=start_row, nrow=num_baselines)

    def write_vis(
        self,
        start_row: int,
        start_chan: int,
        chan_count: int,
        row_count: int,
        vis: Union[np.ndarray, TensorRef],
    ):
        """Writes visibilities to the DATA column of the maintable

        For a plasma-backed measurement set the table is created from the
        first ``TensorRef`` payload and no column slice is written. For a
        regular measurement set the visibilities are written to the channel
        range ``start_chan`` to ``start_chan + chan_count - 1`` inclusive,
        for all polarisations.

        Raises:
            ValueError: a plasma-backed set is given something other than a
                TensorRef as its first payload
            NotImplementedError: a plasma-backed set is written a second time
        """
        assert not self._readonly
        if self._use_plasmastman:
            if self._t is not None:
                raise NotImplementedError("multiple writes to plasma ms not supported")
            if not isinstance(vis, TensorRef):
                raise ValueError(vis)
            # plasma measurement only support being created
            # with a single payload.
            self._create_table(self._model, vis)
            self._ensure_rows(start_row + row_count)
            # the slice write below applies to regular tables only
            return

        assert self._t is not None
        if isinstance(vis, TensorRef):
            # regular tables are written from the referenced array
            vis = vis.data
        blc = [start_chan, 0]
        trc = [start_chan + chan_count - 1, self.num_pols - 1]
        self._t.putcolslice("DATA", vis, blc, trc, startrow=start_row, nrow=row_count)

    def write_pointings(self, pointings: Sequence[Pointing]):
        """Appends one row per pointing to the POINTING subtable.

        INTERVAL and NUM_POLY are written as zeros; TIME and TIME_ORIGIN both
        receive the pointing times converted with ``time_utils.unix_to_mjd``.
        NAME is only written when at least one pointing has a non-empty name.
        An empty ``pointings`` sequence is a no-op.

        Args:
            pointings: the pointings to append
        """
        if not pointings:
            # nothing to append; max() below would fail on an empty sequence
            return

        table = self._subtable("POINTING", readonly=False)
        row_idx = table.nrows()
        table.addrows(len(pointings))

        def np_from_pointing_field(get_field: Callable[[Pointing], Any], dtype):
            return np.fromiter(
                (get_field(p) for p in pointings),
                count=len(pointings),
                dtype=dtype,
            )

        def np_from_pointing_azel_data(get_azel_data: Callable[[Pointing], tuple[float, float]]):
            # Note that the last two dimensions here MUST match the
            # dimensions we set in the column & storage manager definition
            # in _create_pointing_table()
            arr = np_from_pointing_field(get_azel_data, np.dtype((np.float64, 2)))
            arr = np.expand_dims(arr, axis=1)
            return arr

        def putcol(col: str, value: np.ndarray):
            table.putcol(
                col,
                value,
                startrow=row_idx,
                nrow=len(value),
            )

        putcol("INTERVAL", np.full(len(pointings), 0.0))
        putcol("NUM_POLY", np.full(len(pointings), 0))
        putcol("TRACKING", np_from_pointing_field(lambda p: p.tracking, bool))
        putcol("ANTENNA_ID", np_from_pointing_field(lambda p: p.antenna_id, int))
        putcol("DIRECTION", np_from_pointing_azel_data(lambda p: p.direction))
        # NOTE(review): assumes Pointing exposes ``target`` and ``name`` - confirm
        putcol("TARGET", np_from_pointing_azel_data(lambda p: p.target))
        putcol(
            "POINTING_OFFSET",
            np_from_pointing_azel_data(lambda p: p.pointing_offset),
        )
        putcol(
            "SOURCE_OFFSET",
            np_from_pointing_azel_data(lambda p: p.source_offset),
        )
        times = np_from_pointing_field(lambda p: p.time, np.float64)
        times = time_utils.unix_to_mjd(times)
        putcol("TIME", times)
        putcol("TIME_ORIGIN", times)
        # fixed-width unicode dtype sized for the longest name
        max_len = max(len(p.name) for p in pointings)
        if max_len:
            putcol("NAME", np_from_pointing_field(lambda p: p.name, f"U{max_len}"))

    def _ensure_rows(self, nrows):
        """Add missing rows"""
        assert self._t is not None
        current_rows = self._t.nrows()
        if current_rows < nrows:
            self._t.addrows(nrows - current_rows)
def calc_baselines(ms: MeasurementSet) -> int:
    """Number of baselines in input

    Gets the number of baselines from the measurement set by counting the
    leading rows that share the first TIME value.

    :param ms: measurement set
    :return: number of baselines
    """
    row_num = 0
    firsttime = ms.read_column("TIME", row_num, 1)
    nexttime = firsttime
    while nexttime == firsttime:
        row_num = row_num + 1
        if row_num == ms.num_rows:
            break
        nexttime = ms.read_column("TIME", row_num, 1)
    return row_num


def clamp_num_chan(ms: MeasurementSet, start_chan: int, num_chan: int):
    """
    Clamps the num_chan values so the range they define is contained
    by the channel dimensionality of the MS.

    :param ms: The input measurement set
    :param start_chan: The start channel
    :param num_chan: The number of channels; 0 selects all channels of the MS
    :return: The adjusted number of channels
    :raises ValueError: if start_chan exceeds the number of channels in the MS
    """
    if num_chan == 0:
        num_chan = ms.num_channels
    if start_chan > ms.num_channels:
        raise ValueError(f"start_chan = {start_chan} > num_chan in MS ({ms.num_channels})")
    end_chan = start_chan + num_chan
    if end_chan > ms.num_channels:
        num_chan = ms.num_channels - start_chan
        logger.warning(
            "start_chan + num_chan = %d > num_chan in MS (%d), reducing num_chan to %d",
            end_chan,
            ms.num_channels,
            num_chan,
        )
    return num_chan


def vis_mjd_epoch(ms: MeasurementSet):
    """Find the MJD timestamp of the earliest visibility data"""
    # NOTE(review): the head of this expression was reconstructed as a
    # TIME-only query on the main table - confirm against the upstream source
    mjd_time = ms.t.query(columns="TIME").sort("TIME", limit=1).getcell("TIME", 0)
    return mjd_time


async def vis_reader(
    ms,
    start_chan=0,
    num_chan=0,
    num_timestamps=0,
    timestamp_offset=0,
    executor=None,
) -> AsyncIterable[tuple[np.ndarray, float]]:
    """Reads the visibilties and timestamps from a measurement set

    :param ms: The input measurement set, or the name of one to open
        read-only (a set opened here is closed again when reading ends)
    :param start_chan: The start channel **index**
    :param num_chan: The number of channels to read. If 0, all channels are read
    :param num_timestamps: The number of timestamps to read. If 0, all
        timestamps are read
    :param timestamp_offset: The offset applied to time for simulating an
        infinite series of data
    :return: Yields a full set of baselines at a time for each timestamp
    """
    needs_close = False
    if isinstance(ms, str):
        needs_close = True
        ms = MeasurementSet.open(ms, Mode.READONLY)

    try:
        # Channels
        num_chan = clamp_num_chan(ms, start_chan, num_chan)

        loop = asyncio.get_event_loop()
        for mjd_timestamp, start_row_idx, row_count in ms_timestamp_rows(
            ms, num_timestamps, timestamp_offset
        ):
            # the blocking table read runs in the executor
            vis = await loop.run_in_executor(
                executor,
                ms.read_vis,
                start_row_idx,
                start_chan,
                num_chan,
                row_count,
            )
            yield vis, mjd_timestamp
    finally:
        # only close what was opened here, also when reading fails
        if needs_close:
            ms.close()


def ms_timestamp_rows(
    ms: MeasurementSet, num_timestamps=0, repeat_idx=0
) -> Iterable[tuple[float, int, int]]:
    """
    Iterates over the timestamps in the MS, also returning the first row index
    and row count associated with that timestamp in the ``MAIN`` table.

    Timestamps can optionally adjusted by a repeat_idx to allow for looping
    over the MS multiple times with always increasing timestamps (i.e. the
    first timestamp with repeat_idx=1 will be one timestep later than the last
    timestamp with repeat_idx=0).

    :raises ValueError: if num_timestamps is negative, if the TIME column
        decreases, or if a timestamp does not have the expected number of rows
    """
    # Baselines
    num_baselines = calc_baselines(ms)

    # Timestamps
    timestamps_in_ms = ms.num_rows // num_baselines
    if num_timestamps < 0:
        raise ValueError(f"num_timestamps must be >= 0: {num_timestamps}")
    if num_timestamps == 0:
        num_timestamps = timestamps_in_ms
    elif num_timestamps > timestamps_in_ms:
        logger.warning(
            "%d > num_timestamps in MS (%d), reducing to %d",
            num_timestamps,
            timestamps_in_ms,
            timestamps_in_ms,
        )
        num_timestamps = timestamps_in_ms

    # Pre-compute all times as integer/fraction values
    # as required by the ICD payload
    time_data = ms.read_column("TIME", 0, ms.num_rows)
    diffs = np.diff(time_data)
    if np.any(diffs < 0.0):
        raise ValueError("Timestamps must be non-decreasing")
    # Add an additional diff to the end so we verify the number of baselines in the last TS
    baseline_ts_change_idxs = np.flatnonzero(np.concatenate([diffs, [1]])) + 1
    num_baselines_per_ts = np.diff(baseline_ts_change_idxs)
    if np.any(num_baselines_per_ts != num_baselines):
        raise ValueError(
            f"Data per timestamp ({num_baselines_per_ts}) doesn't "
            f"match expected number of baselines ({num_baselines})"
        )

    time_offset = 0
    if repeat_idx != 0:
        if num_timestamps > 1:
            time_range = time_data[num_timestamps * num_baselines - 1] - time_data[0]
            time_increment = time_range / (num_timestamps - 1)
            time_offset = repeat_idx * (time_range + time_increment)
        else:
            # no increment can be calculated so use 1s instead
            time_offset = repeat_idx * 1.0

    # Read until we exhaust the MS
    start_row = 0
    timestamp_count = 0
    while start_row < ms.num_rows and timestamp_count < num_timestamps:
        yield time_data[start_row] + time_offset, start_row, num_baselines
        start_row += num_baselines
        timestamp_count += 1
[docs] class MSWriter(ContextManager): """ A class that handles the writing of data into a new MeasurementSet. """ def __init__( self, output_filename: str, scan: Scan, antennas: Sequence[Antenna], baselines: Optional[Baselines] = None, plasma_socket: Optional[str] = None, ): """ Handles the writing of a measurement set using icd payloads Args: output_filename (str): the output file name tm (BaseTM): the telescope model scan (Scan): the scan model plasma_socket (Optional[str], optional): _description_. Defaults to None. """ num_stations = len(antennas) if not baselines: baselines = Baselines.generate(num_stations, True) self._use_plasmastman = plasma_socket is not None self._row_offset: Optional[int] = None self.num_baselines = len(baselines) self.num_pols = scan.scan_type.num_pols self._scan = scan ms_type = "Plasma" if self._use_plasmastman else "regular" "Creating %s MS at %s for scan %i, %d stations, %d baselines and %d channels", ms_type, output_filename, scan.scan_number, num_stations, self.num_baselines, scan.scan_type.num_channels, ) = MeasurementSet.create(output_filename, scan, antennas, baselines, plasma_socket) def _write_data( self, row: int, scan_id: int, beam_id: str, sw_id: str, pol_id: str, time: float, interval: float, exposure: float, first_chan_idx: int, chan_count: int, uvw: np.ndarray, vis: Union[TensorRef, np.ndarray], ): """Writes both coordinate and visibility data for a given timestamp dump""" logger.debug( "Writing visibility data at row %d for channel indices %d:%d", row, first_chan_idx, first_chan_idx + chan_count, ) # populate missing and current rows with flags if not self._use_plasmastman:, row), first_chan_idx, chan_count, self.num_baselines, vis), first_chan_idx, chan_count, self.num_baselines, None) # TODO: unsafe assumption beam_idx matches field_idx # TODO: unsafe assumption beam_id matches ms_beam_id' # TODO: Shouldn't access private members # pylint: disable=protected-access beam =[beam_id] ms_field_id: int = list(, 
field_id=ms_field_id), sw_id=sw_id, pol_id=pol_id), 0, self.num_pols, self.num_baselines, None), 0, self.num_pols, self.num_baselines, None), self.num_baselines, scan_id), self.num_baselines, uvw, interval, exposure, time) def _calculate_insertion_row(self, payload_seq_no: int, payload_time: float) -> int: """ Calculates the row number to write a spectral window data slice. If payload_seq_no zero or negative then the then row will calculated from bisecting the TIME column using payload_time. Args: payload_seq_no (int): Spectral window sequence number. payload_time (float): Payload time used only when payload_seq_no is negative or zero. Returns: int: The row number """ if payload_seq_no < 0: if is None: row = 1 else: row = + 1 else: row = payload_seq_no * self.num_baselines if self._row_offset is None: self._row_offset = row return row - self._row_offset def write_data_row( self, scan_id: int, beam: Beam, sw: SpectralWindow, payload_seq_no: int, mjd_time: float, interval: float, exposure: float, uvw: np.ndarray, vis: Union[TensorRef, np.ndarray], ): """ Write a full spectral window's worth of data to the main table using payload_seq_no (or mjd_time) to determine the row location. """ num_channels = vis.shape[1] assert sw.count == num_channels, "Visibility channels must match spectral window size" self.write_data( scan_id, beam, sw, payload_seq_no, mjd_time, interval, exposure, 0, num_channels, uvw, vis, ) def write_data( self, scan_id: int, beam: Beam, sw: SpectralWindow, payload_seq_no: int, mjd_time: float, interval: float, exposure: float, first_chan: int, chan_count: int, uvw: np.ndarray, vis: Union[TensorRef, np.ndarray], ): """ Writes a spectral window data slice to the main table using payload_seq_no (or mjd_time) to determine the row location. Note: first_chan refers to the first channel **index** to be be written. """ warnings.warn( "Writing subsets of channel data is deprecated. 
Use write_data_row() instead", DeprecationWarning, ) # Find out the row for this payload/timestamp row = self._calculate_insertion_row(payload_seq_no, mjd_time) self._write_data( row, scan_id, beam.beam_id, sw.spectral_window_id, beam.polarisations.polarisation_id, mjd_time, interval, exposure, first_chan, chan_count, uvw, vis, ) "Written data for %f %s %s to row %i", mjd_time, beam.beam_id, sw.spectral_window_id, row, ) def write_pointings(self, pointings: Sequence[Pointing]): def close(self): def __enter__(self): """context manager protocol""" return self def __exit__(self, exc_type, exc_val, exc_tb): """context manager protocol""" self.close()