"""
The ska_oso_oet.activity.application module contains code related
to OET 'activities' that belong in the application layer. This application
layer holds the application interface, delegating to objects in the domain
layer for business rules and actions.
"""
import collections
import itertools
import logging
import os
import tempfile
import time
from datetime import datetime, timezone
from typing import Any
from pubsub import pub
from pydantic import BaseModel, Field, computed_field, model_serializer
from pydantic_core.core_schema import SerializerFunctionWrapHandler
from ska_db_oda.unitofwork.postgresunitofwork import PostgresUnitOfWork
from ska_oso_pdm import SBDefinition
from ska_oso_pdm._shared import TelescopeType
from ska_oso_pdm.sb_definition.procedures import (
FilesystemScript,
GitScript,
PythonArguments,
PythonProcedure,
)
from ska_oso_pdm.sb_instance import ActivityCall, FunctionArgs, SBInstance
from ska_oso_oet.activity.domain import Activity, ActivityState
from ska_oso_oet.config import OET_HISTORY_SIZE
from ska_oso_oet.event import topics
from ska_oso_oet.procedure import domain
from ska_oso_oet.procedure.application import (
PrepareProcessCommand,
ProcedureSummary,
StartProcessCommand,
)
from ska_oso_oet.procedure.domain import ProcedureState
from ska_oso_oet.utils.ui import API_PATH
LOGGER = logging.getLogger(__name__)

# Activities whose most recent state is one of these are considered finished.
# Only such activities are pruned when the number of tracked activities
# exceeds OET_HISTORY_SIZE.
DELETABLE_ACTIVITY_STATES = frozenset(
    {
        ActivityState.COMPLETE,
        ActivityState.FAILED,
        ActivityState.TERMINATED,
    }
)

# Translation table from the lifecycle state of a Procedure to the state of
# the Activity that owns it. Every state a procedure passes through before it
# is READY is reported as PREPARING.
_PROCEDURE_TO_ACTIVITY_STATE_MAPPING = {
    **dict.fromkeys(
        (
            ProcedureState.IDLE,
            ProcedureState.CREATING,
            ProcedureState.PREP_ENV,
            ProcedureState.LOADING,
            ProcedureState.INITIALISING,
        ),
        ActivityState.PREPARING,
    ),
    ProcedureState.READY: ActivityState.READY,
    ProcedureState.RUNNING: ActivityState.RUNNING,
    ProcedureState.COMPLETE: ActivityState.COMPLETE,
    ProcedureState.FAILED: ActivityState.FAILED,
    ProcedureState.STOPPED: ActivityState.TERMINATED,
    # UNKNOWN is treated as an error condition
    ProcedureState.UNKNOWN: ActivityState.FAILED,
}
class ActivityCommand(BaseModel):
    """
    ActivityCommand captures a request to run an activity of a Scheduling
    Block.

    - activity_name: name of the activity to look up in the SBDefinition
    - sbd_id: ID of the SBDefinition to retrieve from the ODA
    - prepare_only: when True the script is prepared but its non-init
      functions are not started
    - create_env: forwarded to the Git script so that it can request a
      dedicated environment
    - script_args: per-function arguments that take priority over the
      arguments stored in the SBDefinition
    """

    activity_name: str
    sbd_id: str
    prepare_only: bool
    create_env: bool
    script_args: dict[str, domain.ProcedureInput]

    def __init__(
        self,
        activity_name: str,
        sbd_id: str,
        prepare_only: bool,
        create_env: bool,
        script_args: dict[str, domain.ProcedureInput],
    ):
        # An explicit initialiser is kept so that the command can be built
        # with positional arguments as well as keywords.
        field_values = {
            "activity_name": activity_name,
            "sbd_id": sbd_id,
            "prepare_only": prepare_only,
            "create_env": create_env,
            "script_args": script_args,
        }
        super().__init__(**field_values)
