Source code for realtime.receive.modules.consumers.plasma_writer

# -*- coding: utf-8 -*-


""" Takes a SPEAD2 HEAP and writes it to an apache plasma store. This uses the sdp-dal-prototype
    API and will fail if that cannot be loaded
"""

import asyncio
import dataclasses
import logging
import threading
import time
from typing import Dict, Sequence

import numpy as np
import ska_sdp_dal as rpc
from overrides import overrides
from realtime.receive.core import Scan, ScanType
from realtime.receive.core.common import autocast_fields
from realtime.receive.core.uvw_engine import UVWEngine
from sdp_dal_schemas import PROCEDURES, TABLES
from ska_sdp_datamodels.visibility import Visibility

from realtime.receive.modules.consumers.consumer import Consumer
from realtime.receive.modules.tm import TelescopeManager
from realtime.receive.modules.utils import table_utils

from ._config import Config

# Logger for this module, named after the module itself (`__name__`).
logger = logging.getLogger(__name__)


# Alias used in annotations in this module: a sequence of `rpc.connection.Ref` objects.
References = Sequence[rpc.connection.Ref]


@dataclasses.dataclass
@autocast_fields
class PlasmaWriterConfig(Config):
    """Settings understood by the `plasma_writer` consumer."""

    name: str = "plasma_writer"

    plasma_path: str = "/tmp/plasma"
    """The UNIX socket to connect to."""

    payloads_in_flight: int = 10
    """
    The maximum number of payloads to keep in flight in the Plasma Store
    before requesting their results and releasing our local references to
    them. If more payloads are written to Plasma the oldest references are
    released.
    """

    remove_old_references_timeout: float = 0.0
    """
    The maximum amount of time to wait, in seconds, for a pending response
    when removing references to old input objects in the Plasma store.
    """

    wait_for_all_responses_timeout: float = 5.0
    """
    The maximum amount of time to wait, in seconds, for all pending responses
    to be read from Plasma when the consumer is shut down.
    """
@dataclasses.dataclass
class InvocationResults:
    """Running tally of the outcomes reported by RPC invocations."""

    success: int = 0
    fail: int = 0
    unknown: int = 0

    def add(self, result: int):
        """
        Record the outcome of one invocation.

        :param result: the value reported by the invocation; zero counts as
            a success, any other value as a failure.
        """
        if result != 0:
            self.fail += 1
            return
        self.success += 1
