import collections
import dataclasses
import itertools
import logging
from realtime.receive.core import icd
from realtime.receive.modules.aggregators.config import AggregationConfig
from realtime.receive.modules.indexer import Indexer
from realtime.receive.modules.payload_data import PayloadData
# Module-level logger (used to warn about payloads dropped during a flush).
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class PayloadTimeGroup:
    """A batch of PayloadData whose sequence numbers share one time window.

    Every payload's sequence number lies inside the half-open window
    ``[start_time_index, stop_time_index)``. The window may be wider than the
    range of sequence numbers the payloads actually occupy.

    :raises ValueError: on construction, if any payload lies outside the window.
    """

    #: Start time (inclusive) payload sequence number.
    start_time_index: int
    #: Stop time (exclusive) payload sequence number.
    stop_time_index: int
    #: Received visibility payload data.
    payloads: list[PayloadData] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        """Check that each payload's sequence number is inside the window.

        Only the payloads present at construction time are checked; payloads
        appended to ``payloads`` afterwards are not validated.
        """
        first, last = self.start_time_index, self.stop_time_index
        for candidate in self.payloads:
            seq_no = candidate.sequence_number
            if first <= seq_no < last:
                continue
            raise ValueError(
                f"Payload sequence number {seq_no} is not within "
                f"the time window [{first}, {last})"
            )

    @property
    def time_window_size(self):
        """Number of sequence numbers spanned by the window (``stop - start``)."""
        start, stop = self.start_time_index, self.stop_time_index
        return stop - start
@dataclasses.dataclass
class _PayloadTimeGroupList:
    """Result of a single PayloadAggregator flush operation.

    Carries the payload groups handed from the aggregator to the next
    processing stage, together with a count of the payloads that were dropped.
    """

    #: Groups of payloads taken from the PayloadAggregator for different time windows.
    payload_groups: list[PayloadTimeGroup] = dataclasses.field(default_factory=list)
    #: Number of payloads that have been dropped (part of no aggregation).
    dropped_payloads: int = 0

    @property
    def payload_count(self):
        """Total number of payloads across all groups of this flush."""
        sizes = [len(group.payloads) for group in self.payload_groups]
        return sum(sizes)
