import logging
import os
import pickle
from copy import deepcopy
from typing import List, Optional
import dask
import dask.array as da
import numpy as np
import numpy.typing as npt
import pandas
import xarray as xr
from astropy import units as u
from astropy.coordinates import EarthLocation, SkyCoord
from astropy.time import Time
from astropy.units import Quantity
from casacore.tables import table, taql
from dask.delayed import Delayed
from ska_sdp_datamodels.configuration.config_model import Configuration
from ska_sdp_datamodels.science_data_model.polarisation_model import (
PolarisationFrame,
ReceptorFrame,
)
from ska_sdp_datamodels.visibility.vis_model import Visibility
from ..xarray_processors import simplify_baselines_dim, with_chunks
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# File names used inside a visibility cache directory; they are joined to
# the directory path by _generate_file_paths_for_vis_zarr_file.
# Pickle file holding the Visibility attributes.
ATTRS_FILE_NAME = "attrs.pickle"
# Zarr store holding the Visibility dataset itself.
VIS_FILE_NAME = "vis.zarr"
# Pickle file holding the "baselines" coordinate.
BASELINE_FILE_NAME = "baselines.pickle"
[docs]
def create_template_vis_from_ms(
    msname: str,
    ack: bool = False,
    datacolumn: str = "DATA",
    field_ids: Optional[List[int]] = None,
    data_desc_ids: Optional[List[int]] = None,
) -> List[Visibility]:
    """
    Create empty "template" visibility objects from a Measurement Set.

    This function inspects the provided Measurement Set (MS) to determine
    the shapes, types, and metadata required to create Visibility objects.
    It returns a list of these objects where the data arrays (vis, flags,
    weights, uvw) are initialized as empty Dask arrays. These templates
    can be populated later.

    One template is produced for every (field, data description) pair,
    iterating fields in the outer loop and data descriptions in the inner
    loop.

    Parameters
    ----------
    msname : str
        The file path to the Measurement Set.
    ack : bool, optional
        If True, print an acknowledgement message when opening the table.
        Default is False.
    datacolumn : str, optional
        The name of the column in the MS to use for determining the data
        type of the visibility data. Default is "DATA".
    field_ids : list[int], optional
        A list of field IDs to process. If None (or empty), defaults
        to [0].
    data_desc_ids : list[int], optional
        A list of data description IDs to process. If None (or empty),
        defaults to [0].

    Returns
    -------
    list[Visibility]
        A list of Visibility objects corresponding to the selected field
        and data description IDs. The data arrays within are empty Dask
        arrays.

    Raises
    ------
    ValueError
        If the selection for a specific Field ID or Data Description ID
        yields no rows in the MS.
    KeyError
        If the polarization configuration in the MS is not recognized.
    AssertionError
        If the derived number of time steps is inconsistent with the
        per-row time indices.
    """
    # An empty list is treated like None: fall back to the first entry.
    field_ids = field_ids or [0]
    data_desc_ids = data_desc_ids or [0]
    vis_list = []
    with table(msname, ack=ack, readonly=True) as tab:
        for field in field_ids:
            with tab.query(f"FIELD_ID=={field}", style="") as ftab:
                if ftab.nrows() <= 0:
                    raise ValueError(f"Empty selection for FIELD_ID={field}")
                for dd in data_desc_ids:
                    # The data description row points at the spectral
                    # window and polarisation rows used further below.
                    with table(
                        f"{msname}/DATA_DESCRIPTION", ack=ack, readonly=True
                    ) as ddtab:
                        spwid = ddtab.getcol("SPECTRAL_WINDOW_ID")[dd]
                        polid = ddtab.getcol("POLARIZATION_ID")[dd]
                    meta = {"MSV2": {"FIELD_ID": field, "DATA_DESC_ID": dd}}
                    with ftab.query(f"DATA_DESC_ID=={dd}", style="") as ms:
                        if ms.nrows() <= 0:
                            raise ValueError(
                                f"Empty selection for FIELD_ID= {field} "
                                f"and DATA_DESC_ID={dd}"
                            )
                        time = ms.getcol("TIME")
                        antenna1 = ms.getcol("ANTENNA1")
                        antenna2 = ms.getcol("ANTENNA2")
                        integration_time = ms.getcol("INTERVAL")
                        # Only the dtypes are needed for the template, so a
                        # single row of each data column is read.
                        vis_dtype = ms.getcol(datacolumn, nrow=1).dtype
                        flags_dtype = ms.getcol("FLAG", nrow=1).dtype
                        weight_dtype = ms.getcol("WEIGHT", nrow=1).dtype
                        uvw_dtype = ms.getcol("UVW", nrow=1).dtype
                    # TIME is divided by 86400 and then interpreted as MJD
                    # (days) for the log message.
                    start_time = np.min(time) / 86400.0
                    end_time = np.max(time) / 86400.0
                    logger.debug(
                        "create_visibility_from_ms: Observation from %s to %s",
                        Time(start_time, format="mjd").iso,
                        Time(end_time, format="mjd").iso,
                    )
                    # Channel frequencies and widths of the selected
                    # spectral window.
                    with table(
                        f"{msname}/SPECTRAL_WINDOW", ack=ack, readonly=True
                    ) as spwtab:
                        cfrequency = np.array(
                            spwtab.getcol("CHAN_FREQ")[spwid]
                        )
                        cchannel_bandwidth = np.array(
                            spwtab.getcol("CHAN_WIDTH")[spwid]
                        )
                    nchan = cfrequency.shape[0]
                    # Get polarisation info
                    with table(
                        f"{msname}/POLARIZATION", ack=ack, readonly=True
                    ) as poltab:
                        corr_type = poltab.getcol("CORR_TYPE")[polid]
                    # These correspond to the CASA Stokes enumerations
                    if np.array_equal(corr_type, [1, 2, 3, 4]):
                        polarisation_frame = PolarisationFrame("stokesIQUV")
                        npol = 4
                    elif np.array_equal(corr_type, [1, 2]):
                        polarisation_frame = PolarisationFrame("stokesIQ")
                        npol = 2
                    elif np.array_equal(corr_type, [1, 4]):
                        polarisation_frame = PolarisationFrame("stokesIV")
                        npol = 2
                    elif np.array_equal(corr_type, [5, 6, 7, 8]):
                        polarisation_frame = PolarisationFrame("circular")
                        npol = 4
                    elif np.array_equal(corr_type, [5, 8]):
                        polarisation_frame = PolarisationFrame("circularnp")
                        npol = 2
                    elif np.array_equal(corr_type, [9, 10, 11, 12]):
                        polarisation_frame = PolarisationFrame("linear")
                        npol = 4
                    elif np.array_equal(corr_type, [9, 12, 10, 11]):
                        polarisation_frame = PolarisationFrame("linearFITS")
                        npol = 4
                    elif np.array_equal(corr_type, [9, 12]):
                        polarisation_frame = PolarisationFrame("linearnp")
                        npol = 2
                    elif np.array_equal(corr_type, [9]) or np.array_equal(
                        corr_type, [1]
                    ):
                        # A single correlation (code 9 or code 1) is
                        # treated as Stokes I.
                        npol = 1
                        polarisation_frame = PolarisationFrame("stokesI")
                    else:
                        raise KeyError(
                            f"Polarisation not understood: {str(corr_type)}"
                        )
                    # Get configuration
                    with table(
                        f"{msname}/ANTENNA", ack=ack, readonly=True
                    ) as anttab:
                        names = np.array(anttab.getcol("NAME"))
                        # pylint: disable=cell-var-from-loop
                        # ant_map[i] is the position of antenna row i among
                        # the antennas with a non-empty name, or -1 when the
                        # name is empty.
                        ant_map = []
                        actual = 0
                        # This assumes that the names are actually filled in!
                        for name in names:
                            if name != "":
                                ant_map.append(actual)
                                actual += 1
                            else:
                                ant_map.append(-1)
                        if actual == 0:
                            # No antenna has a name: keep every antenna and
                            # use placeholder names so the masks below
                            # select all rows.
                            ant_map = list(range(len(names)))
                            names = np.repeat("No name", len(names))
                        # Every antenna column is filtered with the same
                        # non-empty-name mask.
                        mount = np.array(anttab.getcol("MOUNT"))[names != ""]
                        diameter = np.array(anttab.getcol("DISH_DIAMETER"))[
                            names != ""
                        ]
                        xyz = np.array(anttab.getcol("POSITION"))[names != ""]
                        offset = np.array(anttab.getcol("OFFSET"))[names != ""]
                        stations = np.array(anttab.getcol("STATION"))[
                            names != ""
                        ]
                        # Re-read NAME last, because "names" is the mask
                        # source for all the columns above.
                        names = np.array(anttab.getcol("NAME"))[names != ""]
                    nants = len(names)
                    # NOTE(review): the remapped antenna1/antenna2 are not
                    # referenced again in this function.
                    antenna1 = list(map(lambda i: ant_map[i], antenna1))
                    antenna2 = list(map(lambda i: ant_map[i], antenna2))
                    # Upper-triangular antenna pairs including the diagonal
                    # (k=0), i.e. autocorrelations are part of the template.
                    baselines = pandas.MultiIndex.from_arrays(
                        np.triu_indices(nants, k=0),
                        names=("antenna1", "antenna2"),
                    )
                    nbaselines = len(baselines)
                    # The position of the first selected antenna is used as
                    # the location of the whole configuration.
                    location = EarthLocation(
                        x=Quantity(xyz[0][0], "m"),
                        y=Quantity(xyz[0][1], "m"),
                        z=Quantity(xyz[0][2], "m"),
                    )
                    configuration = Configuration.constructor(
                        name="",
                        location=location,
                        names=names,
                        xyz=xyz,
                        mount=mount,
                        frame="ITRF",
                        receptor_frame=ReceptorFrame("linear"),
                        diameter=diameter,
                        offset=offset,
                        stations=stations,
                    )
                    # Get phasecentres
                    with table(
                        f"{msname}/FIELD", ack=ack, readonly=True
                    ) as fieldtab:
                        pc = fieldtab.getcol("PHASE_DIR")[field, 0, :]
                        source = fieldtab.getcol("NAME")[field]
                    phasecentre = SkyCoord(
                        ra=pc[0] * u.rad,
                        dec=pc[1] * u.rad,
                        frame="icrs",
                        equinox="J2000",
                    )
                    # Assign a time-step index to every row: a new step
                    # starts when a row's time exceeds the current step's
                    # time by more than half of that row's interval.
                    time_index_row = np.zeros_like(time, dtype="int")
                    time_last = time[0]
                    time_index = 0
                    for row, _ in enumerate(time):
                        if time[row] > time_last + 0.5 * integration_time[row]:
                            # NOTE(review): with non-negative intervals the
                            # enclosing condition already implies this, so
                            # rows going backwards in time are not detected
                            # here - confirm whether that is intended.
                            assert (
                                time[row] > time_last
                            ), "MS is not time-sorted - cannot convert"
                            time_index += 1
                            time_last = time[row]
                        time_index_row[row] = time_index
                    ntimes = time_index + 1
                    assert ntimes == len(
                        np.unique(time_index_row)
                    ), "Error in finding data times"
                    # Placeholder (uninitialised) dask arrays with the final
                    # shapes and the dtypes found in the MS.
                    bv_vis = da.empty(
                        [ntimes, nbaselines, nchan, npol], dtype=vis_dtype
                    )
                    bv_flags = da.empty(
                        [ntimes, nbaselines, nchan, npol], dtype=flags_dtype
                    )
                    bv_weight = da.empty(
                        [ntimes, nbaselines, nchan, npol], dtype=weight_dtype
                    )
                    bv_uvw = da.empty([ntimes, nbaselines, 3], dtype=uvw_dtype)
                    bv_times = np.zeros([ntimes])
                    bv_integration_time = np.zeros([ntimes])
                    # Every row writes into its time-step slot, so the last
                    # row of each step determines the stored values.
                    for row, _ in enumerate(time):
                        time_index = time_index_row[row]
                        bv_times[time_index] = time[row]
                        bv_integration_time[time_index] = integration_time[row]
                    vis_template = Visibility.constructor(
                        uvw=bv_uvw,
                        baselines=baselines,
                        time=bv_times,
                        frequency=cfrequency,
                        channel_bandwidth=cchannel_bandwidth,
                        vis=bv_vis,
                        flags=bv_flags,
                        weight=bv_weight,
                        integration_time=bv_integration_time,
                        configuration=configuration,
                        phasecentre=phasecentre,
                        polarisation_frame=polarisation_frame,
                        source=source,
                        meta=meta,
                        low_precision="float64",
                    )
                    # Need to reassign with correct dtype
                    vis_template = vis_template.assign(
                        {
                            "weight": vis_template.weight.astype(weight_dtype),
                            "flags": vis_template.flags.astype(flags_dtype),
                        }
                    )
                    vis_list.append(vis_template)
    return vis_list
