#
# This file is part of the SKA Tango Base project
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE.txt for more info.
"""This module provides for asynchronous execution of tasks."""
from __future__ import annotations

import contextlib
import functools
import threading
import warnings
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import tango
from packaging import version
from ska_control_model import ResultCode
from ska_control_model import TaskStatus as _TaskStatus
from tango import __version__ as tango_version

from ..faults import CmdNotAllowedError, StateModelError
from ..type_hints import (
    SimpleTaskFunctionType,
    TaskCallbackType,
    TaskExecutorProtocol,
    TaskFunctionType,
)
# COMPAT (pytango 9.x.x): PyTangoThreadPoolExecutor and the telemetry helpers
# are only available from pytango 10 onwards, so probe for them here.
PYTANGO_10_OR_NEWER = version.parse(tango_version) >= version.parse("10.0.0")
OPENTELEMETRY_INSTALLED = False

if PYTANGO_10_OR_NEWER:
    from tango.utils import PyTangoThreadPoolExecutor

    # Tracing support is optional: it needs both opentelemetry and the pytango
    # telemetry helpers to be importable.
    try:
        from opentelemetry.trace import get_tracer
        from opentelemetry.trace.propagation.tracecontext import (
            TraceContextTextMapPropagator,
        )
        from tango.utils import get_telemetry_tracer_provider_factory
    except ImportError:
        pass
    else:
        OPENTELEMETRY_INSTALLED = True