[docs]
class ActivitySummary(BaseModel):
    """
    ActivitySummary is a serialisable snapshot of an Activity: its
    identifiers, the arguments its script functions are called with, and the
    history of states it has passed through.

    The ``pid`` and ``procedure_id`` fields are both populated from the
    ``pid`` constructor argument, so they always carry the same value.
    """

    id: int  # pylint: disable=invalid-name
    # The original declaration ended in a stray trailing comma, which made the
    # field default a tuple holding a FieldInfo. No alias is applied here
    # because 'procedure_id' is a field in its own right.
    pid: int = Field(description="ID of the Procedure linked to this Activity")
    procedure_id: int
    sbd_id: str
    activity_name: str
    prepare_only: bool
    script_args: dict[str, domain.ProcedureInput]
    activity_states: list[tuple[ActivityState, float]]
    state: str | None = None
    sbi_id: str

    @computed_field
    @property
    def uri(self) -> str:
        """REST resource location of this activity, derived from its ID."""
        return f"http://localhost{API_PATH}/activities/{self.id}"

    def __init__(
        self,
        id: int,  # pylint: disable=redefined-builtin
        pid: int,
        sbd_id: str,
        activity_name: str,
        prepare_only: bool,
        script_args: dict[str, domain.ProcedureInput],
        activity_states: list[tuple[ActivityState, float]] | collections.deque | None,
        sbi_id: str,
        uri: str | None = None,
        state: str | None = None,
    ):
        """
        Create a new ActivitySummary.

        :param id: Activity ID
        :param pid: ID of the linked Procedure; stored as both ``pid`` and
            ``procedure_id``
        :param sbd_id: ID of the SBDefinition the activity belongs to
        :param activity_name: name of the activity
        :param prepare_only: whether the script is only prepared, not run
        :param script_args: arguments for each script function
        :param activity_states: (state, timestamp) history; None and deques
            are normalised to a list
        :param sbi_id: ID of the SBInstance created for this activity
        :param uri: forwarded to the model initialiser unchanged
            (NOTE(review): presumably ignored by pydantic because ``uri`` is
            a computed field - confirm)
        :param state: explicit state name, only reported when there is no
            state history to derive it from
        """
        # Normalize deque to list for serialization
        if activity_states is None:
            activity_states = []
        elif isinstance(activity_states, collections.deque):
            activity_states = list(activity_states)
        super().__init__(
            id=id,
            pid=pid,
            procedure_id=pid,
            uri=uri,
            sbd_id=sbd_id,
            activity_name=activity_name,
            prepare_only=prepare_only,
            script_args=script_args,
            activity_states=activity_states,
            state=state,
            sbi_id=sbi_id,
        )

    @model_serializer(mode="wrap")
    def _serialize_activity_summary(
        self, default_serializer: SerializerFunctionWrapHandler
    ) -> dict[str, Any]:
        """
        Serialise the summary, replacing the default representation of the
        script arguments, state history and current state.

        :param default_serializer: pydantic's default serialisation handler
        :return: dict representation of this summary
        """
        script_args = {
            fn: {"args": fn_input.args, "kwargs": fn_input.kwargs}
            for fn, fn_input in self.script_args.items()
        }
        # Enum members are reported by name
        activity_states = [
            (state_enum.name, timestamp)
            for (state_enum, timestamp) in self.activity_states
        ]

        dumped = default_serializer(self)
        if self.activity_states:
            # The current state is the one with the latest timestamp
            states_to_time = dict(self.activity_states)
            dumped["state"] = max(states_to_time, key=states_to_time.get).name
        # With an empty history there is nothing to derive the state from
        # (max() would raise ValueError), so the default serialisation of the
        # 'state' field is left in place.
        dumped["activity_states"] = activity_states
        dumped["script_args"] = script_args
        return dumped
[docs]
class ActivityService:
    """
    ActivityService provides the high-level interface and facade for
    the activity domain.

    The interface is used to run activities referenced by Scheduling Blocks.
    Each activity will run a script (or `procedure`), but ActivityService
    does not execute anything itself: it creates the commands that the
    Procedure domain needs in order to create and execute the scripts, and
    publishes them as pypubsub messages.
    """
[docs]
def __init__(self):
    """
    Create a new ActivityService with no known activities and subscribe it
    to procedure lifecycle state change events.
    """
    # Counter used to generate the activity ID for new activities
    self._aid_counter = itertools.count(1)

    # Internal representation of each Activity, keyed by activity ID
    self.activities: dict[int, Activity] = {}
    # Arguments for each script function, keyed by activity ID
    self.script_args: dict[int, dict[str, domain.ProcedureInput]] = {}
    # State history per activity, held in FIFO deques of unbounded size.
    # The total number of activities is limited by the cleanup logic, which
    # is driven by OET_HISTORY_SIZE.
    self.states: dict[int, collections.deque] = collections.defaultdict(
        collections.deque
    )

    # This mapping must be kept because the service needs to check whether
    # a procedure that has been created is the result of an activity request
    self.request_ids_to_aid: dict[int, int] = {}

    # The script workers currently only allow one execution of main(), so we
    # must track which activities have had main dispatched, to prevent a
    # duplicate RUN(main) when lifecycle.ready fires a second time (after
    # main completes). As/when we allow long-lived scripts with multiple
    # executions, this should be removed.
    self._main_dispatched: set[int] = set()

    self._oda = PostgresUnitOfWork()

    # Keep activity state synchronised with procedure lifecycle events
    pub.subscribe(
        self._on_procedure_statechange, topics.procedure.lifecycle.statechange
    )
