Source code for ska_sdp_spectral_line_imaging.msreader
import logging
import xarray as xr
from xradio.measurement_set.open_processing_set import open_processing_set
logger = logging.getLogger(__name__)
[docs]
def msv2_to_datatree(input_path: str, chunks: dict) -> xr.DataTree:
    """Open an MSv2 measurement set as an xarray DataTree.

    The input is opened through ``xr.open_datatree`` using the
    ``"xarray-ms:msv2"`` engine.

    Parameters
    ----------
    input_path : str
        Path to the MSv2 measurement set.
    chunks : dict
        Chunking specification, forwarded unchanged to
        ``xr.open_datatree``.

    Returns
    -------
    xr.DataTree
        The tree returned by ``xr.open_datatree``.
    """
    msv2_engine = "xarray-ms:msv2"
    datatree = xr.open_datatree(
        input_path,
        chunks=chunks,
        engine=msv2_engine,
    )
    return datatree
[docs]
class MSReader:
    """Open a measurement set once and hand out xarray views of it.

    The input is opened at construction time via :meth:`open_dataset` and
    kept as an ``xr.DataTree``. The top-level keys of that tree are treated
    as the available observation IDs.
    """

    def __init__(self, input_path: str, chunks: dict[str, int]):
        """Open ``input_path`` and keep the resulting DataTree.

        Parameters
        ----------
        input_path : str
            Path to the measurement set file.
        chunks : dict[str, int]
            Chunk sizes per dimension; only used by the MSv2 fallback.
        """
        self._input_path = input_path
        self._dt = self.open_dataset(input_path, chunks)

    @staticmethod
    def open_dataset(input_path: str, chunks: dict[str, int]) -> xr.DataTree:
        """Open a measurement set as an xarray DataTree.

        Attempts to read the input as MSv4 (processing set) first. If, and
        only if, that attempt raises ``FileNotFoundError`` the input is
        read as MSv2 instead. Any other error from the MSv4 reader is
        re-raised without trying MSv2.

        Parameters
        ----------
        input_path : str
            Path to the measurement set file.
        chunks : dict[str, int]
            Dictionary specifying chunk sizes for dimensions.
            Required for reading MSv2 format.

        Returns
        -------
        xr.DataTree
            The measurement set as an xarray DataTree object.

        Raises
        ------
        Exception
            If the MSv4 reader fails with anything other than
            ``FileNotFoundError``, or if the MSv2 fallback fails.
        """
        logger.info("Trying to read the input as MSv4.")
        try:
            return open_processing_set(input_path)
        except FileNotFoundError:
            logger.info("Reading as MSv4 failed.")
            logger.info("Trying to read the input as MSv2.")
            # The fallback needs its own handler: a sibling ``except`` of
            # the outer ``try`` never sees exceptions raised in here.
            # Staying inside this handler also keeps the MSv4 error
            # attached as the context of any MSv2 error.
            try:
                return msv2_to_datatree(input_path, chunks)
            except Exception:
                logger.info("Reading input as either MSv4 or MSv2 failed.")
                raise
        except Exception:
            # MSv2 is deliberately not tried on this path, so say so.
            logger.info(
                "Reading the input as MSv4 failed; MSv2 was not attempted."
            )
            raise

    def __get_base_groups(self) -> list[str]:
        """Return the top-level keys of the DataTree (the observation IDs)."""
        return list(self._dt.keys())

    def __select_obs(self, selected_obs: list[str] | str) -> list[str]:
        """Validate an observation selection and normalise it to a list.

        Parameters
        ----------
        selected_obs : list[str] | str
            ``"all"``, a single observation ID, or a list of IDs.

        Returns
        -------
        list[str]
            The selected observation IDs, in the order requested (tree
            order for ``"all"``). Never empty.

        Raises
        ------
        ValueError
            If an ID is unknown, the selection is empty, or the argument
            is neither a string nor a list.
        """
        obs = self.__get_base_groups()

        if isinstance(selected_obs, str):
            if selected_obs == "all":
                chosen = obs
            elif selected_obs in obs:
                chosen = [selected_obs]
            else:
                raise ValueError(
                    f"{selected_obs} is not a valid observation ID. "
                    f"Available ones are {obs}"
                )
        elif isinstance(selected_obs, list):
            for name in selected_obs:
                if name not in obs:
                    raise ValueError(
                        f"Observation ID: {name} not found. "
                        f"Available ones are {obs}"
                    )
            chosen = list(selected_obs)
        else:
            raise ValueError("Provide a valid obs id.")

        # Fail here with a clear message instead of later inside
        # ``xr.concat`` / metadata lookup, which need at least one entry.
        if not chosen:
            raise ValueError(
                f"No observation selected. Available ones are {obs}"
            )
        return chosen

    def __assign_xds(
        self, ds: xr.Dataset, source_obs: str | None = None
    ) -> xr.Dataset:
        """Attach antenna and field-and-source datasets as attributes.

        ``antenna_xds`` is stored in the dataset attributes and
        ``field_and_source_xds`` in the attributes of ``VISIBILITY``.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset to annotate; it must contain a ``VISIBILITY`` variable.
        source_obs : str | None
            Observation whose sub-datasets are attached. Defaults to the
            first group of the tree when omitted.

        Returns
        -------
        xr.Dataset
            A dataset carrying the two extra attributes.
        """
        if source_obs is None:
            source_obs = self.__get_base_groups()[0]
        source_node = self._dt[source_obs]

        antenna_xds = source_node.antenna_xds.to_dataset(inherit=False)
        field_and_source_xds = source_node.field_and_source_base_xds.to_dataset(
            inherit=False
        )

        ds = ds.assign_attrs({"antenna_xds": antenna_xds})
        visibility = ds.VISIBILITY.assign_attrs(
            {"field_and_source_xds": field_and_source_xds}
        )
        return ds.assign({"VISIBILITY": visibility})

    def get_datatree(self) -> xr.DataTree:
        """Get the underlying DataTree object.

        Returns
        -------
        xr.DataTree
            The measurement set as an xarray DataTree object.
        """
        return self._dt

    def get_dataset(self, obs: list[str] | str) -> xr.Dataset:
        """Get dataset from the measurement set.

        The selected observations are concatenated along ``time`` and
        rechunked so that ``time`` is a single chunk. The antenna and
        field-and-source datasets of the first selected observation are
        attached as attributes.

        Parameters
        ----------
        obs : list[str] | str
            Observation ID(s) to retrieve. Can be a single observation ID as
            a string, a list of observation IDs, or "all" to retrieve all
            available observations.

        Returns
        -------
        xr.Dataset
            The concatenated dataset containing the selected observations.

        Raises
        ------
        ValueError
            If the specified observation ID(s) are not found in the
            measurement set, or if nothing is selected.
        """
        selected_obs = self.__select_obs(obs)

        selected_groups = [
            self._dt[name].to_dataset().unify_chunks() for name in selected_obs
        ]

        # NOTE: There is an issue in either xradio/xarray/dask that causes
        # chunk sizes to be different for coordinate variables
        ds = xr.concat(selected_groups, dim="time").chunk({"time": -1})

        # Take the metadata from an observation that is actually part of
        # the result rather than from whichever group is first in the tree.
        return self.__assign_xds(ds, selected_obs[0])