# -*- coding: utf-8 -*-
"""
UDP Protocol Multi-stream SPEAD2 receiver
"""
import asyncio
import collections
import dataclasses
import enum
import functools
import logging
import operator
import time
import warnings
from typing import Dict, List, Optional
import numpy as np
import spead2.recv.asyncio
from overrides import override
from realtime.receive.core import socket_utils
from realtime.receive.core.common import autocast_fields
from realtime.receive.core.icd import ICD, ItemID, LowICD, MidICD, Payload, Telescope
from realtime.receive.modules.aggregation import PayloadAggregator
from realtime.receive.modules.consumers.consumer import Consumer
from realtime.receive.modules.data_reception_handler import DataReceptionHandler
from realtime.receive.modules.periodic_summary_logger import PeriodicSummaryLogger
from realtime.receive.modules.receivers import Config
from realtime.receive.modules.scan_lifecycle_handler import ScanLifecycleHandler
from realtime.receive.modules.utils.continuous_stats_receiver import (
ContinuousStatsReceiver,
KafkaStatsReceiver,
)
logger = logging.getLogger(__name__)
def create_stream(io_thread_pool, ring_heaps, max_heaps):
"""Create a spead2 stream for the given options"""
return spead2.recv.asyncio.Stream(
io_thread_pool,
config=spead2.recv.StreamConfig(max_heaps=max_heaps, stop_on_stop_item=False),
ring_config=spead2.recv.RingStreamConfig(heaps=ring_heaps, contiguous_only=False),
)
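
# Illustrative usage of create_stream; the values mirror what this module uses
# elsewhere (ring_heaps=16 is the Spead2ReceptionConfig default, max_heaps=32
# is what receiver._setup_streams passes):
#
#     pool = spead2.ThreadPool(threads=1)
#     stream = create_stream(pool, ring_heaps=16, max_heaps=32)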
@dataclasses.dataclass(frozen=True)
class FixedLowSOSItemValues:
"""
Values from items send on Low's start-of-stream heaps that should be fixed
across streams. This includes all items except channel ID and frequency.
Objects of this class are used to aggregate information about the streams
of SKA Low data being received. In the future we want to expand this
aggregation to include the channel ID and frequency range, which isn't
accounted for here, as well as including this information for Mid. For more
details see https://jira.skatelescope.org/browse/YAN-1298.
"""
scan_id: int
hardware_id: int
baseline_count: int
station_beam_id: int
subarray_id: int
integration_period: float
frequency_resolution: float
output_resolution: int
zoom_window_id: int
cbf_firmware_version: int
cbf_source_id: str
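
# Being frozen (and therefore hashable), instances can be used directly as
# collections.Counter keys, which is how StartOfStreamLogger aggregates them
# below. A sketch with made-up item values:
#
#     counter = collections.Counter()
#     values = FixedLowSOSItemValues(
#         scan_id=1, hardware_id=0, baseline_count=6, station_beam_id=0,
#         subarray_id=1, integration_period=0.9, frequency_resolution=5.4e3,
#         output_resolution=1, zoom_window_id=0, cbf_firmware_version=1,
#         cbf_source_id="L",
#     )
#     counter[values] += 1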
class StartOfStreamLogger(PeriodicSummaryLogger):
"""
    A class that keeps track of how many start-of-stream (SOS) heaps have been
    successfully received within a period of time. When asked to report on
    these heaps, it returns either the count of SOS heaps for Mid or a summary
    of the item values of the SOS heaps for Low.
"""
MESSAGE = "Successfully received %(summary)s in %(elapsed_time).2f [s]"
def __init__(self, period):
PeriodicSummaryLogger.__init__(self, logger, period, StartOfStreamLogger.MESSAGE)
self._low_counter: collections.Counter = collections.Counter()
self._mid_counter: int = 0
def reset(self):
self._low_counter.clear()
self._mid_counter = 0
def summary(self):
if not self._low_counter and not self._mid_counter:
return False
if self._mid_counter:
return f"{self._mid_counter} Mid start-of-stream heaps"
summary = (
f"{self._low_counter.total()} Low start-of-stream heaps in "
f"{len(self._low_counter)} different set(s), details follow: "
f"{self._low_counter}"
)
return summary
def record_low_start_of_stream(self, item_values: FixedLowSOSItemValues):
"""
Records that a Low SOS heap has been received, keeping track of the
subset of values sent in that heap that should be fixed across streams.
"""
self._low_counter[item_values] += 1
def record_mid_start_of_stream(self):
"""Records that a Mid SOS heap has been received."""
self._mid_counter += 1
def _item_size(item):
assert item.itemsize_bits % 8 == 0
return item.itemsize_bits // 8 * functools.reduce(operator.mul, item.shape, 1)
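
# For example, an item with itemsize_bits=32 and shape (2, 3) occupies
# 32 // 8 * (2 * 3) = 24 bytes; _get_item_group_size sums this over every item
# in an ItemGroup.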
def _get_item_group_size(item_group):
return sum(_item_size(item) for item in item_group.values())
class Spead2ReceiverPayload(Payload):
"""A Payload that updates itself from data coming from spead2 heaps"""
LOW_CBF_SOURCE_ID = b"L"
def __init__(self, start_of_stream_logger):
super().__init__()
self._item_group = spead2.ItemGroup()
self.item_group_size = None
self._has_all_required_item_descriptors = False
self._sos_logger: StartOfStreamLogger = start_of_stream_logger
self._telescope: Telescope | None = None
@property
def _icd(self) -> ICD:
assert self._telescope is not None
return LowICD if self._telescope == Telescope.LOW else MidICD
def set_item_descriptors(self, heap) -> bool:
"""
Updates the SPEAD ItemGroup of this Payload with the ItemDescriptors
from `heap`.
:param heap: A SPEAD heap.
        :returns: Whether all ItemDescriptors this Payload requires have been
            found.
"""
updated_items = self._item_group.update(heap)
items_by_id = {item.id: item for item in updated_items.values()}
self.item_group_size = _get_item_group_size(self._item_group)
# CBF Low sends this item to identify itself, otherwise assume Mid
self._telescope = Telescope.MID
if ItemID.CBF_SOURCE_ID in items_by_id:
cbf_source_id = items_by_id[ItemID.CBF_SOURCE_ID].value
if cbf_source_id == Spead2ReceiverPayload.LOW_CBF_SOURCE_ID:
self._telescope = Telescope.LOW
else:
logger.warning(
"ItemDescriptor 0x%04x (CBF_SOURCE_ID) found, but value "
'"%s" != "%s" for Low, assuming "M" for Mid',
ItemID.CBF_SOURCE_ID,
cbf_source_id.decode(),
Spead2ReceiverPayload.LOW_CBF_SOURCE_ID.decode("ascii"),
)
self._has_all_required_item_descriptors = set(self._item_group.ids()) == self._icd.ITEM_IDS
if self._has_all_required_item_descriptors:
def get_item(x):
return self._item_group[x.value.id]
Items = self._icd.Items
# cache references to items in item group for faster access on
# future updates
# pylint: disable=attribute-defined-outside-init
self._corr_out_data_item = get_item(Items.CORRELATOR_OUTPUT_DATA)
if self._telescope == Telescope.LOW:
self._sps_epoch_item = get_item(Items.SPS_EPOCH)
self._epoch_offset_item = get_item(Items.EPOCH_OFFSET)
# the sdp-cbf-emulator can send multiple channels per stream
# even when emulating Low
corr_out_data_shape = self._corr_out_data_item.shape
self.baseline_count = corr_out_data_shape[-1]
if len(corr_out_data_shape) == 2:
logger.warning(
"Found multi-channel visibilities in SKA Low stream: "
"%r, dealing with it",
corr_out_data_shape,
)
self.channel_count = corr_out_data_shape[0]
else:
self.channel_count = 1
else:
self._baseline_count_item = get_item(Items.BASELINE_COUNT)
self._channel_count_item = get_item(Items.CHANNEL_COUNT)
self._channel_id_item = get_item(Items.CHANNEL_ID)
self._hardware_id_item = get_item(Items.HARDWARE_ID)
self._phase_bin_id_item = get_item(Items.PHASE_BIN_ID)
self._phase_bin_count_item = get_item(Items.PHASE_BIN_COUNT)
self._polarisation_id_item = get_item(Items.POLARISATION_ID)
self._scan_id_item = get_item(Items.SCAN_ID)
self._timestamp_count_item = get_item(Items.TIMESTAMP_COUNT)
self._timestamp_fraction_item = get_item(Items.TIMESTAMP_FRACTION)
# Log that the start-of-stream heap has arrived
if self._telescope == Telescope.LOW:
# Low also sends most values in the SOS heap
self.channel_id = get_item(Items.CHANNEL_ID).value
self.scan_id = get_item(Items.SCAN_ID).value
self.hardware_id = get_item(Items.HARDWARE_ID).value
self._sos_logger.record_low_start_of_stream(
FixedLowSOSItemValues(
self.scan_id,
self.hardware_id,
*[
get_item(item).value
for item in (
Items.BASELINE_COUNT,
Items.STATION_BEAM_ID,
Items.SUBARRAY_ID,
Items.INTEGRATION_PERIOD,
Items.FREQUENCY_RESOLUTION,
Items.OUTPUT_RESOLUTION,
Items.ZOOM_WINDOW_ID,
Items.CBF_FIRMWARE_VERSION,
Items.CBF_SOURCE_ID,
)
],
),
)
else:
self._sos_logger.record_mid_start_of_stream()
return self._has_all_required_item_descriptors
def update(self, heap):
"""
Updates this Payload with the data extracted from the Items in the given
heap.
:param heap: A SPEAD heap.
"""
assert (
self._has_all_required_item_descriptors
), "ItemGroup doesn't have all required ItemDescriptors"
ig = self._item_group
ig.update(heap)
if self._telescope == Telescope.MID:
self.baseline_count = self._baseline_count_item.value
self.channel_count = self._channel_count_item.value
self.channel_id = self._channel_id_item.value
self.hardware_id = self._hardware_id_item.value
self.phase_bin_id = self._phase_bin_id_item.value
self.phase_bin_count = self._phase_bin_count_item.value
self.polarisation_id = self._polarisation_id_item.value
self.scan_id = self._scan_id_item.value
self.timestamp = self._icd.icd_to_unix(
self._timestamp_count_item.value,
self._timestamp_fraction_item.value,
)
corr_out_data = self._corr_out_data_item.value
self.time_centroid_indices = corr_out_data["TCI"]
self.correlated_data_fraction = corr_out_data["FD"]
self.cci = corr_out_data["CCI"]
self.visibilities = corr_out_data["VIS"]
else:
self.timestamp = self._icd.icd_to_unix(
self._sps_epoch_item.value,
self._epoch_offset_item.value,
)
# Low sends single-channel correlator output, but the payload's
# values always have both dimensions.
corr_out_data = self._corr_out_data_item.value
time_centroids = corr_out_data["TCI"]
data_fraction = corr_out_data["FD"]
vis = corr_out_data["VIS"]
if len(vis.shape) == 2:
time_centroids = np.expand_dims(time_centroids, axis=0)
data_fraction = np.expand_dims(data_fraction, axis=0)
vis = np.expand_dims(vis, axis=0)
self.time_centroid_indices = time_centroids
self.correlated_data_fraction = data_fraction
self.visibilities = vis
def incomplete_heaps(stat):
"""Calculate the number of incomplete heaps"""
return stat["incomplete_heaps_evicted"] + stat["incomplete_heaps_flushed"]
@dataclasses.dataclass
class ReceptionStats:
"""Statistics about the data reception process"""
total_bytes: int = 0
num_heaps: int = 0
num_incomplete: int = 0
duration: float = 0.0
per_stream_stats: List[Dict[str, int]] = dataclasses.field(default_factory=list)
NO_STATS = ReceptionStats()
class StatsTracker:
"""A class that keeps track of reception statistics"""
def __init__(self, streams):
self.streams = streams
self._item_group_size = 0
self._first_reception_time = None
self._duration = None
def collect(self) -> ReceptionStats:
"""Collect the current receiver statistics"""
if self._first_reception_time is None:
return NO_STATS
per_stream_stats = [dict(stream.stats) for stream in self.streams]
total_heaps = sum(stat["heaps"] for stat in per_stream_stats)
total_incomplete_heaps = sum(incomplete_heaps(stat) for stat in per_stream_stats)
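        # total_bytes is an estimate: each stream presumably contributes two
        # non-data control heaps (start-of-stream and end-of-stream), hence
        # the "len(self.streams) * 2" subtraction from the heap count below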
return ReceptionStats(
(total_heaps - len(self.streams) * 2) * self._item_group_size,
total_heaps,
total_incomplete_heaps,
self.duration,
per_stream_stats,
)
@property
def duration(self):
"""
        If in progress, how many seconds since data started being received.
        If stopped, how many seconds data was received for.
"""
if self._duration is not None:
return self._duration
assert self._first_reception_time is not None
return time.time() - self._first_reception_time
def reception_started(self, start_time):
"""Specify the time reception started"""
if self._first_reception_time is None:
self._first_reception_time = start_time
    def reception_stopped(self):
        """Specify that reception has now stopped"""
        assert self._duration is None
        if self._first_reception_time is not None:
            self._duration = self.duration

    def inform_item_group_size(self, item_group_size):
        """
        Record the per-heap item group size, which collect() uses to estimate
        the total number of bytes received.
        """
        self._item_group_size = item_group_size
def log_stats(stats: ReceptionStats):
"""Log the given reception statistics"""
if stats == NO_STATS:
logger.info("No reception statistics to log")
return
per_stream_stats_level = logging.DEBUG
level = logging.WARNING if stats.num_incomplete > 0 else logging.INFO
show_per_stream_stats = logger.isEnabledFor(per_stream_stats_level)
total_megabytes = stats.total_bytes / 1024 / 1024
suffix = ", per stream stats follow" if show_per_stream_stats else ""
logger.log(
level,
"Successfully received and processed %.3f [MB], %d heaps (%d incomplete) in %.3f [s] (%.3f [MB/s], %.3f [heaps/s])%s",
total_megabytes,
stats.num_heaps,
stats.num_incomplete,
stats.duration,
total_megabytes / stats.duration,
stats.num_heaps / stats.duration,
suffix,
)
if not show_per_stream_stats:
return
per_stream_log = functools.partial(logger.log, per_stream_stats_level)
per_stream_log("| Stream | Heaps | Incomplete | Worker blocked | ! |")
per_stream_log("+--------+-------+------------+----------------+---+")
for stream_num, stream_stats in enumerate(stats.per_stream_stats):
incomplete = incomplete_heaps(stream_stats)
blocked = stream_stats["worker_blocked"]
heaps = stream_stats["heaps"]
note = "*" if (incomplete + blocked) else " "
per_stream_log(
f"| {stream_num:6d} | {heaps:5d} | {incomplete:10d} | {blocked:14d} | {note} |"
)
class LostDataTracker(PeriodicSummaryLogger):
"""
A class that keeps track of how many data heaps have been lost. When asked
to inform about lost data heaps it prints a warning if any have been
recorded, then it resets itself for a new count.
"""
MESSAGE = (
"Dropped %(summary)d data heaps in the last %(elapsed_time).2f [s] as "
"start_of_stream heap with item descriptors has not been received and "
"buffer is full"
)
LEVEL = logging.WARNING
def __init__(self, period):
self._lost_data_heaps: int = 0
PeriodicSummaryLogger.__init__(
self,
logger,
period,
LostDataTracker.MESSAGE,
level=LostDataTracker.LEVEL,
)
def reset(self):
self._lost_data_heaps = 0
def summary(self):
return self._lost_data_heaps
def record_data_heap_lost(self):
"""Increment the number of heaps that have been lost"""
self._lost_data_heaps += 1
class ItemDescStatus(enum.IntEnum):
"""
An enumeration describing the different status in which a stream can be in
with respect to having received ICD item descriptors for its data heaps.
"""
NOT_RECEIVED = 0
VALID = 1
INVALID = 2
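
# Status transitions, as driven by receiver._really_process_stream_heaps: a
# stream starts as NOT_RECEIVED; its first start-of-stream heap moves it to
# VALID if the heap carries all required item descriptors, or to INVALID
# otherwise (in which case all subsequent data heaps are discarded).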
class TransportProtocols(enum.Enum):
"""Supported transport protocols for SPEAD reception."""
UDP = "udp"
TCP = "tcp"
@dataclasses.dataclass
@autocast_fields
class Spead2ReceptionConfig(Config):
"""Set of options used to build a spead2 network receiver"""
method: str = "spead2_receivers"
num_streams: Optional[int] = None
"""
The number of streams this receiver should open. This option should be
preferred over `num_channels` and `channels_per_stream` (which are now
deprecated).
"""
num_channels: int = 1
"""
*DEPRECATED*, use `num_streams` instead.
The number of channels to receive.
"""
continuous_mode: bool = False
"""
Whether the receiver should re-create the streams and resume receiving data
after all end of streams are reached.
"""
channels_per_stream: int = 0
"""
*DEPRECATED*, use `num_streams` instead.
The number of channels for which data will be sent in a single stream. This
is used in the case where multiple ports are required with multiple
channels per port. Together with `num_channels` they define the number of
streams to listen for (and therefore the number of ports to open). ``0``
means that all channels should be received on a single stream.
"""
stats_receiver_interval: float = 1.0
"""
Period of time, in seconds, between publishing of receiver stats to kafka.
"""
stats_receiver_kafka_config: str = ""
"""
    Kafka endpoint (of the form ``<host>[:<port>]:<topic>``) where receiver
    statistics should be sent. If empty, no statistics are sent.
"""
data_loss_report_rate: float = 1.0
"""
The period, in seconds, at which lost data heaps should be reported, if
any.
"""
start_of_stream_report_rate: float = 1.0
"""
The period, in seconds, at which start-of-stream heaps should be reported
as they arrive.
"""
port_start: int = 41000
"""
    The initial port number to which to bind the receiver streams. Successive
    streams are opened on successive ports after this.
"""
bind_hostname: str = ""
"""
The IP address or hostname of the interface to which to bind for reception.
"""
transport_protocol: TransportProtocols = TransportProtocols.UDP
"""The network transport protocol used by spead2."""
pcap_file: str = ""
"""
Packet capture file to read SPEAD packets from. If set then data reception
will be done by reading data from this file instead of from the network.
The pcap file should have data for one or more valid SPEAD UDP streams. The
``num_streams`` and ``port_start`` configuration options are still used to
determine how many of these streams to "receive" from the file, with each
    stream resulting from filtering the pcap file for a different, successive
destination UDP port starting at ``port_start``.
"""
readiness_filename: str = ""
"""
If given, the name of the file to create (empty) on disk once the receiver
has finished setting itself up and is ready to receive data.
"""
max_pending_data_heaps: int = 10
"""
The number of data heaps on each stream to accumulate in memory before the
start-of-stream heap arrives. Data heaps cannot be processed before the
start-of-stream heap arrives, so a finite queue is implemented to keep some
of them around. Further data heaps that don’t fit in the queue are dropped
and permanently lost.
"""
test_failure: bool = False
"""
    *DEPRECATED*, has no effect.
"""
ring_heaps: int = 16
"""
    The number of ring heaps used by each SPEAD stream.
"""
buffer_size: Optional[int] = None
"""
    The socket buffer size to use. If not given, a value is derived from the
    default buffer size set by spead2, capped at the limit imposed by the OS.
"""
max_packet_size: int = spead2.recv.Stream.DEFAULT_UDP_MAX_SIZE
"""
The maximum packet size to accept on the streams.
"""
receiver_threads: int = 1
"""
    The number of threads allocated to the spead2 I/O thread pool.
"""
reset_time_indexing_after_each_scan: bool = False
"""
Whether to reset the aggregator's time indexing when data reception for a
scan finishes.
"""
def __post_init__(self):
if self.num_streams is None:
warnings.warn(
(
"reception.num_channels and reception.channels_per_stream "
"are deprecated, use reception.num_streams instead"
),
category=DeprecationWarning,
)
if self.num_channels == 0:
                raise ValueError(
                    "reception.num_channels must be a non-zero number of channels"
                )
if self.channels_per_stream == 0:
self.channels_per_stream = self.num_channels
self.num_streams = self.num_channels // self.channels_per_stream
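            # e.g. num_channels=8 with channels_per_stream=2 yields
            # num_streams=4 (one stream, and therefore one port, per 2 channels)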
if self.buffer_size is None:
default_buffer_size = (
spead2.recv.Stream.DEFAULT_UDP_BUFFER_SIZE
if self.transport_protocol == TransportProtocols.UDP
else spead2.recv.Stream.DEFAULT_TCP_BUFFER_SIZE
)
os_max_buffer_size = socket_utils.max_socket_read_buffer_size(
self.transport_protocol.value
)
if os_max_buffer_size >= default_buffer_size:
self.buffer_size = default_buffer_size
else:
logger.debug(
(
"Adjusting default reception buffer_size (%d -> %d) "
"to match OS max settings"
),
default_buffer_size,
os_max_buffer_size,
)
self.buffer_size = os_max_buffer_size
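
# A minimal construction sketch with illustrative values (this assumes the
# Config base class needs no further arguments):
#
#     config = Spead2ReceptionConfig(
#         num_streams=4,
#         port_start=41000,
#         bind_hostname="10.0.0.1",
#         transport_protocol=TransportProtocols.UDP,
#     )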
class receiver(ScanLifecycleHandler):
"""
SPEAD2 receiver
    This class uses the spead2 library to receive multiple streams, each using
    a single UDP reader. As heaps are received they are given to a single
    consumer.
This receiver supports UDP multicast addressing. This means that multiple
receivers (and therefore consumers) can access the same transmitted stream
if they bind to the same multicast IP address and port.
"""
config_class = Spead2ReceptionConfig
def __init__(
self,
config: Spead2ReceptionConfig,
aggregator: PayloadAggregator,
data_reception_handler: DataReceptionHandler | None = None,
):
self.config = config
self._data_reception_handler = data_reception_handler
self._current_scan_id: int | None = None
self._aggregator = aggregator
self._aggregator.inform_num_streams(self.config.num_streams)
self.reception_tracker = None
self.continuous_stats_receiver = ContinuousStatsReceiver(
self, interval=config.stats_receiver_interval
)
if config.stats_receiver_kafka_config:
self.continuous_stats_receiver.add_receiver(
KafkaStatsReceiver.create(config.stats_receiver_kafka_config)
)
self._lost_data_tracker = LostDataTracker(config.data_loss_report_rate)
self._start_of_stream_logger = StartOfStreamLogger(config.start_of_stream_report_rate)
self._end_scan_event = asyncio.Event()
@override
async def start_scan(self, scan_id: int) -> None:
        # No action necessary here.
        # The receiver is always ready to receive data (it either automatically re-creates its
        # streams, or finishes), so there is no need to open streams upon a Scan command.
        pass
@override
async def end_scan(self, scan_id: int) -> None:
if scan_id == self._current_scan_id:
self._end_scan_event.set()
@property
def stats(self):
"""Return the latest receiver statistics"""
if not self.reception_tracker:
return ReceptionStats()
return self.reception_tracker.collect()
@property
def received_payloads(self) -> int:
"""The number of data payloads that have been received."""
return self._aggregator.added_payloads
@property
def aggregated_payloads(self) -> int:
"""The number of data payloads that have been received and aggregated."""
return self._aggregator.aggregated_payloads
@property
def visibilities_generated(self) -> int:
"""The number of visibilities that have been successfully generated."""
return self._aggregator.visibilities_generated
@property
def visibilities_consumed(self) -> int:
"""The number of visibilities that have been successfully consumed."""
return self._aggregator.visibilities_consumed
@property
def consumer(self) -> Consumer:
"""The consumer this receiver finally forwards data to"""
return self._aggregator.consumer
def _setup_streams(self, config):
io_thread_pool = spead2.ThreadPool(threads=config.receiver_threads)
start_time = time.time()
streams = []
recv_port = config.port_start
protocol = config.transport_protocol
for i in range(self.config.num_streams):
stream = create_stream(io_thread_pool, config.ring_heaps, 32)
port = recv_port + i
if config.pcap_file:
pcap_filter = f"udp dst port {port}"
stream.add_udp_pcap_file_reader(config.pcap_file, pcap_filter)
logger.debug(
'Created pcap-based stream %d from %s filtering with "%s"',
i,
config.pcap_file,
pcap_filter,
)
else:
add_reader = stream.add_udp_reader
if protocol == TransportProtocols.TCP:
add_reader = stream.add_tcp_reader
try:
add_reader(
port,
bind_hostname=config.bind_hostname,
buffer_size=config.buffer_size,
max_size=config.max_packet_size,
)
except Exception as e:
raise RuntimeError(f"Cannot read from port {port}") from e
logger.debug("Created stream %d on port %d", i, port)
streams.append(stream)
logger.info(
"Created %d %s receive streams in %.3f [ms]",
self.config.num_streams,
"pcap-based" if config.pcap_file else protocol.value,
(time.time() - start_time) * 1000,
)
payloads = [
Spead2ReceiverPayload(self._start_of_stream_logger)
for _ in range(self.config.num_streams)
]
return io_thread_pool, streams, payloads
async def run(self, ready_event: Optional[asyncio.Event] = None):
"""Receive all heaps, passing them to the consumer"""
await self.continuous_stats_receiver.start()
while True:
io_thread_pool, streams, payloads = self._setup_streams(self.config)
self.reception_tracker = StatsTracker(streams)
self._end_scan_event.clear()
receive_tasks = [
self._process_stream_heaps(stream, payload)
for stream, payload in zip(streams, payloads)
]
self._signal_ready_to_receive(ready_event)
results = await asyncio.gather(*receive_tasks, return_exceptions=False)
streams_stopped_by_end_scan = sum(1 for result in results if result)
if streams_stopped_by_end_scan:
logger.warning(
"%d streams stopped early due to SDP's EndScan", streams_stopped_by_end_scan
)
await self._aggregator.flush()
self.reception_tracker.reception_stopped()
self._current_scan_id = None
if self._data_reception_handler:
await self._data_reception_handler.last_scan_data_received()
if self.config.reset_time_indexing_after_each_scan:
self._aggregator.reset_time_indexing()
io_thread_pool.stop()
await self.continuous_stats_receiver.collect_and_send_stats()
log_stats(self.reception_tracker.collect())
if not self.config.continuous_mode:
break
await self.continuous_stats_receiver.stop()
def _signal_ready_to_receive(self, ready_event: Optional[asyncio.Event] = None):
readiness_filename = self.config.readiness_filename
if readiness_filename:
with open(readiness_filename, "wb"):
pass
logger.debug(
"Created %s to signal we are ready to receive data",
readiness_filename,
)
if ready_event:
ready_event.set()
logger.info("Ready to receive data")
self.continuous_stats_receiver.scan_id = -1
async def _process_stream_heaps(self, stream, payload) -> bool:
end_scan_wait_task = asyncio.create_task(self._end_scan_event.wait())
end_scan_wait_task.add_done_callback(lambda _task: stream.stop())
try:
await self._really_process_stream_heaps(stream, payload)
return self._end_scan_event.is_set()
finally:
if not end_scan_wait_task.done():
end_scan_wait_task.cancel()
stream.stop()
async def _really_process_stream_heaps(self, stream, payload):
pending_data_heaps = []
item_desc_status = ItemDescStatus.NOT_RECEIVED
async for heap in stream:
now = time.time()
self.reception_tracker.reception_started(now)
self._lost_data_tracker.maybe_log(now)
self._start_of_stream_logger.maybe_log(now)
# Handle stream control heaps
if heap.is_start_of_stream():
if not payload.set_item_descriptors(heap):
item_desc_status = ItemDescStatus.INVALID
logger.error(
"start-of-stream heap received, "
"but doesn't contain all required item descriptors, "
"all incoming heaps will be discarded"
)
else:
item_desc_status = ItemDescStatus.VALID
# These need to be done sequentially because the payload
# object is updated with the data from each heap.
for pending_heap in pending_data_heaps:
await self._process_data_heap(pending_heap, payload)
if pending_data_heaps:
pending_data_heaps.clear()
continue
elif heap.is_end_of_stream():
stream.stop()
break
# Handle data heaps
if isinstance(heap, spead2.recv.IncompleteHeap):
continue
if item_desc_status == ItemDescStatus.INVALID:
continue
elif item_desc_status == ItemDescStatus.NOT_RECEIVED:
if len(pending_data_heaps) == self.config.max_pending_data_heaps:
self._lost_data_tracker.record_data_heap_lost()
else:
pending_data_heaps.append(heap)
continue
await self._process_data_heap(heap, payload)
now = time.time()
self._lost_data_tracker.maybe_log(now, force=True)
self._start_of_stream_logger.maybe_log(now, force=True)
async def _process_data_heap(self, heap, payload):
payload.update(heap)
if self._current_scan_id is None:
self._current_scan_id = int(payload.scan_id)
if self._data_reception_handler:
await self._data_reception_handler.first_scan_data_received(self._current_scan_id)
self.reception_tracker.inform_item_group_size(payload.item_group_size)
self.continuous_stats_receiver.scan_id = payload.scan_id
self._aggregator.add_payload(payload)
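
# Hypothetical end-to-end wiring sketch; building the PayloadAggregator (and
# its Consumer) is out of scope here, see realtime.receive.modules.aggregation:
#
#     async def main(aggregator: PayloadAggregator):
#         config = Spead2ReceptionConfig(num_streams=1)
#         rx = receiver(config, aggregator)
#         ready = asyncio.Event()
#         await rx.run(ready_event=ready)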