# -*- coding: utf-8 -*-
#
# This file is part of the SKA Low MCCS project
#
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.
"""This module provides a component manager for a station calibration solver."""
from __future__ import annotations
import logging
import threading
from pathlib import Path
from typing import Any, Callable, Final
import h5py
import numpy as np
from astropy.time import Time
from ska_control_model import CommunicationStatus, PowerState, ResultCode, TaskStatus
from ska_low_mccs_calibration.calibration import calibrate_mccs_visibility
from ska_low_mccs_calibration.eep import convert_eep2npy
from ska_low_mccs_calibration.utils import read_station_config, sdp_visibility_datamodel
from ska_tango_base.base import check_communicating
from ska_tango_base.executor import TaskExecutorComponentManager
from ska_telmodel.data import TMData # type: ignore
from .calibration_solution_product import (
LATEST_STRUCTURE_VERSION,
CalibrationSolutionProduct,
)
# Names exported by ``from <module> import *``.
__all__ = ["StationCalibrationSolverComponentManager"]
# Number of channels the bandwidth is divided into (used by channel_id_to_freq).
NOF_CHANNELS: Final[int] = 512
# Total bandwidth; channel_id_to_freq treats BANDWIDTH / NOF_CHANNELS as the
# width of one channel in MHz.
BANDWIDTH: Final[float] = 400.0
[docs]
class StationCalibrationSolverComponentManager(TaskExecutorComponentManager):
"""A component manager for a station calibration solver."""
[docs]
def __init__(
self: StationCalibrationSolverComponentManager,
root_path: str,
eep_root_path: str,
logger: logging.Logger,
communication_state_callback: Callable[[CommunicationStatus], None],
component_state_callback: Callable[..., None],
):
"""
Initialise a new instance.
:param root_path: the root path for loading data from
:param eep_root_path: the root path for the embedded
element pattern files.
:param logger: the logger to be used by this object.
:param communication_state_callback: callback to be called when
the status of the communications channel between the
component manager and its component changes
:param component_state_callback: callback to be called when the
component state changes.
"""
self._root_path = Path(root_path)
self.eep_root_path = Path(eep_root_path)
self.eep_store_path_to_use = str(self.eep_root_path)
self.eep_suffix = ".npy"
# TODO: Do we want to clear these files in self.internal_eep_store,
# it may save time caching them, but, may also bloat.
self.internal_eep_store = Path("/app/src/ska_low_mccs_calibration/tango/")
self._channel_file_mapping: dict[int, dict[int, str]] = {}
super().__init__(
logger,
communication_state_callback,
component_state_callback,
max_workers=1,
power=None,
fault=None,
)
[docs]
def start_communicating(self: StationCalibrationSolverComponentManager) -> None:
"""Establish communication with the station components."""
if self.communication_state == CommunicationStatus.ESTABLISHED:
return
if self.communication_state == CommunicationStatus.DISABLED:
self._update_communication_state(CommunicationStatus.ESTABLISHED)
self._update_component_state(power=PowerState.ON)
[docs]
def stop_communicating(self: StationCalibrationSolverComponentManager) -> None:
"""Break off communication with the station components."""
if self.communication_state == CommunicationStatus.DISABLED:
return
self._update_communication_state(CommunicationStatus.DISABLED)
self._update_component_state(power=None, fault=None)
[docs]
@check_communicating
def off(
self: StationCalibrationSolverComponentManager,
task_callback: Callable | None = None,
) -> tuple[TaskStatus, str]:
"""
Turn the component off.
:param task_callback: callback to be called when the status of
the command changes
:raises NotImplementedError: Not implemented it's an abstract class
"""
raise NotImplementedError(
"The station calibration solver is an always-on component."
)
[docs]
@check_communicating
def standby(
self: StationCalibrationSolverComponentManager,
task_callback: Callable | None = None,
) -> tuple[TaskStatus, str]:
"""
Put the component into low-power standby mode.
:param task_callback: callback to be called when the status of
the command changes
:raises NotImplementedError: Not implemented it's an abstract class
"""
raise NotImplementedError(
"The station calibration solver is an always-on component."
)
[docs]
@check_communicating
def on(
self: StationCalibrationSolverComponentManager,
task_callback: Callable | None = None,
) -> tuple[TaskStatus, str]:
"""
Turn the component on.
:param task_callback: callback to be called when the status of
the command changes
:raises NotImplementedError: Not implemented it's an abstract class
"""
raise NotImplementedError(
"The station calibration solver is an always-on component."
)
[docs]
@check_communicating
def reset(
self: StationCalibrationSolverComponentManager,
task_callback: Callable | None = None,
) -> tuple[TaskStatus, str]:
"""
Reset the component (from fault state).
:param task_callback: callback to be called when the status of
the command changes
:raises NotImplementedError: Not implemented it's an abstract class
"""
raise NotImplementedError("The station calibration solver cannot be reset.")
# pylint: disable=too-many-locals
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-statements
[docs]
@check_communicating
def solve( # noqa: C901
self: StationCalibrationSolverComponentManager,
data_path: str,
solution_path: str,
station_config_path: tuple[str, str],
task_callback: Callable | None = None,
task_abort_event: threading.Event | None = None,
**kwargs: Any,
) -> None:
"""
Solve for a calibration solution.
:param data_path: path to the stored data to use for solving for calibration.
:param solution_path: the path of the solution to be written.
:param station_config_path: a list used to locate configuration from TelModel.
:param kwargs: any kwargs
:param task_callback: Update task state, defaults to None
:param task_abort_event: Check for abort, defaults to None
"""
def _report_status(status: TaskStatus, **kwargs: Any) -> None:
if task_callback is not None:
task_callback(status=status, **kwargs)
def _check_aborted() -> bool:
if task_abort_event and task_abort_event.is_set():
self.logger.info("Solve task has been aborted")
_report_status(
TaskStatus.ABORTED,
result=(ResultCode.ABORTED, "Task aborted"),
)
return True
return False
self.logger.info(f"Solving for calibration using data {data_path}...")
_report_status(TaskStatus.IN_PROGRESS)
if _check_aborted():
return
eep_filebase = kwargs.pop("eep_filebase", "")
self.logger.info(f"Loading dataset {data_path}...")
try:
observation_data = load_observation_data(str(self._root_path / data_path))
except FileNotFoundError as fnfe:
self.logger.info(f"Cannot open file {data_path}: {repr(fnfe)}")
_report_status(
TaskStatus.COMPLETED, result=(ResultCode.FAILED, "File not found error")
)
return
self.logger.info(f"Successfully loaded dataset {data_path}.")
if _check_aborted():
return
self.logger.info("Loading configuration data...")
# Get station configuration from TelModel using the station_config_path
(
location,
antenna_masks,
baselines,
enu_raw,
pol,
rotation,
) = self._get_station_config(station_config_path)
if _check_aborted():
return
self.logger.info("Packaging visibility...")
# Restructure select data.
masked_antennas = np.where(antenna_masks is True)[0]
enu = np.array(enu_raw, dtype=np.complex64(0).real.dtype).T
v_measurement = _package_visibility(
observation_data["correlation_times_array"],
observation_data["channel_id"],
observation_data["int_time"],
observation_data["itime"],
observation_data["correlations"],
baselines,
masked_antennas,
location,
pol,
enu,
)
if not kwargs.get("ignore_eeps", False):
eep_matches = list(
self.eep_root_path.glob(
f"{eep_filebase}{observation_data['channel_id']}MHz_*pol*.*"
)
)
if len(eep_matches) == 0:
self.logger.error(
f"No EEPs found with correct naming, Looking for:\n"
f"\t{str(self.eep_root_path)+str('/')+eep_filebase}"
f"{str(observation_data['channel_id'])}MHz_?pol.?"
)
_report_status(
TaskStatus.COMPLETED,
result=(ResultCode.FAILED, "File not found error"),
)
return
self._prepare_eeps(
eep_matches,
)
if _check_aborted():
return
self.logger.info("Calibrating visibility...")
gain_table, _, _, masked_antennas, calibration_info = calibrate_mccs_visibility(
vis=v_measurement,
masked_antennas=masked_antennas,
skymodel=kwargs.get("skymodel", "gsm"),
min_uv=kwargs.get("min_uv", 0),
refant=kwargs.get("refant", 1),
ignore_eeps=kwargs.get("ignore_eeps", False),
gain_threshold=kwargs.get("gain_threshold", 0.25),
eep_path=self.eep_store_path_to_use,
eep_rotation_deg=rotation,
eep_suffix=self.eep_suffix,
eep_filebase=eep_filebase,
jones_solve=kwargs.get("jones_solve", False),
back_rotation=kwargs.get("back_rotation", False),
adjust_solar_model=kwargs.get("adjust_solar_model", True),
nside=kwargs.get("nside", 32),
niter=kwargs.get("niter", 200),
)
if _check_aborted():
return
self.logger.info("Saving solution...")
jones_matrices: Any = gain_table.gain.values
flattened_matrices = jones_matrices.flatten()
calibration_solutions = np.array(
(np.real(flattened_matrices), np.imag(flattened_matrices))
).T.ravel()
self.logger.info(f"Storing products in location {solution_path}")
np.save(str(self._root_path / solution_path), calibration_solutions)
struc_version = kwargs.get("structure_version", LATEST_STRUCTURE_VERSION)
self.logger.info(f"File created {solution_path}.npy")
calibration_id = kwargs.get("calibration_id")
calibration_product = CalibrationSolutionProduct(
logger=self.logger,
structure_version=struc_version,
corrcoeff=calibration_info.corrcoeff,
residual_max=calibration_info.residual_max,
residual_std=calibration_info.residual_std,
xy_phase=calibration_info.xy_phase,
n_masked_initial=calibration_info.n_masked_initial,
n_masked_final=calibration_info.n_masked_final,
lst=calibration_info.lst,
galactic_centre_elevation=(calibration_info.galactic_centre_elevation),
sun_elevation=calibration_info.sun_elevation,
sun_adjustment_factor=calibration_info.sun_adjustment_factor,
masked_antennas=masked_antennas.tolist(),
solution=calibration_solutions.tolist(),
frequency_channel=observation_data["channel_id"],
station_id=observation_data["station_id"],
acquisition_time=observation_data["timestamp"],
calibration_id=calibration_id,
)
calibration_product.save_to_hdf5(str(self._root_path / solution_path) + ".h5")
self.logger.info(
f"Solution stored in location {str(self._root_path / solution_path)}"
)
self.logger.info(f"Solution tagged with id: {calibration_id}")
self._channel_file_mapping[observation_data["channel_id"]] = {
observation_data["station_id"]: solution_path
}
_report_status(
TaskStatus.COMPLETED,
result=(ResultCode.OK, "Solution successfully calculated"),
)
self.logger.info(f"Done solving for calibration using data {data_path}...")
def _get_station_config(
self: StationCalibrationSolverComponentManager,
station_config_path: tuple[str, str],
) -> Any:
platform_config = TMData([station_config_path[0]])[
station_config_path[1]
].get_dict()
# We (badly) need to have versioned parsers for this
station_name = list(find_by_key(platform_config, "stations"))[0]
station_config = find_by_key(
platform_config["platform"]["stations"], station_name
)
self.logger.info(f"reading config from {station_name}")
return read_station_config(station_config)
def _prepare_eeps(self, eep_matches: list[Path]) -> None:
for eep_match in eep_matches:
match eep_match.suffix:
case ".npz":
self.eep_store_path_to_use = str(self.eep_root_path)
self.eep_suffix = ".npz"
case ".npy":
self.eep_store_path_to_use = str(self.eep_root_path)
self.eep_suffix = ".npy"
case ".mat":
# must convert to .npy and use internal store.
self.eep_store_path_to_use = str(self.internal_eep_store)
self.eep_suffix = ".npy"
convert_eep2npy(
str(eep_match),
npy_dir=self.eep_store_path_to_use,
)
case _:
print(f"EEP extension {eep_match.suffix} not supported.")
def channel_id_to_freq(channel_id: int | float) -> float:
    """
    Convert a channel id into the frequency of that channel.

    The width of one channel is ``BANDWIDTH / NOF_CHANNELS`` (in MHz); the
    frequency is that width multiplied by the channel id.

    :param channel_id: Id of the channel

    :return: The frequency of the channel.
    """
    return float(channel_id) * (BANDWIDTH / NOF_CHANNELS)
