"""
The ska_oso_oet.procedure.application module holds classes and functionality that
belong in the application layer of the OET. This layer holds the application
interface, delegating to objects in the domain layer for business rules and
actions.
"""
import collections
import logging
import os
import threading
import time
from typing import Any
from pubsub import pub
from pydantic import BaseModel, ConfigDict, computed_field, model_serializer
from pydantic_core.core_schema import SerializerFunctionWrapHandler
from ska_oso_oet import mptools
from ska_oso_oet.config import OET_HISTORY_SIZE, OET_WAIT_FOR_QA_READY
from ska_oso_oet.event import topics
from ska_oso_oet.procedure import domain
from ska_oso_oet.procedure.domain import ProcedureState
from ska_oso_oet.utils.ui import API_PATH
# Root directory of the stock scripts; overridable via SCRIPTS_LOCATION.
base_dir = os.environ.get("SCRIPTS_LOCATION", "/scripts")

# Default post-termination script used by ScriptExecutionService.
ABORT_SCRIPT = domain.FileSystemScript(script_uri=f"file://{base_dir}/abort.py")

# Procedures in one of these states are eligible for removal when the
# service prunes its bookkeeping.
DELETEABLE_STATES = [
    ProcedureState.COMPLETE,
    ProcedureState.FAILED,
    ProcedureState.STOPPED,
    ProcedureState.UNKNOWN,
]

# CI/CD runs take considerably longer than local tests, hence we need a way to
# increase the timeout. The environment variable is only set for CI/CD.
STATE_WAIT_TIMEOUT_SECS = float(os.environ.get("STATE_WAIT_TIMEOUT_SECS", "3.0"))

LOGGER = logging.getLogger(__name__)
[docs]
class PrepareProcessCommand(BaseModel):
    """
    Input argument model for the ScriptExecutionService prepare command.

    It carries the script to load together with the arguments to use when
    initialising it.
    """

    script: domain.ExecutableScript
    init_args: domain.ProcedureInput

    def __init__(
        self, script: domain.ExecutableScript, init_args: domain.ProcedureInput
    ):
        """
        Create a new PrepareProcessCommand.

        :param script: the script to load
        :param init_args: arguments for script initialisation
        :raises ValueError: if 'subarray_id' is among the keyword arguments
            of init_args
        """
        # The check runs before the pydantic constructor, so init_args must
        # already expose a .kwargs mapping at this point.
        caller_set_subarray = "subarray_id" in init_args.kwargs
        if caller_set_subarray:
            raise ValueError(
                "subarray_id is tied to the OET instance and should not be "
                "passed as a script arg at runtime or in the SBDefinition."
            )
        super().__init__(script=script, init_args=init_args)
[docs]
class StartProcessCommand(BaseModel):
    """
    Input argument model for the ScriptExecutionService start command.

    It identifies the prepared process to start, names the function to call,
    and carries the late-binding runtime arguments for that call.
    """

    process_uid: int
    fn_name: str
    run_args: domain.ProcedureInput
    force_start: bool = False

    def __init__(
        self,
        process_uid: int,
        fn_name: str,
        run_args: domain.ProcedureInput,
        force_start: bool = False,
    ):
        """
        Create a new StartProcessCommand.

        The explicit signature allows the fields to be given positionally.

        :param process_uid: ID of the prepared process
        :param fn_name: name of the script function to call
        :param run_args: arguments for the function call
        :param force_start: forwarded to the process manager when running
        """
        field_values = {
            "process_uid": process_uid,
            "fn_name": fn_name,
            "run_args": run_args,
            "force_start": force_start,
        }
        super().__init__(**field_values)
[docs]
class StopProcessCommand(BaseModel):
    """
    Input argument model for the ScriptExecutionService stop command.

    It identifies the process to stop and states whether the
    post-termination abort script should be run afterwards.
    """

    process_uid: int
    run_abort: bool

    def __init__(self, process_uid: int, run_abort: bool):
        """
        Create a new StopProcessCommand.

        The explicit signature allows the fields to be given positionally.

        :param process_uid: ID of the process to stop
        :param run_abort: True to run the abort script once the process has
            been stopped
        """
        field_values = {"process_uid": process_uid, "run_abort": run_abort}
        super().__init__(**field_values)