def _get_current_activity_state(self, activity_id: int) -> ActivityState | None:
    """
    Look up the most recently recorded state of an activity.

    :param activity_id: Activity ID to get state for
    :return: the latest ActivityState, or None if the activity is unknown
        or has no recorded states
    """
    # .get() is used so that the lookup never creates an entry in the
    # defaultdict for an unknown activity
    history = self.states.get(activity_id)
    # Entries are (state, timestamp) pairs, newest last
    return history[-1][0] if history else None
def _cleanup_activity_history(self) -> None:
    """
    Prune the oldest finished activities once more than OET_HISTORY_SIZE
    activities are being tracked.

    Only activities whose latest state is terminal (COMPLETE, FAILED,
    TERMINATED) are removed, so activities that are still in progress are
    always preserved, even if that leaves the history over the limit.
    """
    excess = len(self.states) - OET_HISTORY_SIZE
    if excess <= 0:
        return

    # Activity IDs increase monotonically, so ascending ID order is
    # oldest-first
    finished_oldest_first = (
        activity_id
        for activity_id in sorted(self.states.keys())
        if self._get_current_activity_state(activity_id)
        in DELETABLE_ACTIVITY_STATES
    )
    # Select the victims before deleting anything, and take no more than
    # are needed to get back down to the limit
    to_delete = list(itertools.islice(finished_oldest_first, excess))
    for activity_id in to_delete:
        self._cleanup_activity(activity_id)
def _cleanup_activity(self, activity_id: int) -> None:
    """
    Forget an activity: drop every record the service holds for it.

    :param activity_id: Activity ID to clean up
    """
    # The request ID mapping is keyed by request ID, so the entries that
    # point at this activity have to be found by value
    stale_request_ids = [
        request_id
        for request_id, linked_aid in self.request_ids_to_aid.items()
        if linked_aid == activity_id
    ]

    # Per-activity records; missing entries are tolerated
    for records in (self.states, self.script_args, self.activities):
        records.pop(activity_id, None)
    self._main_dispatched.discard(activity_id)

    for request_id in stale_request_ids:
        self.request_ids_to_aid.pop(request_id)

    LOGGER.debug("Pruned state for oldest activity with ID %d", activity_id)