# Public API of this module. "TaskStatus" is not defined as a module global;
# it is served (with a DeprecationWarning) by the module-level __getattr__,
# hence the F822 suppression.
__all__ = [
    "TaskExecutor",
    "TaskStatus",  # noqa: F822
    "TaskAborted",
    "ExecutorShutdownError",
    "ExecutorNotShutdownError",
]
def __getattr__(name: str) -> Any:
    """
    Resolve deprecated module attributes on demand.

    Looking up ``TaskStatus`` on this module emits a
    :py:class:`DeprecationWarning` and returns
    :py:class:`ska_control_model.TaskStatus`. Any other unknown name raises
    :py:class:`AttributeError`, as a normal module lookup would.

    :param name: the attribute name that normal module lookup did not find.
    :return: the deprecated attribute's value.
    :raises AttributeError: if ``name`` is not a deprecated attribute.
    """
    if name != "TaskStatus":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(
        "Importing 'TaskStatus' from the 'ska_tango_base.executor.executor' has "
        "been deprecated since ska-tango-base 1.4.0. Import from "
        "'ska_control_model' directly instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return _TaskStatus
# Default for TaskExecutor's ``max_queued_tasks``: once this many submitted
# tasks are waiting in the input queue, further submissions are rejected.
DEFAULT_MAX_QUEUED_TASKS = 32
class TaskAborted(Exception):  # noqa N818
    """
    Signal that a long running command's task stopped because it was aborted.

    A task wrapped with :py:meth:`TaskExecutor.task` raises this after noticing
    that its abort event is set; the wrapper then reports the task as aborted.
    """
class ExecutorShutdownError(Exception):
    """Raised when a TaskExecutor is used after it has been shut down."""

    def __init__(self, action: str) -> None:
        """
        Build the error message from the action that was attempted.

        :param action: what the caller was doing, e.g. ``"trying to submit a task"``.
        """
        message = f"TaskExecutor has already been shutdown when {action}"
        super().__init__(message)
class ExecutorNotShutdownError(Exception):
    """
    Raised when a TaskExecutor is restarted without having been shut down.

    :py:meth:`TaskExecutor.start` raises this if :py:meth:`TaskExecutor.shutdown`
    has not been called first.
    """

    def __init__(self) -> None:
        """Set the fixed error message."""
        message = "TaskExecutor has not been shutdown when start() called"
        super().__init__(message)
class TaskExecutor(TaskExecutorProtocol):
    """
    An asynchronous executor of tasks.

    This task executor provides a default implementation of the
    :py:class:`~ska_tango_base.type_hints.TaskExecutorProtocol` for use with the
    :py:class:`~ska_tango_base.long_running_commands.mixin.LRCMixin`.

    It supports multiple worker threads for simultaneous LRC execution, although the
    default of 1 is recommended for sequential execution in most cases.

    A maximum queue length can also be specified after which more submitted tasks will
    be rejected. The default is 32, with an allowed minimum of 1. In other words, this
    task executor will always have a queue and allow at least 1 LRC to be queued while
    another is busy executing. A task counts towards the queue length from the moment
    it is submitted until a worker thread picks it up.
    """

    def __init__(
        self: TaskExecutor,
        max_workers: int = 1,
        unhandled_exception_callback: Callable[[Exception], None] | None = None,
        max_queued_tasks: int = DEFAULT_MAX_QUEUED_TASKS,
    ) -> None:
        """
        Initialise a new TaskExecutor instance.

        :param max_workers: The maximum number of worker threads (minimum 1).
            This is meant to be kept at the default value to allow
            the sequential execution of LRC except for special cases.
        :param unhandled_exception_callback: Callback to be called when a task raises an
            unhandled exception.
        :param max_queued_tasks: The maximum number of tasks allowed in the queue before
            further submissions are rejected (minimum 1).
        """
        # stacklevel=2 attributes the warning to the caller passing the bad value.
        if max_workers < 1:
            warnings.warn(
                "'max_workers' must be equal to or greater than 1. Defaulting to 1. "
                "Passing a valid value may be enforced in the future by raising an "
                "error.",
                FutureWarning,
                stacklevel=2,
            )
        if max_queued_tasks < 1:
            warnings.warn(
                "'max_queued_tasks' must be equal to or greater than 1. "
                "Defaulting to 1. Passing a valid value may be enforced in the future "
                "by raising an error.",
                FutureWarning,
                stacklevel=2,
            )
        self._max_workers = max_workers if max_workers > 0 else 1
        # One extra slot to accommodate an abort task.
        self._max_executing_tasks = self._max_workers + 1
        self._max_queued_tasks = max_queued_tasks if max_queued_tasks > 0 else 1
        self._unhandled_exception_callback = unhandled_exception_callback

        # Management of the _executor is a little complicated due to the
        # abort() method. We might have a scenario where the abort() thread
        # is currently in flight and shutdown() is called concurrently.
        #
        # In this scenario, we don't want the abort() thread to re-create the
        # _executor. We use _is_shutdown to signal that shutdown() has been
        # called so this TaskExecutor is effectively dead. The abort() thread
        # will no longer re-create the _executor in this case and will set its
        # task status to ABORTED.
        #
        # We have to guard the creation of the _executor and the assignment of
        # _is_shutdown with the same lock so that if the abort() thread is in
        # the middle of re-creating the _executor, we end up shutting down the
        # correct PyTangoThreadPoolExecutor in shutdown().
        #
        # We also use this same _executor_lock in submit() so that we can check
        # _is_shutdown and raise an error if it is True. Calling `submit()`
        # after `shutdown()` is considered a bug, so users must ensure they
        # don't do this. The same lock also guards _input_queue_size.
        self._executor_lock = threading.RLock()
        self._is_shutdown = False
        self._executor: ThreadPoolExecutor
        self._input_queue_size: int
        self._abort_event: threading.Event
        self._abort_thread: threading.Thread | None = None
        self._try_start()

    def _try_start(self: TaskExecutor) -> None:
        """
        Create a fresh abort event, queue counter and backing thread pool.

        :raises ExecutorShutdownError: if :py:meth:`shutdown()` has been called and
            :py:meth:`start()` has not been called since.
        """
        with self._executor_lock:
            if self._is_shutdown:
                raise ExecutorShutdownError("trying to start the thread pool")
            # Create a new Event rather than just clearing the old one, in case a
            # running thread is yet to check.
            self._abort_event = threading.Event()
            # Reset under _executor_lock, which submit() also holds while checking
            # and incrementing the counter. When reached from start() or from the
            # abort() thread, the previous pool has already been shut down, so no
            # task from it can still be queued.
            self._input_queue_size = 0
            # COMPAT (pytango 9.x.x): PyTangoThreadPoolExecutor was only added with
            # pytango 10 - revert to using the python built-in ThreadPoolExecutor.
            if PYTANGO_10_OR_NEWER:
                self._executor = PyTangoThreadPoolExecutor(
                    max_workers=self._max_workers,
                )
            else:
                self._executor = ThreadPoolExecutor(
                    max_workers=self._max_workers,
                )

    def start(self: TaskExecutor) -> None:
        """
        Start the backing thread pool.

        This function is automatically called during :py:meth:`!__init__()` and only
        needs to be called in order to recover if :py:meth:`shutdown()` has already
        been called.

        :raises ExecutorNotShutdownError: If :meth:`shutdown()` has not already been
            called.
        """
        with self._executor_lock:
            if not self._is_shutdown:
                raise ExecutorNotShutdownError()
            self._is_shutdown = False
            self._try_start()

    def shutdown(self: TaskExecutor) -> None:
        """
        Abort all outstanding tasks and shutdown the backing thread pool.

        This should be called when the :py:class:`TaskExecutor` object is no
        longer required and after this has been called, the :py:class:`TaskExecutor`
        is unusable with :py:meth:`submit()` or :py:meth:`abort()` raising a
        :py:class:`ExecutorShutdownError`.

        :py:meth:`start()` can be called to reset the :py:class:`TaskExecutor` to a
        usable state.
        """
        with self._executor_lock:
            self._abort_event.set()
            self._executor.shutdown(wait=False)
            self._is_shutdown = True
            executor = self._executor
            abort_thread = self._abort_thread

        # We do not wait while holding the lock, in case there is a task trying
        # to call `submit()` which might cause a deadlock
        executor.shutdown(wait=True)
        if abort_thread is not None:
            abort_thread.join()

    @property
    def max_executing_tasks(self: TaskExecutor) -> int:
        """
        Get the maximum number of simultaneously executing tasks.

        Will always be one more than the given ``max_workers`` when
        instantiating the TaskExecutor to accommodate an abort task.

        :return: The maximum number of simultaneously executing tasks.
        """
        return self._max_executing_tasks

    @property
    def max_queued_tasks(self: TaskExecutor) -> int:
        """
        Get the maximum task queue size.

        :return: The maximum task queue size.
        """
        return self._max_queued_tasks

    def get_input_queue_size(self: TaskExecutor) -> int:
        """
        Get the number of submitted tasks that no worker thread has picked up yet.

        :return: The current input queue size.
        """
        # Used by submit() to enforce max_queued_tasks. The lock is re-entrant,
        # so this is safe to call while _executor_lock is already held.
        with self._executor_lock:
            return self._input_queue_size

    def submit(
        self: TaskExecutor,
        func: TaskFunctionType,
        args: Any = None,
        kwargs: Any = None,
        is_cmd_allowed: Callable[[], bool] | None = None,
        task_callback: TaskCallbackType | None = None,
    ) -> tuple[_TaskStatus, str]:
        """
        Submit a new task.

        The `is_cmd_allowed` callback may raise a `CmdNotAllowedError` before the task
        is to be executed.

        :param func: the task function to be executed.
        :param args: positional arguments to the task function
        :param kwargs: keyword arguments to the task function
        :param is_cmd_allowed: sanity check for task execution
        :param task_callback: the callback to be called when the status
            or progress of the task execution changes
        :raises ExecutorShutdownError: if the executor has been shutdown
        :return: (_TaskStatus, message)
        """
        with self._executor_lock:
            # Check the limit under the same lock that guards the increment
            # below, so concurrent submissions cannot overshoot it.
            if self.get_input_queue_size() >= self.max_queued_tasks:
                return (
                    _TaskStatus.REJECTED,
                    "Input queue supports a maximum of "
                    f"{self.max_queued_tasks} commands",
                )
            if self._is_shutdown:
                raise ExecutorShutdownError("trying to submit a task")

            thread_trace: dict[str, str] = {}
            if OPENTELEMETRY_INSTALLED:
                TraceContextTextMapPropagator().inject(thread_trace)

            start_event = threading.Event()
            try:
                self._executor.submit(
                    self._run,
                    func,
                    args,
                    kwargs,
                    is_cmd_allowed,
                    task_callback,
                    self._abort_event,
                    thread_trace,
                    start_event,
                )
            except RuntimeError:
                # The pool has been shut down by an in-flight abort().
                self._call_task_callback(
                    task_callback,
                    status=_TaskStatus.REJECTED,
                    result=(ResultCode.REJECTED, "Queue is being aborted"),
                )
                return _TaskStatus.REJECTED, "Queue is being aborted"
            except Exception as exc:
                self._call_task_callback(
                    task_callback,
                    status=_TaskStatus.FAILED,
                    result=(
                        ResultCode.FAILED,
                        f"Unhandled exception submitting task: {str(exc)}",
                    ),
                    exception=exc,
                )
                if self._unhandled_exception_callback is not None:
                    self._unhandled_exception_callback(exc)
                return (
                    _TaskStatus.FAILED,
                    f"Unhandled exception submitting task: {str(exc)}",
                )

            try:
                self._call_task_callback(task_callback, status=_TaskStatus.QUEUED)
            finally:
                # The task is already in the pool: always account for it and
                # release its worker, even if the callback raised. Otherwise the
                # worker would block on start_event forever.
                self._input_queue_size += 1
                start_event.set()
        return _TaskStatus.QUEUED, "Task queued"

    def abort(
        self: TaskExecutor, task_callback: TaskCallbackType | None = None
    ) -> tuple[_TaskStatus, str]:
        """
        Tell this executor to abort execution.

        New submissions will be rejected until the queue is empty and no
        tasks are still running. Tasks on the queue will be marked as
        aborted and not run. Tasks already running will be allowed to
        continue running.

        :param task_callback: callback for abort complete
        :raises ExecutorShutdownError: if the executor has been shutdown
        :return: tuple of task status & message
        """

        def _shutdown_and_relaunch(executor: ThreadPoolExecutor) -> None:
            # Wait for the old pool to drain, then replace it with a fresh one
            # unless shutdown() has been called in the meantime.
            with tango.EnsureOmniThread():
                executor.shutdown(wait=True)
                try:
                    self._try_start()
                except ExecutorShutdownError as ex:
                    self._call_task_callback(
                        task_callback,
                        status=_TaskStatus.ABORTED,
                        result=(ResultCode.ABORTED, f"Abort aborted: {ex}"),
                    )
                else:
                    self._call_task_callback(
                        task_callback,
                        status=_TaskStatus.COMPLETED,
                        result=(ResultCode.OK, "Abort completed OK"),
                    )
                finally:
                    # Only clear our own handle: a later abort() may already have
                    # replaced it with a newer thread that shutdown() must join.
                    with self._executor_lock:
                        if self._abort_thread is threading.current_thread():
                            self._abort_thread = None

        with self._executor_lock:
            if self._is_shutdown:
                self._call_task_callback(
                    task_callback,
                    status=_TaskStatus.REJECTED,
                    result=(ResultCode.REJECTED, "TaskExecutor is shutdown"),
                )
                raise ExecutorShutdownError("trying to abort running tasks")

            self._call_task_callback(task_callback, status=_TaskStatus.IN_PROGRESS)
            # We capture the executor that the abort thread needs to wait on in
            # case `shutdown()`, `start()` and `submit()` are somehow run
            # _before_ this thread starts. In that scenario the abort_event
            # associated with that executor will not be set, so the task will
            # not know it is supposed to abort and we may get stuck rejecting
            # tasks for a long time if the thread calls `shutdown()` again.
            self._abort_thread = threading.Thread(
                target=_shutdown_and_relaunch, args=(self._executor,), daemon=True
            )
            self._abort_event.set()
            # We call shutdown here initially just so that we start rejecting
            # new tasks immediately
            self._executor.shutdown(wait=False)
            self._abort_thread.start()
        return _TaskStatus.IN_PROGRESS, "Aborting tasks"

    def _run(
        self: TaskExecutor,
        func: TaskFunctionType,
        args: Any,
        kwargs: Any,
        is_cmd_allowed: Callable[[], bool] | None,
        task_callback: TaskCallbackType | None,
        abort_event: threading.Event,
        thread_trace: dict[str, str],
        start_event: threading.Event,
    ) -> None:
        """
        Execute one submitted task on a worker thread.

        :param func: the task function.
        :param args: positional arguments for ``func`` (``None`` means none).
        :param kwargs: keyword arguments for ``func`` (``None`` means none).
        :param is_cmd_allowed: optional check run before ``func``.
        :param task_callback: callback notified of status changes.
        :param abort_event: the abort event current when the task was submitted.
        :param thread_trace: tracing context injected by :py:meth:`submit`.
        :param start_event: set by :py:meth:`submit` once it has reported QUEUED.
        """
        if OPENTELEMETRY_INSTALLED:
            context = TraceContextTextMapPropagator().extract(thread_trace)
            tracer_provider_factory = get_telemetry_tracer_provider_factory()
            tracer_provider = tracer_provider_factory(self.__class__.__name__)
            tracer = get_tracer(
                instrumenting_module_name=self.__class__.__name__,
                tracer_provider=tracer_provider,
            )

        # Let the submit method finish before we start. This prevents this thread from
        # calling back with "IN PROGRESS" before the submit method has called back with
        # "QUEUED".
        start_event.wait()

        if abort_event.is_set():
            # No decrement here: the queue counter is reset by _try_start() when
            # the pool is relaunched after the abort/shutdown that set this event.
            self._call_task_callback(
                task_callback,
                status=_TaskStatus.ABORTED,
                result=(ResultCode.ABORTED, "Command has been aborted"),
            )
            return

        # The task has now left the input queue, whatever the outcome of the
        # is_cmd_allowed check. Decrement exactly once so the counter only tracks
        # tasks that are still waiting to start.
        try:
            if is_cmd_allowed is not None and not self._check_cmd_allowed(
                is_cmd_allowed, task_callback
            ):
                return
        finally:
            with self._executor_lock:
                self._input_queue_size -= 1

        # Don't set the task to IN_PROGRESS yet, in case func is itself implemented
        # asynchronously. We leave it to func to set the task to IN_PROGRESS, and
        # eventually to set it to COMPLETE
        try:
            args = args or []
            kwargs = kwargs or {}
            if OPENTELEMETRY_INSTALLED:
                span = tracer.start_as_current_span(
                    f"{self.__class__.__name__}._run.{self._task_name(func)}",
                    context,
                    attributes={
                        "function_args": args,
                        "function_kwargs": [f"{k}={v}" for k, v in kwargs.items()],
                    },
                )
            else:
                span = contextlib.nullcontext()
            with span:
                func(
                    *args,
                    task_callback=task_callback,
                    task_abort_event=abort_event,
                    **kwargs,
                )
        except Exception as exc:
            # Catching all exceptions because we're on a thread. Any
            # uncaught exception will take down the thread without giving
            # us any useful diagnostics.
            self._call_task_callback(
                task_callback,
                status=_TaskStatus.FAILED,
                result=(
                    ResultCode.FAILED,
                    f"Unhandled exception during execution: {str(exc)}",
                ),
                exception=exc,
            )
            if self._unhandled_exception_callback is not None:
                self._unhandled_exception_callback(exc)

    def _check_cmd_allowed(
        self: TaskExecutor,
        is_cmd_allowed: Callable[[], bool],
        task_callback: TaskCallbackType | None,
    ) -> bool:
        """
        Run a task's ``is_cmd_allowed`` check, reporting any rejection.

        :param is_cmd_allowed: the check to run.
        :param task_callback: callback notified if the task is rejected.
        :return: whether the task may go on to execute.
        """
        try:
            if is_cmd_allowed():
                return True
            self._call_task_callback(
                task_callback,
                status=_TaskStatus.REJECTED,
                result=(ResultCode.NOT_ALLOWED, "Command is not allowed"),
            )
        except (CmdNotAllowedError, StateModelError) as exc:
            self._call_task_callback(
                task_callback,
                status=_TaskStatus.REJECTED,
                result=(ResultCode.NOT_ALLOWED, str(exc)),
            )
        except Exception as exc:
            # Catching all exceptions because we're on a thread. Any
            # uncaught exception will take down the thread without giving
            # us any useful diagnostics.
            self._call_task_callback(
                task_callback,
                status=_TaskStatus.REJECTED,
                result=(
                    ResultCode.FAILED,
                    f"Exception from 'is_cmd_allowed' method: {str(exc)}",
                ),
                exception=exc,
            )
            if self._unhandled_exception_callback is not None:
                self._unhandled_exception_callback(exc)
        return False

    @staticmethod
    def _task_name(func: TaskFunctionType) -> str:
        """
        Get a human-readable name for a task function, used in tracing spans.

        :param func: the task function, possibly wrapped in ``functools.partial``.
        :return: the function's ``__name__``, or ``"<unknown task>"``.
        """
        task = func
        if not hasattr(task, "__name__"):
            # functools.partial objects have no __name__; unwrap to the callable.
            while isinstance(task, functools.partial):
                task = task.func
        return getattr(task, "__name__", "<unknown task>")

    # This method is for linter to not complain about too many nested ifs
    @staticmethod
    def _call_task_callback(
        task_callback: TaskCallbackType | None,
        **kwargs: Any,
    ) -> None:
        """
        Call ``task_callback`` with ``kwargs`` if a callback was provided.

        :param task_callback: the callback, or ``None`` to do nothing.
        :param kwargs: keyword arguments forwarded to the callback.
        """
        if task_callback is not None:
            task_callback(**kwargs)

    @staticmethod
    def task(task: SimpleTaskFunctionType) -> TaskFunctionType:
        """
        Apply task executor boilerplate to a task.

        Wraps a task function that accepts only a progress callback and abort
        event with common task executor boilerplate code to transition the task
        through the `task state machine
        <https://developer.skao.int/projects/ska-tango-base/en/latest/concepts/long-running-commands.html#long-running-command-tasks>`_.
        It must also return JSONData that will be used as the result of the
        task.

        The task should regularly call the progress callback to report the
        progress of the task. The progress is expected to be an integer, and it
        is recommended to use this integer to represent a percentage of the task
        progress i.e. integer values from 0 to 100. At similar intervals the
        task should check whether the abort event has been set
        (:code:`task_abort_event.is_set()`). If the :py:func:`!is_set()` method
        returns :code:`True`, the task should quickly exit, performing any essential
        cleanup, and raising :py:class:`TaskAborted`.

        The task should never throw an exception, other than TaskAborted, as
        failure should be reported by a
        :py:class:`~ska_control_model.ResultCode` in the :py:class:`~ska_tango_base.type_hints.JSONData`.

        :param task: A task within a task factory. It should accept a progress callback and task abort event (:py:class:`threading.Event`).
        :return: The input task wrapped with task executor boilerplate.
        """  # noqa: E501
        # Copy the wrapped task's metadata except __annotations__, since the
        # wrapper's signature differs from the task's.
        assigned = tuple(
            x for x in functools.WRAPPER_ASSIGNMENTS if x != "__annotations__"
        )

        @functools.wraps(task, assigned=assigned)
        def task_callback_wrapper(
            *args: Any,
            task_callback: TaskCallbackType | None,
            task_abort_event: threading.Event | None,
            **kwargs: Any,
        ) -> None:
            def progress_callback(progress: int) -> None:
                if task_callback is not None:
                    task_callback(progress=progress)

            if task_abort_event is None:
                task_abort_event = threading.Event()

            if task_callback is not None:
                task_callback(status=_TaskStatus.IN_PROGRESS)

            try:
                result = task(
                    *args,
                    progress_callback=progress_callback,
                    task_abort_event=task_abort_event,
                    **kwargs,
                )
            except TaskAborted:
                if task_callback is not None:
                    task_callback(
                        status=_TaskStatus.ABORTED,
                        result=(ResultCode.ABORTED, "Task aborted"),
                    )
                return

            if task_callback is not None:
                task_callback(status=_TaskStatus.COMPLETED, result=result)

        return task_callback_wrapper