[docs]
def get_col_from_ms(
    msname: str,
    colname: str,
    start_time_idx: int,
    ntimes: int,
    num_baselines: int,
    ack=False,
    field_ids: List[int] = None,
    data_desc_ids: List[int] = None,
) -> List[npt.NDArray]:
    """
    Read a contiguous block of rows of one column from a Measurement Set.

    The block starts at row ``start_time_idx * num_baselines`` of the
    selection and spans ``ntimes * num_baselines`` rows. The read is
    repeated for every requested (field, data description) combination,
    with fields iterated in the outer loop.

    Parameters
    ----------
    msname : str
        Path of the Measurement Set.
    colname : str
        Column to read (e.g. "DATA", "UVW", "FLAG").
    start_time_idx : int
        Index of the first time step to read.
    ntimes : int
        Number of time steps to read.
    num_baselines : int
        Number of baselines (rows) per time step.
    ack : bool, optional
        Forwarded to the table open call; if True an acknowledgement
        message is printed. Default is False.
    field_ids : list[int], optional
        Field IDs to select. None or an empty list means [0].
    data_desc_ids : list[int], optional
        Data description IDs to select. None or an empty list means [0].

    Returns
    -------
    list[numpy.ndarray]
        One array per (field, data description) combination, in
        iteration order.

    Raises
    ------
    ValueError
        If a field selection, or a field + data description selection,
        contains no rows.
    """
    selected_fields = field_ids or [0]
    selected_data_descs = data_desc_ids or [0]

    # Rows of a selection are addressed as consecutive blocks of
    # ``num_baselines`` rows per time step.
    first_row = start_time_idx * num_baselines
    row_count = ntimes * num_baselines

    columns = []
    with table(msname, ack=ack, readonly=True) as main_table:
        for field in selected_fields:
            with main_table.query(f"FIELD_ID=={field}", style="") as by_field:
                if by_field.nrows() < 1:
                    raise ValueError(f"Empty selection for FIELD_ID={field}")

                for dd in selected_data_descs:
                    dd_query = f"DATA_DESC_ID=={dd}"
                    with by_field.query(dd_query, style="") as rows:
                        if rows.nrows() < 1:
                            raise ValueError(
                                f"Empty selection for FIELD_ID= {field} "
                                f"and DATA_DESC_ID={dd}"
                            )
                        columns.append(
                            rows.getcol(
                                colname, startrow=first_row, nrow=row_count
                            )
                        )
    return columns
