# Source code for ska_tango_base.long_running_commands.api

#
# This file is part of the SKA Tango Base project
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE.txt for more info.
"""This module provides a convenience client API for invoking long running commands."""

from __future__ import annotations

import inspect
import json
import logging
import threading
import traceback
import typing
import warnings
import weakref
from abc import ABC, abstractmethod

from ska_control_model import ResultCode, TaskStatus
from tango import (
    CommunicationFailed,
    ConnectionFailed,
    DevFailed,
    DeviceProxy,
    DeviceUnlocked,
    EventData,
    EventType,
    Except,
)

from ..callback_scheduler import CallbackID, CallbackScheduler, Queue
from ..faults import CommandError, ResultCodeError
from ..type_hints import LRCCallbackType, LRCSubscriptionsProtocol
from .common import _SUPPORTED_LRC_PROTOCOL_VERSIONS

# Default logger: invoke_lrc falls back to it when no logger is supplied, and
# the lazily created global CallbackScheduler is constructed with it.
module_logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ["invoke_lrc", "connect_lrc_interface", "disconnect_lrc_interface"]


class _LRCSubscriptions:
    """
    LRC event subscriptions that is returned by invoke_lrc.

    Unsubscribes from all events when the instance is deleted.
    """

    def __init__(
        self,
        command_id: str,
        callback_ids: list[CallbackID],
        scheduler: CallbackScheduler,
        protocol_version: int,
    ) -> None:
        """
        Initialise a LRCSubscriptions instance.

        :param command_id: Unique command identifier.
        :param callback_ids: IDs of the callbacks registered with ``scheduler``
            for this command. They are unregistered by unsubscribe_lrc_events.
        :param scheduler: Scheduler the callbacks were registered with. Only a
            weak reference is stored, so this object does not keep it alive.
        :param protocol_version: The LRC client-server protocol version used.
        """
        self._command_id = command_id
        self._lock = threading.Lock()
        self._callback_ids = callback_ids
        self._scheduler = weakref.ref(scheduler)
        self._protocol_version = protocol_version

    def __del__(self) -> None:
        """Unsubscribe from the LRC events when the instance is deleted."""
        self.unsubscribe_lrc_events()

    @property
    def command_id(self) -> str:
        """
        The command ID.

        :returns: the command ID.
        """
        return self._command_id

    def _unsubscribe_lrc_events(self) -> None:
        """
        Unsubscribe from LRC attributes' events.

        .. deprecated:: 1.6.0

            Use unsubscribe_lrc_events instead.
        """
        warnings.warn(
            "_LRCSubscriptions._unsubscribe_lrc_event has been deprecated since "
            "ska-tango-base 1.6.0. Use unsubscribe_lrc_events() instead.",
            DeprecationWarning,
            # Attribute the warning to the code calling the deprecated method
            # rather than to this module.
            stacklevel=2,
        )
        self.unsubscribe_lrc_events()

    def unsubscribe_lrc_events(self) -> None:
        """
        Unsubscribe from LRC attributes' events.

        Does nothing if the scheduler has already been garbage collected.
        Calling it more than once is safe: the stored callback IDs are handed
        over exactly once, so later calls find nothing left to unregister.
        """
        sched = self._scheduler()

        if sched is None:
            return

        # Take ownership of the IDs under the lock so that concurrent callers
        # (e.g. an explicit call racing with __del__) never process the same
        # ID twice through this object.
        with self._lock:
            cids, self._callback_ids = self._callback_ids, []

        for cid in cids:
            try:
                sched.unregister_callback(cid)
            except Exception:
                # Best effort: this also runs from __del__, so it must not
                # raise. Failures are still recorded for diagnosis.
                # NOTE(review): presumably this fires when the ID was already
                # unregistered elsewhere - confirm against CallbackScheduler.
                logging.getLogger(__name__).debug(
                    "scheduler.unregister_callback(%s) failed", cid, exc_info=True
                )

    @property
    def protocol_version(self) -> int:
        """
        The LRC client-server protocol version used.

        :returns: the protocol version.
        """
        return self._protocol_version


