# -*- coding: utf-8 -*-
"""
Primary send functions for ska-sdp-cbf-emulator
"""
import asyncio
import dataclasses
import logging
import time
from typing import Generator
from realtime.receive.core.common import autocast_fields, from_dict
from ska_sdp_cbf_emulator import transmitters
from ska_sdp_cbf_emulator.data_source import (
CorrelatedDataSource,
HardcodedDataSource,
HardcodedDataSourceConfig,
MeasurementSetDataSource,
MeasurementSetDataSourceConfig,
)
# Module-level logger, named after this module so output can be filtered per module.
logger = logging.getLogger(name=__name__)
[docs]
# NOTE(review): decorators apply bottom-up, so ``autocast_fields`` processes the
# raw class before ``dataclasses.dataclass`` does; it presumably coerces field
# values to their annotated types — TODO confirm against realtime.receive.core.
@dataclasses.dataclass
@autocast_fields
class SdpConfigDbConfig:
    """
    Set of options used to establish a connection to the SDP Configuration DB.

    ``to_sender_config`` builds an instance from the ``sdp_config_db`` section
    of a dictionary configuration; omitted options keep the defaults below.
    """
    host: str = "127.0.0.1"
    """The host to connect to (default: the local loopback address)."""
    port: int = 2379
    """The port to connect to."""
    backend: str = "etcd3"
    """The name of the configuration DB backend to use."""
[docs]
@dataclasses.dataclass
@autocast_fields
# pylint: disable-next=too-many-instance-attributes
class SenderConfig:
    """
    Settings that drive a single data-sending operation.

    Covers the identifiers stamped on outgoing payloads, the pacing between
    successive data dumps, the data source (Measurement Set or hardcoded),
    the network transmission settings and the SDP Configuration DB
    connection.
    """
    scan_id: int = 1
    """Scan ID written into every payload of the transmission; must be >= 0."""
    subarray_id: int = 1
    """Subarray ID, only relevant when emulating LOW."""
    beam_id: int = 1
    """Station beam ID, only relevant when emulating LOW."""
    strict_beam_id_conformance: bool = True
    """Whether the beam_id field is checked for strict conformance."""
    zoom_window_id: int = 1
    """Zoom window ID, only relevant when emulating LOW."""
    time_interval: float = 0
    """
    Pause between sending successive data dumps, in seconds.
    A positive value is a fixed pause. Zero follows the differences between
    the visibility timestamps of successive time steps. A negative value
    disables pausing altogether, sending as fast as possible.
    """
    ms: MeasurementSetDataSourceConfig | None = None
    """Settings for reading an input Measurement Set (preferred over ``hardcoded``)."""
    hardcoded: HardcodedDataSourceConfig | None = None
    """Settings for generating hardcoded visibilities."""
    transmission: transmitters.Config = dataclasses.field(default_factory=transmitters.Config)
    """Settings for sending the data over the network."""
    sdp_config_db: SdpConfigDbConfig = dataclasses.field(default_factory=SdpConfigDbConfig)
    """Settings for connecting to the SDP Configuration Database."""
    hardware_id: int = 0xBEEF
    """Hardware ID of the emulator, handed to the transmitter."""

    def __post_init__(self):
        # Negative scan IDs are invalid; reject them at construction time.
        requested_scan = self.scan_id
        if requested_scan < 0:
            raise ValueError(f"scan_id must be >= 0: {requested_scan}")
[docs]
def to_sender_config(dict_config: dict) -> SenderConfig:
    """
    Turn a dictionary into a SenderConfig object.

    :param dict_config: Mapping with optional sections ``sender``,
        ``sdp_config_db``, ``transmission``, ``ms`` and ``hardcoded``.
        Missing ``sender``, ``sdp_config_db`` and ``transmission`` sections
        are treated as empty, so their defaults apply. ``ms`` and
        ``hardcoded`` are only parsed when present. The mapping itself is
        not modified.
    :return: The populated SenderConfig.
    """
    # Read optional sections without inserting placeholder entries into the
    # caller's dictionary (the previous implementation mutated it).
    sender_section = dict_config.get("sender", {})
    config_db_section = dict_config.get("sdp_config_db", {})
    transmission_section = dict_config.get("transmission", {})

    config = from_dict(SenderConfig, sender_section)
    config.sdp_config_db = from_dict(SdpConfigDbConfig, config_db_section)
    config.transmission = transmitters.create_config(**transmission_section)
    if "ms" in dict_config:
        config.ms = from_dict(MeasurementSetDataSourceConfig, dict_config["ms"])
    if "hardcoded" in dict_config:
        config.hardcoded = from_dict(HardcodedDataSourceConfig, dict_config["hardcoded"])
    return config