def _load_vis_xdr(
    vis_chunk: xr.DataArray,
    ms_name: str,
    time_index_xdr: xr.DataArray,
    vis_baseline_indices_to_update: np.ndarray,
    num_baselines_in_ms: int,
    crosscorr_mask_over_baseline: Optional[np.ndarray] = None,
    polarisation_order: Optional[np.ndarray] = None,
    datacolumn: str = "DATA",
    field_id: int = 0,
    data_desc_id: int = 0,
):
    """
    Fill one chunk of the "vis" variable with data read from the MS.

    The rows for the time steps in ``time_index_xdr`` are read from
    ``datacolumn``, reshaped to (time, MS baseline, ...) and written into a
    zero-initialised array shaped like ``vis_chunk`` at the baseline
    positions given by ``vis_baseline_indices_to_update``.

    When ``crosscorr_mask_over_baseline`` is given, the masked baselines are
    complex-conjugated and, if ``polarisation_order`` is also given, their
    last axis is re-ordered accordingly.

    Returns a DataArray carrying the coordinates of ``vis_chunk``.
    """
    first_time_idx = time_index_xdr.data[0]
    n_time_steps = time_index_xdr.size

    # Same as the chunk shape, but with the MS baseline count on axis 1.
    target_shape = (
        vis_chunk.shape[:1] + (num_baselines_in_ms,) + vis_chunk.shape[2:]
    )
    column = get_col_from_ms(
        ms_name,
        colname=datacolumn,
        start_time_idx=first_time_idx,
        ntimes=n_time_steps,
        num_baselines=num_baselines_in_ms,
        field_ids=[field_id],
        data_desc_ids=[data_desc_id],
    )[0]
    ms_block = column.reshape(target_shape)

    mask = crosscorr_mask_over_baseline
    if mask is not None:
        ms_block[:, mask, ...] = np.conj(ms_block[:, mask, ...])
        if polarisation_order is not None:
            masked_rows = ms_block[:, mask, ...]
            ms_block[:, mask, ...] = masked_rows[..., polarisation_order]

    # Baselines absent from the MS stay zero.
    filled = np.zeros_like(vis_chunk)
    filled[:, vis_baseline_indices_to_update, ...] = ms_block
    del ms_block
    return xr.DataArray(filled, coords=vis_chunk.coords)
