"""Takes a SPEAD2 HEAP and writes it to an apache plasma store. This uses the sdp-dal-prototype
API and will fail if that cannot be loaded
"""
from __future__ import annotations
import asyncio
import enum
import errno
import functools
import logging
import threading
import time
from collections.abc import Sequence
import numpy as np
import ska_sdp_dal as rpc
from overrides import overrides
from pydantic import dataclasses
from realtime.receive.core import Scan, ScanType
from realtime.receive.core.icd import Telescope
from realtime.receive.core.uvw_engine import UVWEngine
from sdp_dal_schemas import PROCEDURES, TABLES
from ska_sdp_datamodels.visibility import Visibility
from realtime.receive.modules.consumers.consumer import Consumer
from realtime.receive.modules.tm import TelescopeManager
from realtime.receive.modules.utils import table_utils
from ._config import Config
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Alias used in annotations for a sequence of ska_sdp_dal object references.
References = Sequence[rpc.connection.Ref]
[docs]
class PlasmaExitedAction(enum.Enum):
"""Actions to take when the consumer detects Plasma has exited."""
RECONNECT_AND_RETRY = enum.auto()
"""Try reconnecting to Plasma and upon success retry the current operation."""
IGNORE = enum.auto()
"""Ignore the fact that Plasma exited and don't continue."""
async def _handle_plasma_exit(
consumer: PlasmaWriterConsumer, action: PlasmaExitedAction, func, *args, **kwargs
):
try:
await func(consumer, *args, **kwargs)
except OSError as e:
# If plasma goes down, interacting with its socket raises a "bad file descriptor" error:
# * When ska-sdp-dal is select()-ing on the fd, the original EBADF is set.
# * When ska-sdp-dal is retrieving the notification message, plasma.cc internally turns
# that into a simple error with the message given here.
if (
e.errno != errno.EBADF
and "Failed to read object notification from Plasma socket" not in e.args[0]
):
raise
match action:
case PlasmaExitedAction.RECONNECT_AND_RETRY:
logger.warning(
"Plasma exited while running %s, attempting reconnect and trying again",
func.__name__,
)
consumer.disconnect()
consumer.connect()
await func(consumer, *args, **kwargs)
case PlasmaExitedAction.IGNORE:
logger.warning(
"Plasma exited while running %s, will not attempt again", func.__name__
)
def _as_decorator(action: PlasmaExitedAction):
def decorator(func):
@functools.wraps(func)
async def _wrapper(consumer, *args, **kwargs):
await _handle_plasma_exit(consumer, action, func, *args, **kwargs)
return _wrapper
return decorator
reconnect_on_plasma_exit = _as_decorator(PlasmaExitedAction.RECONNECT_AND_RETRY)
ignore_plasma_exit = _as_decorator(PlasmaExitedAction.IGNORE)
@dataclasses.dataclass
class PlasmaWriterConfig(Config):
    """Configuration for the `plasma_writer` consumer."""

    name: str = "plasma_writer"

    plasma_path: str = "/tmp/plasma"
    """The UNIX socket to connect to."""

    payloads_in_flight: int = 10
    """
    The maximum number of payloads to keep in flight in the Plasma Store before
    requesting their results and releasing our local references to them. If
    more payloads are written to Plasma the oldest references are released.
    """

    remove_old_references_timeout: float = 0.0
    """
    The maximum amount of time to wait, in seconds, for a pending response when
    removing references to old input objects in the Plasma store.
    """

    wait_for_all_responses_timeout: float = 5.0
    """
    The maximum amount of time to wait, in seconds, for all pending responses
    to be read from Plasma when the consumer is shut down.
    """
@dataclasses.dataclass
class InvocationResults:
    """Tally of the outcomes of RPC invocations."""

    # Invocations that reported a result of 0
    success: int = 0
    # Invocations that reported any other result
    fail: int = 0
    # Invocations whose outcome could not be established; not touched by add()
    unknown: int = 0

    def add(self, result: int):
        """Record one result: 0 counts as a success, anything else as a failure."""
        if result != 0:
            self.fail += 1
            return
        self.success += 1
