# Source code for realtime.receive.modules.aggregators.payload_aggregator

import collections
import dataclasses
import itertools
import logging

from realtime.receive.core import icd

from realtime.receive.modules.aggregators.config import AggregationConfig
from realtime.receive.modules.indexer import Indexer
from realtime.receive.modules.payload_data import PayloadData

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class PayloadTimeGroup:
    """A batch of PayloadData whose sequence numbers share a common time window.

    Every payload in the batch lies inside the half-open interval
    ``[start_time_index, stop_time_index)``. The interval is allowed to be wider
    than the span of sequence numbers the payloads actually occupy.
    """

    #: First payload sequence number covered by the window (inclusive).
    start_time_index: int

    #: Payload sequence number at which the window ends (exclusive).
    stop_time_index: int

    #: The visibility payload data received for this window.
    payloads: list[PayloadData] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        """Raise ``ValueError`` if any payload lies outside the declared window."""
        lo = self.start_time_index
        hi = self.stop_time_index
        for item in self.payloads:
            seq = item.sequence_number
            if not (lo <= seq < hi):
                raise ValueError(
                    f"Payload sequence number {seq} is not within "
                    f"the time window [{lo}, {hi})"
                )

    @property
    def time_window_size(self):
        """Number of sequence numbers spanned by the window (stop minus start)."""
        start, stop = self.start_time_index, self.stop_time_index
        return stop - start


@dataclasses.dataclass
class _PayloadTimeGroupList:
    """Result of a single PayloadAggregator flush.

    Hands the payload groups produced for each time window, together with a
    count of discarded payloads, to the next processing stage.
    """

    #: Payload groups removed from the PayloadAggregator, one per time window.
    payload_groups: list[PayloadTimeGroup] = dataclasses.field(default_factory=list)

    #: How many payloads were discarded instead of being placed in any group.
    dropped_payloads: int = 0

    @property
    def payload_count(self):
        """Total number of payloads across all groups of this flush."""
        total = 0
        for group in self.payload_groups:
            total += len(group.payloads)
        return total