def _load_flags_xdr(
    flags_chunk: xr.DataArray,
    ms_name: str,
    time_index_xdr: xr.DataArray,
    vis_baseline_indices_to_update: np.ndarray,
    num_baselines_in_ms: int,
    crosscorr_mask_over_baseline: Optional[np.ndarray] = None,
    polarisation_order: Optional[np.ndarray] = None,
    field_id: int = 0,
    data_desc_id: int = 0,
):
    """
    Fill one chunk of the "flags" variable with the MS "FLAG" column.

    The rows for the time steps in ``time_index_xdr`` are reshaped to
    (time, MS baseline, ...) and written into a zero-initialised array
    shaped like ``flags_chunk`` at ``vis_baseline_indices_to_update``.
    If both a cross-correlation mask and a polarisation order are given,
    the last axis of the masked baselines is re-ordered first.

    Returns a DataArray carrying the coordinates of ``flags_chunk``.
    """
    first_time_idx = time_index_xdr.data[0]
    n_time_steps = time_index_xdr.size

    # Same as the chunk shape, but with the MS baseline count on axis 1.
    target_shape = (
        flags_chunk.shape[:1]
        + (num_baselines_in_ms,)
        + flags_chunk.shape[2:]
    )
    column = get_col_from_ms(
        ms_name,
        colname="FLAG",
        start_time_idx=first_time_idx,
        ntimes=n_time_steps,
        num_baselines=num_baselines_in_ms,
        field_ids=[field_id],
        data_desc_ids=[data_desc_id],
    )[0]
    ms_flags = column.reshape(target_shape)

    mask = crosscorr_mask_over_baseline
    needs_reorder = mask is not None and polarisation_order is not None
    if needs_reorder:
        masked_rows = ms_flags[:, mask, ...]
        ms_flags[:, mask, ...] = masked_rows[..., polarisation_order]

    # Baselines absent from the MS stay zero.
    filled = np.zeros_like(flags_chunk)
    filled[:, vis_baseline_indices_to_update, ...] = ms_flags
    del ms_flags
    return xr.DataArray(filled, coords=flags_chunk.coords)