class _LRCProtocol(ABC):
    """
    Abstract base class for a LRC client-server protocol.

    Holds the state shared by the concrete protocol versions: registering and
    unregistering the event callbacks with the CallbackScheduler, executing the
    command on the device proxy and converting raw status values.
    """

    # Queue size requested when allocating the event queue for the first
    # LRC attribute of a device.
    LRC_EVENT_QUEUE_SIZE = 8192
    # Attribute names a concrete protocol subscribes to; set by subclasses.
    REQUIRED_ATTRIBUTES: list[str]

    def __init__(
        self,
        lrc_callback: LRCCallbackType,
        proxy: DeviceProxy,
        callback_scheduler: CallbackScheduler,
        logger: logging.Logger,
        command: str,
        command_args: tuple[typing.Any, ...] | None = None,
    ) -> None:
        """
        Initialise a LRC protocol instance.

        :param lrc_callback: of client to wrap.
        :param proxy: Tango DeviceProxy.
        :param callback_scheduler: scheduler used to register and unregister
            the change event callbacks.
        :param logger: to use for logging exceptions.
        :param command: Name of command to invoke.
        :param command_args: Optional arguments for the command, defaults to None.
        """
        # Class parameters
        self._lrc_callback = lrc_callback
        self._proxy = proxy
        self._logger = logger
        self._command = command
        self._command_args = command_args
        self._callback_scheduler = callback_scheduler
        # Other variables
        self._calling_thread = threading.current_thread()
        # Set once _execute_command has finished (successfully or not); the
        # event callbacks wait on it before reading the command ID.
        self._submitted = threading.Event()
        # Guards the hand-over of _callback_ids in _unregister_lrc_callbacks.
        self._lock = threading.Lock()
        self._command_id: str = ""
        self._callback_ids: list[CallbackID] = []
        # Set to True after the callbacks have been unregistered.
        self._done = False

    @abstractmethod
    def invoke_lrc(self) -> LRCSubscriptionsProtocol:
        """Invoke the LRC."""
        raise NotImplementedError("_LRCProtocol is abstract.")

    @abstractmethod
    def _wrap_lrc_callback(self, event: EventData) -> None:
        """
        Wrap the instance's given user LRC callback.

        :param event: tango attribute change event
        """
        raise NotImplementedError("_LRCProtocol is abstract.")

    def _register_lrc_callbacks(self, attributes: list[str]) -> None:
        """
        Register callbacks listening to LRC attributes with CallbackScheduler.

        We make sure that all LRC attributes on the same device use the same queue
        so that the order is preserved.

        If registering any attribute fails, the callbacks registered so far are
        unregistered again and the exception is re-thrown through
        _re_throw_exception.

        :param attributes: List of LRC attributes to subscribe to.
        """
        for i, attr in enumerate(attributes):
            # The first attribute gets a freshly allocated queue; every later
            # attribute looks up the queue of the first attribute instead.
            if i == 0:

                def queue_factory() -> Queue:
                    return self._callback_scheduler.allocate_queue(
                        queue_size=self.LRC_EVENT_QUEUE_SIZE
                    )
            else:

                def queue_factory() -> Queue:
                    return self._callback_scheduler.get_queue(
                        self._proxy, attributes[0], EventType.CHANGE_EVENT
                    )

            try:
                f = self._callback_scheduler.register_event_callback(
                    self._proxy,
                    attr,
                    EventType.CHANGE_EVENT,
                    self._wrap_lrc_callback,
                    queue_factory=queue_factory,
                    initial_event=False,
                )
                # Presumably a future: wait for the registration to complete
                # and keep the returned ID for unregistering later.
                event_id = f.result()
                self._callback_ids.append(event_id)
            except Exception as exc:
                self._unregister_lrc_callbacks()
                # NOTE(review): any Exception type can reach this point -
                # confirm Except.re_throw_exception copes with non-DevFailed.
                self._re_throw_exception(
                    exc, f"Subscribing to '{attr}' change event failed."
                )

    def _unregister_lrc_callbacks(self) -> None:
        """
        Unsubscribe from LRC attributes' events.

        Does nothing when no callback IDs are stored. Failures to unregister an
        individual callback are logged as warnings and do not stop the loop.
        """
        if self._callback_ids:
            # Swap the list out under the lock so each ID is handed to the
            # scheduler at most once by this instance.
            with self._lock:
                cids = self._callback_ids
                self._callback_ids = []

            for cid in cids:
                try:
                    self._callback_scheduler.unregister_callback(cid)
                except Exception as e:
                    self._logger.warning(
                        f"scheduler.unregister_callback({cid}) failed with: {e}"
                    )

            self._done = True

    def _execute_command(self) -> None:
        """
        Execute the LRC on the device proxy and check the result code.

        On every failure path the event callbacks are unregistered before the
        exception is raised. The ``_submitted`` event is always set on exit.

        :raises CommandError: If the command is rejected.
        :raises ResultCodeError: If the command returns an unexpected result code.
        """
        try:
            inout_args = (
                (self._command, *self._command_args)
                if self._command_args is not None
                else (self._command,)
            )
            # The reply is unpacked as two single-element sequences: the result
            # code and the command ID.
            [[result_code], [self._command_id]] = self._proxy.command_inout(*inout_args)
        except (
            ConnectionFailed,
            CommunicationFailed,
            DeviceUnlocked,
            DevFailed,
        ) as exc:
            self._unregister_lrc_callbacks()
            self._re_throw_exception(
                exc,
                f"Invocation of command '{self._command}' failed with args: "
                f"{self._command_args}",
            )
        else:
            # Check for valid result codes
            if result_code == ResultCode.REJECTED:
                # The second reply element is used as the rejection detail here.
                msg = f"{self._command} command rejected: {self._command_id}"
                self._logger.error(msg)
                self._unregister_lrc_callbacks()
                raise CommandError(msg)
            if result_code not in [ResultCode.QUEUED, ResultCode.STARTED]:
                msg = (
                    f"Unexpected result code for {self._command} command: {result_code}"
                )
                self._logger.error(msg)
                self._unregister_lrc_callbacks()
                raise ResultCodeError(msg)
        finally:
            # Command submitted, proceed with "subsequent" events.
            self._submitted.set()

    @staticmethod
    def _re_throw_exception(exception: Exception, description: str) -> None:
        """
        Re-throw a Tango DevFailed exception.

        The exception is re-thrown with reason ``SKA_InvokeLRCFailed`` and an
        origin built from the calling function and the first traceback entry of
        ``exception``.

        :param exception: Tango exception.
        :param description: Description to include in the exception information.
        """
        # Frame 0 is this function and frame 1 its direct caller, so index 2 is
        # the function that called the caller.
        calling_fn = inspect.getouterframes(inspect.currentframe())[2].function
        # First traceback entry of the exception being re-thrown.
        frame = traceback.extract_tb(exception.__traceback__)[0]
        Except.re_throw_exception(
            exception,
            "SKA_InvokeLRCFailed",  # Reason
            description,
            # Origin
            f"{calling_fn}::{frame.name} at ({frame.filename}:{frame.lineno})",
        )

    def _convert_and_check_status(
        self,
        raw_status: str | int,
    ) -> TaskStatus | None:
        """
        Convert the raw status to a TaskStatus enum.

        Strings are looked up by member name, anything else by value. An unknown
        status is logged and reported to the user callback through its ``error``
        keyword. A terminal status (ABORTED, COMPLETED, FAILED or REJECTED)
        unregisters the event callbacks.

        :param raw_status: Status as string or integer
        :return: TaskStatus, or None if the raw status could not be converted.
        """
        try:
            if isinstance(raw_status, str):
                status = TaskStatus[raw_status]
            else:
                status = TaskStatus(raw_status)
        except (KeyError, ValueError, TypeError) as exc:
            status = None
            self._logger.exception(
                f"Received unknown TaskStatus from '{self._command_id}' command: "
                f"{raw_status}"
            )
            self._lrc_callback(
                error=Except.to_dev_failed(type(exc), exc, exc.__traceback__).args
            )
        else:
            # The command has reached a final state: stop listening to events.
            if status in [
                TaskStatus.ABORTED,
                TaskStatus.COMPLETED,
                TaskStatus.FAILED,
                TaskStatus.REJECTED,
            ]:
                self._unregister_lrc_callbacks()
        return status