[docs]
class ProcedureHistory(BaseModel):
    """
    ProcedureHistory is a non-functional dataclass holding execution history of
    a Procedure spanning all transactions.

    process_states: records time for each change of ProcedureState (list of
        tuples where tuple contains the ProcedureState and time when state was
        changed to)
    stacktrace: None unless execution_error is True in which case stores
        stacktrace from process
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The defaults were previously written as ``(None,)``: a trailing comma
    # turned them into one-element tuples rather than None.
    process_states: list[tuple[domain.ProcedureState, float]] | None = None
    stacktrace: Any | None = None

    def __init__(
        self,
        process_states: list[tuple[domain.ProcedureState, float]] | None = None,
        stacktrace: Any | None = None,
    ):
        """
        Create a new ProcedureHistory.

        :param process_states: (state, time) records; None is replaced by a
            new empty list
        :param stacktrace: stacktrace to store, if any
        """
        if process_states is None:
            # a fresh list per instance, never a shared default
            process_states = []
        super().__init__(process_states=process_states, stacktrace=stacktrace)

    def __eq__(self, other):
        """
        Two histories are equal when their state records and stacktraces are
        equal; any object that is not a ProcedureHistory compares unequal.
        """
        if not isinstance(other, ProcedureHistory):
            return False
        return bool(
            self.process_states == other.process_states
            and self.stacktrace == other.stacktrace
        )

    def __repr__(self):
        p_history = ", ".join(f"({s!s}, {t!r})" for (s, t) in self.process_states)
        # The 'ProcessHistory' label is kept as-is so the output is unchanged.
        return f"<ProcessHistory(process_states=[{p_history}], stacktrace={self.stacktrace})>"

    @model_serializer
    def _serialize_procedure_history(self) -> dict[str, Any]:
        """
        Serialise the history, replacing each ProcedureState by its name.

        :return: dict with 'process_states' and 'stacktrace' keys
        """
        process_states = [
            (state[0].name, state[1])  # pylint: disable=unsubscriptable-object
            for state in self.process_states
        ]
        return {
            "process_states": process_states,
            "stacktrace": self.stacktrace,
        }
[docs]
class ArgCapture(BaseModel):
    """
    ArgCapture is a struct to record function call and time of invocation.
    """

    # name of the function that was called
    fn: str
    # arguments the function was called with
    fn_args: domain.ProcedureInput
    # Time of invocation; None when not recorded. The annotation is optional
    # so that it agrees with the None default.
    time: float | None = None
[docs]
class AbortSummary(BaseModel):
    """
    AbortSummary is a struct holding a single abort message string.
    """

    abort_message: str
[docs]
class ProcedureSummary(BaseModel):
    """
    ProcedureSummary is a brief representation of a runtime Procedure. It
    captures essential information required to describe a Procedure and to
    distinguish it from other Procedures.
    """

    id: int  # pylint: disable=invalid-name
    script: domain.GitScript | domain.FileSystemScript | None
    script_args: list[ArgCapture] | None
    history: ProcedureHistory | None
    state: domain.ProcedureState | None

    @computed_field
    @property
    def uri(self) -> str:
        """URI of this procedure, derived from API_PATH and the procedure ID."""
        return f"http://localhost{API_PATH}/procedures/{self.id}"

    def __init__(
        self,
        id: int,  # pylint: disable=redefined-builtin
        script: domain.ExecutableScript | None,
        script_args: list[ArgCapture] | None,
        history: ProcedureHistory | None,
        state: domain.ProcedureState | None,
        uri: str | None = None,
    ):
        """
        Create a new ProcedureSummary.

        :param id: procedure ID
        :param script: script the procedure runs
        :param script_args: record of function calls made on the script
        :param history: execution history of the procedure
        :param state: current procedure state
        :param uri: forwarded to the BaseModel constructor. The uri exposed
            by the model is the computed property, so this value is
            presumably ignored under pydantic's default handling of extra
            arguments (confirm).
        """
        super().__init__(
            id=id,
            uri=uri,
            script=script,
            script_args=script_args,
            history=history,
            state=state,
        )

    @model_serializer(mode="wrap")
    def _serialize_procedure_summary(
        self, default_serializer: SerializerFunctionWrapHandler
    ) -> dict[str, Any]:
        """
        Serialise the summary using the default serializer, then replace
        script_args by a dict keyed on function name and state by the state
        name.

        script_args and state are optional fields: when one is None, the
        value produced by the default serializer is left untouched rather
        than raising.

        :param default_serializer: pydantic's default serializer
        :return: dict representation of this summary
        """
        dumped = default_serializer(self)
        if self.script_args is not None:
            # a later call of the same function replaces the earlier entry
            dumped["script_args"] = {
                args.fn: {"args": args.fn_args.args, "kwargs": args.fn_args.kwargs}
                for args in self.script_args
            }
        if self.state is not None:
            dumped["state"] = self.state.name
        return dumped
[docs]
class ScriptContext(BaseModel):
    """
    ScriptContext holds the current execution context provided to scripts.
    This includes persistent operator overrides and any other runtime context
    that should be available to scripts at start time.
    """

    # Initial value comes from the OET_WAIT_FOR_QA_READY configuration value.
    wait_for_qa_ready: bool = OET_WAIT_FOR_QA_READY
[docs]
class ScriptExecutionService:
    """
    ScriptExecutionService provides the high-level interface and facade for
    the script execution domain (i.e., the 'procedure' domain).

    The interface is used to load and run Python scripts in their own
    independent Python child process.

    The shutdown method should be called to ensure cleanup of any
    multiprocessing artefacts owned by this service.

    NOTE(review): no shutdown method is defined in this class as seen here;
    confirm how cleanup is meant to be triggered.
    """

    # defines which lifecycle event to announce when a lifecycle.statechange is received
    # TODO rationalise procedure lifecycle events and topics for multi-run scripts
    state_to_topic = {
        ProcedureState.INITIALISING: topics.procedure.lifecycle.started,
        ProcedureState.RUNNING: topics.procedure.lifecycle.started,
        ProcedureState.READY: topics.procedure.lifecycle.ready,
        ProcedureState.COMPLETE: topics.procedure.lifecycle.complete,
        ProcedureState.FAILED: topics.procedure.lifecycle.failed,
        ProcedureState.STOPPED: topics.procedure.lifecycle.stopped,
    }

    def __init__(
        self,
        abort_script: domain.ExecutableScript = ABORT_SCRIPT,
    ):
        """
        Create a new ScriptExecutionService.

        The .stop() method of this ScriptExecutionService can run a second
        script once the current process has been terminated. By default, this
        second script calls SubArrayNode.abort() to halt further activities
        on the sub-array controlled by the terminated script. To run a
        different script, pass it as the abort_script argument to this
        constructor.

        :param abort_script: post-termination script for two-phase abort
        :raises ValueError: if the SUBARRAY_ID environment variable is not set
        """
        # Validate the environment first, so that a failed construction does
        # not leave a ProcessManager or pypubsub subscriptions behind.
        if "SUBARRAY_ID" not in os.environ:
            raise ValueError(
                "SUBARRAY_ID environment variable is required, and should have been set"
                " by default in the Helm template."
            )
        self.subarray_id = os.environ["SUBARRAY_ID"]

        self.states: dict[int, domain.ProcedureState] = {}
        self.script_args: dict[int, list[ArgCapture]] = {}
        self.scripts: dict[int, domain.ExecutableScript] = {}
        self.history: dict[int, ProcedureHistory] = collections.defaultdict(
            ProcedureHistory
        )

        # Persistent state to be passed to observing scripts, resets on OET
        # restart
        self.script_context = ScriptContext()

        self._state_updating = threading.RLock()

        # Pass our states dict and lock to ProcessManager so they share the same
        # state source, avoiding race conditions between separate state tracking
        self._process_manager = domain.ProcessManager(
            states=self.states, state_lock=self._state_updating
        )
        self._abort_script = abort_script

        # Subscribe to lifecycle events via pypubsub
        pub.subscribe(self._on_statechange, topics.procedure.lifecycle.statechange)
        pub.subscribe(self._on_stacktrace, topics.procedure.lifecycle.stacktrace)

    def prepare(self, cmd: PrepareProcessCommand) -> ProcedureSummary:
        """
        Load and prepare a Python script for execution, but do not commence
        execution.

        The subarray ID of this service is added to the initialisation
        keyword arguments of the command; the command object is modified.

        :param cmd: dataclass argument capturing the script identity and load
            arguments
        :return: ProcedureSummary for the newly created procedure
        """
        cmd.init_args.kwargs = {**cmd.init_args.kwargs, "subarray_id": self.subarray_id}

        # Previously, we set the initial CREATING state in response to an event
        # set from the ScriptWorker. That caused a race where states[pid] could
        # be read in _summarise BEFORE the on_pubsub handler had added the
        # initial states[pid] entry. So, we now set the initial state in the
        # parent and remove the CREATING announcement from the
        # ScriptWorker. By holding the _state_updating lock as we do so, we
        # prevent that race.
        now = time.time()
        with self._state_updating:
            pid = self._process_manager.create(cmd.script, init_args=cmd.init_args)
            self.states[pid] = ProcedureState.CREATING
            # Add the initial state to history to prevent empty process_states list
            self.history[pid].process_states.append((ProcedureState.CREATING, now))
            self.scripts[pid] = cmd.script
            self.script_args[pid] = [
                ArgCapture(fn="init", fn_args=cmd.init_args, time=now)
            ]
            self._prune_old_state()

        return self._summarise(pid)

    def start(self, cmd: StartProcessCommand) -> ProcedureSummary:
        """
        Start execution of a prepared procedure.

        The current script context is added to the run keyword arguments of
        the command under the 'context' key; the command object is modified.

        :param cmd: dataclass argument capturing the execution arguments
        :return: ProcedureSummary for the started procedure
        """
        cmd.run_args.kwargs = {
            **cmd.run_args.kwargs,
            "context": {"wait_for_qa_ready": self.script_context.wait_for_qa_ready},
        }
        self._process_manager.run(
            cmd.process_uid,
            call=cmd.fn_name,
            run_args=cmd.run_args,
            force_start=cmd.force_start,
        )
        # script_args is guarded by _state_updating everywhere else it is
        # read or modified, so take the lock for this update too.
        with self._state_updating:
            self.script_args[cmd.process_uid].append(
                ArgCapture(fn=cmd.fn_name, fn_args=cmd.run_args, time=time.time())
            )
        return self._summarise(cmd.process_uid)

    def summarise(self, pids: list[int] | None = None) -> list[ProcedureSummary]:
        """
        Return ProcedureSummary objects for Procedures with the requested IDs.

        This method accepts an optional list of integers, representing the
        Procedure IDs to summarise. If the pids is left undefined,
        ProcedureSummary objects for all current Procedures will be returned.

        :param pids: optional list of Procedure IDs to summarise.
        :return: list of ProcedureSummary objects
        :raises ValueError: if any requested ID is not known to this service
        """
        # freeze state to prevent mutation from events
        with self._state_updating:
            all_pids = self.states.keys()
            if pids is None:
                pids = all_pids

            missing_pids = {p for p in pids if p not in all_pids}
            if missing_pids:
                raise ValueError(f"Process IDs not found: {missing_pids}")

            return [self._summarise(pid) for pid in pids]

    def stop(self, cmd: StopProcessCommand) -> list[ProcedureSummary]:
        """
        Stop execution of a running procedure, optionally running a
        second script once the first process has terminated.

        :param cmd: dataclass argument capturing the execution arguments
        :return: empty list if no abort script was requested, otherwise a
            single-element list holding the summary of the abort script
            procedure
        """
        self._process_manager.stop(cmd.process_uid)

        # exit early if not instructed to run post-termination script
        if not cmd.run_abort:
            # Did not start a new process so return empty list
            return []

        # prepare second script
        prepare_cmd = PrepareProcessCommand(
            script=self._abort_script,
            init_args=domain.ProcedureInput(),
        )
        procedure_summary = self.prepare(prepare_cmd)

        # wait for the script to be READY, then run it
        # Use Proc.STARTUP_WAIT_SECS as the timeout since script initialization
        # can take several seconds, especially in CI environments
        # NOTE: _wait_for_state gives no indication of a timeout, so the start
        # below is attempted whether or not READY was reached.
        self._wait_for_state(
            procedure_summary.id,
            ProcedureState.READY,
            timeout=mptools.Proc.STARTUP_WAIT_SECS,
        )

        # start the second script
        run_cmd = StartProcessCommand(
            process_uid=procedure_summary.id,
            fn_name="main",
            run_args=domain.ProcedureInput(),
        )
        summary = self.start(run_cmd)
        return [summary]

    def _get_subarray_id(self, pid: int) -> int:
        """
        Return a Subarray id for given procedure ID.

        The ID is taken from the 'subarray_id' keyword argument of the
        function calls recorded for the procedure.

        NOTE(review): prepare() stores the SUBARRAY_ID environment value,
        which is a string, so the int return annotation may not hold.

        :param pid: Procedure ID to summarise
        :return: subarray id
        :raises ValueError: if no subarray ID, or more than one distinct
            subarray ID, was recorded
        """
        procedure_summary = self._summarise(pid)
        subarray_ids = {
            arg_capture.fn_args.kwargs["subarray_id"]
            for arg_capture in procedure_summary.script_args
            if "subarray_id" in arg_capture.fn_args.kwargs
        }
        if not subarray_ids:
            raise ValueError("Subarray ID not specified")
        if len(subarray_ids) > 1:
            raise ValueError("Multiple subarray IDs found")
        return subarray_ids.pop()

    def _summarise(self, pid: int) -> ProcedureSummary:
        """
        Return a ProcedureSummary for the Procedure with the given ID.

        CAUTION: do NOT modify the arguments! SES state is exposed here.

        :param pid: Procedure ID to summarise
        :return: ProcedureSummary
        :raises KeyError: if the ID is not known to this service
        """
        with self._state_updating:
            state = self.states[pid]
            script = self.scripts[pid]
            script_args = self.script_args[pid]
            history = self.history[pid]
            return ProcedureSummary(
                id=pid,
                script=script,
                script_args=script_args,
                history=history,
                state=state,
            )

    def _prune_old_state(self):
        """
        Remove the state associated with the oldest deletable Procedures so
        that the state history remains below the history limit
        OET_HISTORY_SIZE.

        Only the oldest entries in excess of the limit are considered, and of
        those only the ones in a deletable state are removed.
        """
        # Delete oldest deletable procedure if procedure limit reached
        with self._state_updating:
            excess = len(self.states) - OET_HISTORY_SIZE
            if excess <= 0:
                return

            # dicts preserve insertion order, so the leading keys are the
            # oldest procedures
            pids_to_consider = list(self.states)[:excess]
            for old_pid in pids_to_consider:
                if self.states.get(old_pid, None) not in DELETEABLE_STATES:
                    continue
                # pop() rather than del: a state change event can add a
                # states entry for a procedure that has no script or
                # script_args record, and a KeyError here would abort the
                # prepare() call that triggered the pruning.
                self.states.pop(old_pid, None)
                self.history.pop(old_pid, None)
                self.script_args.pop(old_pid, None)
                self.scripts.pop(old_pid, None)

    def _on_statechange(self, msg_src, new_state: ProcedureState) -> None:
        """
        Callback method that updates Procedure history whenever a message on
        the procedure.lifecycle.statechange topic is received.

        Messages whose source cannot be converted to an integer procedure ID
        are ignored.

        :param msg_src: PID of the procedure (as string or int)
        :param new_state: new ProcedureState
        """
        try:
            pid = int(msg_src)
        except (ValueError, TypeError):
            return

        now = time.time()
        # Note: the summaries are built inside the lock so that the PID cannot
        # be cleaned up by _prune_old_state between the state update and the
        # summarise call
        events_to_publish = []
        with self._state_updating:
            previous = self.states.get(pid, None)
            self.states[pid] = new_state
            # CREATING is now set within this process to avoid a race. The
            # second CREATING can be ignored - it's sent from the child process
            if new_state is not ProcedureState.CREATING:
                self.history[pid].process_states.append((new_state, now))

            # Check if we need to publish events and collect them to publish after releasing the lock
            # schedule a legacy lifecycle status change event when appropriate
            if new_state in self.state_to_topic:
                result = self._summarise(pid)
                events_to_publish.append((self.state_to_topic[new_state], result))

            # special case: there's no unique state to signify loading complete
            if previous == ProcedureState.LOADING and new_state == ProcedureState.IDLE:
                LOGGER.debug(
                    "Publishing created event: previous=%s, new_state=%s",
                    previous,
                    new_state,
                )
                result = self._summarise(pid)
                events_to_publish.append((topics.procedure.lifecycle.created, result))

        if not events_to_publish:
            LOGGER.debug(
                (
                    "No pubsub event published for state change: previous=%s,"
                    " new_state=%s"
                ),
                previous,
                new_state,
            )

        # Publish events outside the lock to avoid holding it during pubsub operations
        # msg_src=None signals to EventBusWorker.republish to substitute the
        # worker's own name, ensuring the event is forwarded to other workers.
        for topic, result in events_to_publish:
            pub.sendMessage(
                topic,
                msg_src=None,
                request_id=None,
                result=result,
            )

    def _on_stacktrace(self, msg_src, stacktrace: str) -> None:
        """
        Callback method to record stacktrace event in the Procedure history
        whenever a message on procedure.lifecycle.stacktrace is received.

        Messages whose source cannot be converted to an integer procedure ID
        are ignored.

        :param msg_src: PID of the procedure (as string or int)
        :param stacktrace: stacktrace string
        """
        try:
            pid = int(msg_src)
        except (ValueError, TypeError):
            return

        with self._state_updating:
            self.history[pid].stacktrace = stacktrace

    def handle_wait_for_qa_ready_enable(
        self,
        # msg_src MUST be part of method signature for pypubsub to function
        msg_src,  # pylint: disable=unused-argument
    ) -> None:
        """
        Handler for operator.wait_for_qa_ready.enable topic.

        Sets wait_for_qa_ready to True.

        :param msg_src: component from which the request originated
        """
        self.script_context.wait_for_qa_ready = True

    def handle_wait_for_qa_ready_disable(
        self,
        # msg_src MUST be part of method signature for pypubsub to function
        msg_src,  # pylint: disable=unused-argument
    ) -> None:
        """
        Handler for operator.wait_for_qa_ready.disable topic.

        Sets wait_for_qa_ready to False.

        :param msg_src: component from which the request originated
        """
        self.script_context.wait_for_qa_ready = False

    def _wait_for_state(
        self,
        pid: int,
        state: ProcedureState,
        timeout=STATE_WAIT_TIMEOUT_SECS,
        tick=0.01,
    ):
        """
        A time-bound wait for a Procedure to reach the requested state.

        Returns None both when the state was reached and when the wait timed
        out; callers cannot tell the two apart from the return value.

        :param pid: ID of Procedure to wait for
        :param state: ProcedureState to wait for
        :param timeout: wait timeout, in seconds
        :param tick: time between state checks, in seconds
        """
        deadline = time.time() + timeout
        sleep_secs = tick
        # presumably _sleep_secs returns a non-positive value once the
        # deadline has passed, which ends the loop - confirm against mptools
        while self.states.get(pid, None) != state and sleep_secs > 0:
            time.sleep(sleep_secs)
            sleep_secs = mptools._sleep_secs(  # pylint: disable=protected-access
                tick, deadline
            )