import asyncio
import collections
import dataclasses
import itertools
import logging
import time
from contextlib import AbstractAsyncContextManager
from typing import Sequence
from realtime.receive.core import icd
from realtime.receive.core.common import autocast_fields
from ska_sdp_datamodels.visibility.vis_model import Visibility
from realtime.receive.modules.consumers.consumer import Consumer
from realtime.receive.modules.indexer import Indexer
from realtime.receive.modules.payload_data import PayloadData
from realtime.receive.modules.scan_providers import ScanProvider
from realtime.receive.modules.tm import TelescopeManager
from realtime.receive.modules.utils.periodic_task import PeriodicTask
from realtime.receive.modules.utils.sdp_visibility import VisibilityBuilder
# Module-level logger used by the aggregation code in this module.
logger = logging.getLogger(__name__)
@dataclasses.dataclass
@autocast_fields
class AggregationConfig:
    """
    Settings that govern how the receiver aggregates incoming data before
    passing the result on to Consumers.
    """

    time_period: float = 5
    """
    Interval, in seconds, between background aggregations of payloads. A
    non-positive value disables the background aggregation; aggregation can
    then still be requested manually, and it is always performed
    automatically at shutdown.
    """

    integration_interval_tolerance: int = 5
    """
    How many integration intervals are left out of an aggregation step, so
    that slower streams that are still behind on their data get a chance to
    catch up.
    """

    payloads_per_backoff: int = 20
    """
    This is a *HIGHLY TECHNICAL* option, so only change if you know what you're
    doing.

    How many payloads are added to the temporary `VisibilityBuilder` object
    during the periodic background aggregation before control is handed back
    to the IO loop. Adding payloads to the `VisibilityBuilder` is CPU
    intensive, so adding many of them in one go can stall the other
    coroutines running concurrently under the same IO loop. Handing control
    back every so often lets those coroutines make progress, at the cost of
    the aggregation step itself taking longer.
    """

    backoff_time: float = 0.0
    """
    This is a *HIGHLY TECHNICAL* option, so only change if you know what you're
    doing.

    How long to back off for each time control is handed back to the IO loop
    while payloads are aggregated into a `VisibilityBuilder` (see
    `payloads_per_backoff` for details). The default simply yields control
    without asking for any sleeping time in between.
    """

    @property
    def tolerance(self):
        """Short alias for ``integration_interval_tolerance``."""
        return self.integration_interval_tolerance
