Source code for ska_ser_skallop.connectors.remoting.tangobridge.wscontrol

"""Facilitates a tangogql websocket connection to a client."""
import asyncio
import atexit
import json
import logging
import socket
from concurrent.futures import Future
from threading import Event, Lock
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterable,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    NamedTuple,
    Union,
)

import requests
from websockets.client import WebSocketClientProtocol
from websockets.exceptions import ConnectionClosedError, InvalidStatusCode

from .base import (
    Disconnected,
    MessagePusher,
    Reconnected,
    Subscriber,
    WSHealthSubscriber,
)
from .control import cancel_future
from .factories import TBridgeFactory
from .subscription_control import SubscriptionController
from .ws_messages import ws_message_init

logger = logging.getLogger(__name__)


class Selector(Subscriber):
    """Filter incoming events and forward the accepted ones to subscribers.

    Events arrive through :py:meth:`push_event`, are buffered on an internal
    queue and are handed to the downstream subscribers by :py:meth:`listen`
    when the predicate accepts them.
    """

    def __init__(self, predicate: Callable[[Any], bool], name="") -> None:
        """Initialise the selector.

        :param predicate: function evaluated on every queued event; the event
            is forwarded to the subscribers only when it returns True.
        :type predicate: Callable[[Any], bool]
        :param name: human readable name of the selector (used by callers for
            naming the listening routine), defaults to ""
        :type name: str, optional
        """
        self.name = name
        self.predicate = predicate
        self.subscribers: List[Subscriber] = []
        # the queue only exists once :py:meth:`bind` has been called
        self.queue: Union[asyncio.Queue, None] = None

    def push_event(self, event: Any):
        """Buffer an event pushed by the producer.

        :param event: the event handed over by the producer
        :type event: Any
        """
        # fails when the selector has not been bound yet (see bind)
        assert self.queue
        self.queue.put_nowait(event)

    def subscribe(self, subscriber: Subscriber):
        """Register a downstream subscriber for the selected events.

        :param subscriber: object whose ``push_event`` is called for every
            event accepted by the predicate
        :type subscriber: Subscriber
        """
        self.subscribers.append(subscriber)

    def bind(self, loop):
        """Create the internal event queue.

        :param loop: the asyncio loop the selector is meant to run on; the
            argument is accepted but not used by this implementation.
        """
        self.queue = asyncio.Queue()

    async def _publish_event(self, event):
        # hand the accepted event to every registered subscriber, in
        # subscription order
        for downstream in self.subscribers:
            downstream.push_event(event)

    async def listen(self, stop: Event):
        """Consume queued events and publish those accepted by the predicate.

        Despite its name, the loop keeps running for as long as ``stop`` is
        *set* and ends once the flag is observed cleared. The flag is only
        re-checked after an event has been taken from the queue.

        :param stop: threading event controlling the life time of the loop
        :type stop: Event
        """
        assert self.queue
        while stop.is_set():
            candidate = await self.queue.get()
            if not self.predicate(candidate):
                continue
            await self._publish_event(candidate)
class WSHealthSelector(Selector):
    """Selector that only lets websocket health events through.

    Health events are instances of ``Reconnected`` or ``Disconnected``; they
    are forwarded to the registered health subscribers instead of the plain
    subscribers of the base class.
    """

    def __init__(self) -> None:
        """Initialise the selector with the health predicate and a fixed name."""
        super().__init__(
            self._select_for_health_monitoring, "websocket health selector"
        )
        self.health_subscribers: list[WSHealthSubscriber] = []

    def _select_for_health_monitoring(
        self, event: Union[Dict, "Reconnected", "Disconnected"]
    ):
        # only reconnect/disconnect notifications are of interest
        return isinstance(event, (Reconnected, Disconnected))

    async def _publish_event(self, event: Union["Reconnected", "Disconnected"]):
        # each health subscriber is awaited in turn, in subscription order
        for listener in self.health_subscribers:
            await listener.push_health_event(event)

    def subscribe_health_subscriber(self, subscriber: WSHealthSubscriber):
        """Register an object that must be notified of websocket health events.

        :param subscriber: the subscriber whose ``push_health_event`` is awaited
            for every health event.
        """
        self.health_subscribers.append(subscriber)