[docs]
def prepare_run_activity(self, cmd: ActivityCommand, request_id: int) -> None:
    """
    Prepare to run the activity of a Scheduling Block. This includes retrieving the script
    from the scheduling block and sending the request messages to the
    ScriptExecutionService to prepare the script.

    The request_id is required to be propagated through the messages sent to the Procedure layer,
    so the REST layer can wait for the correct response event.

    If anything fails before the request has been published, the state
    recorded for the new activity is discarded again and the exception is
    re-raised, so that a failed request does not leave a half-registered
    activity behind.

    :param cmd: dataclass argument capturing the activity name and SB ID
    :param request_id: The original request_id from the REST layer
    :raises KeyError: if the named activity is not present in the SBDefinition
    :raises RuntimeError: if the activity's script is of an unsupported type
    """
    aid = next(self._aid_counter)

    # Check if we need to clean up oldest activities before adding the new one
    self._cleanup_activity_history()

    # Record PREPARING state immediately to prevent race conditions with procedure lifecycle events
    self.states[aid].append((ActivityState.PREPARING, time.time()))

    # Publish initial activity lifecycle state change event
    pub.sendMessage(
        topics.activity.lifecycle.statechange,
        msg_src=str(aid),
        new_state=ActivityState.PREPARING,
    )

    try:
        with self._oda as oda:
            sbd: SBDefinition = oda.sbds.get(cmd.sbd_id)

            # Validate the request before anything is persisted, so that an
            # invalid request does not leave an SBInstance behind in the ODA
            if (pdm_script := sbd.activities.get(cmd.activity_name)) is None:
                raise KeyError(
                    f"Activity '{cmd.activity_name}' not present in the SBDefinition"
                    f" {cmd.sbd_id}"
                )
            script = self._get_oet_script(pdm_script, cmd.create_env)
            script_args = self._combine_script_args(pdm_script, cmd)

            telescope = sbd.telescope
            sbi = self._create_sbi(telescope, cmd, sbd_version=sbd.metadata.version)
            sbi = oda.sbis.add(sbi)
            oda.commit()

        sbd_path = self.write_sbd_to_file(sbd)
        # 'main' is created on demand, mirroring the default used for 'init'
        # below, so the SB location can always be passed to the script
        main_input = script_args.setdefault("main", domain.ProcedureInput())
        main_input.kwargs.update({"sb_json": sbd_path, "sbi_id": sbi.sbi_id})

        prepare_cmd = PrepareProcessCommand(
            script=script,
            init_args=script_args.get("init", domain.ProcedureInput()),
        )
        pub.sendMessage(
            topics.request.procedure.create,
            # Setting the msg_src as None means the republish logic will recognise the
            # message has originated from its local pypubsub and should be republished
            msg_src=None,
            request_id=request_id,
            cmd=prepare_cmd,
        )
    except Exception:
        # No Activity record exists yet for this ID. Leaving its state
        # history behind would make summarise() fail for every caller, as
        # it summarises every ID that has a state history.
        self.states.pop(aid, None)
        raise

    # The Activity dataclass is an internal representation of the Activity. The procedure_id will be populated
    # once the procedure created event has been received
    # NOTE(review): these records are stored after the create request has
    # been published; this relies on the 'created' response not being
    # delivered before this method returns - confirm.
    self.activities[aid] = Activity(
        activity_id=aid,
        procedure_id=None,
        activity_name=cmd.activity_name,
        sbd_id=cmd.sbd_id,
        prepare_only=cmd.prepare_only,
        sbi_id=sbi.sbi_id,
    )
    self.script_args[aid] = script_args
    self.request_ids_to_aid[request_id] = aid
[docs]
def link_procedure_to_activity(
    self, prepared_summary: ProcedureSummary, request_id: int
) -> ActivitySummary | None:
    """
    Record which Procedure was created for an Activity, now that the
    ProcedureSummary for the new Procedure is available.

    :param prepared_summary: the ProcedureSummary for the Procedure related
        to the requested Activity
    :param request_id: The original request_id from the REST layer
    :returns: an ActivitySummary describing the state of the Activity that
        the Procedure is linked to, or None if the Procedure was not
        created from an Activity
    """
    aid = self.request_ids_to_aid.get(request_id)
    if aid is None:
        # The request was not sent via the activity domain, so this
        # procedure does not belong to any activity
        return None

    # The procedure ID was unknown when the Activity was registered
    self.activities[aid].procedure_id = prepared_summary.id
    return self._summarise(aid)
[docs]
def dispatch_main(self, ready_summary: ProcedureSummary) -> None:
    """
    Request execution of the Procedure's non-init functions now that init
    has completed and the procedure is in READY state. Triggered by
    procedure.lifecycle.ready.

    Nothing happens if the procedure is not linked to an activity, if the
    activity is prepare-only, or if the functions have already been
    dispatched for this activity.

    :param ready_summary: the ProcedureSummary for the Procedure that is
        now ready
    """
    # lifecycle.ready carries request_id=None, so request_ids_to_aid cannot
    # be used here: search the activities for the matching procedure ID
    match = next(
        (
            (candidate_aid, candidate_activity)
            for candidate_aid, candidate_activity in self.activities.items()
            if candidate_activity.procedure_id == ready_summary.id
        ),
        None,
    )
    if match is None:
        return
    aid, activity = match
    if activity.prepare_only:
        return

    # lifecycle.ready fires on every READY transition, i.e. after every user
    # function completes. main() must only ever be called once per activity.
    if aid in self._main_dispatched:
        return
    self._main_dispatched.add(aid)

    # TODO: should we allow here for multiple functions or limit to just main as is assumed by PM?
    pending_fns = [name for name in self.script_args[aid] if name != "init"]
    for name in pending_fns:
        pub.sendMessage(
            topics.request.procedure.start,
            # msg_src=None marks the message as originating from the local
            # pypubsub, so the republish logic will republish it
            msg_src=None,
            request_id=None,
            cmd=StartProcessCommand(
                ready_summary.id,
                fn_name=name,
                run_args=self.script_args[aid][name],
                force_start=True,
            ),
        )

    pub.sendMessage(
        topics.sb.lifecycle.started, msg_src=None, sbi_id=activity.sbi_id
    )