class PlasmaWriterConsumer(Consumer):
    """A heap consumer that writes incoming data into a Plasma Store.

    Each consumed visibility is written to the store as a set of tensors and
    tables, after which the ``process_visibility`` RPC procedure is invoked
    with references to them. ``start_scan`` and ``end_scan`` are invoked in the
    same fashion when scans start and end.

    Objects are written to the store directly from the consumer's coroutines.
    Only the potentially blocking wait for RPC responses is handed over to the
    event loop's default executor, so that the loop isn't stalled by it.
    """

    config_class = PlasmaWriterConfig

    @overrides
    def __init__(
        self,
        config: PlasmaWriterConfig,
        tm: TelescopeManager,
        uvw_engine: UVWEngine,
    ):
        """Set up the bookkeeping of object references and connect to the store."""
        super().__init__(config, tm, uvw_engine)
        self._plasma_path = config.plasma_path
        self.store: rpc.Store | None = None
        self.caller: rpc.Caller | None = None
        # (inputs, outputs) of each RPC invocation still in flight, oldest first
        self._rpc_inout_refs: list[tuple[References, References]] = []
        # References to the tables shared by all payloads, keyed by table name
        self._common_refs: dict = {}
        # References to the per-scan-type tables, keyed by scan type ID and table name
        self._scan_type_refs: dict[str, dict] = {}
        # References to the tensors holding each scan ID, keyed by scan ID
        self._scan_id_refs: dict[int, rpc.connection.Ref] = {}
        self._invocation_results = InvocationResults()
        self.connect()

    def disconnect(self):
        """Disconnect from the Plasma store, if still connected."""
        if not self.store:
            return
        connection = self.store.conn
        # Closing the notification socket is best-effort only
        try:
            connection._socket.close()  # pylint: disable=protected-access
        except Exception as ex:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to close notification socket, ignoring: %s", ex)
        connection.client.disconnect()

    def connect(self):
        """Connect to the plasma store and reset all local tensor references."""
        self.store = rpc.Store(self.config.plasma_path)
        self.caller = rpc.Caller(
            [PROCEDURES[name] for name in ("process_visibility", "start_scan", "end_scan")],
            self.store,
            broadcast=True,
            verbose=True,
            minimum_processors=0,
        )
        # References obtained through a previous connection are of no use anymore
        self._rpc_inout_refs.clear()
        self._common_refs.clear()
        self._scan_type_refs.clear()
        self._scan_id_refs.clear()

    @property
    def invocation_results(self):
        """A summary of the results from RPC invocations."""
        return self._invocation_results

    @property
    def invocations_in_flight(self):
        """
        The number of RPC invocations that are still in flight, and for which
        we haven't collected their output.
        """
        return len(self._rpc_inout_refs)

    @property
    def num_processors(self):
        """The number of processors known to this caller."""
        # pylint: disable-next=protected-access
        return len(self.caller._processors)

    def find_processors(self):
        """Search the caller for processors.

        :returns: the number of processors known to the caller after the search.
        """
        self.caller.find_processors(verbose=True)
        return self.num_processors

    def _invoke_rpc(self, rpc_call_name, *input_refs) -> None:
        """Invoke the named procedure of the caller and remember the references involved."""
        logger.debug("Remotely invoking %s", rpc_call_name)
        output_refs = getattr(self.caller, rpc_call_name)(*input_refs)
        # Keep arguments and responses in scope
        self._rpc_inout_refs.append((input_refs, output_refs))

    @reconnect_on_plasma_exit
    async def start_scan(self, scan: Scan) -> None:
        """Invoke the start_scan RPC procedure on remote processors."""
        scan_id = scan.scan_number
        scan_id_tensor = self.store.put_new_tensor(np.array([scan_id], dtype=np.int64))
        assert scan_id not in self._scan_id_refs
        self._scan_id_refs[scan_id] = scan_id_tensor
        self.find_processors()
        # pylint: disable-next=no-member
        self._invoke_rpc("start_scan", scan_id_tensor)

    @reconnect_on_plasma_exit
    async def end_scan(self, scan: Scan) -> None:
        """Invoke the end_scan RPC procedure on remote processors."""
        self.find_processors()
        scan_id_tensor = self._get_existing_scan_id_tensor_ref(scan.scan_number)
        # pylint: disable-next=no-member
        self._invoke_rpc("end_scan", scan_id_tensor)

    def _get_existing_scan_id_tensor_ref(self, scan_id) -> rpc.connection.Ref:
        """Return the reference to the tensor holding ``scan_id``, creating it if missing."""
        if scan_id not in self._scan_id_refs:
            logger.warning("Tensor for scan %d not found in Plasma, creating one anyway", scan_id)
            self._scan_id_refs[scan_id] = self.store.put_new_tensor(
                np.array([scan_id], dtype=np.int64)
            )
        return self._scan_id_refs[scan_id]

    @overrides
    @reconnect_on_plasma_exit
    async def consume(self, visibility: Visibility):
        """Write ``visibility`` to Plasma and invoke ``process_visibility`` on it.

        Before writing, references to the oldest invocations are released so that no
        more than the configured number of payloads stays in flight.
        """
        self.find_processors()
        await self._remove_old_references()
        self._write_data_to_plasma(visibility)

    @overrides
    @ignore_plasma_exit
    async def astop(self):
        """Wait for pending RPC responses, then drop local references and disconnect."""
        try:
            timed_out = await self._wait_for_all_responses()
            if timed_out:
                logger.warning(
                    "Timed out while waiting for pending RPC responses, will disconnect now anyway"
                )
        finally:
            # Release local resources even if waiting failed (e.g. because Plasma exited),
            # otherwise the connection would be left open behind us
            self._remove_all_common_refs()
            self.disconnect()

    @staticmethod
    def _get_polarizations(scan_type: ScanType):
        """Get a columnar table serialization of polarizations for a single ScanType."""
        beams = scan_type.beams
        return {
            "polarizations_id": [beam.polarisations.polarisation_id for beam in beams],
            "corr_type": [
                [corr_type.name for corr_type in beam.polarisations.correlation_type]
                for beam in beams
            ],
        }

    @staticmethod
    def _get_spectral_windows(scan_type: ScanType):
        """Get a columnar table serialization of spectral windows for a single ScanType."""
        # One row per spectral window, each paired with the ID of the channels it belongs to
        rows = [
            (beam.channels.channels_id, sw)
            for beam in scan_type.beams
            for sw in beam.channels.spectral_windows
        ]

        def column(name: str):
            return [getattr(sw, name) for _, sw in rows]

        return {
            "channels_id": [channels_id for channels_id, _ in rows],
            "spectral_window_id": column("spectral_window_id"),
            "count": column("count"),
            "start": column("start"),
            "stride": column("stride"),
            "freq_min": column("freq_min"),
            "freq_max": column("freq_max"),
        }

    @staticmethod
    def _get_channels(scan_type: ScanType):
        """Get a columnar table serialization of channels for a single ScanType."""
        return {
            "channels_id": [beam.channels.channels_id for beam in scan_type.beams],
        }

    @staticmethod
    def _get_fields(scan_type: ScanType):
        """Get a columnar table serialization of fields for a single ScanType.

        Beams without a field don't contribute any row.
        """
        fields = [beam.field for beam in scan_type.beams if beam.field]
        return {
            "field_id": [field.field_id for field in fields],
            "phase_dir": [
                {
                    "ra": field.phase_dir.ra.rad,
                    "dec": field.phase_dir.dec.rad,
                    "reference_time": field.phase_dir.reference_time,
                    "reference_frame": field.phase_dir.reference_frame,
                }
                for field in fields
            ],
        }

    @staticmethod
    def _get_beams(scan_type: ScanType):
        """Get a columnar table serialization of beams for a single ScanType.

        Beams without a field don't contribute any row.
        """
        # TODO(cgray): this integer mapping of beams can't be used directly with measurement
        # sets due to BEAM_ID corresponding to a row entry. Unless beams across all
        # scans are written this will need additional mapping.
        beams = [beam for beam in scan_type.beams if beam.field]
        return {
            "beam_id": [beam.beam_id for beam in beams],
            "field_id": [beam.field.field_id for beam in beams],
            "channels_id": [beam.channels.channels_id for beam in beams],
            "polarizations_id": [beam.polarisations.polarisation_id for beam in beams],
            "function": [beam.function for beam in beams],
        }

    @staticmethod
    def _get_scan_type(scan_type: ScanType):
        """
        Return a simple dict for the given ScanType, suitable for turning into a SCAN_TYPE.
        """
        return {
            "scan_type_id": [scan_type.scan_type_id],
            "integration_time": [scan_type.integration_time],
            "scan_intents": [scan_type.scan_intents],
            "averaging_channels": [scan_type.averaging_channels],
            "averaging_samples": [scan_type.averaging_samples],
        }

    def _write_common_metadata(self, visibility: Visibility):
        """Write the baseline and antenna tables, returning their references by name."""
        baselines = table_utils.get_baselines(visibility)
        antennas = table_utils.get_antennas(visibility.attrs["configuration"])
        return {
            "baselines": self.store.put_new_table(baselines, schema=TABLES["BASELINE"]),
            "antennas": self.store.put_new_table(antennas, schema=TABLES["ANTENNA2"]),
        }

    def _write_scan_type_metadata(self, scan_type: ScanType):
        """Write the tables describing ``scan_type``, returning their references by name."""
        return {
            "polarizations": self.store.put_new_table(
                self._get_polarizations(scan_type), schema=TABLES["POLARIZATION"]
            ),
            "spectral_window": self.store.put_new_table(
                self._get_spectral_windows(scan_type), schema=TABLES["SPECTRAL_WINDOW"]
            ),
            "channels": self.store.put_new_table(
                self._get_channels(scan_type), schema=TABLES["CHANNELS"]
            ),
            "field": self.store.put_new_table(self._get_fields(scan_type), schema=TABLES["FIELD"]),
            "beam": self.store.put_new_table(self._get_beams(scan_type), schema=TABLES["BEAM"]),
            "scan_type": self.store.put_new_table(
                self._get_scan_type(scan_type), schema=TABLES["SCAN_TYPE"]
            ),
        }

    def _write_data_to_plasma(self, visibility: Visibility):
        """
        Writes common_metadata tables to plasma on first visibility,
        then time dependent data on subsequent ones.

        Tables describing a scan type are likewise written only the first time a
        visibility of that scan type is seen. Once everything is in the store the
        ``process_visibility`` procedure is invoked.

        :raises ValueError: if the beam ID of the visibility is not among the beams
            of its scan type.
        """
        payload_seq_numbers = visibility.attrs["meta"]["payload_seq_numbers"]
        scan: Scan = visibility.attrs["meta"]["scan"]
        scan_type = scan.scan_type
        if not self._common_refs:
            self._common_refs = self._write_common_metadata(visibility)
        scan_type_id = scan_type.scan_type_id
        if scan_type_id not in self._scan_type_refs:
            self._scan_type_refs[scan_type_id] = self._write_scan_type_metadata(scan_type)
        scan_type_refs = self._scan_type_refs[scan_type_id]
        channel_ids = visibility.meta["channel_ids"]
        # start_scan() should have been called before, but if the receiver restarted
        # then a reference to the exact scan_id in plasma won't exist yet
        scan_id_tensor = self._get_existing_scan_id_tensor_ref(scan.scan_number)
        vis_beam_id = visibility.meta["beam_id"]
        beams = scan_type.beams
        beam_position = [beam.beam_id for beam in beams].index(vis_beam_id)
        phase_direction = beams[beam_position].field.phase_dir
        # Plasma can't transmit string parameters, so vis_beam_id can't be directly communicated.
        # Instead an index into the BEAM table is sent. That table only has rows for beams
        # with a field (see _get_beams), so beams without one must not be counted.
        beam_index = sum(1 for beam in beams[:beam_position] if beam.field)
        uvw = self._get_uvw(visibility, phase_direction)
        is_mid = visibility.meta["telescope"] == Telescope.MID
        # Put everything on Plasma according to
        # sdp_dal_schemas/schemas/procedures.json at process_visibility.inputs
        input_refs = (
            scan_id_tensor,
            self.store.put_new_tensor(np.array([beam_index])),
            self.store.put_new_tensor(np.array([is_mid])),
            self.store.put_new_tensor(payload_seq_numbers),
            self.store.put_new_tensor(visibility.time.data),
            self.store.put_new_tensor(visibility.exposure.data),
            self.store.put_new_tensor(channel_ids),
            self.store.put_new_tensor(visibility.flags.data.astype(bool)),
            self.store.put_new_tensor(visibility.weight.data),
            scan_type_refs["scan_type"],
            scan_type_refs["beam"],
            self._common_refs["baselines"],
            scan_type_refs["spectral_window"],
            scan_type_refs["channels"],
            self._common_refs["antennas"],
            scan_type_refs["field"],
            scan_type_refs["polarizations"],
            self.store.put_new_tensor(uvw),
            self.store.put_new_tensor(visibility.vis.data, rpc.complex64),
            self.store.put_new_tensor(visibility.time_centroids.data),
        )
        # pylint: disable-next=no-member
        self._invoke_rpc("process_visibility", *input_refs)

    def _get_response(self, refs, timeout, executor_done_evt):
        """Obtain results from previous calls referenced by `refs`.

        ``executor_done_evt`` is set on the way out, whether results were obtained
        or not, to signal that this function isn't running anymore.
        """
        try:
            return [ref["output"].get(timeout) for ref in refs]
        finally:
            executor_done_evt.set()

    async def wait_for_oldest_reponse(self, timeout):
        """
        Waits for the response of the oldest RPC call issued by this consumer,
        and removes references to its inputs and output objects so they can be
        freed by plasma.

        The references are dropped even if the wait fails, in which case the
        results of the call are accounted for as unknown.

        :param timeout: passed on to each ``get()`` call made to obtain an output.
        :returns: the list of results obtained for the call.
        :raises IndexError: if there are no invocations in flight.
        """
        # _in_refs has to be kept around until out_refs are awaited for a reponse
        _in_refs, out_refs = self._rpc_inout_refs.pop(0)
        executor_done_evt = threading.Event()
        try:
            all_results = await asyncio.get_running_loop().run_in_executor(
                None, self._get_response, out_refs, timeout, executor_done_evt
            )
        except BaseException:
            # We might have been interrupted while waiting for the future, so
            # let's make sure self._get_response finishes before we do to avoid
            # hanging the default threadpool executor for future users.
            while not executor_done_evt.is_set():
                await asyncio.sleep(0.1)
            self._invocation_results.unknown += len(out_refs)
            raise
        # all_results is a list of ndarrays, each with 1 dimension
        for invocation_res in all_results:
            assert len(invocation_res.shape) == 1
            if len(invocation_res) != 1:
                logger.warning(
                    "Received result array with count != 1, reported "
                    "results might be incorrect"
                )
            if len(invocation_res) == 0:
                self._invocation_results.unknown += 1
            else:
                self._invocation_results.add(invocation_res[0])
        return all_results

    async def _remove_old_references(self):
        """Remove references to calls older than the configured limit."""
        limit = self.config.payloads_in_flight
        # The emptiness check keeps a limit of 0 (or less) from popping off an empty list
        while self._rpc_inout_refs and len(self._rpc_inout_refs) >= limit:
            try:
                await self.wait_for_oldest_reponse(self.config.remove_old_references_timeout)
            except rpc.connection.TimeoutException:
                logger.warning("Timed out while waiting for response, ignoring")

    async def _wait_for_all_responses(self):
        """Wait for the responses of all in-flight invocations within the configured timeout.

        The timeout is a budget shared by all invocations: the time spent waiting for
        one response is taken off the time available to wait for the following ones.

        :returns: whether any of the waits timed out.
        """
        timeout = self.config.wait_for_all_responses_timeout
        timed_out = False
        while self._rpc_inout_refs:
            # A monotonic clock keeps system time adjustments from skewing the budget
            start = time.monotonic()
            try:
                await self.wait_for_oldest_reponse(timeout)
            except rpc.connection.TimeoutException:
                timed_out = True
            finally:
                timeout = max(0, timeout - (time.monotonic() - start))
        return timed_out

    @ignore_plasma_exit
    async def wait_for_all_responses(self) -> bool:
        """Wait for the responses for all in-flight invocations.

        :returns: whether there was a timeout while waiting for all responses.
            The value is relayed to the caller by the ``ignore_plasma_exit`` wrapper.
        """
        return await self._wait_for_all_responses()

    def _remove_all_common_refs(self):
        """Forget the references to the common and per-scan-type tables."""
        self._common_refs = {}
        self._scan_type_refs = {}