Source code for ska_oso_oet.activity.application

"""
The ska_oso_oet.activity.application module contains code related
to OET 'activities' that belong in the application layer. This application
layer holds the application interface, delegating to objects in the domain
layer for business rules and actions.
"""

import collections
import itertools
import logging
import os
import tempfile
import time
from datetime import datetime, timezone
from typing import Any

from pubsub import pub
from pydantic import BaseModel, Field, computed_field, model_serializer
from pydantic_core.core_schema import SerializerFunctionWrapHandler
from ska_db_oda.unitofwork.postgresunitofwork import PostgresUnitOfWork
from ska_oso_pdm import SBDefinition
from ska_oso_pdm._shared import TelescopeType
from ska_oso_pdm.sb_definition.procedures import (
    FilesystemScript,
    GitScript,
    PythonArguments,
    PythonProcedure,
)
from ska_oso_pdm.sb_instance import ActivityCall, FunctionArgs, SBInstance

from ska_oso_oet.activity.domain import Activity, ActivityState
from ska_oso_oet.config import OET_HISTORY_SIZE
from ska_oso_oet.event import topics
from ska_oso_oet.procedure import domain
from ska_oso_oet.procedure.application import (
    PrepareProcessCommand,
    ProcedureSummary,
    StartProcessCommand,
)
from ska_oso_oet.procedure.domain import ProcedureState
from ska_oso_oet.utils.ui import API_PATH

LOGGER = logging.getLogger(__name__)

# Activity states from which no further transition is expected. Only
# activities whose most recent state is one of these may be pruned once the
# number of tracked activities exceeds OET_HISTORY_SIZE.
DELETABLE_ACTIVITY_STATES = frozenset(
    {
        ActivityState.COMPLETE,
        ActivityState.FAILED,
        ActivityState.TERMINATED,
    }
)

# Translation table from the lifecycle state of a procedure to the state of
# the activity that owns it. Every procedure state that precedes READY is
# reported as PREPARING at the activity level.
_PROCEDURE_TO_ACTIVITY_STATE_MAPPING = {
    **dict.fromkeys(
        (
            ProcedureState.IDLE,
            ProcedureState.CREATING,
            ProcedureState.PREP_ENV,
            ProcedureState.LOADING,
            ProcedureState.INITIALISING,
        ),
        ActivityState.PREPARING,
    ),
    ProcedureState.READY: ActivityState.READY,
    ProcedureState.RUNNING: ActivityState.RUNNING,
    ProcedureState.COMPLETE: ActivityState.COMPLETE,
    ProcedureState.FAILED: ActivityState.FAILED,
    ProcedureState.STOPPED: ActivityState.TERMINATED,
    # UNKNOWN is treated as an error condition
    ProcedureState.UNKNOWN: ActivityState.FAILED,
}


class ActivityCommand(BaseModel):
    """
    Request to run one named activity of a Scheduling Block Definition.

    An ActivityCommand is consumed by ActivityService.prepare_run_activity,
    which uses it to look up the SBD, select the script for the activity and
    assemble the arguments the script will be called with.
    """

    # name of the activity to look up in the SBD's activities
    activity_name: str
    # identifier of the SBDefinition to retrieve from the ODA
    sbd_id: str
    # when True, the script is prepared but its functions are not started
    prepare_only: bool
    # forwarded to the Git script definition as its create_env flag
    create_env: bool
    # per-function arguments that take priority over those held in the SBD
    script_args: dict[str, domain.ProcedureInput]

    def __init__(
        self,
        activity_name: str,
        sbd_id: str,
        prepare_only: bool,
        create_env: bool,
        script_args: dict[str, domain.ProcedureInput],
    ):
        """
        Create a new ActivityCommand.

        The explicit signature allows the command to be constructed with
        positional as well as keyword arguments.

        :param activity_name: name of the activity to run
        :param sbd_id: ID of the SBDefinition that declares the activity
        :param prepare_only: True to prepare the script without running it
        :param create_env: value for the create_env flag of a Git script
        :param script_args: function name to ProcedureInput overrides
        """
        field_values = {
            "activity_name": activity_name,
            "sbd_id": sbd_id,
            "prepare_only": prepare_only,
            "create_env": create_env,
            "script_args": script_args,
        }
        super().__init__(**field_values)