def _load_weight_xdr(
    weight_chunk: xr.DataArray,
    ms_name: str,
    time_index_xdr: xr.DataArray,
    vis_baseline_indices_to_update: np.ndarray,
    num_baselines_in_ms: int,
    crosscorr_mask_over_baseline: Optional[np.ndarray] = None,
    polarisation_order: Optional[np.ndarray] = None,
    field_id: int = 0,
    data_desc_id: int = 0,
):
    """
    Fill one chunk of the "weight" variable with the MS "WEIGHT" column.

    The rows for the time steps in ``time_index_xdr`` are reshaped to
    (time, MS baseline, last axis of the chunk). A new axis is inserted at
    position 2 so the values broadcast over the remaining chunk axis when
    they are written into a zero-initialised array shaped like
    ``weight_chunk`` at ``vis_baseline_indices_to_update``.
    If both a cross-correlation mask and a polarisation order are given,
    the last axis of the masked baselines is re-ordered first.

    Returns a DataArray carrying the coordinates of ``weight_chunk``.
    """
    first_time_idx = time_index_xdr.data[0]
    n_time_steps = time_index_xdr.size

    target_shape = (
        weight_chunk.shape[0],
        num_baselines_in_ms,
        weight_chunk.shape[-1],
    )
    column = get_col_from_ms(
        ms_name,
        colname="WEIGHT",
        start_time_idx=first_time_idx,
        ntimes=n_time_steps,
        num_baselines=num_baselines_in_ms,
        field_ids=[field_id],
        data_desc_ids=[data_desc_id],
    )[0]
    ms_weights = column.reshape(target_shape)

    mask = crosscorr_mask_over_baseline
    needs_reorder = mask is not None and polarisation_order is not None
    if needs_reorder:
        masked_rows = ms_weights[:, mask, ...]
        ms_weights[:, mask, ...] = masked_rows[..., polarisation_order]

    # Baselines absent from the MS stay zero.
    filled = np.zeros_like(weight_chunk)
    filled[:, vis_baseline_indices_to_update, ...] = np.expand_dims(
        ms_weights, axis=2
    )
    del ms_weights
    return xr.DataArray(filled, coords=weight_chunk.coords)
def _load_uvw_xdr(
    uvw_chunk: xr.DataArray,
    ms_name: str,
    time_index_xdr: xr.DataArray,
    vis_baseline_indices_to_update: np.ndarray,
    num_baselines_in_ms: int,
    crosscorr_mask_over_baseline: Optional[np.ndarray] = None,
    field_id: int = 0,
    data_desc_id: int = 0,
):
    """
    Fill one chunk of the "uvw" variable with the MS "UVW" column.

    The rows for the time steps in ``time_index_xdr`` are reshaped to
    (time, MS baseline, 3) and negated. Baselines selected by
    ``crosscorr_mask_over_baseline`` are negated a second time, so they
    keep the sign stored in the MS. The result is written into a
    zero-initialised array shaped like ``uvw_chunk`` at
    ``vis_baseline_indices_to_update``.

    Returns a DataArray carrying the coordinates of ``uvw_chunk``.
    """
    first_time_idx = time_index_xdr.data[0]
    n_time_steps = time_index_xdr.size

    column = get_col_from_ms(
        ms_name,
        colname="UVW",
        start_time_idx=first_time_idx,
        ntimes=n_time_steps,
        num_baselines=num_baselines_in_ms,
        field_ids=[field_id],
        data_desc_ids=[data_desc_id],
    )[0]
    ms_uvw = column.reshape((uvw_chunk.shape[0], num_baselines_in_ms, 3))

    # This sign switch was done in the original converter in data models
    ms_uvw = -1 * ms_uvw
    mask = crosscorr_mask_over_baseline
    if mask is not None:
        ms_uvw[:, mask, :] *= -1

    # Baselines absent from the MS stay zero.
    filled = np.zeros_like(uvw_chunk)
    filled[:, vis_baseline_indices_to_update, ...] = ms_uvw
    del ms_uvw
    return xr.DataArray(filled, coords=uvw_chunk.coords)
def _ms_has_row_where(ms, condition: str) -> bool:
    """Return True if at least one row of ``ms`` satisfies ``condition``."""
    # "limit 1" is enough: only the existence of a matching row matters.
    # The result table is closed as soon as the row count has been read.
    with taql(
        f"select ANTENNA1 from $1 where {condition} limit 1",
        tables=[ms],
    ) as matching_rows:
        return matching_rows.nrows() > 0


