Source code for realtime.receive.modules.consumers.plasma_writer

"""Takes visibility data received from SPEAD2 heaps and writes it to an Apache Plasma store.
This uses the ska-sdp-dal API and will fail if that cannot be loaded.
"""

from __future__ import annotations

import asyncio
import enum
import errno
import functools
import logging
import threading
import time
from collections.abc import Sequence

import numpy as np
import ska_sdp_dal as rpc
from overrides import overrides
from pydantic import dataclasses
from realtime.receive.core import Scan, ScanType
from realtime.receive.core.icd import Telescope
from realtime.receive.core.uvw_engine import UVWEngine
from sdp_dal_schemas import PROCEDURES, TABLES
from ska_sdp_datamodels.visibility import Visibility

from realtime.receive.modules.consumers.consumer import Consumer
from realtime.receive.modules.tm import TelescopeManager
from realtime.receive.modules.utils import table_utils

from ._config import Config

logger = logging.getLogger(__name__)


# The Plasma object references passed to, or returned by, a single RPC invocation.
References = Sequence[rpc.connection.Ref]


class PlasmaExitedAction(enum.Enum):
    """How the consumer reacts once it notices that the Plasma store has gone away."""

    #: Reconnect to Plasma and, if that succeeds, run the interrupted operation again.
    RECONNECT_AND_RETRY = enum.auto()

    #: Do not retry the interrupted operation; the Plasma exit is only logged.
    IGNORE = enum.auto()
async def _handle_plasma_exit( consumer: PlasmaWriterConsumer, action: PlasmaExitedAction, func, *args, **kwargs ): try: await func(consumer, *args, **kwargs) except OSError as e: # If plasma goes down, interacting with its socket raises a "bad file descriptor" error: # * When ska-sdp-dal is select()-ing on the fd, the original EBADF is set. # * When ska-sdp-dal is retrieving the notification message, plasma.cc internally turns # that into a simple error with the message given here. if ( e.errno != errno.EBADF and "Failed to read object notification from Plasma socket" not in e.args[0] ): raise match action: case PlasmaExitedAction.RECONNECT_AND_RETRY: logger.warning( "Plasma exited while running %s, attempting reconnect and trying again", func.__name__, ) consumer.disconnect() consumer.connect() await func(consumer, *args, **kwargs) case PlasmaExitedAction.IGNORE: logger.warning( "Plasma exited while running %s, will not attempt again", func.__name__ ) def _as_decorator(action: PlasmaExitedAction): def decorator(func): @functools.wraps(func) async def _wrapper(consumer, *args, **kwargs): await _handle_plasma_exit(consumer, action, func, *args, **kwargs) return _wrapper return decorator reconnect_on_plasma_exit = _as_decorator(PlasmaExitedAction.RECONNECT_AND_RETRY) ignore_plasma_exit = _as_decorator(PlasmaExitedAction.IGNORE)
@dataclasses.dataclass
class PlasmaWriterConfig(Config):
    """Configuration for the `plasma_writer` consumer."""

    name: str = "plasma_writer"

    plasma_path: str = "/tmp/plasma"
    """The UNIX socket of the Plasma store to connect to."""

    payloads_in_flight: int = 10
    """
    The maximum number of payloads to keep in flight in the Plasma Store before
    requesting their results and releasing our local references to them. If more
    payloads are written to Plasma, the oldest references are released.
    """

    remove_old_references_timeout: float = 0.0
    """
    The maximum amount of time to wait, in seconds, for a pending response when
    removing references to old input objects in the Plasma store.
    """

    wait_for_all_responses_timeout: float = 5.0
    """
    The maximum amount of time to wait, in seconds, for all pending responses to be
    read from Plasma when the consumer is shut down.
    """
@dataclasses.dataclass
class InvocationResults:
    """Running tally of the outcomes reported by RPC invocations."""

    # invocations that reported a zero result
    success: int = 0
    # invocations that reported a non-zero result
    fail: int = 0
    # invocations whose outcome could not be determined
    unknown: int = 0

    def add(self, result: int):
        """Record one reported outcome: zero counts as success, anything else as failure."""
        if result == 0:
            self.success += 1
            return
        self.fail += 1