class _LRCProtocolV1(_LRCProtocol):
    REQUIRED_ATTRIBUTES = [
        "longRunningCommandStatus",
        "longRunningCommandProgress",
        "longRunningCommandResult",
    ]

    def invoke_lrc(self) -> LRCSubscriptionsProtocol:
        """
        Invoke the LRC with protocol version 1.

        :return: LRCSubscriptions
        """
        self._register_lrc_callbacks(self.REQUIRED_ATTRIBUTES)
        self._execute_command()
        return _LRCSubscriptions(
            self._command_id, self._callback_ids, self._callback_scheduler, 1
        )

    def _wrap_lrc_callback(self, event: EventData) -> None:
        """
        Wrap the instance's given user LRC callback with protocol version 1.

        :param event: tango attribute change event
        """
        if self._done:
            return

        # Check for tango error
        if event.err:
            self._logger.error(
                f"'{self._command_id}' command encountered error(s): {event.errors}"
            )
            self._lrc_callback(error=event.errors)
            return

        # Wait for the command to have an ID. Timeout is command_inout timeout + 1.
        if not self._submitted.wait(timeout=4):
            self._unregister_lrc_callbacks()
            return

        try:
            cmd_idx = event.attr_value.value.index(self._command_id)
            lrc_attr_value = event.attr_value.value[cmd_idx + 1]
        except ValueError:
            pass  # Do nothing, as will often be called for unrelated events
        except IndexError as exc:
            self._logger.exception(
                f"'{self._command_id}' command has no status/progress/result value"
            )
            self._lrc_callback(
                error=Except.to_dev_failed(type(exc), exc, exc.__traceback__).args
            )
        else:
            match event.attr_value.name:
                case "longrunningcommandstatus":
                    status = self._convert_and_check_status(lrc_attr_value)
                    self._lrc_callback(status=status)
                case "longrunningcommandprogress":
                    self._lrc_callback(progress=int(lrc_attr_value))
                case "longrunningcommandresult":
                    self._lrc_callback(result=json.loads(lrc_attr_value))