def _load_data_vars(
    vis: Visibility,
    ms_name: str,
    datacolumn: str = "DATA",
    field_id: int = 0,
    data_desc_id: int = 0,
):
    """
    Replace the data variables of a template Visibility with lazily
    loaded Measurement Set data.

    Pre-requisites:

    * vis dimensions:
        time, baselineid, frequency, polarisation
      Measurement set "data" dimensions:
        rows (time * baselineid), frequency, polarisation
    * Measurement set may or may not contain auto-correlated values,
      but the visibility always expects auto-correlated values.
      Thus necessary conversions are made here. In case auto-corrs
      are absent in MS, auto-correlations in the Visibility dataset
      are set to zero.

    Parameters
    ----------
    vis : Visibility
        Chunked template whose "vis", "flags", "weight" and "uvw"
        variables are replaced.
    ms_name : str
        Path of the Measurement Set.
    datacolumn : str, optional
        MS column holding the visibilities. Default is "DATA".
    field_id : int, optional
        Field ID to select. Default is 0.
    data_desc_id : int, optional
        Data description ID to select. Default is 0.

    Returns
    -------
    Visibility
        ``vis`` with "vis", "flags", "weight" and "uvw" replaced by
        ``xr.map_blocks`` results; nothing is read until computed.

    Raises
    ------
    ValueError
        If the field / data description selection contains no rows.
    RuntimeError
        If the MS mixes baselines with antenna1 > antenna2 and
        antenna1 < antenna2, or if the antenna order is reversed and the
        polarisation frame is not "linear", "circular" or "linearFITS".
    """
    with table(ms_name, readonly=True, ack=False) as tab:
        with tab.query(
            f"FIELD_ID=={field_id} AND DATA_DESC_ID=={data_desc_id}", style=""
        ) as ms:
            if ms.nrows() <= 0:
                raise ValueError(
                    f"Empty selection for FIELD_ID={field_id} "
                    f"and DATA_DESC_ID={data_desc_id}"
                )

            ms_contains_autocorrelations = _ms_has_row_where(
                ms, "ANTENNA1 == ANTENNA2"
            )
            logger.info(
                "Does measurement set contain autocorrelations? %s",
                ms_contains_autocorrelations,
            )

            ms_is_baseline_order_reversed = _ms_has_row_where(
                ms, "ANTENNA1 > ANTENNA2"
            )
            logger.info(
                "In the measurement set, is the baseline antenna "
                "order reversed (i.e. is antenna1 > antenna2)? %s",
                ms_is_baseline_order_reversed,
            )

            # A reversed MS must be reversed for every cross-correlation.
            if ms_is_baseline_order_reversed and _ms_has_row_where(
                ms, "ANTENNA1 < ANTENNA2"
            ):
                raise RuntimeError(
                    "Order of antennas in baseline pairs is not consistent."
                )

            ms_ant1_col = ms.getcol("ANTENNA1")
            ms_ant2_col = ms.getcol("ANTENNA2")

    # Running index of each time step, chunked like ``vis`` so that every
    # block knows which MS rows it has to read.
    time_index_xdr = xr.DataArray(
        da.arange(vis.time.size), coords={"time": vis.time}
    ).pipe(with_chunks, vis.chunksizes)

    nantennas = vis.configuration.id.size

    # Visibility always has baselines with autocorrelations,
    # and order antenna1 <= antenna2
    vis_baseline_indices = pandas.MultiIndex.from_arrays(
        np.triu_indices(nantennas, k=0), names=("antenna1", "antenna2")
    )

    # Reversing each (antenna1, antenna2) pair of a reversed MS gives the
    # antenna1 <= antenna2 order used by ``vis_baseline_indices``.
    if ms_is_baseline_order_reversed:
        ms_baseline_indices_order = slice(None, None, -1)
    else:
        ms_baseline_indices_order = slice(None, None, None)

    ms_baseline_indices = pandas.MultiIndex.from_arrays(
        [ms_ant1_col, ms_ant2_col],
        names=("antenna1", "antenna2"),
    ).unique()
    num_baselines_in_ms = len(ms_baseline_indices)

    # Position in the Visibility baseline axis of every MS baseline.
    vis_baseline_indices_to_update = np.array(
        [
            vis_baseline_indices.get_loc(indices[ms_baseline_indices_order])
            for indices in ms_baseline_indices
        ]
    )

    crosscorr_baseline_mask = None
    polarisation_order = None
    if ms_is_baseline_order_reversed:
        # Only cross-correlations are affected by the swapped antenna order.
        crosscorr_baseline_mask = ms_baseline_indices.get_level_values(
            "antenna1"
        ) != ms_baseline_indices.get_level_values("antenna2")

        if vis._polarisation_frame in ["linear", "circular"]:
            polarisation_order = [0, 2, 1, 3]
        elif vis._polarisation_frame == "linearFITS":
            polarisation_order = [0, 1, 3, 2]
        else:
            # Exceptions do not interpolate logging-style arguments, so the
            # message is formatted explicitly.
            raise RuntimeError(
                f"Unsupported polarisation frame '{vis._polarisation_frame}' "
                "when antenna order in baselines is reversed"
            )

    # Leading arguments shared by all four block loaders.
    shared_args = [
        ms_name,
        time_index_xdr,
        vis_baseline_indices_to_update,
        num_baselines_in_ms,
        crosscorr_baseline_mask,
    ]

    # vis
    vis_data_xdr = xr.map_blocks(
        _load_vis_xdr,
        vis.vis,
        args=[
            *shared_args,
            polarisation_order,
            datacolumn,
            field_id,
            data_desc_id,
        ],
        template=vis.vis,
    )

    # flags
    flag_data_xdr = xr.map_blocks(
        _load_flags_xdr,
        vis.flags,
        args=[*shared_args, polarisation_order, field_id, data_desc_id],
        template=vis.flags,
    )

    # weight
    weight_data_xdr = xr.map_blocks(
        _load_weight_xdr,
        vis.weight,
        args=[*shared_args, polarisation_order, field_id, data_desc_id],
        template=vis.weight,
    )

    # uvw
    uvw_data_xdr = xr.map_blocks(
        _load_uvw_xdr,
        vis.uvw,
        args=[*shared_args, field_id, data_desc_id],
        template=vis.uvw,
    )

    return vis.assign(
        {
            "vis": vis_data_xdr,
            "flags": flag_data_xdr,
            "weight": weight_data_xdr,
            "uvw": uvw_data_xdr,
        }
    )