class MessageContext(NamedTuple):
    """Bundles a message inbox, outbox and subscribers list into a single object.

    Being a ``NamedTuple`` the bundle is immutable: the fields cannot be
    re-assigned after creation, although the objects they refer to can still
    be mutated.
    """

    # queue used as the message outbox
    outbox: asyncio.Queue
    # queue used as the message inbox
    inbox: asyncio.Queue
    # NOTE(review): the default list is created once, when the class is defined,
    # so every instance built without an explicit ``subscribers`` argument shares
    # the same list object. Pass a fresh list when sharing is not wanted.
    subscribers: List[Subscriber] = []
class BufferedSubscriber(Subscriber):
    """Subscriber that places received events in a buffer for later retrieval.

    Events are stored on an :py:class:`asyncio.Queue` and are retrieved
    asynchronously, in arrival order, with :py:meth:`get_event`.
    """

    def __init__(self) -> None:
        """Initialise the buffered subscriber with an empty, unbounded queue."""
        self.queue: asyncio.Queue = asyncio.Queue()

    def push_event(self, event: Any):
        """Receive a pushed event and store it in the buffer.

        :param event: A event that the producer is required to push to subscribers
        :type event: Any
        """
        self.queue.put_nowait(event)

    async def get_event(self) -> Any:
        """Asynchronously wait for the next buffered event.

        :return: the oldest event in the buffer; the call blocks (asynchronously)
            until an event is available.
        """
        # the awaited item used to be dropped, so callers always got None
        return await self.queue.get()