class _LRCProtocolV2(_LRCProtocol):
    """
    LRC client-server protocol version 2.

    Listens to the change events of the single ``_lrcEvent`` attribute. The
    value belonging to this command is decoded from JSON and passed to the user
    callback as keyword arguments.
    """

    REQUIRED_ATTRIBUTES = [
        "_lrcEvent",
    ]

    def invoke_lrc(self) -> LRCSubscriptionsProtocol:
        """
        Invoke the LRC with protocol version 2.

        The callback is registered before the command is executed, so events
        arriving while the command is being submitted are not missed.

        :return: LRCSubscriptions
        """
        self._register_lrc_callbacks(self.REQUIRED_ATTRIBUTES)
        self._execute_command()
        return _LRCSubscriptions(
            self._command_id, self._callback_ids, self._callback_scheduler, 2
        )

    def _wrap_lrc_callback(self, event: EventData) -> None:
        """
        Wrap the instance's given user LRC callback with protocol version 2.

        :param event: tango attribute change event
        """
        # Check for tango error
        if event.err:
            self._logger.error(
                f"'{self._command_id}' command encountered error(s): {event.errors}"
            )
            self._lrc_callback(error=event.errors)
            return

        # Wait for the command to have an ID. Timeout is command_inout timeout + 1.
        if not self._submitted.wait(timeout=4):
            # Giving up drops all further events for this command, so leave a
            # trace instead of going quiet without explanation.
            self._logger.warning(
                f"Timed out waiting for '{self._command}' command to be submitted; "
                "unsubscribing from its LRC events."
            )
            self._unregister_lrc_callbacks()
            return

        try:
            # The value following the command ID in the attribute value is the
            # JSON encoded event of this command.
            cmd_idx = event.attr_value.value.index(self._command_id)
            lrc_attr_value = event.attr_value.value[cmd_idx + 1]
        except ValueError:
            pass  # Do nothing, as will often be called for unrelated events
        except IndexError as exc:
            self._logger.exception(
                f"'{self._command_id}' command has no status/progress/result value"
            )
            self._lrc_callback(
                error=Except.to_dev_failed(type(exc), exc, exc.__traceback__).args
            )
        else:
            # Use a separate name so the Tango event argument is not shadowed.
            payload = json.loads(lrc_attr_value)
            if "status" in payload:
                payload["status"] = self._convert_and_check_status(payload["status"])
            if "progress" in payload and not isinstance(payload["progress"], int):
                self._logger.warning(
                    f"'{self._command}' command's progress is not an int, but "
                    f"{type(payload['progress'])}. "
                    "Its type may be checked and enforced in the future."
                )
            self._lrc_callback(**payload)