[docs]
async def packetise(config: SenderConfig):
    """
    Build the data source selected by ``config`` and transmit its data.

    A Measurement Set source (``config.ms``) takes precedence over a
    hardcoded source (``config.hardcoded``). The source's async context is
    entered for the duration of the transmission. Packaging and sending are
    delegated to :func:`packetise_visibilities` using the transmitter
    specified in the configuration, so this function is agnostic of the
    transmission protocol.

    :param config: The sender configuration.
    :return: The result of :func:`packetise_visibilities` (the number of
        heaps sent).
    :raises RuntimeError: If neither a Measurement Set nor a hardcoded data
        source is configured.
    """
    ms_settings = config.ms
    hardcoded_settings = config.hardcoded
    if ms_settings is None and hardcoded_settings is None:
        raise RuntimeError("No data source has been configured")

    if ms_settings is not None:
        source = MeasurementSetDataSource(ms_settings)
    else:
        source = HardcodedDataSource(hardcoded_settings)

    async with source:
        return await packetise_visibilities(config, source)
[docs]
async def packetise_visibilities(config: SenderConfig, data_source: CorrelatedDataSource):
    """
    Reads data from the provided data source and transmits it using the
    transmitter specified in the configuration.

    Before each datum is sent, waits according to ``config.time_interval``:
    a positive value is a fixed period, 0 follows the differences between the
    data's timestamps, and a negative value disables waiting. Time already
    spent since the previous send started is subtracted from the wait.

    :param config: The sender configuration (scan and LOW identifiers,
        pacing, transmission settings, hardware id).
    :param data_source: The source of correlated visibilities. ``packetise``
        enters its async context before calling this function.
    :return: The number of heaps sent by the transmitter.
    :raises ValueError: If a datum's channel count (``visibilities.shape[1]``)
        differs from ``data_source.channels.count``.
    """
    log_info = [
        ("scan id", config.scan_id),
        ("time interval", f"{config.time_interval} (0 == as per source, <0 == fly through)"),
        ("data source", data_source.name),
    ]
    log_info += data_source.info
    for name, value in log_info:
        logger.info("%-20s: %s", name, value)

    # The factory returns an already-primed generator: per datum, send(timestamp)
    # yields the interval to wait and the following next() readies it again.
    intervals = _create_interval_generator(config.time_interval)
    transmitter = await transmitters.create(
        config.transmission,
        transmitters.TransmitterInitData(
            config.scan_id,
            data_source.num_baselines,
            data_source.channels,
            transmitters.TransmitterLowInitData(
                config.subarray_id,
                config.beam_id,
                config.zoom_window_id,
                data_source.integration_period,
                data_source.channel_frequencies,
                data_source.channel_resolution,
                data_source.visibility_epoch,
                config.strict_beam_id_conformance,
            ),
            config.hardware_id,
        ),
    )

    # Use a monotonic clock so pacing and throughput figures are not skewed by
    # wall-clock adjustments (NTP steps, manual changes) during a transmission.
    start_time = time.monotonic()
    async with transmitter:
        prev_send_start = time.monotonic()
        # Iterate over timesteps in the data
        async for datum in data_source.data():
            # Explicit check instead of ``assert``, which is stripped under -O.
            expected_channels = data_source.channels.count
            actual_channels = datum.visibilities.shape[1]
            if actual_channels != expected_channels:
                raise ValueError(
                    f"Datum has {actual_channels} channels but the data source "
                    f"declares {expected_channels}"
                )
            # Interval to emulate, reduced by the time already elapsed since the
            # previous send started (data reading and transmission overhead).
            waiting_time = intervals.send(datum.unix_timestamp)
            next(intervals)
            if waiting_time > 0:
                waiting_time -= time.monotonic() - prev_send_start
                if waiting_time > 0:
                    await asyncio.sleep(waiting_time)
            prev_send_start = time.monotonic()
            await transmitter.send(datum)

    # Report totals; bytes are divided by 1024 * 1024 before logging.
    duration = time.monotonic() - start_time
    data_size = transmitter.bytes_sent / 1024 / 1024
    heaps_sent = transmitter.heaps_sent
    # A very short run (e.g. an empty source) can measure a zero duration.
    if duration > 0:
        data_rate = data_size / duration
        heap_rate = heaps_sent / duration
    else:
        data_rate = heap_rate = 0.0
    logger.info(
        "Scan %s sent %.3f [MB], %d heaps in %.3f [s] (%.3f [MB/s], %.3f [heaps/s])",
        config.scan_id,
        data_size,
        heaps_sent,
        duration,
        data_rate,
        heap_rate,
    )
    return heaps_sent
def _create_interval_generator(time_interval: float) -> Generator[float, float, None]:
"""
Create a generator that will yield intervals according to
the time_interval value:
* > 0: fixed
* ==0: difference between successive values
* < 0: immediate
"""
if time_interval > 0:
def intervals():
while True:
yield time_interval
elif time_interval == 0:
def intervals():
interval = 0
prev_vis_time = yield
yield 0
while True:
vis_time = yield
if vis_time < prev_vis_time:
yield interval
else:
interval = vis_time - prev_vis_time
yield interval
prev_vis_time = vis_time
else:
def intervals():
while True:
yield 0
interval_generator = intervals()
next(interval_generator)
return interval_generator