class PlasmaWriterConsumer(Consumer):
    """A heap consumer that writes incoming data into a Plasma Store.

    Data is written to the store directly from :meth:`consume`, i.e. on the thread
    running the event loop. Only the potentially blocking wait for RPC responses is
    offloaded to the event loop's default executor (see
    :meth:`wait_for_oldest_reponse`).

    References to the inputs and outputs of every RPC invocation are held locally
    until its response is collected, so that the objects are not freed by Plasma
    while remote processors may still need them. Before writing each payload,
    :meth:`consume` collects the oldest responses so that fewer than
    ``config.payloads_in_flight`` invocations remain outstanding.
    """

    config_class = PlasmaWriterConfig

    @overrides
    def __init__(
        self,
        config: PlasmaWriterConfig,
        tm: TelescopeManager,
        uvw_engine: UVWEngine,
    ):
        super().__init__(config, tm, uvw_engine)
        self._plasma_path = config.plasma_path
        self.store: rpc.Store | None = None
        self.caller: rpc.Caller | None = None
        # (input refs, output refs) of each invocation whose response has not been
        # collected yet, oldest first
        self._rpc_inout_refs: list[tuple[References, References]] = []
        # baseline/antenna tables; written when empty (first use, or after a reset)
        self._common_refs: dict = {}
        # per-ScanType metadata table refs, keyed by scan_type_id
        self._scan_type_refs: dict[str, dict] = {}
        # scan-id tensor refs, keyed by scan number
        self._scan_id_refs: dict[int, rpc.connection.Ref] = {}
        self._invocation_results = InvocationResults()
        self.connect()

    def disconnect(self):
        """Disconnect from the Plasma store, if one has been connected.

        Closing the notification socket is best-effort: failures are logged and
        ignored so that the Plasma client itself still gets disconnected.
        """
        if self.store:
            connection = self.store.conn
            try:
                connection._socket.close()  # pylint: disable=protected-access
            except Exception as ex:  # pylint: disable=broad-exception-caught
                logger.warning("Failed to close notification socket, ignoring: %s", ex)
            connection.client.disconnect()

    def connect(self):
        """Connect to the plasma store and reset all local tensor references.

        References from a previous connection are dropped (without collecting their
        responses), so metadata tables and scan-id tensors are re-written on demand.
        """
        self.store = rpc.Store(self.config.plasma_path)
        self.caller = rpc.Caller(
            [PROCEDURES[name] for name in ("process_visibility", "start_scan", "end_scan")],
            self.store,
            broadcast=True,
            verbose=True,
            minimum_processors=0,
        )
        self._rpc_inout_refs.clear()
        self._common_refs.clear()
        self._scan_type_refs.clear()
        self._scan_id_refs.clear()

    @property
    def invocation_results(self):
        """A summary of the results from RPC invocations."""
        return self._invocation_results

    @property
    def invocations_in_flight(self):
        """
        The number of RPC invocations that are still in flight, and for which we
        haven't collected their output.
        """
        return len(self._rpc_inout_refs)

    @property
    def num_processors(self):
        """The number of processors known to this caller."""
        # pylint: disable-next=protected-access
        return len(self.caller._processors)

    def find_processors(self):
        """Search the caller for processors.

        :returns: the number of processors known after the search.
        """
        self.caller.find_processors(verbose=True)
        return self.num_processors

    def _invoke_rpc(self, rpc_call_name, *input_refs) -> None:
        """Invoke ``rpc_call_name`` on the caller and remember its in/out references."""
        logger.debug("Remotely invoking %s", rpc_call_name)
        output_refs = getattr(self.caller, rpc_call_name)(*input_refs)
        # Keep arguments and responses in scope
        self._rpc_inout_refs.append((input_refs, output_refs))

    @reconnect_on_plasma_exit
    async def start_scan(self, scan: Scan) -> None:
        """Invoke the start_scan RPC procedure on remote processors."""
        scan_id = scan.scan_number
        scan_id_tensor = self.store.put_new_tensor(np.array([scan_id], dtype=np.int64))
        assert scan_id not in self._scan_id_refs
        self._scan_id_refs[scan_id] = scan_id_tensor
        self.find_processors()
        # pylint: disable-next=no-member
        self._invoke_rpc("start_scan", scan_id_tensor)

    @reconnect_on_plasma_exit
    async def end_scan(self, scan: Scan) -> None:
        """Invoke the end_scan RPC procedure on remote processors."""
        self.find_processors()
        scan_id_tensor = self._get_existing_scan_id_tensor_ref(scan.scan_number)
        # pylint: disable-next=no-member
        self._invoke_rpc("end_scan", scan_id_tensor)

    def _get_existing_scan_id_tensor_ref(self, scan_id) -> rpc.connection.Ref:
        """Return the tensor for ``scan_id``, writing one (with a warning) if missing."""
        if scan_id not in self._scan_id_refs:
            logger.warning("Tensor for scan %d not found in Plasma, creating one anyway", scan_id)
            self._scan_id_refs[scan_id] = self.store.put_new_tensor(
                np.array([scan_id], dtype=np.int64)
            )
        return self._scan_id_refs[scan_id]

    @overrides
    @reconnect_on_plasma_exit
    async def consume(self, visibility: Visibility):
        """Write ``visibility`` to Plasma and invoke ``process_visibility`` on it.

        Responses of the oldest in-flight invocations are collected first so their
        Plasma objects can be released.
        """
        self.find_processors()
        await self._remove_old_references()
        self._write_data_to_plasma(visibility)

    @overrides
    @ignore_plasma_exit
    async def astop(self):
        """Collect pending responses (bounded by a timeout), then disconnect."""
        timed_out = await self._wait_for_all_responses()
        if timed_out:
            logger.warning(
                "Timed out while waiting for pending RPC responses, will disconnect now anyway"
            )
        self._remove_all_common_refs()
        self.disconnect()

    @staticmethod
    def _get_polarizations(scan_type: ScanType):
        """Get a columnar table serialization of polarizations for a single ScanType."""
        return {
            "polarizations_id": [beam.polarisations.polarisation_id for beam in scan_type.beams],
            "corr_type": [
                [corr.name for corr in beam.polarisations.correlation_type]
                for beam in scan_type.beams
            ],
        }

    @staticmethod
    def _get_spectral_windows(scan_type: ScanType):
        """Get a columnar table serialization of spectral windows for a single ScanType."""

        def get_spectral_windows_attribute(name: str):
            return [
                getattr(sw, name)
                for beam in scan_type.beams
                for sw in beam.channels.spectral_windows
            ]

        return {
            # channels_id is repeated once per spectral window so all columns align
            "channels_id": [
                beam.channels.channels_id
                for beam in scan_type.beams
                for _sw in beam.channels.spectral_windows
            ],
            "spectral_window_id": get_spectral_windows_attribute("spectral_window_id"),
            "count": get_spectral_windows_attribute("count"),
            "start": get_spectral_windows_attribute("start"),
            "stride": get_spectral_windows_attribute("stride"),
            "freq_min": get_spectral_windows_attribute("freq_min"),
            "freq_max": get_spectral_windows_attribute("freq_max"),
        }

    @staticmethod
    def _get_channels(scan_type: ScanType):
        """Get a columnar table serialization of channels for a single ScanType."""
        return {
            "channels_id": [beam.channels.channels_id for beam in scan_type.beams],
        }

    @staticmethod
    def _get_fields(scan_type: ScanType):
        """Get a columnar table serialization of fields for a single ScanType.

        Only beams that have a field are included.
        """
        return {
            "field_id": [beam.field.field_id for beam in scan_type.beams if beam.field],
            "phase_dir": [
                {
                    "ra": beam.field.phase_dir.ra.rad,
                    "dec": beam.field.phase_dir.dec.rad,
                    "reference_time": beam.field.phase_dir.reference_time,
                    "reference_frame": beam.field.phase_dir.reference_frame,
                }
                for beam in scan_type.beams
                if beam.field
            ],
        }

    @staticmethod
    def _get_beams(scan_type: ScanType):
        """Get a columnar table serialization of beams for a single ScanType.

        Only beams that have a field are included.
        """
        # TODO(cgray): this integer mapping of beams can't be used directly with measurement
        # sets due to BEAM_ID corresponding to a row entry. Unless beams across all
        # scans are written this will need additional mapping.
        return {
            "beam_id": [beam.beam_id for beam in scan_type.beams if beam.field],
            "field_id": [beam.field.field_id for beam in scan_type.beams if beam.field],
            "channels_id": [beam.channels.channels_id for beam in scan_type.beams if beam.field],
            "polarizations_id": [
                beam.polarisations.polarisation_id for beam in scan_type.beams if beam.field
            ],
            "function": [beam.function for beam in scan_type.beams if beam.field],
        }

    @staticmethod
    def _get_scan_type(scan_type: ScanType):
        """
        Return a simple dict for the given ScanType, suitable for turning into a SCAN_TYPE.
        """
        return {
            "scan_type_id": [scan_type.scan_type_id],
            "integration_time": [scan_type.integration_time],
            "scan_intents": [scan_type.scan_intents],
            "averaging_channels": [scan_type.averaging_channels],
            "averaging_samples": [scan_type.averaging_samples],
        }

    def _write_common_metadata(self, visibility: Visibility):
        """Write the baseline and antenna tables derived from ``visibility``."""
        baselines = table_utils.get_baselines(visibility)
        antennas = table_utils.get_antennas(visibility.attrs["configuration"])
        return {
            "baselines": self.store.put_new_table(baselines, schema=TABLES["BASELINE"]),
            "antennas": self.store.put_new_table(antennas, schema=TABLES["ANTENNA2"]),
        }

    def _write_scan_type_metadata(self, scan_type: ScanType):
        """Write all per-ScanType metadata tables and return their references."""
        return {
            "polarizations": self.store.put_new_table(
                self._get_polarizations(scan_type), schema=TABLES["POLARIZATION"]
            ),
            "spectral_window": self.store.put_new_table(
                self._get_spectral_windows(scan_type), schema=TABLES["SPECTRAL_WINDOW"]
            ),
            "channels": self.store.put_new_table(
                self._get_channels(scan_type), schema=TABLES["CHANNELS"]
            ),
            "field": self.store.put_new_table(self._get_fields(scan_type), schema=TABLES["FIELD"]),
            "beam": self.store.put_new_table(self._get_beams(scan_type), schema=TABLES["BEAM"]),
            "scan_type": self.store.put_new_table(
                self._get_scan_type(scan_type), schema=TABLES["SCAN_TYPE"]
            ),
        }

    def _write_data_to_plasma(self, visibility: Visibility):
        """
        Write ``visibility`` to Plasma and invoke ``process_visibility`` on it.

        Common metadata tables are written on the first call (after a reset) and
        per-ScanType tables the first time each scan type is seen; the
        time-dependent data is written on every call.
        """
        payload_seq_numbers = visibility.attrs["meta"]["payload_seq_numbers"]
        scan: Scan = visibility.attrs["meta"]["scan"]
        scan_type = scan.scan_type

        if not self._common_refs:
            self._common_refs = self._write_common_metadata(visibility)
        if scan_type.scan_type_id not in self._scan_type_refs:
            self._scan_type_refs[scan_type.scan_type_id] = self._write_scan_type_metadata(
                scan_type
            )
        scan_type_refs = self._scan_type_refs[scan_type.scan_type_id]

        channel_ids = visibility.meta["channel_ids"]
        # start_scan() should have been called before, but if the receiver restarted
        # then a reference to the exact scan_id in plasma won't exist yet
        scan_id_tensor = self._get_existing_scan_id_tensor_ref(scan.scan_number)

        vis_beam_id = visibility.meta["beam_id"]
        # Plasma can't transmit string parameters, so vis_beam_id can't be directly
        # communicated. Instead an index into the BEAM table is sent. That table only
        # holds beams that have a field (see _get_beams), so the index must be
        # computed over the same filtered list to point at the right row.
        beams_in_table = [beam for beam in scan_type.beams if beam.field]
        beam_index = [beam.beam_id for beam in beams_in_table].index(vis_beam_id)
        phase_direction = beams_in_table[beam_index].field.phase_dir
        uvw = self._get_uvw(visibility, phase_direction)
        is_mid_telescope = visibility.meta["telescope"] == Telescope.MID

        # Put everything on Plasma according to
        # sdp_dal_schemas/schemas/procedures.json at process_visibility.inputs
        input_refs = (
            scan_id_tensor,
            self.store.put_new_tensor(np.array([beam_index])),
            self.store.put_new_tensor(np.array([is_mid_telescope])),
            self.store.put_new_tensor(payload_seq_numbers),
            self.store.put_new_tensor(visibility.time.data),
            self.store.put_new_tensor(visibility.exposure.data),
            self.store.put_new_tensor(channel_ids),
            self.store.put_new_tensor(visibility.flags.data.astype(bool)),
            self.store.put_new_tensor(visibility.weight.data),
            scan_type_refs["scan_type"],
            scan_type_refs["beam"],
            self._common_refs["baselines"],
            scan_type_refs["spectral_window"],
            scan_type_refs["channels"],
            self._common_refs["antennas"],
            scan_type_refs["field"],
            scan_type_refs["polarizations"],
            self.store.put_new_tensor(uvw),
            self.store.put_new_tensor(visibility.vis.data, rpc.complex64),
            self.store.put_new_tensor(visibility.time_centroids.data),
        )
        # pylint: disable-next=no-member
        self._invoke_rpc("process_visibility", *input_refs)

    def _get_response(self, refs, timeout, executor_done_evt):
        """Obtain results from previous calls referenced by `refs`.

        Runs in an executor thread; ``executor_done_evt`` is set when this returns or
        raises, so the event loop side can tell when the thread is free again.
        """
        try:
            return [ref["output"].get(timeout) for ref in refs]
        finally:
            executor_done_evt.set()

    async def wait_for_oldest_reponse(self, timeout):
        """
        Waits for the response of the oldest RPC call issued by this consumer, and
        removes references to its inputs and output objects so they can be freed by
        plasma.

        The references are dropped even if waiting fails; in that case all outputs
        of the call are counted as ``unknown`` in :attr:`invocation_results` and the
        exception (e.g. a timeout) is propagated.

        :param timeout: passed to each output reference's ``get``.
        :returns: the list of result arrays, one per output reference.
        :raises IndexError: if no invocation is in flight.
        """
        # _in_refs has to be kept around until out_refs are awaited for a response
        _in_refs, out_refs = self._rpc_inout_refs.pop(0)
        executor_done_evt = threading.Event()
        try:
            all_results = await asyncio.get_running_loop().run_in_executor(
                None, self._get_response, out_refs, timeout, executor_done_evt
            )
        except BaseException:
            # We might have been interrupted while waiting for the future, so
            # let's make sure self._get_response finishes before we do to avoid
            # hanging the default threadpool executor for future users.
            while not executor_done_evt.is_set():
                await asyncio.sleep(0.1)
            self._invocation_results.unknown += len(out_refs)
            raise

        if all_results:
            # all_results is a list of ndarrays, each with 1 dimension
            for invocation_res in all_results:
                assert len(invocation_res.shape) == 1
                if len(invocation_res) != 1:
                    logger.warning(
                        "Received result array with count != 1, reported "
                        "results might be incorrect"
                    )
                if len(invocation_res) == 0:
                    self._invocation_results.unknown += 1
                else:
                    self._invocation_results.add(invocation_res[0])
        return all_results

    async def _remove_old_references(self):
        """Remove references to calls older than the configured limit.

        Timeouts while waiting are logged and otherwise ignored; the references of
        the timed-out call are dropped regardless.
        """
        # The emptiness check also protects against payloads_in_flight <= 0, which
        # would otherwise keep trying to pop from an empty list (IndexError).
        while (
            self._rpc_inout_refs
            and len(self._rpc_inout_refs) >= self.config.payloads_in_flight
        ):
            try:
                await self.wait_for_oldest_reponse(self.config.remove_old_references_timeout)
            except rpc.connection.TimeoutException:
                logger.warning("Timed out while waiting for response, ignoring")

    async def _wait_for_all_responses(self):
        """Collect responses for all in-flight calls within one shared time budget.

        :returns: whether any of the waits timed out.
        """
        timeout = self.config.wait_for_all_responses_timeout
        timed_out = False
        while self._rpc_inout_refs:
            # monotonic() so wall-clock adjustments can't distort the budget
            start = time.monotonic()
            try:
                await self.wait_for_oldest_reponse(timeout)
            except rpc.connection.TimeoutException:
                timed_out = True
            finally:
                timeout = max(0, timeout - (time.monotonic() - start))
        return timed_out

    @ignore_plasma_exit
    async def wait_for_all_responses(self) -> bool:
        """Wait for the responses for all in-flight invocations.

        :returns: whether there was a timeout while waiting for all responses.
        """
        return await self._wait_for_all_responses()

    def _remove_all_common_refs(self):
        """Drop references to the common and per-ScanType metadata tables."""
        self._common_refs.clear()
        self._scan_type_refs.clear()