def _get_protocol_impl(version: int) -> type[_LRCProtocol]:
    match version:
        case 2:
            return _LRCProtocolV2
        case 1:
            return _LRCProtocolV1
        case _:
            # We know that _find_latest_compatible_protocol_version can only
            # return something in the range _SUPPORTED_LRC_PROTOCOL_VERSIONS
            # so we cannot hit this prong.
            assert False


# Scheduler used whenever a caller does not pass its own. It starts out as None
# and is created on first use by _ensure_global_scheduler.
_GLOBAL_SCHEDULER: CallbackScheduler | None = None
# Lock invariant: Only a single thread is initialising the GLOBAL_SCHEDULER
_GLOBAL_SCHEDULER_LOCK = threading.Lock()


def _ensure_global_scheduler() -> CallbackScheduler:
    """
    Return the global CallbackScheduler, creating it on first use.

    Creation happens while holding the module lock, so concurrent first calls
    end up sharing one scheduler.

    :return: the global scheduler.
    """
    global _GLOBAL_SCHEDULER  # noqa: PLW0603
    with _GLOBAL_SCHEDULER_LOCK:
        scheduler = _GLOBAL_SCHEDULER
        if scheduler is None:
            scheduler = CallbackScheduler(
                name="GlobalLRCScheduler", logger=module_logger
            )
            _GLOBAL_SCHEDULER = scheduler
        return scheduler


def connect_lrc_interface(
    device: DeviceProxy,
    scheduler: CallbackScheduler | None = None,
) -> None:
    """
    Connect a CallbackScheduler to the LRC interface of a device.

    Pre-connecting to the LRC interface can speed up multiple calls to
    invoke_lrc as they no longer have to connect the event streams
    individually.

    :param device: the device to connect to
    :param scheduler: scheduler to connect with or None for the global scheduler
    :raises RuntimeError: If the supported client-server protocol versions do
        not overlap.
    """
    protocol_version = _find_latest_compatible_protocol_version(device)
    attributes = _get_protocol_impl(protocol_version).REQUIRED_ATTRIBUTES

    if scheduler is None:
        scheduler = _ensure_global_scheduler()

    # We don't need to wait on the futures here because we will wait for them to
    # connect in `invoke_lrc`. I think it just confuses users if we return
    # these to let them do it.
    for i, attr in enumerate(attributes):
        # The first attribute gets a freshly allocated queue; every later
        # attribute looks up the queue of the first attribute instead.
        if i == 0:

            def queue_factory() -> Queue:
                return scheduler.allocate_queue(
                    queue_size=_LRCProtocol.LRC_EVENT_QUEUE_SIZE
                )
        else:

            def queue_factory() -> Queue:
                return scheduler.get_queue(
                    device, attributes[0], EventType.CHANGE_EVENT
                )

        scheduler.connect_event_stream(
            device, attr, EventType.CHANGE_EVENT, queue_factory=queue_factory
        )