[docs]
def summarise(self, activity_ids: list[int] | None = None) -> list[ActivitySummary]:
    """
    Return ActivitySummary objects for Activities with the requested IDs.

    When no IDs are given, every Activity that currently has a state
    history is summarised.

    :param activity_ids: optional list of Activity IDs to summarise.
    :return: list of ActivitySummary objects
    :raises ValueError: if any requested ID is not a known Activity
    """
    known_ids = self.states.keys()
    requested = known_ids if activity_ids is None else activity_ids

    missing_pids = {
        candidate for candidate in requested if candidate not in known_ids
    }
    if missing_pids:
        raise ValueError(f"Activity IDs not found: {missing_pids}")

    return list(map(self._summarise, requested))
def _summarise(self, aid: int) -> ActivitySummary:
    """
    Build the ActivitySummary for a single Activity.

    NOTE(review): an Activity is registered with procedure_id=None until it
    has been linked to a Procedure, and that value is passed here as pid,
    which ActivitySummary declares as int - confirm whether summaries can
    be requested before the link is made.

    :param aid: Activity ID to summarise
    :return: ActivitySummary
    """
    # The history is held in a deque; ActivitySummary is given a plain list
    history = list(self.states[aid])
    activity = self.activities[aid]
    fn_inputs = self.script_args[aid]

    return ActivitySummary(
        id=aid,
        pid=activity.procedure_id,
        sbd_id=activity.sbd_id,
        sbi_id=activity.sbi_id,
        activity_name=activity.activity_name,
        prepare_only=activity.prepare_only,
        script_args=fn_inputs,
        activity_states=history,
    )
def _get_oet_script(
    self, pdm_script: PythonProcedure, create_env: bool
) -> domain.ExecutableScript:
    """
    Translate the PDM description of a script, as retrieved from the SB,
    into the equivalent OET script object.

    :param pdm_script: PDM script to convert
    :param create_env: environment flag applied to Git scripts only
    :return: the OET GitScript or FileSystemScript for the PDM script
    :raises RuntimeError: if the PDM script is neither a GitScript nor a
        FilesystemScript
    """
    # The GitScript test comes first, as in the original branch order
    if isinstance(pdm_script, GitScript):
        return domain.GitScript(
            script_uri=pdm_script.path,
            git_args=domain.GitArgs(
                git_repo=pdm_script.repo,
                git_branch=pdm_script.branch,
                git_commit=pdm_script.commit,
            ),
            create_env=create_env,
        )

    if isinstance(pdm_script, FilesystemScript):
        return domain.FileSystemScript(script_uri=pdm_script.path)

    raise RuntimeError(
        f"Cannot run script with type {pdm_script.__class__.__name__}"
    )
def _combine_script_args(
    self, pdm_script: PythonProcedure, cmd: ActivityCommand
) -> dict[str, domain.ProcedureInput]:
    """
    Build the OET arguments for each script function by layering the
    keyword arguments sent in the command on top of those stored in the SB.

    BTN-3197 Positional arguments are no longer forwarded to scripts, instead a
    warning is logged if any are present in either the SBD or the command.

    :param pdm_script: PDM script holding the function arguments of the SB
    :param cmd: command holding the overriding arguments
    :return: dict of function name to ProcedureInput, keyword arguments only
    """
    combined: dict[str, domain.ProcedureInput] = {}

    # Start from the arguments defined in the SB, converted from the PDM
    # format into ProcedureInputs
    for fn, fn_args in pdm_script.function_args.items():
        if fn_args.args:
            LOGGER.warning(
                (
                    "Ignoring positional arguments for '%s' in SBD %s (scripts must"
                    " use keyword arguments only): %s"
                ),
                fn,
                cmd.sbd_id,
                fn_args.args,
            )
        combined[fn] = domain.ProcedureInput(**fn_args.kwargs)

    # Layer the command's arguments on top
    for fn_name, cmd_input in cmd.script_args.items():
        if cmd_input.args:
            LOGGER.warning(
                (
                    "Ignoring positional arguments for '%s' in command for SBD %s "
                    "(scripts must use keyword arguments only): %s"
                ),
                fn_name,
                cmd.sbd_id,
                cmd_input.args,
            )
        if cmd_input.kwargs:
            # SB kwargs (if the SB knows this function) form the base and
            # the command's kwargs win on conflict
            base_kwargs = combined[fn_name].kwargs if fn_name in combined else {}
            combined[fn_name] = domain.ProcedureInput(
                **{**base_kwargs, **cmd_input.kwargs}
            )

    return combined
