Source code for realtime.receive.modules.aggregation

import asyncio
import collections
import dataclasses
import itertools
import logging
import time
from contextlib import AbstractAsyncContextManager
from typing import Sequence

from realtime.receive.core import icd
from realtime.receive.core.common import autocast_fields
from ska_sdp_datamodels.visibility.vis_model import Visibility

from realtime.receive.modules.consumers.consumer import Consumer
from realtime.receive.modules.indexer import Indexer
from realtime.receive.modules.payload_data import PayloadData
from realtime.receive.modules.scan_providers import ScanProvider
from realtime.receive.modules.tm import TelescopeManager
from realtime.receive.modules.utils.periodic_task import PeriodicTask
from realtime.receive.modules.utils.sdp_visibility import VisibilityBuilder

logger = logging.getLogger(__name__)


@dataclasses.dataclass
@autocast_fields
class AggregationConfig:
    """
    Options controlling how data aggregation happens in the receiver before
    it's handed over to Consumers.
    """

    time_period: float = 5
    """
    Period, in seconds, after which payloads should be aggregated. If this is
    a non-positive number then aggregation doesn't happen in the background,
    but it can still be triggered manually. It is still triggered
    automatically at shutdown regardless.
    """

    integration_interval_tolerance: int = 5
    """
    Number of integration intervals to leave out of an aggregation step to
    cater for slower streams that haven't caught up with receiving some of
    their data.
    """

    payloads_per_backoff: int = 20
    """
    This is a *HIGHLY TECHNICAL* option, so only change it if you know what
    you're doing.

    Number of payloads to put into the temporary `VisibilityBuilder` object
    during the periodic background aggregation before giving control back to
    the IO loop. Adding payloads into the `VisibilityBuilder` is a
    CPU-intensive task, so if many payloads are being added this can stall
    the other coroutines that are executing in the system concurrently under
    the same IO loop. Yielding control back to the IO loop every now and
    then allows other coroutines to progress, at the expense of this final
    aggregation step taking longer.
    """

    backoff_time: float = 0.0
    """
    This is a *HIGHLY TECHNICAL* option, so only change it if you know what
    you're doing.

    Time to back off for when yielding control back to the IO loop during
    payload aggregation into a `VisibilityBuilder`. See
    `payloads_per_backoff` for more details. By default we simply yield back
    without requesting any sleep time in between.
    """

    @property
    def tolerance(self):
        """Shorthand for self.integration_interval_tolerance"""
        return self.integration_interval_tolerance
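
# A minimal configuration sketch (illustrative values, not recommendations):
# aggregate every 2 seconds, hold back the last 3 integration intervals for
# late streams, and yield to the IO loop for 1 ms after every 50 payloads:
#
#     config = AggregationConfig(
#         time_period=2.0,
#         integration_interval_tolerance=3,
#         payloads_per_backoff=50,
#         backoff_time=0.001,
#     )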

class PayloadAggregator(AbstractAsyncContextManager):
    """
    A class that takes individual payloads and aggregates them following the
    given settings. Aggregated data is put together into a Visibility
    object, which is then given to a Consumer object to consume.

    Users need to call ``add_payload`` every time a new ICD payload that
    needs to be aggregated is found. Aggregation happens asynchronously and
    periodically. Users might also want to trigger the aggregation of any
    pending payloads, and the consumption of the resulting Visibility
    objects, if any, by calling the ``flush`` method. Finally, the ``astop``
    method should be invoked when users are done using this class, which
    ensures all background tasks are completed and a final flush of pending
    payloads is performed. This behaviour is also available when using
    objects of this class as an asynchronous context manager; see the usage
    sketch at the end of this module.
    """

    def __init__(
        self,
        config: AggregationConfig,
        consumer: Consumer,
        scan_provider: ScanProvider,
        tm: TelescopeManager,
        num_streams: int | None = None,
        first_channel_id: int | None = None,
    ):
        self._config = config
        self._consumer = consumer
        self._scan_provider = scan_provider
        self._tm = tm
        self._time_indexer = Indexer()
        self._added_payloads = 0
        self._aggregated_payloads = 0
        self._visibilities_generated = 0
        self._visibilities_consumed = 0
        self._visibilities_flagged_fraction = 0.0
        self._next_aggregation_start = 0
        # When aggregating we keep track of payloads on each time step,
        # which we differentiate with the payload's sequence number
        self._payloads_by_sequence_number: dict[int, list[PayloadData]] = collections.defaultdict(
            list
        )
        self._num_streams = num_streams
        self._first_channel_id = first_channel_id
        self._periodic_aggregation_task = (
            PeriodicTask(config.time_period, self._flush) if config.time_period > 0 else None
        )
        # _flush() is called in a background task *and* directly via flush(),
        # so we want to make sure only one of them runs at a time
        self._flush_lock = asyncio.Lock()

    @property
    def _all_payloads(self):
        return itertools.chain.from_iterable(self._payloads_by_sequence_number.values())

    @property
    def _payloads_added_since_last_flush(self) -> bool:
        return next(self._all_payloads, None) is not None

    def inform_num_streams(self, num_streams: int) -> None:
        """
        Inform this object of the total number of streams that are being
        received.
        """
        self._num_streams = num_streams

    @property
    def added_payloads(self) -> int:
        """The number of data payloads that have been added to this object."""
        return self._added_payloads

    @property
    def aggregated_payloads(self) -> int:
        """
        The number of data payloads that have been aggregated into Visibility
        objects.
        """
        return self._aggregated_payloads

    @property
    def visibilities_generated(self) -> int:
        """The number of visibilities that have been generated."""
        return self._visibilities_generated

    @property
    def visibilities_consumed(self) -> int:
        """The number of visibilities that have been successfully consumed."""
        return self._visibilities_consumed

    @property
    def visibilities_flagged_fraction(self) -> float:
        """
        The fraction of data flagged as missing in the last generated
        Visibility.
        """
        return self._visibilities_flagged_fraction

    @property
    def consumer(self) -> Consumer:
        """The consumer this aggregator forwards visibilities to."""
        return self._consumer

    async def _consume(self, visibility: Visibility) -> None:
        self._visibilities_generated += 1
        try:
            await self._consumer.consume(visibility)
        except Exception:  # pylint: disable=broad-except
            logger.exception("Unexpected error while consuming visibility")
        else:
            self._visibilities_consumed += 1

    def add_payload(self, payload: icd.Payload) -> None:
        """
        Add a single payload to be considered for aggregation in the next
        Visibility object.

        :param payload: The payload to add
        """
        # By now we should have been informed about the total number of
        # streams. If not, we shouldn't allow adding payloads, as that would
        # trigger the usage of this member in the asynchronous aggregation
        # task
        assert self._num_streams is not None
        current_scan = self._scan_provider.current_scan
        if current_scan is None:
            logger.warning(
                "Skipping payload because there is currently no scan being observed",
            )
            return
        if current_scan.scan_number != payload.scan_id:
            logger.warning(
                "Skipping payload because its scan ID doesn't match the current scan "
                "being observed (%d != %d)",
                payload.scan_id,
                current_scan.scan_number,
            )
            return

        payload_sequence_no = self._time_indexer.get_index(payload.timestamp)
        payload_data = PayloadData.from_payload(payload, payload_sequence_no, current_scan)
        sequence_number_payloads = self._payloads_by_sequence_number[payload_sequence_no]
        sequence_number_payloads.append(payload_data)
        self._added_payloads += 1

    async def flush(self, ignore_tolerance=True) -> None:
        """
        Create and consume Visibility objects for any payloads pending
        aggregation. If a background aggregation is currently taking place,
        we wait for it to finish before running a new one.

        :param ignore_tolerance: Whether to ignore the tolerance settings, in
            which case all data received until now will be included in the
            resulting `Visibility`.
        """
        await self._flush(ignore_tolerance=ignore_tolerance)

    def reset_time_indexing(self) -> None:
        """
        Reset the internal time indexing used to calculate payload sequence
        numbers. This shouldn't normally be called, but can be useful in
        some situations; e.g., when receiving data for different scans that
        is tagged with the same timestamps.

        This method can only be called if no payloads have been added since
        the last ``flush``; otherwise a ``RuntimeError`` is raised.
        """
        if self._payloads_added_since_last_flush:
            raise RuntimeError(
                "Payloads have been added since last flush, can't reset time indexing"
            )
        self._time_indexer.reset()
        self._next_aggregation_start = 0
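
    # Hypothetical usage sketch: between two scans that reuse timestamps one
    # could run
    #
    #     await aggregator.flush()
    #     aggregator.reset_time_indexing()
    #
    # so the next scan's payloads are indexed from sequence number 0 again.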

    async def add_payloads_and_flush(
        self, payloads: Sequence[icd.Payload], **flush_kwargs
    ) -> None:
        """Adds all payloads and tries to immediately create a Visibility."""
        for payload in payloads:
            self.add_payload(payload)
        await self.flush(**flush_kwargs)

    @staticmethod
    def _sorted_by_time(payloads: Sequence[PayloadData]):
        return not payloads or all(
            payloads[i].sequence_number <= payloads[i + 1].sequence_number
            for i in range(len(payloads) - 1)
        )

    async def _flush(self, ignore_tolerance=False) -> None:
        async with self._flush_lock:
            await self._flush_unprotected(ignore_tolerance)

    async def _flush_unprotected(self, ignore_tolerance: bool) -> None:
        timer_start = time.monotonic()
        if not self._payloads_added_since_last_flush:
            return

        # get the time window we'll build a Visibility for,
        # pop those payloads and put them into a Visibility
        t_start, t_stop = self._calculate_visibility_time_window(ignore_tolerance)
        payloads_to_add, payloads_to_drop = self._pop_payloads_between(t_start, t_stop)
        if not payloads_to_add:
            if payloads_to_drop:
                logger.warning(
                    "%d payloads that arrived too late are being dropped",
                    len(payloads_to_drop),
                )
            return

        self._next_aggregation_start = t_stop
        vis_builder = await self._build_visibility(
            t_start,
            t_stop,
            payloads_to_add,
        )
        assert not vis_builder.is_empty
        self._visibilities_flagged_fraction = vis_builder.flagged_fraction
        visibility = vis_builder.get_visibility()

        timer_end = time.monotonic()
        times, baselines, channels, pols = visibility.vis.shape
        logger.info(
            "%6d payloads aggregated in %6.3f [s] into Visibility with shape "
            "(%3d, %d, %d, %d) with time index in [%4d, %4d). "
            "Flagged fraction: %.3f",
            len(payloads_to_add),
            timer_end - timer_start,
            times,
            baselines,
            channels,
            pols,
            t_start,
            t_stop,
            self._visibilities_flagged_fraction,
        )
        await self._consume(visibility)

    async def _build_visibility(
        self, t_start: int, t_stop: int, payloads_to_add: Sequence[PayloadData]
    ):
        # henceforth we assume all payloads carry the same #pols and
        # #channels, and belong to the same scan
        example_payload = payloads_to_add[0]
        num_pols = example_payload.visibilities.shape[-1]
        channels_per_payload = example_payload.channel_count
        scan = example_payload.scan

        # TODO (rtobar)
        # This is not correct if no payloads have been added for the first
        # stream, and is a reflection of the fact that this information comes
        # only from the SPEAD streams, which is an unreliable source. Instead
        # we should be told explicitly what's the first channel ID that any
        # of the receiver's streams should expect data for (possibly covered
        # in ADR-81).
        first_channel_id = min(payload.first_channel_id for payload in payloads_to_add)

        vis_shape = (
            t_stop - t_start,
            self._tm.num_baselines,
            self._num_streams * channels_per_payload,
            num_pols,
        )
        vis_builder = VisibilityBuilder(vis_shape, scan, self._tm, t_start, first_channel_id)
        # start at 1 to avoid yielding in the first iteration unnecessarily
        for i, payload in enumerate(payloads_to_add, start=1):
            if i % self._config.payloads_per_backoff == 0:
                await asyncio.sleep(self._config.backoff_time)
            assert (
                t_start <= payload.sequence_number < t_stop
            ), f"constraint not satisfied: {t_start} <= {payload.sequence_number} < {t_stop}"
            vis_builder.add_payload(payload)
        self._aggregated_payloads += len(payloads_to_add)
        return vis_builder

    def _calculate_visibility_time_window(self, ignore_tolerance: bool):
        """
        Find the half-open interval (`[t_start, t_stop)`) of time-indexes
        that should be aggregated next. If no visibilities should be
        aggregated, then `t_start == t_stop`.
""" t_start = max( min(payload.sequence_number for payload in self._all_payloads), self._next_aggregation_start, ) t_max = max( max(payload.sequence_number for payload in self._all_payloads), t_start, ) t_stop_max = t_max + 1 tolerance = 0 if ignore_tolerance else self._config.tolerance # find highest time step where all streams have data # t_stop will be the time step after that one t_stop = max(t_stop_max - tolerance, t_start) while t_stop < t_stop_max and self._all_payloads_received_for_t(t_stop): t_stop += 1 assert t_start <= t_stop <= t_stop_max return t_start, t_stop def _all_payloads_received_for_t(self, t: int): return len(self._payloads_by_sequence_number[t]) == self._num_streams def _payloads_in_range(self, start: int, stop: int): return sum(1 for payload in self._all_payloads if start <= payload.sequence_number < stop) def _pop_payloads_between(self, start: int, stop: int): dropped_payloads = [] popped_payloads = [] all_payload_times = list(self._payloads_by_sequence_number.keys()) for payload_time in all_payload_times: if payload_time >= stop: continue payloads = self._payloads_by_sequence_number.pop(payload_time) if payload_time < start: dropped_payloads += payloads else: popped_payloads += payloads if __debug__: remaining_payloads_in_range = self._payloads_in_range(start, stop) if remaining_payloads_in_range: raise AssertionError( f"{remaining_payloads_in_range} payloads in [{start}, {stop})" ) return popped_payloads, dropped_payloads

    async def astart(self):
        """Start the background aggregation task, if required."""
        if self._periodic_aggregation_task:
            await self._periodic_aggregation_task.astart()
        return self

    async def astop(self) -> None:
        """Finish all background tasks and flush all pending payloads."""
        if self._periodic_aggregation_task:
            await self._periodic_aggregation_task.astop()
        await self.flush()
        self.reset_time_indexing()

    async def __aenter__(self):
        await self.astart()
        return self

    async def __aexit__(self, *_args, **_kwargs):
        await self.astop()
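
# Usage sketch, not part of the module's API: a minimal, hypothetical example
# of driving a PayloadAggregator as an asynchronous context manager. The
# consumer, scan provider, telescope manager, payloads and stream count are
# assumed to be provided by the surrounding application.
async def _example_usage(
    consumer: Consumer,
    scan_provider: ScanProvider,
    tm: TelescopeManager,
    payloads: Sequence[icd.Payload],
    num_streams: int,
) -> None:
    # aggregate in the background every 2 seconds (illustrative value)
    config = AggregationConfig(time_period=2.0)
    aggregator = PayloadAggregator(
        config, consumer, scan_provider, tm, num_streams=num_streams
    )
    # __aenter__/__aexit__ call astart()/astop(), so the periodic background
    # aggregation starts here and all pending payloads are flushed on exit
    async with aggregator:
        for payload in payloads:
            aggregator.add_payload(payload)
        # force aggregation of everything received so far
        await aggregator.flush()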