# pylint: disable=too-many-locals
# pylint: disable=too-many-arguments
def _package_visibility(
    correlation_times_array: Any,
    channel_id: Any,
    int_time: Any,
    itime: Any,
    correlations: Any,
    baselines: Any,
    masked_antennas: Any,
    location: Any,
    pol: Any,
    enu: Any,
) -> Any:
    """
    Bundle the measured correlations into an SDP visibility data model.

    :param correlation_times_array: unix timestamps of the correlations.
    :param channel_id: id of the observed channel.
    :param int_time: integration time, passed through unchanged.
    :param itime: index of the timestamp to use.
    :param correlations: the correlation data, indexed as ``[:, pol]``.
    :param baselines: pair of antenna index sequences, one per baseline end.
    :param masked_antennas: indices of antennas whose baselines are flagged.
    :param location: location passed to ``Time`` and to the data model.
    :param pol: polarisation selection applied to the data and flags.
    :param enu: antenna positions, indexed by antenna.

    :return: the result of ``sdp_visibility_datamodel``.
    """
    ant1, ant2 = baselines[0], baselines[1]

    # A baseline is flagged when either end is masked ...
    touches_masked = np.logical_or(
        np.isin(ant1, masked_antennas),
        np.isin(ant2, masked_antennas),
    )
    # ... and autocorrelations are flagged too.
    baseline_flags = np.logical_or(touches_masked, ant1 == ant2)
    # One flag column for each of the four polarisations in the dataset.
    flags_per_pol = np.tile(baseline_flags, (4, 1)).T

    # Interpret the timestamps as unix time, then present them in FITS format.
    timestamps = Time(correlation_times_array, format="unix", location=location)
    timestamps.format = "fits"

    return sdp_visibility_datamodel(
        vis=correlations[:, pol],
        flags=flags_per_pol[:, pol],
        uvw=enu[ant1] - enu[ant2],
        ant1=ant1,
        ant2=ant2,
        location=location,
        antpos_enu=enu,
        time=timestamps[itime],
        int_time=int_time,
        frequency_mhz=channel_id_to_freq(channel_id),
    )
