import asyncio
import logging
import typing
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from deprecated import deprecated
from overrides import override
from realtime.receive.core import Scan
from realtime.receive.core.icd import Payload
from realtime.receive.modules.aggregators.config import AggregationConfig
from realtime.receive.modules.aggregators.ops.batch_by_payload_seq_no import BatchByPayloadSeqNo
from realtime.receive.modules.aggregators.ops.consumer_dispatcher import ConsumerDispatcher
from realtime.receive.modules.aggregators.ops.visibility_generator import VisibilityGenerator
from realtime.receive.modules.aggregators.payload_aggregator import PayloadAggregator
from realtime.receive.modules.consumers.consumer import Consumer
from realtime.receive.modules.reception_range import ReceiverReceptionRanges
from realtime.receive.modules.scan_providers import ScanProvider
from realtime.receive.modules.tm import TelescopeManager
from realtime.receive.modules.utils.collections import keydefaultdict
from realtime.receive.modules.utils.periodic_task import PeriodicTask
from realtime.receive.modules.utils.sdp_visibility import find_visibility_beam
# Module-level logger.
logger = logging.getLogger(__name__)
# Identifier of a scan type; used as the key into ``ScanProvider.scan_types``
# and into the per-scan-type ``ReceiverReceptionRanges`` mapping.
ScanTypeID: typing.TypeAlias = str
# Key identifying one beam pipeline: (scan type ID, CBF beam ID of a Payload).
BeamPipelineIndex: typing.TypeAlias = tuple[ScanTypeID, Payload.BeamID]
class _BeamPipeline(typing.NamedTuple):
    """Per-beam chain of stages that turns payloads into visibilities.

    Data flows through the stages in this order:

    1. ``payload_aggregator`` collects incoming payloads and groups them by time.
    2. ``payload_batcher`` batches each flushed time group.
    3. ``visibility_generator`` converts the batches into Visibility objects.
    4. ``consumer_dispatcher`` forwards the Visibility objects to the consumer.

    Payload objects contain a CBF beam ID, while the resulting Visibility
    objects contain an SDP beam ID. The CBF -> SDP beam ID mapping is performed
    based on the contents of the SDP Execution Block ``beams``.
    """

    payload_aggregator: PayloadAggregator
    payload_batcher: BatchByPayloadSeqNo
    visibility_generator: VisibilityGenerator
    consumer_dispatcher: ConsumerDispatcher

    def add_payload(self, payload: Payload):
        """Hand ``payload`` to the aggregator for inclusion in a future Visibility."""
        aggregator = self.payload_aggregator
        aggregator.add_payload(payload)

    async def flush(self, *, full_flush: bool = False) -> None:
        """Push aggregated payloads through the remaining stages to the consumer.

        The payload aggregator is asked for a time group. If it returns a
        truthy group, that group is batched, converted to Visibility objects,
        and dispatched. Payloads that are too old may be dropped by the
        aggregator.

        :param full_flush: If True, flush all payload buffers; otherwise flush
            only full or timed-out payload buffers. Defaults to False.
        """
        time_group = self.payload_aggregator.flush(full_flush=full_flush)
        if not time_group:
            # Nothing is ready to be turned into visibilities yet.
            return
        batches = self.payload_batcher.batch(time_group)
        visibilities = self.visibility_generator.aiterate(batches)
        await self.consumer_dispatcher.dispatch_each(visibilities)