class PayloadAggregator:
    """Collect individual payloads and group them by time window on flushes.

    Users should call :meth:`add_payload` whenever a new ICD payload is received
    and needs to be stored for later aggregation.

    Aggregation itself does **not** happen in this class. This class prepares
    the payloads for aggregation by providing them in structured, time-based
    groups when :meth:`flush` is called. The actual creation of Visibility
    objects happens in the VisibilityGenerator class. A top-level ``Aggregator``
    class manages the complete flow from payload aggregation through
    visibility generation to consumption.

    Payloads are grouped when :meth:`flush` is called, either periodically or
    manually. During a flush, the payloads are grouped by time windows based on
    the configured aggregation rules. Payloads that are incomplete or arrive too
    late may be dropped, depending on the configured tolerance. See
    :class:`AggregationConfig` for more details.

    When finished using this class, users should ensure all payloads have been
    flushed (e.g. with ``flush(full_flush=True)``).

    Detailed intended usage:

    The primary use case for this class is to collect payloads from multiple
    streams and organize them into time-based groups for later aggregation
    into Visibility objects.

    The aggregation happens along two dimensions: time and frequency. Its
    purpose is to reduce the number of transactions between the receiver and
    the plasma store.

    Intended behaviour:

    The receiver passes all received payloads to the aggregator, which buffers
    them by sequence number. A flush retrieves the grouped payloads. Apart from
    time steps older than the configured tolerance, only time steps that are
    "complete" (one payload per stream) are released for assembly into a
    visibility object.

    TODO:
    There are two possible failure modes that need to be addressed; only the
    first is described here: a timestamp is lost from a stream. There is no
    error correction for this and the visibility object will be incomplete.
    To avoid the receiver being blocked by a stream that is lost, the
    aggregation waits for a certain number of time steps (the tolerance)
    before it flushes the incomplete visibility object.
    """

    def __init__(
        self,
        config: AggregationConfig,
        num_streams: int,
    ):
        """
        :param config: aggregation settings; this class reads its ``tolerance``
            and ``num_timestamps_per_aggregation`` values.
        :param num_streams: number of payloads expected per time step before
            that time step is considered complete.
        """
        self._config = config
        # This simple mechanism is used for guessing the time index for a
        # received ICD payload (which is later used, for example, to calculate
        # the row number under which the visibilities in the payload should be
        # written in a Measurement Set). These payloads contain a timestamp but
        # no reference to their T0. Inferring a time index this way assumes
        # payloads are received in order (which might not hold), and relies on
        # payloads containing data for all visibilities (but maybe a portion of
        # the channels), so the MS row calculation doesn't break.
        self._time_indexer = Indexer[float]()
        self._added_payloads = 0
        # Sequence number where the next flush starts; buffered payloads with a
        # lower sequence number are treated as late and dropped.
        self._next_aggregation_start = 0
        # NOTE(review): not read anywhere in this class.
        self._last_tstart = 0
        # When aggregating we keep track of payloads on each time step,
        # which we differentiate with the payload's sequence number.
        self._payloads_by_sequence_number: dict[int, list[PayloadData]] = (
            collections.defaultdict(list)
        )
        self._num_streams = num_streams

    @property
    def _all_payloads(self):
        """Iterate over every buffered payload, across all sequence numbers."""
        return itertools.chain.from_iterable(self._payloads_by_sequence_number.values())

    @property
    def _has_payloads(self) -> bool:
        """Whether at least one payload is currently buffered."""
        # Buckets may exist but be empty, so test the lists, not the keys.
        return any(self._payloads_by_sequence_number.values())

    @property
    def added_payloads(self) -> int:
        """The number of data payloads that have been added to this object."""
        return self._added_payloads

    def add_payload(self, payload: icd.Payload) -> None:
        """Add a single payload to be considered for aggregation in the next Visibility object.

        The payload's timestamp is mapped to a sequence number by the internal
        time indexer, and the payload is buffered under that sequence number
        until a :meth:`flush` releases or drops it.
        """
        sequence_number = self._time_indexer.get_index(payload.timestamp)
        payload_data = PayloadData.from_payload(payload, sequence_number)
        self._payloads_by_sequence_number[sequence_number].append(payload_data)
        self._added_payloads += 1

    def reset_time_indexing(self) -> None:
        """Reset the internal time indexing used to calculate payload sequence numbers.

        NOTE: This shouldn't normally be called, but can be useful in some
        situations; e.g., when receiving data for different scans that is
        tagged with the same timestamps.

        :raises RuntimeError: if payloads added since the last :meth:`flush`
            are still pending.
        """
        if self._has_payloads:
            raise RuntimeError(
                "Payloads have been added since last flush, can't reset time indexing"
            )
        self._time_indexer.reset()
        self._next_aggregation_start = 0
        # Only empty buckets can remain here; remove them so keys from the old
        # indexing don't linger alongside the new sequence numbers.
        self._payloads_by_sequence_number.clear()

    def flush(self, *, full_flush: bool = False) -> PayloadTimeGroup | None:
        """Flush the buffered payloads into a group for visibility generation.

        Groups payloads into time windows and prepares them for conversion into
        Visibility objects. May drop payloads that arrived too late; a warning
        is logged whenever that happens.

        :param full_flush: If True, flush all payload buffers, else flush only
            complete or timed-out payload buffers. Defaults to False.
        :return: a single PayloadTimeGroup combining all flushed payloads, or
            None if no payload was flushed.
        """
        if not self._has_payloads:
            return None

        # Get the time window to consider for this aggregation and the groups
        # of payloads that would result in individual Visibilities.
        time_start, time_stop = self._calculate_visibility_time_window(full_flush=full_flush)
        time_groups = self._pop_payloads_by_time_groups(
            time_start, time_stop, full_flush=full_flush
        )

        # Report drops regardless of whether anything else gets flushed;
        # otherwise they would be lost silently on the early return below.
        if time_groups.dropped_payloads:
            logger.warning(
                "%d payloads that arrived too late have been dropped",
                time_groups.dropped_payloads,
            )

        if time_groups.payload_count == 0:
            return None

        # Adapt to the new single-group interface (scan id unknown here).
        groups = time_groups.payload_groups
        return PayloadTimeGroup(
            start_time_index=min(group.start_time_index for group in groups),
            # NOTE(review): stop_time_index is already exclusive, so the +1
            # widens the window by one step; kept since consumers may rely on it.
            stop_time_index=max(group.stop_time_index for group in groups) + 1,
            payloads=[payload for group in groups for payload in group.payloads],
        )

    def _calculate_visibility_time_window(self, *, full_flush: bool) -> tuple[int, int]:
        """Find the interval of time indexes to be flushed.

        If nothing should be aggregated, then ``t_start == t_stop``. Time steps
        at least ``tolerance`` steps behind the newest buffered step are always
        included; newer ones are included only while they are complete.

        :param full_flush: when true, uses a tolerance of 0 to flush all possible payloads
        :return: the left-closed, right-open interval ``[t_start, t_stop)``
        """
        sequence_numbers = [payload.sequence_number for payload in self._all_payloads]
        t_start = max(min(sequence_numbers), self._next_aggregation_start)
        t_max = max(max(sequence_numbers), t_start)
        t_stop_max = t_max + 1
        tolerance = 0 if full_flush else self._config.tolerance
        # Find the highest time step where all streams have data;
        # t_stop will be the time step after that one.
        t_stop = max(t_stop_max - tolerance, t_start)
        while t_stop < t_stop_max and self._all_payloads_received_for_t(t_stop):
            t_stop += 1
        return t_start, t_stop

    def _pop_payloads_by_time_groups(
        self, start: int, stop: int, *, full_flush: bool
    ) -> _PayloadTimeGroupList:
        """Remove and return groups of payloads for the time window ``[start, stop)``.

        Buffered payloads with sequence numbers below ``start`` arrived too late;
        they are removed and counted in ``dropped_payloads``. If the config's
        ``num_timestamps_per_aggregation`` is positive, the window is split into
        groups of that many time steps; a trailing partial group is only emitted
        on a full flush, otherwise its payloads stay buffered for a later flush.

        If at least one payload is returned, ``self._next_aggregation_start`` is
        updated so the next flush starts where this one ended.
        """
        next_aggregation_start = stop
        collection = _PayloadTimeGroupList()

        # Drop every buffered time step older than `start`. Dict keys are in
        # insertion order, which is no longer sorted once a late payload
        # re-creates a previously popped key, so every key must be inspected.
        times_to_drop = [t for t in self._payloads_by_sequence_number if t < start]
        for time_to_drop in times_to_drop:
            dropped = self._payloads_by_sequence_number.pop(time_to_drop)
            collection.dropped_payloads += len(dropped)

        # Get payload groups for each time window as necessary
        requested_size = self._config.num_timestamps_per_aggregation
        if requested_size > 0:
            for group_start in range(start, stop, requested_size):
                group_end = group_start + requested_size
                if group_end > stop:
                    if not full_flush:
                        # Keep this partial group's payloads for a future flush.
                        next_aggregation_start = group_start
                        continue
                    # Force the creation of a smaller Visibility when flushing fully.
                    group_end = stop
                collection.payload_groups.append(
                    self._pop_payloads_between(group_start, group_end)
                )
        elif start == stop:
            return collection
        else:
            collection.payload_groups.append(self._pop_payloads_between(start, stop))

        if collection.payload_count > 0:
            self._next_aggregation_start = next_aggregation_start
        return collection

    def _all_payloads_received_for_t(self, t: int) -> bool:
        """Whether the number of payloads buffered for time step ``t`` equals the stream count."""
        # Use .get() so probing a missing time step doesn't insert an empty
        # bucket into the defaultdict (which would also disturb key order).
        return len(self._payloads_by_sequence_number.get(t, ())) == self._num_streams

    def _payloads_in_range(self, start: int, stop: int) -> int:
        """Count buffered payloads whose sequence number lies in ``[start, stop)``."""
        return sum(1 for payload in self._all_payloads if start <= payload.sequence_number < stop)

    def _pop_payloads_between(self, start: int, stop: int) -> PayloadTimeGroup:
        """Remove and return all buffered payloads with sequence numbers in ``[start, stop)``."""
        times_in_range = [t for t in self._payloads_by_sequence_number if start <= t < stop]
        payloads = [
            payload
            for payload_time in times_in_range
            for payload in self._payloads_by_sequence_number.pop(payload_time)
        ]
        # Build through the constructor so PayloadTimeGroup validates the window.
        payload_group = PayloadTimeGroup(start, stop, payloads)
        if __debug__:
            remaining_payloads_in_range = self._payloads_in_range(start, stop)
            if remaining_payloads_in_range:
                raise AssertionError(
                    f"{remaining_payloads_in_range} payloads in [{start}, {stop})"
                )
        return payload_group