class Websocket:
    """Manages a websocket connection and wraps the websocket api provided by a factory."""

    def __init__(self, factory: TBridgeFactory) -> None:
        """Initialise the websocket.

        No connection is made here; see :py:meth:`connect`.

        :param factory: the factory to use for getting the websocket implementation.
        :type factory: TBridgeFactory
        """
        self.url = factory.get_tango_gql_ws_url()
        self.websockets = factory.get_websockets()
        self.websocket: Union[None, WebSocketClientProtocol] = None
        # threading events, so that the state can be observed from other threads
        self.ws_healthy = Event()
        self.running = Event()
        self._connection_status: Literal["closed", "open"] = "closed"
        # set by get_health_selector; receives the connection acknowledgements
        self.subscriber: Union[None, BufferedSubscriber] = None
        # errors observed while pinging (see ping)
        self._exceptions: List[Exception] = []
        self.reconnected_flag: Event = Event()

    def __bool__(self) -> bool:
        """Evaluate if a websocket have been initialised and connected.

        :return: True if a websocket connection object is currently held.
        """
        return self.websocket is not None

    async def block_until_healthy(self, timeout=60):
        """Asynchronously wait until a websocket is healthy.

        The health flag is polled every 10 seconds, so the number of polls is
        ``int(timeout / 10)``; a timeout below 10 raises immediately when the
        websocket is not healthy.

        :param timeout: the time to block until the task can not continue
        :raises TimeoutError: When unhealthy longer than given timeout
        """
        polling_period = 10
        counter = int(timeout / polling_period)
        while not self.ws_healthy.is_set():
            if counter == 0:
                raise TimeoutError(
                    f"Websocket is not becoming healthy after {timeout}s"
                )
            logger.warning(
                f"waiting for {polling_period} seconds for websocket to be healthy"
            )
            await asyncio.sleep(polling_period)
            counter -= 1

    def set_health_ok(self):
        """Set the observed health of the websocket as healthy."""
        self.ws_healthy.set()

    def set_health_not_ok(self):
        """Set the observed health of the websocket as not healthy."""
        self.ws_healthy.clear()

    def healthy(self) -> bool:
        """Report the current observed health of websocket.

        :return: Returns True if health is ok
        """
        return self.ws_healthy.is_set()

    def wait_until_healthy(self, timeout=1):
        """Block the calling thread until the health of the ws is observed to be ok.

        :param timeout: The time to wait (in seconds), defaults to 1
        :type timeout: int, optional
        :raises TimeoutError: when the wait for a websocket exceeds the given timeout
        """
        if not self.ws_healthy.wait(timeout):
            raise TimeoutError(
                f"Timeout after {timeout} seconds waiting for websocket to be healthy"
            )

    def get_health_selector(self) -> Selector:
        """Create a selector for the connection acknowledgements of the websocket.

        A new :py:class:`BufferedSubscriber` is created, stored as
        ``self.subscriber`` and subscribed to the selector, so that
        :py:meth:`ping` can wait for the acknowledgements.

        :return: the created selector object
        """

        def select_for_connection_ack(event: Any) -> bool:
            # The message stream also carries Reconnected/Disconnected objects
            # (see get_messages). Anything that can not be queried like a
            # mapping is simply not an acknowledgement; raising here would end
            # the listening task of the selector.
            getter = getattr(event, "get", None)
            if getter is None:
                return False
            if event_type := getter("type"):
                return "connection_ack" in event_type
            return False

        health_selector = Selector(
            select_for_connection_ack,
            "select for acknowledge data",
        )
        self.subscriber = BufferedSubscriber()
        health_selector.subscribe(self.subscriber)
        return health_selector

    async def close(self):
        """Call the websocket close command asynchronously.

        Clears the running flag, which also stops the loops of :py:meth:`connect`,
        :py:meth:`get_messages` and :py:meth:`monitor_ws`.
        """
        self.running.clear()
        if self.websocket:
            await self.websocket.close()
            self._connection_status = "closed"
            self.websocket = None

    async def send(self, message: Union[str, bytes, Iterable[Any], AsyncIterable[Any]]):
        """Send an asynchronous message over the websocket.

        Nothing is sent (and no error is raised) when there is no connection.

        :param message: The websocket message or messages
        :type message: Union[str, bytes, Iterable[Any], AsyncIterable[Any]]
        """
        if self.websocket:
            await self.websocket.send(message)

    async def connect(self):
        """Asynchronously connect to a remote websocket service provider.

        After a successful connection the initialisation message is sent and
        the websocket is marked healthy. If the connection can not be made the
        websocket is marked unhealthy and the task retries every 5 seconds
        until either a connection is successful or the websocket is closed
        (:py:meth:`close`).
        """
        self.running.set()
        while self.running.is_set():
            try:
                self.websocket = await self.websockets.connect(  # type: ignore
                    uri=self.url, subprotocols=["graphql-ws"]
                )
                self._connection_status = "open"
                connection_init_message = ws_message_init()
                await self.send(connection_init_message)
                self.set_health_ok()
                return
            except InvalidStatusCode:
                logger.warning(
                    f"Unable to connect to websocket service {self.url}, will retry "
                    "connection in 5 seconds"
                )
                self.set_health_not_ok()
                await asyncio.sleep(5)
            except requests.exceptions.ConnectionError:
                logger.warning(
                    f"Unable to authenticate on {self.url} as connection is not "
                    "available"
                )
                self.set_health_not_ok()
                await asyncio.sleep(5)
            except socket.error:
                logger.warning(
                    "Unable to establish an outside connection to the network, will retry "
                    "connecting in 5 seconds"
                )
                self.set_health_not_ok()
                await asyncio.sleep(5)
            except Exception as general_exception:
                # best effort: any other failure is logged and retried as well
                logger.exception(
                    "Unable to establish an outside connection to the network, will retry "
                    f"connecting in 5 seconds exception: {general_exception}"
                )
                self.set_health_not_ok()
                await asyncio.sleep(5)

    async def get_messages(self) -> AsyncGenerator[Any, None]:
        """Get incoming websocket messages asynchronously as json decoded objects.

        For example:

        .. code-block:: python

            async for message in ws.get_messages():
                handle_message(message)

        Besides decoded messages the generator also produces:

        * ``Disconnected()`` when the websocket did not become healthy in time
          or when the connection closed unexpectedly,
        * ``Reconnected()`` when the reconnected flag was set (see
          :py:meth:`monitor_ws`),
        * an empty dict for a message that could not be handled.

        :yield: The json decoded message from the websocket
        """
        while self.running.is_set():
            try:
                await self.block_until_healthy()
            except TimeoutError as error:
                logger.exception(error)
                yield Disconnected()
            assert self.websocket, "messages called before websocket initialised"
            # i.e websocket was disconnected and reconnected
            # this means the next message to send is a reconnect
            if self.reconnected_flag.is_set():
                self.reconnected_flag.clear()
                logger.info("submitting a reconnection flag event")
                yield Reconnected()
            try:
                async for message in self.websocket:
                    try:
                        decoded = json.loads(message)
                        yield decoded
                    except Exception as exception:
                        logger.warning(exception)
                        yield {}
            except ConnectionClosedError:
                self.set_health_not_ok()
                logger.warning(
                    "unable to receive incoming messages as connection unexpectedly closed"
                )
                yield Disconnected()

    async def monitor_ws(self, timeout=1):
        """Asynchronously ping the web socket continuously until ws is closed (:py:meth:`close`).

        When a ping finds the connection closed, the websocket is reconnected
        and the reconnected flag is set so that :py:meth:`get_messages` can
        report it.

        Note this method assumes a health selector has already been defined as per
        :py:meth:`get_health_selector`.

        :param timeout: How long to wait for a ping result to be received before
            deeming it as faulty; also used as the pause between two pings,
            defaults to 1.
        :type timeout: int, optional
        """
        assert self.subscriber
        while self.running.is_set():
            await self.ping(timeout)
            if self._connection_status == "closed":
                logger.warning("websocket connection closed unexpectedly")
                await self.connect()
                # if we got disconnected but reconnected again it means
                # we need to signal that we have reconnected
                self.reconnected_flag.set()
                logger.info("websocket reconnected")
            await asyncio.sleep(timeout)

    async def ping(self, timeout=1):
        """Asynchronously ping the web socket.

        The ping consists of sending the connection initialisation message and
        waiting for an acknowledgement to arrive on ``self.subscriber``.
        Failures are not raised: the websocket is marked unhealthy and the
        error is kept in ``self._exceptions``.

        :param timeout: How long to wait for a ping result to be received before
            deeming it as faulty, defaults to 1.
        :type timeout: int, optional
        """
        assert self.subscriber
        connection_init_message = ws_message_init()
        try:
            await self.send(connection_init_message)
        except (ConnectionClosedError, ConnectionRefusedError) as error:
            # the connection is gone: flag it so that monitor_ws reconnects
            self._connection_status = "closed"
            self.set_health_not_ok()
            self._exceptions.append(error)
            return
        receive_ack = self.subscriber.get_event()
        try:
            await asyncio.wait_for(receive_ack, timeout=timeout)
        except asyncio.exceptions.TimeoutError as error:
            # NOTE(review): the health flag is not restored by a later
            # successful ping, only by connect() - confirm this is intended
            self.set_health_not_ok()
            self._exceptions.append(error)