[docs]
def write_sbd_to_file(self, sbd: SBDefinition) -> str:
    """
    Dump the SBD as JSON into a new temporary file.

    The file is created with delete=False, so it remains on disk after
    this method returns.

    :param sbd: SBDefinition to write
    :return: path of the file that was written
    """
    sbd_file = tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".json", encoding="utf-8"
    )
    with sbd_file:
        LOGGER.debug("Writing SB %s to path: %s", sbd.sbd_id, sbd_file.name)
        sbd_file.write(sbd.model_dump_json())
    return sbd_file.name
def _create_sbi(
    self, telescope: TelescopeType, cmd: ActivityCommand, sbd_version: int
) -> SBInstance:
    """
    Build the SBInstance that records this activity request.

    :param telescope: telescope the SBDefinition is for
    :param cmd: command providing the SBD ID, the activity name and the
        runtime arguments
    :param sbd_version: version of the SBDefinition being run
    :return: an SBInstance holding a single ActivityCall, timestamped with
        the current UTC time
    """
    runtime_args = []
    for fn_name, procedure_input in cmd.script_args.items():
        python_args = PythonArguments(
            args=list(procedure_input.args), kwargs=procedure_input.kwargs
        )
        runtime_args.append(
            FunctionArgs(function_name=fn_name, function_args=python_args)
        )

    activity_call = ActivityCall(
        activity_ref=cmd.activity_name,
        executed_at=datetime.now(tz=timezone.utc),
        runtime_args=runtime_args,
    )

    # sbi_id is left as None and will be set when uploaded to the ODA
    return SBInstance(
        interface="https://schema.skao.int/ska-oso-pdm-sbi/0.1",
        telescope=telescope,
        sbd_ref=cmd.sbd_id,
        sbd_version=sbd_version,
        # Taken from the environment; None when SUBARRAY_ID is not set
        subarray_id=os.getenv("SUBARRAY_ID"),
        activities=[activity_call],
    )
def _on_procedure_statechange(
    self, msg_src: str, new_state: ProcedureState
) -> None:
    """
    Listener for procedure state change events. When the procedure belongs
    to an activity, the equivalent activity state is appended to that
    activity's history and announced as an activity state change.

    :param msg_src: procedure ID as string
    :param new_state: new procedure state
    """
    try:
        procedure_id = int(msg_src)
    except ValueError:
        LOGGER.warning(
            "Received invalid procedure ID %s in state change event", msg_src
        )
        return

    # Find the activity that owns this procedure, if any
    activity_id = next(
        (
            aid
            for aid, activity in self.activities.items()
            if activity.procedure_id == procedure_id
        ),
        None,
    )
    if activity_id is None:
        # Not a procedure created for an activity: nothing to record
        return

    activity_state = self._map_procedure_state_to_activity_state(new_state)
    self.states[activity_id].append((activity_state, time.time()))

    pub.sendMessage(
        topics.activity.lifecycle.statechange,
        msg_src=str(activity_id),
        new_state=activity_state,
    )
    LOGGER.info(
        "Activity %d state updated to %s due to procedure %d state change to %s",
        activity_id,
        activity_state.value,
        procedure_id,
        new_state.value,
    )
def _map_procedure_state_to_activity_state(
    self, procedure_state: ProcedureState
) -> ActivityState:
    """
    Translate a ProcedureState into the corresponding ActivityState.

    This is a method rather than a plain dictionary lookup so that a
    warning can be logged when a state without a mapping is received.

    :param procedure_state: the procedure state to map
    :return: corresponding activity state, or FAILED for a state that has
        no mapping
    """
    if procedure_state in _PROCEDURE_TO_ACTIVITY_STATE_MAPPING:
        return _PROCEDURE_TO_ACTIVITY_STATE_MAPPING[procedure_state]

    # Whatever was received may not be an enum member at all
    if hasattr(procedure_state, "value"):
        reported = procedure_state.value
    else:
        reported = str(procedure_state)
    LOGGER.warning(
        "Unknown procedure state %s received, mapping to FAILED", reported
    )
    return ActivityState.FAILED