class ActivitySummary(BaseModel):
    """
    Serialisable snapshot of an Activity and its state history.

    The summary carries the identifiers that link the activity to its
    procedure, SBDefinition and SBInstance, the script arguments per function
    and the list of (state, timestamp) pairs recorded for the activity. When
    dumped, enum states are emitted by name and ``state`` is derived from the
    state history.
    """

    id: int  # pylint: disable=invalid-name
    # Required field without an alias. ``procedure_id`` below is a separate
    # field that __init__ populates with the same value.
    pid: int = Field(...)
    procedure_id: int
    sbd_id: str
    activity_name: str
    prepare_only: bool
    script_args: dict[str, domain.ProcedureInput]
    activity_states: list[tuple[ActivityState, float]]
    state: str | None = None
    sbi_id: str

    @computed_field
    @property
    def uri(self) -> str:
        """
        URI of this activity, built from API_PATH and the activity ID.
        """
        return f"http://localhost{API_PATH}/activities/{self.id}"

    def __init__(
        self,
        id: int,  # pylint: disable=redefined-builtin
        pid: int,
        sbd_id: str,
        activity_name: str,
        prepare_only: bool,
        script_args: dict[str, domain.ProcedureInput],
        activity_states: list[tuple[ActivityState, float]] | collections.deque | None,
        sbi_id: str,
        uri: str | None = None,
        state: str | None = None,
    ):
        """
        Create a new ActivitySummary.

        :param id: activity ID
        :param pid: ID of the procedure linked to the activity; stored as
            both ``pid`` and ``procedure_id``
        :param sbd_id: ID of the SBDefinition the activity belongs to
        :param activity_name: name of the activity
        :param prepare_only: True if the script is prepared but not run
        :param script_args: function name to ProcedureInput mapping
        :param activity_states: (state, timestamp) history as a list or
            deque; None is treated as an empty history
        :param sbi_id: ID of the SBInstance created for the activity
        :param uri: forwarded to the model constructor unchanged; ``uri`` is
            declared as a computed property on this model
        :param state: optional state string, reported on serialisation only
            when the state history is empty
        """
        # Normalize deque to list for serialization
        if activity_states is None:
            activity_states = []
        elif isinstance(activity_states, collections.deque):
            activity_states = list(activity_states)

        super().__init__(
            id=id,
            pid=pid,
            procedure_id=pid,
            uri=uri,
            sbd_id=sbd_id,
            activity_name=activity_name,
            prepare_only=prepare_only,
            script_args=script_args,
            activity_states=activity_states,
            state=state,
            sbi_id=sbi_id,
        )

    @model_serializer(mode="wrap")
    def _serialize_activity_summary(
        self, default_serializer: SerializerFunctionWrapHandler
    ) -> dict[str, Any]:
        """
        Serialise the summary, replacing the fields that need custom output.

        ``script_args`` is flattened to plain ``args``/``kwargs`` dicts,
        ``activity_states`` reports each state by enum name, and ``state`` is
        set to the name of the state with the latest timestamp.

        :param default_serializer: pydantic handler producing the default dump
        :return: dict representation of this summary
        """
        dumped = default_serializer(self)

        dumped["script_args"] = {
            fn: {"args": fn_input.args, "kwargs": fn_input.kwargs}
            for fn, fn_input in self.script_args.items()
        }
        dumped["activity_states"] = [
            (state_enum.name, timestamp)
            for (state_enum, timestamp) in self.activity_states
        ]

        # An empty history is valid input (None is normalised to [] by
        # __init__), and max() raises ValueError on an empty mapping. In that
        # case the default-serialised ``state`` value is left untouched.
        states_to_time = dict(self.activity_states)
        if states_to_time:
            dumped["state"] = max(states_to_time, key=states_to_time.get).name

        return dumped