# pylint: disable=too-many-locals
def load_observation_data(data_file_path: str) -> dict[str, Any]:
    """
    Load the observation data from file and return it in a dictionary.

    The file is opened read-only and closed again before returning; every
    value handed back has already been read into memory.

    :param data_file_path: path to the observation visibility data.

    :return: a dictionary containing the observation data. Only the last
        time block is returned in ``correlations``; ``timestamp`` is ``-1``
        when the file metadata has no ``timestamp`` attribute.
    """
    # The context manager guarantees the HDF5 handle is released even when a
    # key is missing or a read fails part-way through.
    with h5py.File(data_file_path, "r") as datafile:
        correlation_metadata = dict(datafile["root"].attrs)
        ntimes = correlation_metadata["n_blocks"]
        itime = ntimes - 1  # Index of timestamp we want

        # Get the data associated with the correlation matrices
        # (np.squeeze reads the dataset into an in-memory array).
        correlation_data = np.squeeze(datafile["correlation_matrix"]["data"])

        # Get the timestamps associated with the correlation matrices
        correlation_times = datafile["sample_timestamps"]
        if ntimes > 1:
            correlation_times_array = np.squeeze(correlation_times["data"])
            correlations = correlation_data[itime]
        else:
            # With a single block the squeeze above already dropped the time
            # axis, so the data is used as-is.
            correlation_times_array = correlation_times["data"][itime]
            correlations = correlation_data

    return {
        "correlations": correlations,
        "itime": itime,
        # Chop off rounding errors
        "int_time": round(correlation_metadata["tsamp"], 12),
        "timestamp": correlation_metadata.pop("timestamp", -1),
        "correlation_times_array": correlation_times_array,
        "n_ant": correlation_metadata["n_antennas"],
        "n_stokes": correlation_metadata["n_stokes"],
        "n_baselines": correlation_metadata["n_baselines"],
        "n_pol": correlation_metadata["n_pols"],
        "channel_id": correlation_metadata["channel_id"],
        "station_id": correlation_metadata["station_id"],
    }
def find_by_key(data: dict, target: str) -> Any:
    """
    Search nested dict breadth-first for the first target key and return its value.

    This method is used to find station and antenna config within Low platform spec
    files, and should eventually be replaced by functions specifically designed to
    parse these files, aware of schema versions, etc, probably within ska-telmodel.

    Only nested dictionaries are descended into; lists and other containers
    are treated as leaf values.

    :param data: generic nested dictionary to traverse through.
    :param target: key to find the first value of.

    :returns: the value of the first matching key in breadth-first order, or
        ``None`` when the key does not occur.
    """
    # Imported locally so the function stays self-contained; a deque gives
    # O(1) pops from the front where list.pop(0) is O(n).
    from collections import deque  # pylint: disable=import-outside-toplevel

    pending = deque(data.items())
    while pending:
        key, value = pending.popleft()
        if key == target:
            return value
        if isinstance(value, dict):
            pending.extend(value.items())
    return None