class WSController(MessagePusher):
    """Monitors and controls a websocket connection."""

    monitor_polling_period = 5

    # NOTE(review): the default factory is created once, when the class is
    # defined, and is therefore shared by all controllers built without an
    # explicit factory - confirm that sharing is intended.
    def __init__(self, factory: TBridgeFactory = TBridgeFactory()) -> None:
        """Initialise a websocket controller.

        The websocket daemon routine is dispatched immediately and the
        constructor waits up to 10 seconds for the websocket to become healthy.
        The tear down of the controller is registered to run at interpreter
        exit.

        :param factory: The factory to use for getting a websocket implementation
            , defaults to TBridgeFactory()
        :type factory: TBridgeFactory
        """
        self.factory = factory
        self.websocket = Websocket(factory)
        self._controller = factory.get_controller()
        self.running = Event()
        self.selectors: List[Subscriber] = []
        self.selector_listeners: List["Future[None]"] = []
        # guards self.selectors
        self.lock = Lock()
        self.running.set()
        self._subscription_controller = SubscriptionController(self)
        self.inbox = factory.generate_async_queue()
        self._ws_deamon = self._load_ws_controller()
        try:
            self.websocket.wait_until_healthy(10)
        except TimeoutError:
            # NOTE(review): presumably a concurrent.futures.Future, in which
            # case this blocks until the daemon has finished and the returned
            # exception is discarded - confirm this is the intended handling.
            self._ws_deamon.exception()
        logger.info(f"ws connected on {self.websocket.url}")
        self._ws_health_selector = WSHealthSelector()
        self._ws_health_selector.subscribe_health_subscriber(
            self._subscription_controller
        )
        self.add_selector(self._ws_health_selector)
        atexit.register(self.tear_down)

    def _load_ws_controller(self):
        # dispatch the daemon routine and hand back the handle returned by the
        # concurrency controller
        return self._controller.dispatch_concurrent_routine(
            self._ws_deamon_routine(), name="ws_controller"
        )

    @property
    def ws_healthy(self) -> bool:
        """Whether the websocket connection is healthy.

        :return: Returns True if healthy
        """
        return self.websocket.healthy()

    def wait_until_ws_healthy(self, timeout=1):
        """Block until asynchronous monitoring threads have set the websocket as healthy.

        :param timeout: The maximum time to wait for websocket to become healthy,
            defaults to 1
        :type timeout: int, optional
        :raises TimeoutError: when the websocket is not healthy within the timeout
        """
        self.websocket.wait_until_healthy(timeout)

    def reload(self):
        """Close a current websocket connection and create a new one.

        The current daemon routine is waited for before a new one is
        dispatched.
        """
        self._controller.run_async_task(self._close_ws())
        self._ws_deamon.result()
        self._ws_deamon = self._load_ws_controller()

    def tear_down_ws(self):
        """Tear down monitoring threads related to the websocket and close the ws connection."""
        # NOTE disable closing ws as it seems to be giving problems with
        # proper closing down of processes during pytests
        # the ws is closed implicitly as a consequence of canceling the ws deamon
        # future
        # self._controller.run_async_task(self.close_ws(), "close_ws")
        self.finish_selector_listeners()
        cancel_future(self._ws_deamon)

    def _generate_monitor_service_health_task(self, timeout=1):
        # the health selector feeds the acknowledgements that monitor_ws waits on
        health_selector = self.websocket.get_health_selector()
        self.add_selector(health_selector)
        return self._controller.create_async_task(self.websocket.monitor_ws(timeout))

    async def _ws_deamon_routine(self, retries=4, wait_time=5, timeout=0.5):
        """Connect the websocket and run the receive, publish and monitor tasks.

        :param retries: maximum number of failed attempts before giving up
        :param wait_time: seconds to wait after a failed attempt
        :param timeout: ping timeout handed to the websocket monitoring task
        """
        nr_of_retries = 0
        while nr_of_retries < retries:
            try:
                await self.websocket.connect()
                if self.running.is_set():
                    receive_task = self._controller.create_async_task(
                        self._receive_ws_messages()
                    )
                    produce_message_events_task = self._controller.create_async_task(
                        self._produce_message_events_routine()
                    )
                    monitor_service_health_task = (
                        self._generate_monitor_service_health_task(timeout)
                    )
                    await asyncio.gather(
                        receive_task,
                        produce_message_events_task,
                        monitor_service_health_task,
                        return_exceptions=True,
                    )
                else:
                    nr_of_retries = retries  # close the loop
            # TODO implement more specific exception catching
            except Exception as exception:
                # count the failed attempt, otherwise `retries` never bounds
                # the loop and a persistent failure is retried forever
                nr_of_retries += 1
                logger.info(
                    f"exception raised on ws deamon task: {exception} will retry "
                    f"connect in {wait_time} seconds"
                )
                await asyncio.sleep(wait_time)
            finally:
                # also runs when the routine is cancelled (see tear_down_ws)
                await self._close_ws()
                logger.debug("websocket connection closed")
                self.websocket.set_health_not_ok()

    def push_message(self, item: Any):
        """Push a new message to be send by the websocket being controlled.

        :param item: The message to be send
        :type item: Any
        """
        self._controller.run_async_task(self.push_message_routine(item), "push_message")

    async def _close_ws(self):
        await self.websocket.close()

    async def _produce_message_events_routine(self):
        # fan every received message out to all registered selectors
        while self.running.is_set():
            message = await self.inbox.get()
            with self.lock:
                for selector in self.selectors:
                    selector.push_event(message)

    async def push_message_routine(self, item: Any):
        """Send an asynchronous message on the websocket (must be from asyncio thread).

        :param item: The item to send
        :raises TimeoutError: when a messages could not be send due to a faulty
            websocket remaining faulty for longer than 10 seconds
        """
        timeout = 10
        try:
            await self.websocket.block_until_healthy(timeout=timeout)
        except TimeoutError as error:
            logger.warning(
                f"unable to push {item}, websocket have been unhealthy for longer than {timeout}"
            )
            raise TimeoutError(
                f"unable to push {item}, websocket have been unhealthy for longer than {timeout}"
            ) from error
        await self.websocket.send(item)

    async def _receive_ws_messages(self):
        # move everything the websocket produces onto the inbox
        try:
            async for message in self.websocket.get_messages():
                self.inbox.put_nowait(message)
        except Exception as exception:
            # use the module logger, like the rest of this module
            logger.exception(exception)

    def add_subscription(self, device: str, attribute: str) -> int:
        """Add a new subscription to the websocket based on events from a device attribute.

        Note the websocket will only produce a new subscription if there does not
        already exist a subscription for the same device and attribute, otherwise
        it will just "piggyback" on an existing subscription.

        :param device: The device (tango device producer) which must be subscribed to
        :type device: str
        :param attribute: The attribute from the device which will generate events.
        :type attribute: str
        :return: The subscription id to use for when a subscription needs to be
            removed (:py:meth:`remove_subscription`).
        """
        return self._subscription_controller.add_subscription(device, attribute)

    def remove_subscription(self, device: str, attribute: str) -> Union[None, int]:
        """Remove a subscription as identified by the given device and attribute.

        Note a subscription will only be removed virtually if other subscriptions
        still exist to the same device and attribute. If no subscriptions to the
        same device and attribute remains, then the actual subscription will be
        removed.

        :param device: The device for which a subscription have been made.
        :type device: str
        :param attribute: The attribute for which a subscription has been made.
        :type attribute: str
        :return: Returns empty if subscriptions still remain to the given device
            and attribute, otherwise will return the "base" subscription id upon
            which the subscriptions have been "piggy backed" on.
        """
        return self._subscription_controller.remove_subscription(device, attribute)

    def listen_to_websocket_health(self, subscriber: WSHealthSubscriber):
        """Add a ws health subscriber that will receive ws health change events.

        :param subscriber: The object to be called when an event occurs
        """
        self._ws_health_selector.subscribe_health_subscriber(subscriber)

    def add_selector(self, selector: Selector):
        """Add a selector (filter) to listen for incoming subscribed events.

        The selector will push events to downstream subscribers when certain kind
        of events (as defined by the selector's predicate function) have been
        received from the websocket.

        :param selector: the selector to bind, start listening and register for
            incoming messages
        :type selector: Selector
        """
        selector.bind(self._controller.get_loop())
        # the selector listens for as long as this controller is running
        future = self._controller.dispatch_concurrent_routine(
            selector.listen(self.running), selector.name
        )
        self.selector_listeners.append(future)
        with self.lock:
            self.selectors.append(selector)

    def finish_selector_listeners(self):
        """End the selector listening routines by cancelling their futures."""
        for future in self.selector_listeners:
            cancel_future(future)

    def tear_down(self):
        """Gracefully tear down threads related to monitoring and close ws connection."""
        self.running.clear()
        self.tear_down_ws()