class ActivityService:
    """
    ActivityService provides the high-level interface and facade for the
    activity domain.

    The interface is used to run activities referenced by Scheduling Blocks.
    Each activity will run a script (or `procedure`) but ActivityService will
    create the necessary commands for Procedure domain to create and execute
    the scripts.
    """

    def __init__(
        self,
    ):
        """
        Create a new ActivityService and subscribe it to procedure lifecycle
        state change events.
        """
        # ActivityService stores activity state history using FIFO deques with
        # unbounded size per activity. The total number of activities is
        # managed by cleanup logic based on OET_HISTORY_SIZE
        self.states: dict[int, collections.deque] = collections.defaultdict(
            collections.deque
        )
        self.script_args: dict[int, dict[str, domain.ProcedureInput]] = {}
        self.activities: dict[int, Activity] = {}
        # We need to store this state as the service needs to check if a
        # procedure that has been created is the result of an activity request
        self.request_ids_to_aid: dict[int, int] = {}
        # The script workers currently only allow one execution of main(), so
        # we must track which activities have had main dispatched, to prevent
        # duplicate RUN(main) when lifecycle.ready fires a second time (after
        # main completes). As/when we allow long-lived scripts with multiple
        # executions, this should be removed.
        self._main_dispatched: set[int] = set()
        # counter used to generate activity ID for new activities
        self._aid_counter = itertools.count(1)

        self._oda = PostgresUnitOfWork()

        # Subscribe to procedure lifecycle events for state synchronization
        pub.subscribe(
            self._on_procedure_statechange, topics.procedure.lifecycle.statechange
        )

    def _get_current_activity_state(self, activity_id: int) -> ActivityState | None:
        """
        Get the current state of an activity from its state history.

        :param activity_id: Activity ID to get state for
        :return: Current ActivityState or None if no states recorded
        """
        # .get() is used so that the defaultdict does not create an empty
        # history for an unknown ID as a side effect of the lookup
        state_history = self.states.get(activity_id)
        if not state_history:
            return None
        # Return the most recent state (last item in deque)
        return state_history[-1][0]

    def _cleanup_activity_history(self) -> None:
        """
        Clean up oldest activities in terminal states if we've exceeded the
        maximum number of activities.

        Only removes activities that are in terminal states (COMPLETE, FAILED,
        TERMINATED) to preserve activities that are still in progress.
        """
        if len(self.states) <= OET_HISTORY_SIZE:
            return

        # Walk activities by ID (oldest first), collecting those in terminal
        # states until enough have been selected to get back within the limit
        to_delete = []
        for activity_id in sorted(self.states.keys()):
            remaining_count = len(self.states) - len(to_delete)
            if remaining_count <= OET_HISTORY_SIZE:
                break

            current_state = self._get_current_activity_state(activity_id)
            if current_state in DELETABLE_ACTIVITY_STATES:
                to_delete.append(activity_id)

        for activity_id in to_delete:
            self._cleanup_activity(activity_id)

    def _cleanup_activity(self, activity_id: int) -> None:
        """
        Remove activity and all its related data from state variables.

        :param activity_id: Activity ID to clean up
        """
        self._forget_activity(activity_id)
        LOGGER.debug("Pruned state for oldest activity with ID %d", activity_id)

    def _forget_activity(self, activity_id: int) -> None:
        """
        Drop every piece of bookkeeping held for the given activity.

        Unknown IDs are ignored, so this is safe to call for an activity that
        was only partially registered.

        :param activity_id: Activity ID to remove
        """
        self.states.pop(activity_id, None)
        self.script_args.pop(activity_id, None)
        self.activities.pop(activity_id, None)
        self._main_dispatched.discard(activity_id)

        # The request_id mapping is keyed the other way round, so the request
        # IDs pointing at this activity must be searched for by value
        request_ids_to_remove = [
            req_id
            for req_id, aid in self.request_ids_to_aid.items()
            if aid == activity_id
        ]
        for req_id in request_ids_to_remove:
            del self.request_ids_to_aid[req_id]

    def _find_activity_for_procedure(
        self, procedure_id: int
    ) -> tuple[int, Activity] | None:
        """
        Find the activity linked to a procedure.

        :param procedure_id: ID of the procedure to look up
        :return: (activity ID, Activity) for the first activity whose
            procedure_id matches, or None if no activity is linked
        """
        for aid, activity in self.activities.items():
            if activity.procedure_id == procedure_id:
                return aid, activity
        return None

    def prepare_run_activity(self, cmd: ActivityCommand, request_id: int) -> None:
        """
        Prepare to run the activity of a Scheduling Block. This includes
        retrieving the script from the scheduling block and sending the request
        messages to the ScriptExecutionService to prepare the script.

        The request_id is required to be propagated through the messages sent
        to the Procedure layer, so the REST layer can wait for the correct
        response event.

        If any step of the preparation fails, the state recorded for the new
        activity is discarded before the exception is re-raised, so that a
        failed request does not leave a half-registered activity behind.

        :param cmd: dataclass argument capturing the activity name and SB ID
        :param request_id: The original request_id from the REST layer
        :raises KeyError: if the activity is not present in the SBDefinition
        """
        aid = next(self._aid_counter)

        # Check if we need to clean up oldest activities before adding the new one
        self._cleanup_activity_history()

        # Record PREPARING state immediately to prevent race conditions with
        # procedure lifecycle events
        self.states[aid].append((ActivityState.PREPARING, time.time()))

        try:
            # Publish initial activity lifecycle state change event
            pub.sendMessage(
                topics.activity.lifecycle.statechange,
                msg_src=str(aid),
                new_state=ActivityState.PREPARING,
            )
            self._prepare_and_register(aid, cmd, request_id)
        except Exception:
            # summarise() expects every ID in self.states to also be present
            # in self.activities and self.script_args, and an entry stuck in
            # PREPARING would never become eligible for pruning.
            self._forget_activity(aid)
            LOGGER.warning(
                "Preparation of activity %d failed, discarding its state", aid
            )
            raise

    def _prepare_and_register(
        self, aid: int, cmd: ActivityCommand, request_id: int
    ) -> None:
        """
        Create the SBInstance, request creation of the procedure and register
        the activity under the given activity ID.

        :param aid: ID allocated to the new activity
        :param cmd: the activity command being processed
        :param request_id: The original request_id from the REST layer
        :raises KeyError: if the activity is not present in the SBDefinition
        """
        with self._oda as oda:
            sbd: SBDefinition = oda.sbds.get(cmd.sbd_id)

            # Validate the request before anything is written to the ODA so
            # that an unknown activity name does not persist an SBInstance
            if (pdm_script := sbd.activities.get(cmd.activity_name)) is None:
                raise KeyError(
                    f"Activity '{cmd.activity_name}' not present in the SBDefinition"
                    f" {cmd.sbd_id}"
                )

            telescope = sbd.telescope
            sbi = self._create_sbi(telescope, cmd, sbd_version=sbd.metadata.version)
            sbi = oda.sbis.add(sbi)
            oda.commit()

        script = self._get_oet_script(pdm_script, cmd.create_env)
        script_args = self._combine_script_args(pdm_script, cmd)

        sbd_path = self.write_sbd_to_file(sbd)

        # Neither the SBD nor the command is obliged to supply arguments for
        # main, so start from an empty input when none were given
        main_input = script_args.setdefault("main", domain.ProcedureInput())
        main_input.kwargs.update({"sb_json": sbd_path, "sbi_id": sbi.sbi_id})

        prepare_cmd = PrepareProcessCommand(
            script=script,
            init_args=script_args.get("init", domain.ProcedureInput()),
        )
        pub.sendMessage(
            topics.request.procedure.create,
            # Setting the msg_src as None means the republish logic will recognise the
            # message has originated from its local pypubsub and should be republished
            msg_src=None,
            request_id=request_id,
            cmd=prepare_cmd,
        )

        # The Activity dataclass is an internal representation of the Activity.
        # The procedure_id will be populated once the procedure created event
        # has been received
        self.activities[aid] = Activity(
            activity_id=aid,
            procedure_id=None,
            activity_name=cmd.activity_name,
            sbd_id=cmd.sbd_id,
            prepare_only=cmd.prepare_only,
            sbi_id=sbi.sbi_id,
        )
        self.script_args[aid] = script_args
        self.request_ids_to_aid[request_id] = aid

    def dispatch_main(self, ready_summary: ProcedureSummary) -> None:
        """
        Start the Procedure's main function now that init has completed and
        the procedure is in READY state. Triggered by procedure.lifecycle.ready.

        Nothing is dispatched if the procedure is not linked to an activity,
        if the activity was requested as prepare_only, or if main has already
        been dispatched for the activity.

        :param ready_summary: the ProcedureSummary for the Procedure that is
            now ready
        """
        # Find the activity linked to this procedure by iterating activities
        # (lifecycle.ready carries request_id=None so we can't use request_ids_to_aid)
        match = self._find_activity_for_procedure(ready_summary.id)
        if match is None:
            return
        aid, activity = match
        if activity.prepare_only:
            return

        # lifecycle.ready fires on every READY transition, after every user
        # function completes. Skip if main has already been called for this
        # activity so that we only ever call main() once.
        if aid in self._main_dispatched:
            return
        self._main_dispatched.add(aid)

        # TODO: should we allow here for multiple functions or limit to just main as is assumed by PM?
        fns_to_start = [fn for fn in self.script_args[aid] if fn != "init"]
        for fn in fns_to_start:
            start_cmd = StartProcessCommand(
                ready_summary.id,
                fn_name=fn,
                run_args=self.script_args[aid][fn],
                force_start=True,
            )
            pub.sendMessage(
                topics.request.procedure.start,
                # Setting the msg_src as None means the republish logic will recognise the
                # message has originated from its local pypubsub and should be republished
                msg_src=None,
                request_id=None,
                cmd=start_cmd,
            )

        pub.sendMessage(
            topics.sb.lifecycle.started, msg_src=None, sbi_id=activity.sbi_id
        )

    def summarise(self, activity_ids: list[int] | None = None) -> list[ActivitySummary]:
        """
        Return ActivitySummary objects for Activities with the requested IDs.

        This method accepts an optional list of integers, representing the
        Activity IDs to summarise. If the IDs are left undefined,
        ActivitySummary objects for all current Activities will be returned.

        :param activity_ids: optional list of Activity IDs to summarise.
        :return: list of ActivitySummary objects
        :raises ValueError: if any requested Activity ID is not known
        """
        all_activity_ids = self.states.keys()
        if activity_ids is None:
            activity_ids = all_activity_ids

        missing_pids = {p for p in activity_ids if p not in all_activity_ids}
        if missing_pids:
            raise ValueError(f"Activity IDs not found: {missing_pids}")

        return [self._summarise(activity_id) for activity_id in activity_ids]

    def _summarise(self, aid: int) -> ActivitySummary:
        """
        Return a ActivitySummary for the Activity with the given ID.

        :param aid: Activity ID to summarise
        :return: ActivitySummary
        """
        state = list(self.states[aid])  # Convert deque to list for ActivitySummary
        activity = self.activities[aid]
        script_args = self.script_args[aid]
        return ActivitySummary(
            id=aid,
            pid=activity.procedure_id,
            activity_name=activity.activity_name,
            sbd_id=activity.sbd_id,
            script_args=script_args,
            activity_states=state,
            prepare_only=activity.prepare_only,
            sbi_id=activity.sbi_id,
        )

    def _get_oet_script(
        self, pdm_script: PythonProcedure, create_env: bool
    ) -> domain.ExecutableScript:
        """
        Converts the PDM representation of the script retrieved from the SB
        into the OET representation.

        :param pdm_script: script definition taken from the SBDefinition
        :param create_env: create_env flag to set on a Git script
        :return: the OET script equivalent to the PDM script
        :raises RuntimeError: if the script is neither a Git script nor a
            filesystem script
        """
        if isinstance(pdm_script, GitScript):
            git_args = domain.GitArgs(
                git_repo=pdm_script.repo,
                git_branch=pdm_script.branch,
                git_commit=pdm_script.commit,
            )
            return domain.GitScript(
                script_uri=pdm_script.path,
                git_args=git_args,
                create_env=create_env,
            )
        elif isinstance(pdm_script, FilesystemScript):
            return domain.FileSystemScript(script_uri=pdm_script.path)
        else:
            raise RuntimeError(
                f"Cannot run script with type {pdm_script.__class__.__name__}"
            )

    def _combine_script_args(
        self, pdm_script: PythonProcedure, cmd: ActivityCommand
    ) -> dict[str, domain.ProcedureInput]:
        """
        Combines the function args from the SB with any overwrites sent in the
        command, returning a dict of the OET representation of the args for
        each function.

        BTN-3197 Positional arguments are no longer forwarded to scripts,
        instead a warning is logged if any are present in either the SBD or
        the command.

        :param pdm_script: script definition taken from the SBDefinition
        :param cmd: the activity command carrying any argument overrides
        :return: function name to ProcedureInput mapping, keyword args only
        """
        # First turn the PDM format into the OET ProcedureInput, kwargs only
        script_args = {}
        for fn, fn_args in pdm_script.function_args.items():
            if fn_args.args:
                LOGGER.warning(
                    (
                        "Ignoring positional arguments for '%s' in SBD %s (scripts must"
                        " use keyword arguments only): %s"
                    ),
                    fn,
                    cmd.sbd_id,
                    fn_args.args,
                )
            script_args[fn] = domain.ProcedureInput(**fn_args.kwargs)

        # Then add any from the cmd, with the cmd taking priority
        for fn_name, cmd_input in cmd.script_args.items():
            if cmd_input.args:
                LOGGER.warning(
                    (
                        "Ignoring positional arguments for '%s' in command for SBD %s "
                        "(scripts must use keyword arguments only): %s"
                    ),
                    fn_name,
                    cmd.sbd_id,
                    cmd_input.args,
                )
            # A command entry without kwargs neither overrides nor creates an
            # entry for the function
            if not cmd_input.kwargs:
                continue
            if fn_name in script_args:
                sbd_input = script_args[fn_name]
                # Merge: SBD kwargs form the base, cmd kwargs override on conflict.
                script_args[fn_name] = domain.ProcedureInput(
                    **{**sbd_input.kwargs, **cmd_input.kwargs}
                )
            else:
                script_args[fn_name] = domain.ProcedureInput(**cmd_input.kwargs)

        return script_args

    def write_sbd_to_file(self, sbd: SBDefinition) -> str:
        """
        Writes the SBD json to a temporary file location and returns the path.

        The file is created with delete=False, so it remains on disk after
        this method returns.

        :param sbd: the SBDefinition to write
        :return: path of the file containing the SBD JSON
        """
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".json", encoding="utf-8"
        ) as f:
            path = f.name
            LOGGER.debug("Writing SB %s to path: %s", sbd.sbd_id, path)
            f.write(sbd.model_dump_json())
        return path

    def _create_sbi(
        self, telescope: TelescopeType, cmd: ActivityCommand, sbd_version: int
    ) -> SBInstance:
        """
        Creates an SBInstance from the relevant fields in the command.

        :param telescope: telescope the SBDefinition is for
        :param cmd: the activity command being processed
        :param sbd_version: version of the SBDefinition being executed
        :return: a new SBInstance recording this activity call
        """
        function_args = [
            FunctionArgs(
                function_name=fn_name,
                function_args=PythonArguments(
                    args=list(procedure_input.args), kwargs=procedure_input.kwargs
                ),
            )
            for (fn_name, procedure_input) in cmd.script_args.items()
        ]
        # sbi_id is left as None and will be set when uploaded to the ODA
        return SBInstance(
            interface="https://schema.skao.int/ska-oso-pdm-sbi/0.1",
            telescope=telescope,
            sbd_ref=cmd.sbd_id,
            sbd_version=sbd_version,
            subarray_id=os.getenv("SUBARRAY_ID"),
            activities=[
                ActivityCall(
                    activity_ref=cmd.activity_name,
                    executed_at=datetime.now(tz=timezone.utc),
                    runtime_args=function_args,
                )
            ],
        )

    def _on_procedure_statechange(
        self, msg_src: str, new_state: ProcedureState
    ) -> None:
        """
        Handle procedure state change events to update corresponding activity
        states.

        Events whose source is not a valid procedure ID, or whose procedure is
        not linked to an activity, are ignored.

        :param msg_src: procedure ID as string
        :param new_state: new procedure state
        """
        try:
            procedure_id = int(msg_src)
        except (TypeError, ValueError):
            # int() raises TypeError rather than ValueError for a source such
            # as None; neither should escape from the event callback
            LOGGER.warning(
                "Received invalid procedure ID %s in state change event", msg_src
            )
            return

        # Find the activity associated with this procedure
        match = self._find_activity_for_procedure(procedure_id)
        if match is None:
            # This procedure is not linked to any activity, ignore the event
            return
        activity_id, _ = match

        # Map procedure state to activity state
        activity_state = self._map_procedure_state_to_activity_state(new_state)

        # Update activity state
        current_time = time.time()
        self.states[activity_id].append((activity_state, current_time))

        # Publish activity lifecycle state change event
        pub.sendMessage(
            topics.activity.lifecycle.statechange,
            msg_src=str(activity_id),
            new_state=activity_state,
        )

        LOGGER.info(
            "Activity %d state updated to %s due to procedure %d state change to %s",
            activity_id,
            activity_state.value,
            procedure_id,
            new_state.value,
        )

    def _map_procedure_state_to_activity_state(
        self, procedure_state: ProcedureState
    ) -> ActivityState:
        """
        Map ProcedureState to ActivityState according to the implementation
        plan.

        This code is implemented as a function rather than a dict.get() call
        so that we can log when unknown states are received.

        :param procedure_state: the procedure state to map
        :return: corresponding activity state
        """
        try:
            return _PROCEDURE_TO_ACTIVITY_STATE_MAPPING[procedure_state]
        except KeyError:
            LOGGER.warning(
                "Unknown procedure state %s received, mapping to FAILED",
                procedure_state.value
                if hasattr(procedure_state, "value")
                else str(procedure_state),
            )
            return ActivityState.FAILED