[docs]
def load_ms_as_dataset_with_time_chunks(
    ms_name: str,
    times_per_chunk: int,
    ack: bool = False,
    datacolumn: str = "DATA",
    field_id: int = 0,
    data_desc_id: int = 0,
) -> Visibility:
    """
    Load one field / data description of an MSv2 as a time-chunked dataset.

    A template Visibility is built from the MS metadata, its baselines
    dimension is simplified, and it is chunked along time only (every other
    dimension is a single chunk). The data variables are then replaced by
    lazily loaded MS data.

    Parameters
    ----------
    ms_name : str
        Path of the Measurement Set.
    times_per_chunk : int
        Number of time steps per Dask chunk.
    ack : bool, optional
        If True, print an acknowledgement message when opening the table.
        Default is False.
    datacolumn : str, optional
        Column holding the visibilities. Default is "DATA".
    field_id : int, optional
        Field ID to load. Default is 0.
    data_desc_id : int, optional
        Data description ID to load. Default is 0.

    Returns
    -------
    Visibility
        The Visibility dataset with dask-backed arrays.

    Notes
    -----
    The `baselines` dimension in the returned dataset is simplified to a
    NumPy array of baseline IDs, rather than the standard Pandas MultiIndex
    used by the Visibility class. This modification is necessary because
    `xarray` operations like `map_blocks` do not support Pandas MultiIndex
    coordinates.

    **Important:** You must restore the baselines to the original Pandas
    MultiIndex format before passing this object to any functions in
    `ska-sdp-func-python`.
    """
    # Get observation metadata
    templates = create_template_vis_from_ms(
        ms_name,
        ack=ack,
        datacolumn=datacolumn,
        field_ids=[field_id],
        data_desc_ids=[data_desc_id],
    )
    simplified_template = simplify_baselines_dim(templates[0])

    # Split along time only; -1 keeps a dimension in a single chunk.
    chunking = dict.fromkeys(
        ("baselineid", "frequency", "polarisation", "spatial"), -1
    )
    chunking = {"time": times_per_chunk, **chunking}
    chunked_template = with_chunks(simplified_template, chunking)

    return _load_data_vars(
        chunked_template, ms_name, datacolumn, field_id, data_desc_id
    )
def _generate_file_paths_for_vis_zarr_file(vis_cache_directory):
    """
    Build the paths of the three cache artifacts inside a cache directory.

    Returns a tuple ``(attributes_file, baselines_file, vis_zarr_file)``.
    """
    cache_file_names = (ATTRS_FILE_NAME, BASELINE_FILE_NAME, VIS_FILE_NAME)
    return tuple(
        os.path.join(vis_cache_directory, file_name)
        for file_name in cache_file_names
    )
[docs]
def write_ms_to_zarr(
    input_ms_paths: str | list[str],
    vis_cache_directory,
    zarr_chunks,
    ack=False,
    datacolumn="DATA",
    field_id: int = 0,
    data_desc_id: int = 0,
):
    """
    Convert one or more MSv2 into a single Visibility dataset and write it
    to zarr.

    Each measurement set is loaded with ``zarr_chunks["time"]`` time steps
    per chunk, the datasets are concatenated along "time", and the result
    is written to ``vis_cache_directory``. The write is computed eagerly
    before this function returns.

    NOTE: The baselines coordinates in Visibility are simplified.
    See note section in :py:func:`load_ms_as_dataset_with_time_chunks`
    """
    # A single path is handled like a one-element list.
    ms_paths = (
        [input_ms_paths]
        if isinstance(input_ms_paths, str)
        else input_ms_paths
    )

    per_ms_datasets = []
    for ms_path in ms_paths:
        per_ms_datasets.append(
            load_ms_as_dataset_with_time_chunks(
                ms_path,
                zarr_chunks["time"],
                ack=ack,
                datacolumn=datacolumn,
                field_id=field_id,
                data_desc_id=data_desc_id,
            )
        )

    combined = xr.concat(per_ms_datasets, dim="time", data_vars="minimal")
    delayed_write = write_visibility_to_zarr(
        vis_cache_directory, zarr_chunks, combined
    )

    logger.warning("Triggering eager compute to dump visibilities to zarr.")
    dask.compute(delayed_write)