class PlasmaWriterConsumer(Consumer):
    """
    A consumer that writes incoming visibilities into a Plasma Store and
    invokes RPC procedures on the processors attached to it.

    Objects are written to the store, and RPC calls are issued, directly from
    the coroutines of this class. Collecting the responses of those calls
    takes a timeout and may block, so it is deferred to the event loop's
    default executor (see `wait_for_oldest_reponse`).
    """

    config_class = PlasmaWriterConfig

    @overrides
    def __init__(
        self,
        config: PlasmaWriterConfig,
        tm: TelescopeManager,
        uvw_engine: UVWEngine,
    ):
        """
        Connect to the Plasma Store and set up the RPC caller.

        :param config: the consumer configuration; `plasma_path` is the store
            socket to connect to.
        :param tm: forwarded to the base `Consumer`.
        :param uvw_engine: forwarded to the base `Consumer`.
        """
        super().__init__(config, tm, uvw_engine)
        self.store = rpc.Store(config.plasma_path)
        self.caller = rpc.Caller(
            [PROCEDURES[name] for name in ("process_visibility", "start_scan", "end_scan")],
            self.store,
            broadcast=True,
            verbose=True,
            minimum_processors=0,
        )
        # (inputs, outputs) of every RPC call whose response hasn't been
        # collected yet, oldest first.
        self._rpc_inout_refs: list[tuple[References, References]] = []
        # Tables written once and shared by every process_visibility call.
        self._common_refs: Dict = {}
        # Tables written once per scan type, keyed by scan type ID.
        self._scan_type_refs: Dict[str, Dict] = {}
        # Scan ID tensors written by start_scan/end_scan, keyed by scan ID.
        self._scan_id_refs: Dict[int, rpc.connection.Ref] = {}
        self._invocation_results = InvocationResults()

    @property
    def invocation_results(self):
        """A summary of the results from RPC invocations"""
        return self._invocation_results

    @property
    def invocations_in_flight(self):
        """
        The number of RPC invocations that are still in flight, and for which
        we haven't collected their output.
        """
        return len(self._rpc_inout_refs)

    @property
    def num_processors(self):
        """The number of processors known to this caller"""
        # pylint: disable-next=protected-access
        return len(self.caller._processors)

    def find_processors(self):
        """
        Search the caller for processors.

        :returns: the number of processors known to the caller after the
            search.
        """
        self.caller.find_processors(verbose=True)
        return self.num_processors

    def _invoke_rpc(self, rpc_call, *input_refs) -> None:
        """
        Issue `rpc_call` with `input_refs` and remember the call as in flight.

        :param rpc_call: the callable performing the remote invocation.
        :param input_refs: the arguments to pass to `rpc_call`.
        """
        output_refs = rpc_call(*input_refs)
        # Keep arguments and responses in scope
        self._rpc_inout_refs.append((input_refs, output_refs))

    async def start_scan(self, scan_id: int) -> None:
        """
        Invoke the start_scan RPC procedure on remote processors.

        The scan ID is written to the store as a one-element int64 tensor,
        which is remembered so that `end_scan` and visibility writes for the
        same scan reuse it.

        :param scan_id: the ID of the scan that is starting; it must not have
            been started on this consumer before.
        """
        # Check for duplicates first so a repeated scan doesn't write a new
        # tensor to the store only to discard it.
        assert scan_id not in self._scan_id_refs
        scan_id_tensor = self.store.put_new_tensor(np.array([scan_id], dtype=np.int64))
        self._scan_id_refs[scan_id] = scan_id_tensor
        self.find_processors()
        # pylint: disable-next=no-member
        self._invoke_rpc(self.caller.start_scan, scan_id_tensor)

    async def end_scan(self, scan_id: int) -> None:
        """
        Invoke the end_scan RPC procedure on remote processors.

        :param scan_id: the ID of the scan that is ending. If `start_scan`
            was never called for it a warning is logged and a new scan ID
            tensor is written for this call.
        """
        if scan_id not in self._scan_id_refs:
            logger.warning("Scan %d ended, but we didn't see it start, continuing anyway", scan_id)
            scan_id_tensor = self.store.put_new_tensor(np.array([scan_id], dtype=np.int64))
        else:
            scan_id_tensor = self._scan_id_refs[scan_id]
        self.find_processors()
        # pylint: disable-next=no-member
        self._invoke_rpc(self.caller.end_scan, scan_id_tensor)

    @overrides
    async def consume(self, visibility: Visibility):
        """
        Write `visibility` to the store and invoke process_visibility on it.

        Before writing, processors are searched for again and references to
        the oldest in-flight calls are released so that fewer than
        `payloads_in_flight` remain.
        """
        logger.debug("Storing payload on plasma for remote invocations")
        self.find_processors()
        await self._remove_old_references()
        self._write_data_to_plasma(visibility)

    @overrides
    async def astop(self):
        """
        Collect the pending RPC responses, release all references held by
        this consumer and disconnect from the store.
        """
        timed_out = await self._wait_for_all_responses()
        if timed_out:
            logger.warning(
                "Timed out while waiting for pending RPC responses, will disconnect now anyway"
            )
        self._remove_all_common_refs()
        # Scan ID tensors are references into the store too; drop them
        # together with the other ones before disconnecting.
        self._scan_id_refs = {}
        self.store.conn.client.disconnect()

    @staticmethod
    def _get_polarizations(scan_type: ScanType):
        """
        Gets a columnar table serialization of polarizations for a single
        ScanType
        """
        return {
            "polarizations_id": [beam.polarisations.polarisation_id for beam in scan_type.beams],
            "corr_type": [
                [corr.name for corr in beam.polarisations.correlation_type]
                for beam in scan_type.beams
            ],
        }

    @staticmethod
    def _get_spectral_windows(scan_type: ScanType):
        """
        Gets a columnar table serialization of spectral windows for a single
        ScanType
        """

        def get_spectral_windows_attribute(name: str):
            # One entry per spectral window, in beam order.
            return [
                getattr(sw, name)
                for beam in scan_type.beams
                for sw in beam.channels.spectral_windows
            ]

        return {
            # The owning channels ID is repeated for each spectral window so
            # this column lines up with the per-window columns below.
            "channels_id": [
                beam.channels.channels_id
                for beam in scan_type.beams
                for sw in beam.channels.spectral_windows
            ],
            "spectral_window_id": get_spectral_windows_attribute("spectral_window_id"),
            "count": get_spectral_windows_attribute("count"),
            "start": get_spectral_windows_attribute("start"),
            "stride": get_spectral_windows_attribute("stride"),
            "freq_min": get_spectral_windows_attribute("freq_min"),
            "freq_max": get_spectral_windows_attribute("freq_max"),
        }

    @staticmethod
    def _get_channels(scan_type: ScanType):
        """
        Gets a columnar table serialization of channels for a single ScanType
        """
        return {
            "channels_id": [beam.channels.channels_id for beam in scan_type.beams],
        }

    @staticmethod
    def _get_fields(scan_type: ScanType):
        """
        Gets a columnar table serialization of fields for a single ScanType.
        Beams without a field are skipped.
        """
        fields = [beam.field for beam in scan_type.beams if beam.field]
        return {
            "field_id": [field.field_id for field in fields],
            "phase_dir": [
                {
                    "ra": field.phase_dir.ra.rad,
                    "dec": field.phase_dir.dec.rad,
                    "reference_time": field.phase_dir.reference_time,
                    "reference_frame": field.phase_dir.reference_frame,
                }
                for field in fields
            ],
        }

    @staticmethod
    def _get_beams(scan_type: ScanType):
        """
        Gets a columnar table serialization of beams for a single ScanType.
        Beams without a field are skipped.
        """
        # TODO(cgray): this integer mapping of beams can't be used directly with measurement
        # sets due to BEAM_ID corresponding to a row entry. Unless beams across all
        # scans are written this will need additional mapping.
        beams = [beam for beam in scan_type.beams if beam.field]
        return {
            "beam_id": [beam.beam_id for beam in beams],
            "field_id": [beam.field.field_id for beam in beams],
            "channels_id": [beam.channels.channels_id for beam in beams],
            "polarizations_id": [beam.polarisations.polarisation_id for beam in beams],
            "function": [beam.function for beam in beams],
        }

    def _write_common_metadata(self, visibility: Visibility):
        """
        Write the baseline and antenna tables for `visibility` to the store.

        :returns: the references to the new tables, keyed by "baselines" and
            "antennas".
        """
        baselines = table_utils.get_baselines(visibility)
        antennas = table_utils.get_antennas(visibility.attrs["configuration"])
        return {
            "baselines": self.store.put_new_table(baselines, schema=TABLES["BASELINE"]),
            "antennas": self.store.put_new_table(antennas, schema=TABLES["ANTENNA2"]),
        }

    def _write_scan_type_metadata(self, scan_type: ScanType):
        """
        Write the tables describing `scan_type` to the store.

        :returns: the references to the new tables, keyed by table name.
        """
        return {
            # "scan_type_id": self.store.put_new_tensor(
            #     np.array([scan_type.scan_type_id])
            # ),
            "polarizations": self.store.put_new_table(
                self._get_polarizations(scan_type),
                schema=TABLES["POLARIZATION"],
            ),
            "spectral_window": self.store.put_new_table(
                self._get_spectral_windows(scan_type),
                schema=TABLES["SPECTRAL_WINDOW"],
            ),
            "channels": self.store.put_new_table(
                self._get_channels(scan_type),
                schema=TABLES["CHANNELS"],
            ),
            "field": self.store.put_new_table(
                self._get_fields(scan_type),
                schema=TABLES["FIELD"],
            ),
            "beam": self.store.put_new_table(
                self._get_beams(scan_type),
                schema=TABLES["BEAM"],
            ),
        }

    def _write_data_to_plasma(self, visibility: Visibility):
        """
        Write `visibility` to the store and invoke process_visibility on it.

        The common metadata tables are written with the first visibility, and
        the scan type tables with the first visibility of each scan type;
        later calls reuse those references and only write the time-dependent
        data.
        """
        # Written to the store as a tensor below.
        payload_seq_numbers = visibility.attrs["meta"]["payload_seq_numbers"]
        scan: Scan = visibility.attrs["meta"]["scan"]
        scan_type = scan.scan_type
        if not self._common_refs:
            self._common_refs = self._write_common_metadata(visibility)
        if scan_type.scan_type_id not in self._scan_type_refs:
            self._scan_type_refs[scan_type.scan_type_id] = self._write_scan_type_metadata(
                scan_type
            )
        scan_type_refs = self._scan_type_refs[scan_type.scan_type_id]

        intervals = visibility.integration_time.data.astype("float64")
        # The integration times are used for the exposures as well.
        exposures = intervals
        channel_ids = visibility.meta["channel_ids"]
        uvw = self._get_uvw(visibility, scan)

        # start_scan() MUST have been called for the visibility argument to
        # make it all the way through here, because the receiver drops payloads
        # for scans it doesn't know about -- and as soon as it learns about a
        # scan, it tells its consumer about it via start_scan()
        assert scan.scan_number in self._scan_id_refs, f"Unknown scan id {scan.scan_number}"

        # Put everything on Plasma according to
        # sdp_dal_schemas/schemas/procedures.json at process_visibility.inputs
        input_refs = (
            self._scan_id_refs[scan.scan_number],
            self.store.put_new_tensor(payload_seq_numbers),
            self.store.put_new_tensor(visibility.time.data),
            self.store.put_new_tensor(intervals),
            self.store.put_new_tensor(exposures),
            self.store.put_new_tensor(channel_ids),
            self.store.put_new_tensor(visibility.flags.data.astype(bool)),
            self.store.put_new_tensor(visibility.weight.data),
            scan_type_refs["beam"],
            self._common_refs["baselines"],
            scan_type_refs["spectral_window"],
            scan_type_refs["channels"],
            self._common_refs["antennas"],
            scan_type_refs["field"],
            scan_type_refs["polarizations"],
            self.store.put_new_tensor(uvw),
            self.store.put_new_tensor(visibility.vis.data, rpc.complex64),
        )
        # pylint: disable-next=no-member
        self._invoke_rpc(self.caller.process_visibility, *input_refs)

    def _get_response(self, refs, timeout, executor_done_evt):
        """
        Obtains results from previous calls referenced by `refs`.

        :param refs: the outputs of one RPC call.
        :param timeout: passed to the `get` of each output.
        :param executor_done_evt: set when this method finishes, whether it
            succeeded or not.
        """
        try:
            return [ref["output"].get(timeout) for ref in refs]
        finally:
            executor_done_evt.set()

    async def wait_for_oldest_reponse(self, timeout):
        """
        Waits for the response of the oldest RPC call issued by this consumer,
        and removes references to its inputs and output objects so they can be
        freed by plasma.

        The results are added to `invocation_results`; if they could not be
        obtained the call's outputs are counted as unknown and the exception
        is re-raised.

        :param timeout: how long to wait for each output of the call.
        :returns: the list of results of the call.
        """
        # _in_refs has to be kept around until out_refs are awaited for a reponse
        _in_refs, out_refs = self._rpc_inout_refs.pop(0)
        executor_done_evt = threading.Event()
        try:
            all_results = await asyncio.get_running_loop().run_in_executor(
                None, self._get_response, out_refs, timeout, executor_done_evt
            )
        except BaseException:
            # We might have been interrupted while waiting for the future, so
            # let's make sure self._get_response finishes before we do to avoid
            # hanging the default threadpool executor for future users.
            while not executor_done_evt.is_set():
                await asyncio.sleep(0.1)
            self._invocation_results.unknown += len(out_refs)
            raise
        else:
            if all_results:
                # all_results is a list of ndarrays, each with 1 dimension
                for invocation_res in all_results:
                    assert len(invocation_res.shape) == 1
                    if len(invocation_res) != 1:
                        logger.warning(
                            "Received result array with count != 1, reported "
                            "results might be incorrect"
                        )
                    if len(invocation_res) == 0:
                        self._invocation_results.unknown += 1
                    else:
                        self._invocation_results.add(invocation_res[0])
            return all_results

    async def _remove_old_references(self):
        """Removes references to calls older than the configured limit."""
        # The emptiness check keeps a limit of zero (or less) from popping
        # from an empty list: it then simply drains every pending call.
        while (
            self._rpc_inout_refs
            and len(self._rpc_inout_refs) >= self.config.payloads_in_flight
        ):
            try:
                await self.wait_for_oldest_reponse(self.config.remove_old_references_timeout)
            except rpc.connection.TimeoutException:
                logger.warning("Timed out while waiting for response, ignoring")

    async def _wait_for_all_responses(self):
        """
        Collect the responses of all in-flight calls, oldest first.

        The configured `wait_for_all_responses_timeout` is a budget shared by
        all calls: the time spent on each one is deducted from what the next
        ones are given.

        :returns: whether any of the calls timed out.
        """
        timeout = self.config.wait_for_all_responses_timeout
        timed_out = False
        while self._rpc_inout_refs:
            # Monotonic clock: wall-clock adjustments must not affect the
            # remaining budget.
            start = time.monotonic()
            try:
                await self.wait_for_oldest_reponse(timeout)
            except rpc.connection.TimeoutException:
                timed_out = True
            finally:
                timeout = max(0, timeout - (time.monotonic() - start))
        return timed_out

    async def wait_for_all_responses(self) -> bool:
        """
        Wait for the responses for all in-flight invocations.

        :returns: whether there was a timeout while waiting for all responses.
        """
        return await self._wait_for_all_responses()

    def _remove_all_common_refs(self):
        """Drop the references to the common and per-scan-type tables."""
        self._common_refs = {}
        self._scan_type_refs = {}