Source code for ska_sdp_spectral_line_imaging.msreader

import logging

import xarray as xr
from xradio.measurement_set.open_processing_set import open_processing_set

logger = logging.getLogger(__name__)


def msv2_to_datatree(input_path: str, chunks: dict) -> xr.DataTree:
    """Open an MSv2 measurement set as an xarray DataTree.

    Delegates to ``xr.open_datatree`` with the ``xarray-ms:msv2`` engine.

    Parameters
    ----------
    input_path : str
        Path passed as the first argument to ``xr.open_datatree``.
    chunks : dict
        Chunk specification forwarded as the ``chunks`` keyword.

    Returns
    -------
    xr.DataTree
        The object returned by ``xr.open_datatree`` for the given input.
    """
    open_options = {"chunks": chunks, "engine": "xarray-ms:msv2"}
    datatree = xr.open_datatree(input_path, **open_options)
    return datatree
class MSReader:
    """Open a measurement set once and serve it as xarray objects.

    The constructor opens the input via :meth:`open_dataset` and keeps the
    resulting DataTree. :meth:`get_datatree` returns that tree as-is, and
    :meth:`get_dataset` concatenates selected top-level groups of the tree
    along the ``time`` dimension.

    Parameters
    ----------
    input_path : str
        Path to the measurement set.
    chunks : dict[str, int]
        Chunk sizes per dimension; only used when the MSv2 reader is needed.
    """

    def __init__(self, input_path: str, chunks: dict[str, int]):
        self._input_path = input_path
        self._dt = self.open_dataset(input_path, chunks)

    @staticmethod
    def open_dataset(input_path: str, chunks: dict[str, int]) -> xr.DataTree:
        """Open a measurement set as an xarray DataTree.

        Attempts to read the input as MSv4 (processing set) first. If that
        attempt raises ``FileNotFoundError``, the input is read as MSv2
        instead. Any other error from the MSv4 reader is propagated without
        trying MSv2.

        Parameters
        ----------
        input_path : str
            Path to the measurement set file.
        chunks : dict[str, int]
            Dictionary specifying chunk sizes for dimensions. Required for
            reading MSv2 format.

        Returns
        -------
        xr.DataTree
            The measurement set as an xarray DataTree object.

        Raises
        ------
        Exception
            If the MSv4 reader fails with anything other than
            ``FileNotFoundError``, or if the MSv2 fallback fails as well.
        """
        try:
            logger.info("Trying to read the input as MSv4.")
            return open_processing_set(input_path)
        except FileNotFoundError:
            logger.info("Reading as MSv4 failed.")
            logger.info("Trying to read the input as MSv2.")
            # The fallback needs its own try block: a sibling ``except``
            # clause never sees exceptions raised inside this handler, so a
            # failing MSv2 read would otherwise go unlogged.
            try:
                return msv2_to_datatree(input_path, chunks)
            except Exception:
                logger.info("Reading input as either MSv4 or MSv2 failed.")
                raise
        except Exception:
            # MSv2 has not been tried on this path, so say so.
            logger.info(
                "Reading input as MSv4 failed with an unexpected error; "
                "MSv2 was not attempted."
            )
            raise

    def __get_base_groups(self) -> list[str]:
        """Return the keys of the underlying DataTree as a list."""
        return list(self._dt.keys())

    def __select_obs(self, selected_obs: list[str] | str) -> list[str]:
        """Validate an observation selection against the available groups.

        Parameters
        ----------
        selected_obs : list[str] | str
            ``"all"``, a single observation ID, or a non-empty list of
            observation IDs.

        Returns
        -------
        list[str]
            The observation IDs to load.

        Raises
        ------
        ValueError
            If an ID is not one of the available groups, if the list is
            empty, or if the argument is neither a string nor a list.
        """
        obs = self.__get_base_groups()

        if selected_obs == "all":
            return obs

        if isinstance(selected_obs, str):
            if selected_obs not in obs:
                raise ValueError(
                    f"{selected_obs} is not a valid observation ID. "
                    f"Available ones are {obs}"
                )
            return [selected_obs]

        if isinstance(selected_obs, list):
            # Fail here with a clear message rather than later inside
            # ``xr.concat``, which cannot concatenate zero datasets.
            if not selected_obs:
                raise ValueError(
                    "No observation ID provided. "
                    f"Available ones are {obs}"
                )
            for name in selected_obs:
                if name not in obs:
                    raise ValueError(
                        f"Observation ID: {name} not found. "
                        f"Available ones are {obs}"
                    )
            return selected_obs

        raise ValueError("Provide a valid obs id.")

    def __assign_xds(
        self, ds: xr.Dataset, obs_name: str | None = None
    ) -> xr.Dataset:
        """Attach antenna and field/source datasets to ``ds`` as attributes.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset to annotate; it must contain a ``VISIBILITY`` variable.
        obs_name : str | None, optional
            Group of the DataTree from which the auxiliary datasets are
            taken. Defaults to the first group of the tree.

        Returns
        -------
        xr.Dataset
            A dataset with ``antenna_xds`` in its attributes and
            ``field_and_source_xds`` in the attributes of ``VISIBILITY``.
        """
        if obs_name is None:
            obs_name = self.__get_base_groups()[0]
        source_node = self._dt[obs_name]

        ds = ds.assign_attrs(
            {"antenna_xds": source_node.antenna_xds.to_dataset(inherit=False)}
        )
        field_and_source = source_node.field_and_source_base_xds.to_dataset(
            inherit=False
        )
        ds = ds.assign(
            {
                "VISIBILITY": ds.VISIBILITY.assign_attrs(
                    {"field_and_source_xds": field_and_source}
                )
            }
        )
        return ds

    def get_datatree(self) -> xr.DataTree:
        """Get the underlying DataTree object.

        Returns
        -------
        xr.DataTree
            The measurement set as an xarray DataTree object.
        """
        return self._dt

    def get_dataset(self, obs: list[str] | str) -> xr.Dataset:
        """Get dataset from the measurement set.

        Parameters
        ----------
        obs : list[str] | str
            Observation ID(s) to retrieve. Can be a single observation ID
            as a string, a list of observation IDs, or "all" to retrieve
            all available observations.

        Returns
        -------
        xr.Dataset
            The concatenated dataset containing the selected observations.
            Its auxiliary antenna and field/source datasets are taken from
            the first selected observation.

        Raises
        ------
        ValueError
            If the specified observation ID(s) are not found in the
            measurement set, or if no observation is selected.
        """
        selected_obs = self.__select_obs(obs)
        selected_groups = [
            self._dt[name].to_dataset().unify_chunks() for name in selected_obs
        ]
        # NOTE: There is an issue in either xradio/xarray/dask that causes
        # chunk sizes to be different for coordinate variables
        ds = xr.concat(selected_groups, dim="time").chunk({"time": -1})
        # Take the auxiliary datasets from a group that was actually
        # selected, not from whichever group comes first in the tree.
        return self.__assign_xds(ds, selected_obs[0])