class PayloadAggregator:
    """Collects individual payloads and groups them by time window on flushes.

    Users should call :meth:`add_payload` whenever a new ICD payload is received
    and needs to be stored for later aggregation.

    Aggregation itself does **not** happen in this class. This class prepares the
    payloads for aggregation by providing them in structured time-based groups
    when :meth:`flush` is called. The actual creation of Visibility objects
    happens in the VisibilityGenerator class. A top-level ``Aggregator`` class
    has been created to manage the complete flow from payload aggregation
    through visibility generation to consumption.

    Grouping of payloads happens when :meth:`flush` is called, either
    periodically or manually. During a flush, the payloads are grouped by time
    windows based on the configured aggregation rules. Time steps that are still
    incomplete are flushed anyway once they fall outside the configured
    tolerance, and payloads that arrive after their time step has already been
    flushed are dropped (a warning is logged). See :class:`AggregationConfig`
    for more details.

    When finished using this class, users should ensure all payloads have been
    flushed (e.g. with ``flush(full_flush=True)``).

    Detailed intended usage:

    The primary use case for this class is to collect payloads from multiple
    streams and organize them into time-based groups for later aggregation into
    Visibility objects. There is a two-dimensional aggregation happening here:
    the first dimension is time and the second is frequency. The purpose is to
    reduce the transaction count between the receiver and the plasma store.

    Intended behaviour:

    The receiver passes all received payloads into the aggregator, which buffers
    them by sequence number. A flush is triggered to retrieve the grouped
    payloads. Only time steps in the buffer that are "complete" (payloads from
    all streams present), or that have timed out, are handed over for assembly
    into a visibility object.

    TODO: There are two possible failure modes that need to be addressed. The
    first is if a timestamp is lost from a stream. There is no error correction
    for this and the visibility object will be incomplete. To avoid the receiver
    being blocked by a stream that is lost, the aggregation waits for a certain
    number of timestamps (tolerance) before it flushes the incomplete visibility
    object.
    """

    def __init__(
        self,
        config: AggregationConfig,
        num_streams: int,
    ):
        self._config = config
        # This simple mechanism is used for guessing the time index for a received
        # ICD payload (which is later on used, for example, to calculate the row
        # number under which the visibilities in the payload should be written to
        # in a Measurement Set). These payloads contain a timestamp but no
        # reference to their T0. Using this class to infer a time index for these
        # payloads works by assuming payloads are received in order (which might
        # not hold), and because payloads contain data for all visibilities (but
        # maybe a portion of the channels), so the MS row calculation doesn't break.
        self._time_indexer = Indexer[float]()
        self._added_payloads = 0
        # Sequence number at which the next aggregation window must start;
        # buffered payloads older than this are dropped on the next flush.
        self._next_aggregation_start = 0
        self._last_tstart = 0
        # When aggregating we keep track of payloads on each time step,
        # which we differentiate with the payload's sequence number
        self._payloads_by_sequence_number: dict[int, list[PayloadData]] = collections.defaultdict(
            list
        )
        self._num_streams = num_streams

    @property
    def _all_payloads(self):
        """Iterate over every buffered payload, across all sequence numbers."""
        return itertools.chain.from_iterable(self._payloads_by_sequence_number.values())

    @property
    def _has_payloads(self) -> bool:
        """Whether at least one payload is currently buffered."""
        return next(self._all_payloads, None) is not None

    @property
    def added_payloads(self) -> int:
        """The number of data payloads that have been added to this object."""
        return self._added_payloads

    def add_payload(self, payload: icd.Payload) -> None:
        """Add a single payload to be considered for aggregation on a later flush.

        The payload's timestamp is mapped to a sequence number by the internal
        time indexer, and the payload is buffered under that sequence number.
        """
        payload_sequence_no = self._time_indexer.get_index(payload.timestamp)
        payload_data = PayloadData.from_payload(payload, payload_sequence_no)
        self._payloads_by_sequence_number[payload_sequence_no].append(payload_data)
        self._added_payloads += 1

    def reset_time_indexing(self) -> None:
        """Reset the internal time indexing used to calculate payload sequence numbers.

        NOTE: This shouldn't be normally called, but can be useful in some
        situations; e.g., when receiving data for different scans that is tagged
        with the same timestamps.

        :raises RuntimeError: if payloads added since the last :meth:`flush` are
            still buffered (i.e. the aggregator has not been fully flushed).
        """
        if self._has_payloads:
            raise RuntimeError(
                "Payloads have been added since last flush, can't reset time indexing"
            )
        self._time_indexer.reset()
        self._next_aggregation_start = 0

    def flush(self, *, full_flush: bool = False) -> PayloadTimeGroup | None:
        """Flush the aggregated payloads into a group for visibility generation.

        Groups buffered payloads into time windows and pops them from the
        aggregator so they can be converted into Visibility objects. Payloads
        older than the start of the window (they arrived after their time step
        had already been flushed) are dropped, and a warning is logged.

        :param full_flush: If True, flush all payload buffers, else flush only
            full or timed-out payload buffers. Defaults to False.
        :return: a single :class:`PayloadTimeGroup` holding the payloads of every
            time window popped by this flush, or None if no payloads were
            buffered or none were ready to be flushed.
        """
        if not self._has_payloads:
            return None

        # get the time window we should consider for this aggregation
        # and the groups of payload that will result in individual Visibilities
        time_start, time_stop = self._calculate_visibility_time_window(full_flush=full_flush)
        time_groups = self._pop_payloads_by_time_groups(
            time_start, time_stop, full_flush=full_flush
        )

        # Report discarded payloads whenever there are any, including flushes that
        # produce no group at all (the most common case for late arrivals).
        if time_groups.dropped_payloads:
            logger.warning(
                "%d payloads that arrived too late have been dropped",
                time_groups.dropped_payloads,
            )

        if not time_groups.payload_groups:
            return None

        # adapt to new interface with unknown scan id: merge all groups into one.
        # NOTE(review): each group's stop_time_index is already exclusive, so the
        # "+ 1" makes the merged window one step wider than the payloads need.
        # PayloadTimeGroup allows an oversized window and downstream code may rely
        # on this — confirm before changing.
        groups = time_groups.payload_groups
        return PayloadTimeGroup(
            start_time_index=min(group.start_time_index for group in groups),
            stop_time_index=max(group.stop_time_index for group in groups) + 1,
            payloads=[payload for group in groups for payload in group.payloads],
        )

    def _calculate_visibility_time_window(self, *, full_flush: bool) -> tuple[int, int]:
        """Find the interval of time indexes to be flushed.

        The interval starts at the oldest buffered sequence number, but never
        before the start recorded by the previous aggregation. It covers every
        time step up to ``tolerance`` steps before the newest buffered sequence
        number unconditionally (those steps are considered timed out), and then
        extends further while consecutive time steps have payloads from all
        streams. If nothing should be flushed, then ``t_start == t_stop``.

        :param full_flush: when true uses a tolerance of 0 to flush all possible payloads
        :return: the left-closed, right-open interval ``[t_start, t_stop)``
        """
        t_start = max(
            min(payload.sequence_number for payload in self._all_payloads),
            self._next_aggregation_start,
        )
        t_max = max(
            max(payload.sequence_number for payload in self._all_payloads),
            t_start,
        )
        t_stop_max = t_max + 1
        tolerance = 0 if full_flush else self._config.tolerance

        # find highest time step where all streams have data
        # t_stop will be the time step after that one
        t_stop = max(t_stop_max - tolerance, t_start)
        while t_stop < t_stop_max and self._all_payloads_received_for_t(t_stop):
            t_stop += 1
        return t_start, t_stop

    def _pop_payloads_by_time_groups(
        self, start: int, stop: int, *, full_flush: bool
    ) -> _PayloadTimeGroupList:
        """Pop the buffered payloads in ``[start, stop)`` as groups, dropping older ones.

        Payloads with a sequence number below ``start`` are discarded and
        counted in ``dropped_payloads``. When the configured
        ``num_timestamps_per_aggregation`` is positive, the window is split into
        groups of that many time steps; a trailing partial group is left in the
        aggregator for a later flush unless ``full_flush`` is set, in which case
        it is emitted as a smaller group. If any payload is popped,
        ``self._next_aggregation_start`` is advanced to where the next
        aggregation should begin (it is not part of the returned object).
        """
        next_aggregation_start = stop
        collection = _PayloadTimeGroupList()

        # Find times for which we will drop data. Dict keys follow insertion
        # order, not numeric order: a late payload for an already-flushed
        # sequence number is re-inserted after newer keys, so every key must be
        # checked (stopping at the first key >= start would miss it and later
        # make flush() build a group with an out-of-window payload).
        times_to_drop = [t for t in self._payloads_by_sequence_number if t < start]
        for time_to_drop in times_to_drop:
            collection.dropped_payloads += len(self._payloads_by_sequence_number.pop(time_to_drop))

        # Get payload groups for each time window as necessary
        requested_size = self._config.num_timestamps_per_aggregation
        if requested_size > 0:
            for group_start in range(start, stop, requested_size):
                group_end = group_start + requested_size
                # Force the creation of smaller Visibility when flushing fully
                # otherwise ensure those payloads remain in the aggregator for a future flush
                if group_end > stop:
                    if not full_flush:
                        next_aggregation_start = group_start
                        continue
                    group_end = stop
                collection.payload_groups.append(
                    self._pop_payloads_between(group_start, group_end)
                )
        elif start == stop:
            return collection
        else:
            collection.payload_groups.append(self._pop_payloads_between(start, stop))

        if collection.payload_count > 0:
            self._next_aggregation_start = next_aggregation_start
        return collection

    def _all_payloads_received_for_t(self, t: int) -> bool:
        """Whether exactly ``num_streams`` payloads are buffered for time step ``t``."""
        # Use .get() rather than indexing: indexing the defaultdict would insert an
        # empty list for every missing time step probed by the caller.
        return len(self._payloads_by_sequence_number.get(t, ())) == self._num_streams

    def _payloads_in_range(self, start: int, stop: int):
        """Count buffered payloads whose sequence number lies in ``[start, stop)``."""
        return sum(1 for payload in self._all_payloads if start <= payload.sequence_number < stop)

    def _pop_payloads_between(self, start: int, stop: int) -> PayloadTimeGroup:
        """Pop every buffered payload below ``stop`` into a group for ``[start, stop)``.

        Payloads below ``start`` are expected to have been dropped by the caller
        beforehand; this method itself pops everything below ``stop``.

        :raises AssertionError: only when ``__debug__`` is true, if payloads in
            ``[start, stop)`` are still buffered afterwards (sanity check).
        """
        payload_group = PayloadTimeGroup(start, stop)
        # Materialize the keys first: the dict is mutated while popping.
        times_to_pop = [t for t in self._payloads_by_sequence_number if t < stop]
        for payload_time in times_to_pop:
            payload_group.payloads += self._payloads_by_sequence_number.pop(payload_time)

        if __debug__:
            remaining_payloads_in_range = self._payloads_in_range(start, stop)
            if remaining_payloads_in_range:
                raise AssertionError(
                    f"{remaining_payloads_in_range} payloads in [{start}, {stop})"
                )
        return payload_group