[docs]
class PayloadAggregator(AbstractAsyncContextManager):
"""
A class that takes individual payloads and aggregates them following the
given settings. Aggregated data is put together into a Visiblity object,
and given to a Consumer object to consume.
Users need to call ``add_payload`` every time a new ICD payload that needs
to be aggregated is found.
Aggregation happens asynchronously and periodically. Users might also want
to trigger the aggregation of any pending payloads and the consumption of
the resulting Visibility objects, if any, by calling the ``flush`` method.
Finally, the ``aclose`` method should be invoked when users are done using
this class, which ensures all background tasks are completed, and a final
flush of pending payloads. This behavior is also available when using
objects of this class as an asynchronous context manager.
"""
def __init__(
self,
config: AggregationConfig,
consumer: Consumer,
scan_provider: ScanProvider,
tm: TelescopeManager,
num_streams: int | None = None,
first_channel_id: int | None = None,
):
self._config = config
self._consumer = consumer
self._scan_provider = scan_provider
self._tm = tm
self._time_indexer = Indexer()
self._added_payloads = 0
self._aggregated_payloads = 0
self._visibilities_generated = 0
self._visibilities_consumed = 0
self._visibilities_flagged_fraction = 0.0
self._next_aggregation_start = 0
# When aggregating we keep track of payloads on each time step,
# which we differentiate with the payload's sequence number
self._payloads_by_sequence_number: dict[int, list[PayloadData]] = collections.defaultdict(
list
)
self._num_streams = num_streams
self._first_channel_id = first_channel_id
self._periodic_aggregation_task = (
PeriodicTask(config.time_period, self._flush) if config.time_period > 0 else None
)
# _flush() is called in a background task *and* directly via flush(),
# we want to make sure only one of them runs at a time
self._flush_lock = asyncio.Lock()
@property
def _all_payloads(self):
return itertools.chain.from_iterable(self._payloads_by_sequence_number.values())
@property
def _payloads_added_since_last_flush(self) -> bool:
return next(self._all_payloads, None) is not None
@property
def added_payloads(self) -> int:
"""The number of data payloads that have been added to this object."""
return self._added_payloads
@property
def aggregated_payloads(self) -> int:
"""
The number of data payloads that have been aggregated into Visibility
objects.
"""
return self._aggregated_payloads
@property
def visibilities_generated(self) -> int:
"""The number of visibilities that have been generated."""
return self._visibilities_generated
@property
def visibilities_consumed(self) -> int:
"""The number of visibilities that have been successfully consumed."""
return self._visibilities_consumed
@property
def visibilities_flagged_fraction(self) -> float:
"""
The faction of data flagged as missing in the last generated
Visibility.
"""
return self._visibilities_flagged_fraction
@property
def consumer(self) -> Consumer:
"""The consumer this aggregator forwards visibilities to"""
return self._consumer
async def _consume(self, visibility: Visibility) -> None:
self._visibilities_generated += 1
try:
await self._consumer.consume(visibility)
except Exception: # pylint: disable=broad-except
logger.exception("Unexpected error while consuming visibility")
else:
self._visibilities_consumed += 1
[docs]
def add_payload(self, payload: icd.Payload) -> None:
"""
Add a single payload to be considered for aggregation in the next
Visibility object.
:param payload: The payload to add
"""
# we have been informed about the total number of streams. If not,
# we shouldn't allow adding payloads as that will trigger the usage of
# this member in the asynchronous aggregation task
assert self._num_streams is not None
current_scan = self._scan_provider.current_scan
if current_scan is None:
logger.warning(
"Skipping payload because there is currently no scan being observed",
)
return
if current_scan.scan_number != payload.scan_id:
logger.warning(
"Skipping payload because its scan ID doesn't match the current scan "
"being observed (%d != %d)",
payload.scan_id,
current_scan.scan_number,
)
return
payload_sequence_no = self._time_indexer.get_index(payload.timestamp)
payload_data = PayloadData.from_payload(payload, payload_sequence_no, current_scan)
sequence_number_payloads = self._payloads_by_sequence_number[payload_sequence_no]
sequence_number_payloads.append(payload_data)
self._added_payloads += 1
[docs]
async def flush(self, ignore_tolerance=True) -> None:
"""
Create and consume Visibility objects for any payloads pending
aggregation. If a background aggregation is currently taking place,
we wait for it to finish before running a new one.
:param ignore_tolerance: Whether to ignore the tolerance settings, in
which case all data recieved until now will be included in the
resulting `Visibility`.
"""
await self._flush(ignore_tolerance=ignore_tolerance)
[docs]
def reset_time_indexing(self) -> None:
"""
Reset the internal time indexing used to calculate payload sequence
numbers. This shouldn't be normally called, but can be useful in some
situations; e.g., when receiving data for different scans that is
tagged with the same timestamps.
This method can only be called if no payloads have been called since
the last `flush`; otherwise a ``RuntimeError`` is raised.
"""
if self._payloads_added_since_last_flush:
raise RuntimeError(
"Payloads have been added since last flush, can't reset time indexing"
)
self._time_indexer.reset()
self._next_aggregation_start = 0
[docs]
async def add_payloads_and_flush(
self, payloads: Sequence[icd.Payload], **flush_kwargs
) -> None:
"""Adds all payloads and tries to immediately create a Visibility."""
for payload in payloads:
self.add_payload(payload)
await self.flush(**flush_kwargs)
@staticmethod
def _sorted_by_time(payloads: Sequence[PayloadData]):
return not payloads or all(
payloads[i].sequence_number <= payloads[i + 1].sequence_number
for i in range(len(payloads) - 1)
)
async def _flush(self, ignore_tolerance=False) -> None:
async with self._flush_lock:
await self._flush_unprotected(ignore_tolerance)
async def _flush_unprotected(self, ignore_tolerance: bool) -> None:
timer_start = time.monotonic()
if not self._payloads_added_since_last_flush:
return
# get the time window we'll build a Visibility for,
# pop those payloads and put them into a Visibility
t_start, t_stop = self._calculate_visibility_time_window(ignore_tolerance)
payloads_to_add, payloads_to_drop = self._pop_payloads_between(t_start, t_stop)
if not payloads_to_add:
if payloads_to_drop:
logger.warning(
"%d payloads that arrived too late are being dropped",
len(payloads_to_drop),
)
return
self._next_aggregation_start = t_stop
vis_builder = await self._build_visibility(
t_start,
t_stop,
payloads_to_add,
)
assert not vis_builder.is_empty
self._visibilities_flagged_fraction = vis_builder.flagged_fraction
visibility = vis_builder.get_visibility()
timer_end = time.monotonic()
times, baselines, channels, pols = visibility.vis.shape
logger.info(
"%6d payloads aggregated in %6.3f [s] into Visibility with shape "
"(%3d, %d, %d, %d) with time index in [%4d, %4d). "
"Flagged fraction: %.3f",
len(payloads_to_add),
timer_end - timer_start,
times,
baselines,
channels,
pols,
t_start,
t_stop,
self._visibilities_flagged_fraction,
)
await self._consume(visibility)
async def _build_visibility(
self, t_start: int, t_stop: int, payloads_to_add: Sequence[PayloadData]
):
# henceforth we assume all payloads carry the same #pols and #channels,
# and belong to the same scan
example_payload = payloads_to_add[0]
num_pols = example_payload.visibilities.shape[-1]
channels_per_payload = example_payload.channel_count
scan = example_payload.scan
# TODO (rtobar)
# This is not correct if no payloads have been added for the first stream,
# and is a reflection of the fact that this information comes only from
# the SPEAD streams, which is an unreliable source. Instead we should be
# told explicitly what's the first channel ID that any of the receiver's
# streams should expect data for (possibly covered in ADR-81).
first_channel_id = min(payload.first_channel_id for payload in payloads_to_add)
vis_shape = (
t_stop - t_start,
self._tm.num_baselines,
self._num_streams * channels_per_payload,
num_pols,
)
vis_builder = VisibilityBuilder(vis_shape, scan, self._tm, t_start, first_channel_id)
# start at 1 to avoid yielding in the first iteration unnecessarily
for i, payload in enumerate(payloads_to_add, start=1):
if i % self._config.payloads_per_backoff == 0:
await asyncio.sleep(self._config.backoff_time)
assert (
t_start <= payload.sequence_number < t_stop
), f"constraint not satisfied: {t_start} <= {payload.sequence_number} < {t_stop}"
vis_builder.add_payload(payload)
self._aggregated_payloads += len(payloads_to_add)
return vis_builder
def _calculate_visibility_time_window(self, ignore_tolerance: bool):
"""
Find the half-open interval (`[t_start, t_stop)`) of
time-indexes that should be aggregated next. If no visibilities
should be aggregated, then `t_start == t_stop`.
"""
t_start = max(
min(payload.sequence_number for payload in self._all_payloads),
self._next_aggregation_start,
)
t_max = max(
max(payload.sequence_number for payload in self._all_payloads),
t_start,
)
t_stop_max = t_max + 1
tolerance = 0 if ignore_tolerance else self._config.tolerance
# find highest time step where all streams have data
# t_stop will be the time step after that one
t_stop = max(t_stop_max - tolerance, t_start)
while t_stop < t_stop_max and self._all_payloads_received_for_t(t_stop):
t_stop += 1
assert t_start <= t_stop <= t_stop_max
return t_start, t_stop
def _all_payloads_received_for_t(self, t: int):
return len(self._payloads_by_sequence_number[t]) == self._num_streams
def _payloads_in_range(self, start: int, stop: int):
return sum(1 for payload in self._all_payloads if start <= payload.sequence_number < stop)
def _pop_payloads_between(self, start: int, stop: int):
dropped_payloads = []
popped_payloads = []
all_payload_times = list(self._payloads_by_sequence_number.keys())
for payload_time in all_payload_times:
if payload_time >= stop:
continue
payloads = self._payloads_by_sequence_number.pop(payload_time)
if payload_time < start:
dropped_payloads += payloads
else:
popped_payloads += payloads
if __debug__:
remaining_payloads_in_range = self._payloads_in_range(start, stop)
if remaining_payloads_in_range:
raise AssertionError(
f"{remaining_payloads_in_range} payloads in [{start}, {stop})"
)
return popped_payloads, dropped_payloads
[docs]
async def astart(self):
"""Start the background aggregation task, if required."""
if self._periodic_aggregation_task:
await self._periodic_aggregation_task.astart()
return self
[docs]
async def astop(self) -> None:
"""Finish all background tasks and flush all pending payloads."""
if self._periodic_aggregation_task:
await self._periodic_aggregation_task.astop()
await self.flush()
self.reset_time_indexing()
async def __aenter__(self):
await self.astart()
return self
async def __aexit__(self, *_args, **_kwargs):
await self.astop()