def disconnect_lrc_interface(
    device: DeviceProxy,
    scheduler: CallbackScheduler | None = None,
) -> None:
    """
    Disconnect a CallbackScheduler from the LRC interface of a device.

    :param device: the device to disconnect from
    :param scheduler: scheduler to disconnect or None for the global scheduler
    :raises RuntimeError: If the supported client-server protocol versions do
        not overlap.
    """
    protocol_version = _find_latest_compatible_protocol_version(device)
    attributes = _get_protocol_impl(protocol_version).REQUIRED_ATTRIBUTES

    if scheduler is None:
        scheduler = _ensure_global_scheduler()

    for attr in attributes:
        scheduler.disconnect_event_stream(device, attr, EventType.CHANGE_EVENT)
def invoke_lrc(
    lrc_callback: LRCCallbackType,
    proxy: DeviceProxy,
    command: str,
    command_args: tuple[typing.Any, ...] | None = None,
    logger: logging.Logger | None = None,
    *,
    callback_scheduler: CallbackScheduler | None = None,
) -> LRCSubscriptionsProtocol:
    """
    Invoke a long running command (LRC) and monitor its progress with callbacks.

    Subscribe to the relevant LRC attributes and inform the client about change
    events via the provided lrc_callback with either the status, progress,
    result or error.

    :param lrc_callback: Client LRC callback to call whenever the LRC's state
        changes.
    :param proxy: Tango DeviceProxy.
    :param command: Name of command to invoke.
    :param command_args: Optional arguments for the command, defaults to None.
    :param logger: Logger to use for logging exceptions. If not provided, then a
        default module logger will be used.
    :param callback_scheduler: Scheduler to register the event callbacks with.
        If not provided, the global scheduler is used.
    :return: An object modelling LRCSubscriptionsProtocol, containing the
        command ID. A reference to the instance must be kept until the command
        is completed or the client is not interested in its events anymore, as
        deleting the instance unsubscribes from the LRC change events and thus
        stops any further callbacks.
    :raises CommandError: If the command is rejected.
    :raises ResultCodeError: If the command returns an unexpected result code.
    :raises ValueError: If the lrc_callback does not accept `**kwargs`.
    :raises RuntimeError: If the supported client-server protocol versions do
        not overlap.
    """
    # Validate the callback before touching the device.
    if not _is_future_proof_lrc_callback(lrc_callback):
        raise ValueError("lrc_callback must accept **kwargs")

    logger = logger or module_logger
    protocol_version = _find_latest_compatible_protocol_version(proxy)

    if callback_scheduler is None:
        callback_scheduler = _ensure_global_scheduler()

    impl = _get_protocol_impl(protocol_version)
    return impl(
        lrc_callback, proxy, callback_scheduler, logger, command, command_args
    ).invoke_lrc()
def _is_future_proof_lrc_callback(lrc_callback: LRCCallbackType) -> bool:
    """
    Check whether a callback accepts arbitrary keyword arguments.

    :param lrc_callback: Callback to inspect.
    :return: True if the signature contains a ``**kwargs`` parameter.
    """
    parameters = inspect.signature(lrc_callback).parameters.values()
    return any(
        param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters
    )


def _find_latest_compatible_protocol_version(server_proxy: DeviceProxy) -> int:
    """
    Find latest compatible protocol between client's and server's supported versions.

    :param server_proxy: Server proxy to query lrcProtocolVersions.
    :return: Highest version supported by both the client and the server.
    :raises RuntimeError: If the supported client-server protocol versions do
        not overlap.
    """
    server_min, server_max = (
        server_proxy.lrcProtocolVersions
        if "lrcProtocolVersions" in server_proxy.get_attribute_list()
        else (1, 1)  # Assume server supports only the 1st version of the protocol
    )
    client_min, client_max = _SUPPORTED_LRC_PROTOCOL_VERSIONS

    # The two closed ranges overlap when each one starts before the other ends.
    if server_min <= client_max and client_min <= server_max:
        return min(server_max, client_max)

    msg = (
        f"Incompatible LRC protocol version for {server_proxy.dev_name()}. "
        f"Supported client version range {(client_min, client_max)} "
        "does not overlap with supported server version range "
        f"{(server_min, server_max)}"
    )
    raise RuntimeError(msg)