[docs]
def write_visibility_to_zarr(
    directory_to_write, zarr_chunks, visibility: Visibility
) -> Delayed:
    """
    Write a Visibility to a zarr store in the provided directory.

    Since the native xarray ``to_zarr()`` function does not allow writing
    python-object-like attributes and coordinates, this function first
    writes the attributes and the "baselines" coordinate values as
    pickle files and removes them from the visibility. The rest of the
    visibility is then written to a zarr store.

    The two pickle files are written immediately when this function is
    called; only the zarr write is deferred. The target directory is
    created if it does not exist yet.

    Parameters
    ----------
    directory_to_write : str
        Directory that receives the pickle files and the zarr store.
    zarr_chunks : dict
        Chunking applied to the visibility before it is written.
    visibility : Visibility
        Dataset to write.

    Returns
    -------
    dask.delayed
        Returns a dask delayed zarr writer task which the user
        needs to call compute on to write the actual visibilities.
    """
    attributes_file, baselines_file, vis_zarr_file = (
        _generate_file_paths_for_vis_zarr_file(directory_to_write)
    )

    # The pickle files are opened before anything else touches the
    # directory, so make sure it exists.
    os.makedirs(directory_to_write, exist_ok=True)

    attrs = deepcopy(visibility.attrs)
    with open(attributes_file, "wb") as file:
        pickle.dump(attrs, file)

    baselines = deepcopy(visibility.baselines).compute()
    with open(baselines_file, "wb") as file:
        pickle.dump(baselines, file)

    # compute=False: the zarr write only happens once the returned
    # object is computed.
    writer = (
        visibility.drop_attrs()
        .drop_vars("baselines")
        .pipe(with_chunks, zarr_chunks)
        .to_zarr(vis_zarr_file, mode="w", compute=False)
    )
    return writer
[docs]
def read_visibility_from_zarr(vis_cache_directory, vis_chunks) -> Visibility:
    """
    Rebuild a Visibility dataset from a zarr cache directory.

    The zarr store is opened with the requested chunking, then the
    attributes and the "baselines" coordinate, which are stored as separate
    pickle files next to it, are attached again.

    Parameters
    ----------
    vis_cache_directory : str
        Directory containing the zarr store and the two pickle files.
    vis_chunks : dict
        Chunking scheme (e.g. ``{'time': 1, 'frequency': 10}``), passed
        as ``chunks`` to ``xr.open_dataset``.

    Returns
    -------
    Visibility
        The dataset with attributes and baseline coordinates restored.
    """
    attrs_path, baselines_path, zarr_path = (
        _generate_file_paths_for_vis_zarr_file(vis_cache_directory)
    )

    dataset = xr.open_dataset(zarr_path, chunks=vis_chunks, engine="zarr")

    # NOTE: unpickling can execute arbitrary code, so only read cache
    # directories that come from a trusted source.
    with open(attrs_path, "rb") as attrs_stream:
        restored_attrs = pickle.load(attrs_stream)
    dataset = dataset.assign_attrs(restored_attrs)

    with open(baselines_path, "rb") as baselines_stream:
        restored_baselines = pickle.load(baselines_stream)
    dataset = dataset.assign({"baselines": restored_baselines})

    # Explicitly load antenna1 and antenna2 coordinates
    dataset.antenna1.load()
    dataset.antenna2.load()

    return dataset
[docs]
def check_if_cache_files_exist(vis_cache_directory):
    """
    Tell whether a directory holds a complete visibility cache.

    A cache is complete when the attributes pickle file and the baselines
    pickle file exist as files and the zarr store exists as a directory.

    Parameters
    ----------
    vis_cache_directory : str
        The path to the directory to inspect.

    Returns
    -------
    bool
        True if all three artifacts exist; False otherwise.
    """
    attrs_path, baselines_path, zarr_path = (
        _generate_file_paths_for_vis_zarr_file(vis_cache_directory)
    )

    # Checked in this order; the first missing artifact ends the check.
    if not os.path.isfile(attrs_path):
        return False
    if not os.path.isfile(baselines_path):
        return False
    return os.path.isdir(zarr_path)
[docs]
def read_ms_field_id(ms_path: str):
    """
    Read the name of the first field of a measurement set.

    The value is taken from the first row of the "NAME" column of the
    FIELD sub-table.

    Parameters
    ----------
    ms_path : str
        Path to the measurement set.

    Returns
    -------
    str
        The first field's name, or "UNKNOWN_FIELD" when that name is
        empty.
    """
    field_table_path = ms_path + "/FIELD"
    with table(field_table_path) as field_table:
        first_field_name = field_table.getcol("NAME")[0]
    return first_field_name or "UNKNOWN_FIELD"