Source code for ska_low_mccs.calibration_solver.solver_component_manager

# -*- coding: utf-8 -*-
#
# This file is part of the SKA Low MCCS project
#
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.
"""This module provides a component manager for a station calibration solver."""

from __future__ import annotations

import logging
import threading
from collections import deque
from pathlib import Path
from typing import Any, Callable, Final

import h5py
import numpy as np
from astropy.time import Time
from ska_control_model import CommunicationStatus, PowerState, ResultCode, TaskStatus
from ska_low_mccs_calibration.calibration import calibrate_mccs_visibility
from ska_low_mccs_calibration.eep import convert_eep2npy
from ska_low_mccs_calibration.utils import read_station_config, sdp_visibility_datamodel
from ska_tango_base.base import check_communicating
from ska_tango_base.executor import TaskExecutorComponentManager
from ska_telmodel.data import TMData  # type: ignore

from .calibration_solution_product import (
    LATEST_STRUCTURE_VERSION,
    CalibrationSolutionProduct,
)

__all__ = ["StationCalibrationSolverComponentManager"]


# Number of channels that BANDWIDTH is divided into; channel_id_to_freq()
# uses BANDWIDTH / NOF_CHANNELS as the width of a single channel.
NOF_CHANNELS: Final[int] = 512
# Total bandwidth in MHz (channel_id_to_freq() names the derived per-channel
# width ``channel_bw_mhz`` and its result is used as ``frequency_mhz``).
BANDWIDTH: Final[float] = 400.0


