# -*- coding: utf-8 -*-
#
# This file is part of the SKA PST project.
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.
"""Module for handling long running device proxy tasks."""
from __future__ import annotations
import concurrent.futures
import logging
import queue
import threading
from typing import Optional, cast
from ska_pst.lmc.job.task import DeviceCommandTaskContext
from ska_tango_base.commands import ResultCode
# Module-level logger.
# NOTE(review): not referenced in the visible code; DeviceCommandTaskExecutor calls
# logging.getLogger(__name__) itself when no logger is supplied (which returns this same logger).
_logger = logging.getLogger(__name__)
class DeviceCommandTaskExecutor:
    """
    Class to handle executing and tracking commands on device proxies.

    This class uses a queue to receive task commands, while a background
    thread receives these messages and then executes the commands.

    Since the remote commands also run in the background on the device, this
    class will use a :py:class:`LongRunningCommand` to wait on when the
    command completes.

    Clients should submit `DeviceCommandTask` tasks to the `TaskExecutor` rather
    than building up a `DeviceCommandTaskContext` and sending it to the
    task queue.

    Instances of this class and the `TaskExecutor` class work together by sharing
    a queue. If creating separate instances of both classes, make sure that
    the queue between them is the same.
    """

    def __init__(
        self: DeviceCommandTaskExecutor,
        task_queue: queue.Queue[DeviceCommandTaskContext],
        max_parallel_workers: int = 4,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        Initialise the executor.

        :param task_queue: the queue used to submit tasks to.
            This should be shared by the `TaskExecutor` which is the producer
            of the messages this class consumes.
        :type task_queue: queue.Queue[DeviceCommandTaskContext]
        :param max_parallel_workers: maximum number of worker threads in the
            executor's thread pool, defaults to 4. One of these workers runs the
            queue-processing loop.
        :type max_parallel_workers: int
        :param logger: the logger to use in class, defaults to None, in which case
            the logger for this module is used.
        :type logger: Optional[logging.Logger], optional
        """
        self._logger = logger or logging.getLogger(__name__)
        self._task_queue = task_queue
        # guards start()/stop() so concurrent calls can't create or shut down the pool twice
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self._max_parallel_workers = max_parallel_workers
        self._running = False

    def __del__(self: DeviceCommandTaskExecutor) -> None:
        """Tear down class being destroyed."""
        # __init__ may not have completed (it raised, or the instance was made via
        # __new__), so don't assume the instance attributes exist.
        if getattr(self, "_running", False):
            self.stop()

    def stop(self: DeviceCommandTaskExecutor) -> None:
        """
        Stop the executor.

        Sets the stop event so the queue-processing loop exits, then shuts down the
        thread pool, waiting for the loop (including any task it is currently
        handling) to finish. Calling this when not running is a no-op.
        """
        with self._lock:
            if not self._running:
                return
            self._running = False
            self._stop.set()
            self._tpe.shutdown()

    def start(self: DeviceCommandTaskExecutor) -> None:
        """
        Start the executor.

        Creates a new thread pool and submits the queue-processing loop to it.
        Calling this while already running is a no-op.
        """
        with self._lock:
            if self._running:
                return
            # need to reset this each time we start, as stop() leaves it set.
            self._stop = threading.Event()
            self._tpe = concurrent.futures.ThreadPoolExecutor(
                max_workers=self._max_parallel_workers, thread_name_prefix="DeviceCommandTaskThread"
            )
            self._tpe.submit(self._process_queue)
            # only mark as running once the pool exists, so a failure above (e.g. an
            # invalid max_workers) doesn't leave stop() referring to a missing pool.
            self._running = True

    def _process_queue(self: DeviceCommandTaskExecutor) -> None:
        """
        Process messages off task queue.

        This method loops reading messages off the task queue. Once a message is
        received it will call the `_handle_task` method.

        The loop runs until the stop event is set by :py:meth:`stop`.

        If `_handle_task` raises, the error is logged and the task context is
        signalled as failed, so the loop keeps processing later tasks and the
        failure is reported through the task context rather than lost.
        `task_done()` is always called for every task taken off the queue.
        """
        # NOTE(review): tasks are handled one at a time on this thread, so only one
        # pool worker is ever used. Confirm whether tasks were meant to be submitted
        # to self._tpe so that max_parallel_workers gives real parallelism.
        while not self._stop.is_set():
            try:
                # string type form: avoids a runtime lookup of the project type
                task_context = cast("DeviceCommandTaskContext", self._task_queue.get(timeout=0.01))
            except queue.Empty:
                continue

            try:
                self._logger.debug(f"DeviceCommandTaskExecutor received a device task: {task_context}")
                self._handle_task(task_context)
            except Exception as exc:
                # top-level boundary of the background thread: an escaping exception
                # would silently end the loop and leave task_done() uncalled.
                self._logger.exception(f"Unexpected error while handling device task {task_context}")
                self._signal_unexpected_failure(task_context, exc)
            finally:
                self._task_queue.task_done()

    def _signal_unexpected_failure(
        self: DeviceCommandTaskExecutor, task_context: DeviceCommandTaskContext, exc: Exception
    ) -> None:
        """Signal a task as failed after `_handle_task` raised, never letting an error escape."""
        try:
            task_context.signal_failed(exception=exc)
        except Exception:
            # best effort: a failure to signal must not kill the processing loop
            self._logger.exception(f"Unable to signal failure for device task {task_context}")

    def _handle_task(self: DeviceCommandTaskExecutor, task_context: DeviceCommandTaskContext) -> None:
        """
        Handle task request that has been received.

        Method will use a :py:class:`LongRunningCommand` to ensure the command completes or times out.
        If the result code is OK then the `task_context` is signalled as being completed successfully.
        Otherwise it is signalled as failed, with the result's exception if there is one, or else with
        the string form of the result.

        :param task_context: the device command context that should be processed.
        :type task_context: DeviceCommandTaskContext
        """
        # the local import is used to allow mocking of LongRunningCommand
        from ska_pst.lmc.util import LongRunningCommand

        device = task_context.device
        command = task_context.command
        command_args = task_context.command_args
        timeout = task_context.timeout
        command_str = f"{device}.{command}()"

        lrc = LongRunningCommand(command=command)
        result = lrc(proxy=device.device, command_args=command_args, timeout=timeout)

        if result.result_code == ResultCode.OK:
            task_context.signal_complete(result=result.result)
        elif result.exception is not None:
            # attach the exception so its traceback is included in the log record
            self._logger.error(f"Error while executing command {command_str}", exc_info=result.exception)
            task_context.signal_failed(exception=result.exception)
        else:
            self._logger.error(
                (
                    f"{command_str} failed with status '{result.result_code.name}'."
                    f" Message: {result.result}"
                )
            )
            task_context.signal_failed_from_str(msg=str(result.result))