class MultiBeamPipeline(AbstractAsyncContextManager):
    """Aggregates the visibility beams of an observation and streams them to a consumer.

    Payloads are routed to one ``_BeamPipeline`` per (scan type ID, CBF beam ID)
    pair; each pipeline is built on first use by :meth:`_create_beam_pipeline`.
    All beam pipelines share a single :class:`ConsumerDispatcher`, so scan
    start and end events are dispatched only once across all beams, and all
    generated Visibility objects go to a single consumer.

    When ``config.time_period`` is positive, :meth:`flush` is invoked
    periodically by a :class:`PeriodicTask` while the pipeline is running
    (between :meth:`astart` and :meth:`astop`). :meth:`astop` always performs
    a final full flush.
    """

    def __init__(
        self,
        config: AggregationConfig,
        consumer: Consumer,
        tm: TelescopeManager,
        scan_provider: ScanProvider,
        reception_ranges: ReceiverReceptionRanges,
    ):
        """Create the pipeline; no background work starts until :meth:`astart`.

        :param config: Aggregation settings, shared by all beam pipelines.
        :param consumer: Destination of the generated Visibility objects.
        :param tm: Telescope manager passed to each ``VisibilityGenerator``.
        :param scan_provider: Maps payload timestamps to scans and provides
            scan type definitions.
        :param reception_ranges: Reception ranges per scan type and SDP beam.
        """
        self._config = config
        self._consumer = consumer
        self._consumer_dispatcher = ConsumerDispatcher(self._consumer)
        self._tm = tm
        self._scan_provider = scan_provider
        # NOTE(review): _current_scan and _sent_scans are not read anywhere in
        # this class; confirm whether they are still needed.
        self._current_scan = None
        self._sent_scans = set[Scan]()
        self._reception_ranges = reception_ranges
        # Missing entries are built by _create_beam_pipeline from their key.
        self._beam_pipelines = keydefaultdict[BeamPipelineIndex, _BeamPipeline](
            self._create_beam_pipeline
        )
        # True while astop() is tearing the pipeline down; add_payload()
        # rejects new payloads during that window.
        self._stopping = False
        # Periodic task for automatic aggregation (disabled if time_period <= 0).
        self._periodic_flush_task = (
            PeriodicTask(config.time_period, self.flush) if config.time_period > 0 else None
        )
        # Owns everything entered by astart(); closed by astop().
        self._exit_stack = AsyncExitStack()

    def _create_beam_pipeline(self, scan_type_beam_idx: BeamPipelineIndex) -> _BeamPipeline:
        """Create the ``_BeamPipeline`` for a (scan type ID, CBF beam ID) pair.

        The CBF beam ID is mapped to an SDP beam using the scan type's
        definition. That SDP beam's reception range provides the stream count
        for the payload aggregator and configures the visibility generator.

        :param scan_type_beam_idx: ``(scan_type_id, cbf_beam_id)`` key.
        :return: A new pipeline that uses the shared consumer dispatcher.
        :raises ValueError: If no SDP beam matches the CBF beam ID.
        """
        scan_type_id, beam_id = scan_type_beam_idx
        beam = find_visibility_beam(beam_id, self._scan_provider.scan_types[scan_type_id])
        if beam is None:
            raise ValueError(f"No SDP beam found for CBF beam {beam_id}")
        # Reception ranges are indexed by SDP beam ID, not CBF beam ID.
        beam_reception_range = self._reception_ranges[scan_type_id][beam.beam_id]
        return _BeamPipeline(
            PayloadAggregator(self._config, beam_reception_range.stream_count),
            BatchByPayloadSeqNo(
                self._scan_provider,
                scan_type_beam_idx,
                self._config.num_timestamps_per_aggregation,
            ),
            VisibilityGenerator(self._config, self._tm, beam_reception_range),
            # shared instance to only dispatch single start and end scan
            # events across all beams
            self._consumer_dispatcher,
        )

    def add_payload(self, payload: Payload) -> None:
        """Route a payload to the beam pipeline for its scan type and beam.

        :param payload: Payload to aggregate.
        :raises RuntimeError: If called while :meth:`astop` is running.
        """
        if self._stopping:
            raise RuntimeError("can't add payload while stopping aggregation")
        beam_id = payload.beam_id
        scan_type_id = self._scan_provider.query_scan(payload.timestamp).scan_type_id
        beam_pipeline = self._beam_pipelines[(scan_type_id, beam_id)]
        beam_pipeline.add_payload(payload)

    async def flush(self, *, full_flush: bool = False) -> None:
        """Flush all beam pipelines concurrently and dispatch the resulting visibilities.

        :param full_flush: If True, flush all payload buffers; otherwise flush
            only full or timed-out payload buffers. Defaults to False.
        """
        # The generator is fully unpacked before gather() awaits anything, so
        # pipelines created during the flush do not disturb the iteration.
        await asyncio.gather(
            *(
                beam_pipeline.flush(full_flush=full_flush)
                for beam_pipeline in self._beam_pipelines.values()
            )
        )

    def reset_time_indexing(self) -> None:
        """Reset the internal time indexing used for payload sequencing.

        Delegates to ``reset_time_indexing()`` on every beam pipeline's
        payload aggregator.
        """
        for beam in self._beam_pipelines.values():
            beam.payload_aggregator.reset_time_indexing()

    # Properties for statistics and access to internal components

    @property
    def added_payloads(self) -> int:
        """Number of payloads added to the aggregator, summed over all beams."""
        return sum(
            beam.payload_aggregator.added_payloads for beam in self._beam_pipelines.values()
        )

    @property
    def aggregated_payloads(self) -> int:
        """Number of payloads aggregated into visibilities, summed over all beams."""
        return sum(
            beam.visibility_generator.aggregated_payloads for beam in self._beam_pipelines.values()
        )

    @property
    def visibilities_generated(self) -> int:
        """Number of visibilities generated, summed over all beams."""
        return sum(
            beam.visibility_generator.visibilities_generated
            for beam in self._beam_pipelines.values()
        )

    @property
    def visibilities_consumed(self) -> int:
        """Number of visibilities that have been successfully consumed."""
        return self._consumer_dispatcher.visibilities_consumed

    @property
    @deprecated(reason="used only by tests")
    def consumer(self) -> Consumer:
        """The consumer this aggregation system forwards visibilities to."""
        return self._consumer

    # Async context manager methods

    async def astart(self):
        """Start the aggregation system.

        Enters the consumer dispatcher, registers a final full flush, and
        starts the periodic flush task if one is configured. The order is
        chosen so that :meth:`astop` (which unwinds in reverse) does three
        things in turn:

        1. Stops the periodic task.
        2. Flushes the remaining payloads.
        3. Exits the consumer dispatcher.

        If any step fails, the steps already taken are undone before the
        error propagates.

        :return: This instance.
        """
        async with AsyncExitStack() as stack:
            # The dispatcher must be active before anything flushes into it,
            # and must still be active during the final flush.
            await stack.enter_async_context(self._consumer_dispatcher)
            stack.push_async_callback(self.flush, full_flush=True)
            # Entered last so it is stopped first, before the final full flush
            # (assumes exiting PeriodicTask stops further flush calls - TODO
            # confirm).
            if self._periodic_flush_task:
                await stack.enter_async_context(self._periodic_flush_task)
            # Startup succeeded: transfer ownership of everything entered above
            # to the long-lived exit stack closed by astop().
            self._exit_stack.push_async_exit(stack.pop_all())
        return self

    async def astop(self):
        """Stop the aggregation system.

        Stops the periodic flush task (if any), performs a final full flush of
        any remaining payloads, then exits the consumer dispatcher.
        :meth:`add_payload` raises :class:`RuntimeError` while this runs.
        """
        self._stopping = True
        try:
            await self._exit_stack.aclose()
        finally:
            self._stopping = False

    @override
    async def __aenter__(self):
        await self.astart()
        return self

    @override
    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.astop()