[docs] class StationCalibrationSolverComponentManager(TaskExecutorComponentManager): """A component manager for a station calibration solver."""
[docs] def __init__( self: StationCalibrationSolverComponentManager, root_path: str, eep_root_path: str, logger: logging.Logger, communication_state_callback: Callable[[CommunicationStatus], None], component_state_callback: Callable[..., None], ): """ Initialise a new instance. :param root_path: the root path for loading data from :param eep_root_path: the root path for the embedded element pattern files. :param logger: the logger to be used by this object. :param communication_state_callback: callback to be called when the status of the communications channel between the component manager and its component changes :param component_state_callback: callback to be called when the component state changes. """ self._root_path = Path(root_path) self.eep_root_path = Path(eep_root_path) self.eep_store_path_to_use = str(self.eep_root_path) self.eep_suffix = ".npy" # TODO: Do we want to clear these files in self.internal_eep_store, # it may save time caching them, but, may also bloat. self.internal_eep_store = Path("/app/src/ska_low_mccs_calibration/tango/") self._channel_file_mapping: dict[int, dict[int, str]] = {} super().__init__( logger, communication_state_callback, component_state_callback, max_workers=1, power=None, fault=None, )
[docs] def start_communicating(self: StationCalibrationSolverComponentManager) -> None: """Establish communication with the station components.""" if self.communication_state == CommunicationStatus.ESTABLISHED: return if self.communication_state == CommunicationStatus.DISABLED: self._update_communication_state(CommunicationStatus.ESTABLISHED) self._update_component_state(power=PowerState.ON)
[docs] def stop_communicating(self: StationCalibrationSolverComponentManager) -> None: """Break off communication with the station components.""" if self.communication_state == CommunicationStatus.DISABLED: return self._update_communication_state(CommunicationStatus.DISABLED) self._update_component_state(power=None, fault=None)
[docs] @check_communicating def off( self: StationCalibrationSolverComponentManager, task_callback: Callable | None = None, ) -> tuple[TaskStatus, str]: """ Turn the component off. :param task_callback: callback to be called when the status of the command changes :raises NotImplementedError: Not implemented it's an abstract class """ raise NotImplementedError( "The station calibration solver is an always-on component." )
[docs] @check_communicating def standby( self: StationCalibrationSolverComponentManager, task_callback: Callable | None = None, ) -> tuple[TaskStatus, str]: """ Put the component into low-power standby mode. :param task_callback: callback to be called when the status of the command changes :raises NotImplementedError: Not implemented it's an abstract class """ raise NotImplementedError( "The station calibration solver is an always-on component." )
[docs] @check_communicating def on( self: StationCalibrationSolverComponentManager, task_callback: Callable | None = None, ) -> tuple[TaskStatus, str]: """ Turn the component on. :param task_callback: callback to be called when the status of the command changes :raises NotImplementedError: Not implemented it's an abstract class """ raise NotImplementedError( "The station calibration solver is an always-on component." )
[docs] @check_communicating def reset( self: StationCalibrationSolverComponentManager, task_callback: Callable | None = None, ) -> tuple[TaskStatus, str]: """ Reset the component (from fault state). :param task_callback: callback to be called when the status of the command changes :raises NotImplementedError: Not implemented it's an abstract class """ raise NotImplementedError("The station calibration solver cannot be reset.")
# pylint: disable=too-many-locals # pylint: disable=too-many-return-statements # pylint: disable=too-many-statements
[docs] @check_communicating def solve( # noqa: C901 self: StationCalibrationSolverComponentManager, data_path: str, solution_path: str, station_config_path: tuple[str, str], task_callback: Callable | None = None, task_abort_event: threading.Event | None = None, **kwargs: Any, ) -> None: """ Solve for a calibration solution. :param data_path: path to the stored data to use for solving for calibration. :param solution_path: the path of the solution to be written. :param station_config_path: a list used to locate configuration from TelModel. :param kwargs: any kwargs :param task_callback: Update task state, defaults to None :param task_abort_event: Check for abort, defaults to None """ def _report_status(status: TaskStatus, **kwargs: Any) -> None: if task_callback is not None: task_callback(status=status, **kwargs) def _check_aborted() -> bool: if task_abort_event and task_abort_event.is_set(): self.logger.info("Solve task has been aborted") _report_status( TaskStatus.ABORTED, result=(ResultCode.ABORTED, "Task aborted"), ) return True return False self.logger.info(f"Solving for calibration using data {data_path}...") _report_status(TaskStatus.IN_PROGRESS) if _check_aborted(): return eep_filebase = kwargs.pop("eep_filebase", "") self.logger.info(f"Loading dataset {data_path}...") try: observation_data = load_observation_data(str(self._root_path / data_path)) except FileNotFoundError as fnfe: self.logger.info(f"Cannot open file {data_path}: {repr(fnfe)}") _report_status( TaskStatus.COMPLETED, result=(ResultCode.FAILED, "File not found error") ) return self.logger.info(f"Successfully loaded dataset {data_path}.") if _check_aborted(): return self.logger.info("Loading configuration data...") # Get station configuration from TelModel using the station_config_path ( location, antenna_masks, baselines, enu_raw, pol, rotation, ) = self._get_station_config(station_config_path) if _check_aborted(): return self.logger.info("Packaging visibility...") # Restructure select data. 
masked_antennas = np.where(antenna_masks is True)[0] enu = np.array(enu_raw, dtype=np.complex64(0).real.dtype).T v_measurement = _package_visibility( observation_data["correlation_times_array"], observation_data["channel_id"], observation_data["int_time"], observation_data["itime"], observation_data["correlations"], baselines, masked_antennas, location, pol, enu, ) if not kwargs.get("ignore_eeps", False): eep_matches = list( self.eep_root_path.glob( f"{eep_filebase}{observation_data['channel_id']}MHz_*pol*.*" ) ) if len(eep_matches) == 0: self.logger.error( f"No EEPs found with correct naming, Looking for:\n" f"\t{str(self.eep_root_path)+str('/')+eep_filebase}" f"{str(observation_data['channel_id'])}MHz_?pol.?" ) _report_status( TaskStatus.COMPLETED, result=(ResultCode.FAILED, "File not found error"), ) return self._prepare_eeps( eep_matches, ) if _check_aborted(): return self.logger.info("Calibrating visibility...") gain_table, _, _, masked_antennas, calibration_info = calibrate_mccs_visibility( vis=v_measurement, masked_antennas=masked_antennas, skymodel=kwargs.get("skymodel", "gsm"), min_uv=kwargs.get("min_uv", 0), refant=kwargs.get("refant", 1), ignore_eeps=kwargs.get("ignore_eeps", False), gain_threshold=kwargs.get("gain_threshold", 0.25), eep_path=self.eep_store_path_to_use, eep_rotation_deg=rotation, eep_suffix=self.eep_suffix, eep_filebase=eep_filebase, jones_solve=kwargs.get("jones_solve", False), back_rotation=kwargs.get("back_rotation", False), adjust_solar_model=kwargs.get("adjust_solar_model", True), nside=kwargs.get("nside", 32), niter=kwargs.get("niter", 200), ) if _check_aborted(): return self.logger.info("Saving solution...") jones_matrices: Any = gain_table.gain.values flattened_matrices = jones_matrices.flatten() calibration_solutions = np.array( (np.real(flattened_matrices), np.imag(flattened_matrices)) ).T.ravel() self.logger.info(f"Storing products in location {solution_path}") np.save(str(self._root_path / solution_path), 
calibration_solutions) struc_version = kwargs.get("structure_version", LATEST_STRUCTURE_VERSION) self.logger.info(f"File created {solution_path}.npy") calibration_id = kwargs.get("calibration_id") calibration_product = CalibrationSolutionProduct( logger=self.logger, structure_version=struc_version, corrcoeff=calibration_info.corrcoeff, residual_max=calibration_info.residual_max, residual_std=calibration_info.residual_std, xy_phase=calibration_info.xy_phase, n_masked_initial=calibration_info.n_masked_initial, n_masked_final=calibration_info.n_masked_final, lst=calibration_info.lst, galactic_centre_elevation=(calibration_info.galactic_centre_elevation), sun_elevation=calibration_info.sun_elevation, sun_adjustment_factor=calibration_info.sun_adjustment_factor, masked_antennas=masked_antennas.tolist(), solution=calibration_solutions.tolist(), frequency_channel=observation_data["channel_id"], station_id=observation_data["station_id"], acquisition_time=observation_data["timestamp"], calibration_id=calibration_id, ) calibration_product.save_to_hdf5(str(self._root_path / solution_path) + ".h5") self.logger.info( f"Solution stored in location {str(self._root_path / solution_path)}" ) self.logger.info(f"Solution tagged with id: {calibration_id}") self._channel_file_mapping[observation_data["channel_id"]] = { observation_data["station_id"]: solution_path } _report_status( TaskStatus.COMPLETED, result=(ResultCode.OK, "Solution successfully calculated"), ) self.logger.info(f"Done solving for calibration using data {data_path}...")
def _get_station_config( self: StationCalibrationSolverComponentManager, station_config_path: tuple[str, str], ) -> Any: platform_config = TMData([station_config_path[0]])[ station_config_path[1] ].get_dict() # We (badly) need to have versioned parsers for this station_name = list(find_by_key(platform_config, "stations"))[0] station_config = find_by_key( platform_config["platform"]["stations"], station_name ) self.logger.info(f"reading config from {station_name}") return read_station_config(station_config) def _prepare_eeps(self, eep_matches: list[Path]) -> None: for eep_match in eep_matches: match eep_match.suffix: case ".npz": self.eep_store_path_to_use = str(self.eep_root_path) self.eep_suffix = ".npz" case ".npy": self.eep_store_path_to_use = str(self.eep_root_path) self.eep_suffix = ".npy" case ".mat": # must convert to .npy and use internal store. self.eep_store_path_to_use = str(self.internal_eep_store) self.eep_suffix = ".npy" convert_eep2npy( str(eep_match), npy_dir=self.eep_store_path_to_use, ) case _: print(f"EEP extension {eep_match.suffix} not supported.")
def channel_id_to_freq(channel_id: int | float) -> float:
    """
    Turn channel_id to frequency.

    The result is ``channel_id * BANDWIDTH / NOF_CHANNELS``, i.e. the
    channel index multiplied by the width of one channel.

    :param channel_id: Id of the channel
    :return: The frequency of the channel, in MHz.
    """
    channel_bw_mhz = BANDWIDTH / NOF_CHANNELS
    return float(channel_id) * channel_bw_mhz


# pylint: disable=too-many-locals
# pylint: disable=too-many-arguments
def _package_visibility(
    correlation_times_array: Any,
    channel_id: Any,
    int_time: Any,
    itime: Any,
    correlations: Any,
    baselines: Any,
    masked_antennas: Any,
    location: Any,
    pol: Any,
    enu: Any,
) -> Any:
    """
    Package one correlation matrix into an SDP visibility data model.

    Baselines involving a masked antenna, and all autocorrelations, are
    flagged.

    :param correlation_times_array: unix timestamps of the correlation
        matrices.
    :param channel_id: channel id; converted to MHz with channel_id_to_freq.
    :param int_time: integration time, passed through to the data model.
    :param itime: index of the timestamp to attach to the visibilities.
    :param correlations: visibilities indexed ``[baseline, pol]`` --
        presumably 4 polarisation products per baseline, matching the flag
        tiling below (NOTE(review): confirm).
    :param baselines: pair ``(ant1, ant2)`` of antenna-index arrays.
    :param masked_antennas: antenna indices whose baselines are flagged.
    :param location: passed as ``location`` to astropy Time and the model.
    :param pol: selector applied to the polarisation axis of vis and flags.
    :param enu: antenna positions, rows indexed by antenna number.
    :return: the object built by ``sdp_visibility_datamodel``.
    """
    # A baseline is flagged when either of its antennas is masked.
    bl_flags = np.logical_or(
        np.isin(baselines[0], masked_antennas),
        np.isin(baselines[1], masked_antennas),
    )
    # May as well flag autos as well
    bl_flags = np.logical_or(bl_flags, baselines[0] == baselines[1])
    # Stack flags for the four polarisations in the dataset: (n_baselines, 4).
    vis_flags = np.tile(bl_flags, (4, 1)).T
    frequency_mhz = channel_id_to_freq(channel_id)
    time_array = Time(correlation_times_array, format="unix", location=location)
    time_array.format = "fits"
    data_time = time_array[itime]
    return sdp_visibility_datamodel(
        vis=correlations[:, pol],
        flags=vis_flags[:, pol],
        uvw=enu[baselines[0]] - enu[baselines[1]],
        ant1=baselines[0],
        ant2=baselines[1],
        location=location,
        antpos_enu=enu,
        time=data_time,
        int_time=int_time,
        frequency_mhz=frequency_mhz,
    )


# pylint: disable=too-many-locals
def load_observation_data(data_file_path: str) -> dict[str, Any]:
    """
    Load the observation data from file and return it in a dictionary.

    Reads the attributes of the ``root`` group and the ``data`` datasets of
    the ``correlation_matrix`` and ``sample_timestamps`` groups. Only the
    last time block (``itime = n_blocks - 1``) is kept as ``correlations``.

    :param data_file_path: path to the observation visibility data.
    :return: a dictionary containing the observation data.
    """
    # Use a context manager so the HDF5 handle is always released, even if a
    # required attribute or dataset is missing. Everything returned below is
    # read into memory (attrs copied to a dict, datasets read via numpy /
    # h5py indexing) before the file is closed.
    with h5py.File(data_file_path, "r") as datafile:
        correlation_metadata = dict(datafile["root"].attrs)
        ntimes = correlation_metadata["n_blocks"]
        itime = ntimes - 1  # Index of timestamp we want
        int_time = correlation_metadata["tsamp"]
        n_ant = correlation_metadata["n_antennas"]
        channel_id = correlation_metadata["channel_id"]
        station_id = correlation_metadata["station_id"]
        n_baselines = correlation_metadata["n_baselines"]
        n_pol = correlation_metadata["n_pols"]
        n_stokes = correlation_metadata["n_stokes"]
        timestamp = correlation_metadata.pop("timestamp", -1)

        # Chop off rounding errors
        int_time = round(int_time, 12)

        # Get the data associated with the correlation matrices
        correlation_data = np.squeeze(datafile["correlation_matrix"]["data"])
        # Get the timestamps associated with the correlation matrices
        correlation_times = datafile["sample_timestamps"]
        if ntimes > 1:
            correlation_times_array = np.squeeze(correlation_times["data"])
            correlations = correlation_data[itime]
        else:
            correlation_times_array = correlation_times["data"][itime]
            correlations = correlation_data

    return {
        "correlations": correlations,
        "itime": itime,
        "int_time": int_time,
        "timestamp": timestamp,
        "correlation_times_array": correlation_times_array,
        "n_ant": n_ant,
        "n_stokes": n_stokes,
        "n_baselines": n_baselines,
        "n_pol": n_pol,
        "channel_id": channel_id,
        "station_id": station_id,
    }


def find_by_key(data: dict, target: str) -> Any:
    """
    Search nested dict breadth-first for the first target key and return its value.

    This method is used to find station and antenna config within Low platform
    spec files, and should eventually be replaced by functions specifically
    designed to parse these files, aware of schema versions, etc, probably
    within ska-telmodel.

    :param data: generic nested dictionary to traverse through.
    :param target: key to find the first value of.
    :returns: the value of the shallowest ``target`` key (ties broken by
        insertion order), or None if the key does not occur.
    """
    # deque gives O(1) pops from the front; list.pop(0) is O(n) per pop.
    bfs_queue = deque(data.items())
    while bfs_queue:
        key, value = bfs_queue.popleft()
        if key == target:
            return value
        if isinstance(value, dict):
            bfs_queue.extend(value.items())
    return None