#!/usr/bin/env python
from __future__ import annotations
import argparse
from collections.abc import Generator
from pathlib import Path
from typing import Any, NamedTuple, TypeAlias, cast
import numpy as np
import numpy.typing as npt
from casacore.tables import table
from tqdm.auto import tqdm
from low_comm_tools.exceptions import VisError
from low_comm_tools.log_config import logger
from low_comm_tools.ms_utils import ComplexArray
from low_comm_tools.vis_rotate import hermite_transpose
[docs]
# Type alias: NumPy array whose elements are 64-bit signed integers.
IntArray: TypeAlias = npt.NDArray[np.int64]
[docs]
# Type alias: NumPy array of floating-point values of any precision.
FloatArray: TypeAlias = npt.NDArray[np.floating[Any]]
[docs]
class BaselineChunk(NamedTuple):
    """A chunk of rows read from the main table of a MeasurementSet,
    containing all columns needed to reverse baselines.

    Every field describes the same contiguous run of rows, starting at
    ``row_start`` and spanning ``chunk_size`` rows.
    """

    data: dict[str, ComplexArray]
    """Visibility data keyed by column name. Each array has shape (chunk_size, n_chan, n_corr)."""
    uvws: FloatArray
    """UVW coordinates, shape (chunk_size, 3)."""
    ant_1: IntArray
    """ANTENNA1 indices, shape (chunk_size,)."""
    ant_2: IntArray
    """ANTENNA2 indices, shape (chunk_size,)."""
    row_start: int
    """Starting row index of this chunk within the full table."""
    chunk_size: int
    """Number of rows in this chunk (may be less than requested on the final chunk)."""
[docs]
def _iter_baseline_chunks(
    ms_table: table,
    present_columns: list[str],
    chunk_size: int,
) -> Generator[BaselineChunk, None, None]:
    """Iterate over rows of a MeasurementSet main table in chunks.

    Each chunk holds the requested columns together with the ``UVW``,
    ``ANTENNA1`` and ``ANTENNA2`` columns for the same rows. The table is only
    read from here. A progress bar is advanced by the number of rows actually
    read in each chunk.

    Args:
        ms_table (table): An opened MeasurementSet main table.
        present_columns (list[str]): Visibility column names to read.
        chunk_size (int): Number of rows to read per chunk. Must be at least 1.

    Yields:
        BaselineChunk: One chunk of rows at a time.

    Raises:
        ValueError: If ``chunk_size`` is less than 1. Because this is a
            generator, the error surfaces when the first chunk is requested.
    """
    if chunk_size < 1:
        # A zero step would never advance through the table (endless loop) and
        # a negative step would walk backwards out of it.
        msg = f"chunk_size must be a positive integer, got {chunk_size}"
        raise ValueError(msg)

    table_length = len(ms_table)
    with tqdm(desc="Reversing baselines", total=table_length, unit="rows") as pbar:
        for lower_row in range(0, table_length, chunk_size):
            # The final chunk may hold fewer rows than chunk_size; request
            # exactly the number of rows that remain.
            actual_chunk = min(chunk_size, table_length - lower_row)
            data = {
                col: ms_table.getcol(col, startrow=lower_row, nrow=actual_chunk)
                for col in present_columns
            }
            uvws = ms_table.getcol("UVW", startrow=lower_row, nrow=actual_chunk)
            ant_1 = ms_table.getcol("ANTENNA1", startrow=lower_row, nrow=actual_chunk)
            ant_2 = ms_table.getcol("ANTENNA2", startrow=lower_row, nrow=actual_chunk)
            yield BaselineChunk(
                data=data,
                uvws=uvws,
                ant_1=ant_1,
                ant_2=ant_2,
                row_start=lower_row,
                chunk_size=actual_chunk,
            )
            # Advance by the rows actually read so the bar never overshoots
            # its total on a short final chunk.
            pbar.update(actual_chunk)
[docs]
def check_ant_1_lt_ant_2(ms_path: Path) -> bool:
    """Check the baseline ordering in an MS.

    The comparisons are non-strict, so rows where ANTENNA1 == ANTENNA2 are
    compatible with either ordering. If every row satisfies both orderings
    (for example an empty table), True is returned.

    Args:
        ms_path (Path): Path to MS

    Raises:
        VisError: If baseline ordering is mixed

    Returns:
        bool: True if ANTENNA1<=ANTENNA2 for all rows, False if
            ANTENNA1>=ANTENNA2 for all rows (with at least one strict case)
    """
    # Nothing is written here, so open read-only: the check then also works
    # on an MS the caller cannot modify.
    with table(ms_path.as_posix(), readonly=True) as tab:
        antenna_1 = tab.getcol("ANTENNA1")
        antenna_2 = tab.getcol("ANTENNA2")

    # bool() converts the NumPy scalar into the plain bool promised by the
    # return annotation.
    all_ascending = bool((antenna_1 <= antenna_2).all())
    all_descending = bool((antenna_1 >= antenna_2).all())

    if not (all_ascending or all_descending):
        msg = f"{ms_path} does not have all ANTENNA1>ANTENNA2 or ANTENNA1<ANTENNA2!"
        raise VisError(msg)

    return all_ascending
[docs]
def check_baseline_ordering(ms_path: Path) -> None:
    """Report the baseline ordering of an MS through the logger.

    Calls `check_ant_1_lt_ant_2` and logs, at INFO level, which of the two
    orderings it reported. Nothing is returned and the user is not prompted.
    Any exception raised by `check_ant_1_lt_ant_2` propagates to the caller.

    Args:
        ms_path (Path): Path to MS
    """
    ascending = check_ant_1_lt_ant_2(ms_path)
    msg = (
        f"{ms_path} has ANTENNA1<ANTENNA2 for all rows"
        if ascending
        else f"{ms_path} has ANTENNA1>ANTENNA2 for all rows"
    )
    logger.info(msg)
[docs]
def reverse_baselines(ms_path: Path, chunk_size: int = 1000) -> Path:
    """Reverse baseline order in an MS.

    The MS is modified in place, one chunk of rows at a time. If processing is
    interrupted part-way, rows in already-written chunks stay reversed while
    the remaining rows do not.

    Performs three steps on each row:
    1. Hermite transpose (transpose + conjugate) the data
    2. Swap ANTENNA1/ANTENNA2 label
    3. Negate UVWs

    Args:
        ms_path (Path): Path to MS
        chunk_size (int, optional): Number of rows to process in a chunk.
            Must be at least 1. Defaults to 1000.

    Raises:
        ValueError: If chunk_size is less than 1
        RuntimeError: If no DATA-like columns are found

    Returns:
        Path: Path to modified MS
    """
    if chunk_size < 1:
        # Reject this before the table is opened for writing: a non-positive
        # chunk size would never advance through the rows.
        msg = f"chunk_size must be a positive integer, got {chunk_size}"
        raise ValueError(msg)

    _columns = [
        "DATA",
        "CORRECTED_DATA",
        "MODEL_DATA",
        "WEIGHT",
        "WEIGHT_SPECTRUM",
        "FLAG",
    ]
    with table(ms_path.as_posix(), readonly=False) as tab:
        # Query the column names once rather than once per candidate column.
        available_columns = set(tab.colnames())
        present_columns = [col for col in _columns if col in available_columns]
        if not present_columns:
            msg = (
                f"No recognised data columns found in {ms_path}. "
                f"Expected one of: {_columns}"
            )
            raise RuntimeError(msg)
        logger.info(f"Processing {tab.nrows()} rows in: {ms_path}")
        logger.info(f"Will reverse baselines in {present_columns}")
        for chunk in _iter_baseline_chunks(tab, present_columns, chunk_size):
            row_start = chunk.row_start
            n_rows = chunk.chunk_size
            # Hermite transpose
            for col, vis in chunk.data.items():
                vis_shape = vis.shape
                n_corr = vis_shape[-1]
                if n_corr == 4:
                    # Need to apply on 2x2 matrix: unfold the correlation axis,
                    # transform, then restore the original column shape.
                    jones_shape = (*vis_shape[:-1], 2, 2)
                    vis_corr = hermite_transpose(vis.reshape(*jones_shape))
                    vis_corr = vis_corr.reshape(vis_shape)
                else:
                    # 1- or 2-pol: no 2x2 matrix structure, conjugate only
                    vis_corr = vis.conj()
                # Guard against writing back an array of the wrong shape.
                assert vis_corr.shape == vis.shape, (
                    f"Column {col} changed shape! ({vis.shape} -> ({vis_corr.shape}))"
                )
                tab.putcol(col, vis_corr, startrow=row_start, nrow=n_rows)
            # Swap ANTENNA1 <-> ANTENNA2
            tab.putcol("ANTENNA1", chunk.ant_2, startrow=row_start, nrow=n_rows)
            tab.putcol("ANTENNA2", chunk.ant_1, startrow=row_start, nrow=n_rows)
            # Negate UVW
            tab.putcol("UVW", -chunk.uvws, startrow=row_start, nrow=n_rows)
    return ms_path
[docs]
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the baseline reversal tool.

    The parser accepts one positional ``ms_path`` plus the optional
    ``-c/--chunksize`` (integer, default 1000) and ``-d/--dry-run`` (flag).

    Returns:
        argparse.ArgumentParser: The configured parser
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Reverse baseline ordering in a MeasurementSet",
    )

    # Positional: the MeasurementSet to inspect or modify.
    cli.add_argument("ms_path", help="Path to MS file", type=Path)

    # Optional: number of rows handled per chunk.
    chunk_options = {
        "type": int,
        "default": 1000,
        "help": "Number of rows to process in a chunk",
    }
    cli.add_argument("-c", "--chunksize", **chunk_options)

    # Optional: report the ordering only, without modifying anything.
    dry_run_options = {
        "action": "store_true",
        "help": "Do not execute baseline reversal - just check the ANTENNA1/2 ordering",
    }
    cli.add_argument("-d", "--dry-run", **dry_run_options)

    return cli
[docs]
def main() -> None:
    """Command-line entry point.

    Parses the arguments, reports the current baseline ordering, and then
    reverses the baselines unless a dry run was requested.
    """
    args = get_parser().parse_args()

    # The ordering report is always produced, including for dry runs.
    check_baseline_ordering(args.ms_path)

    if not args.dry_run:
        reverse_baselines(ms_path=args.ms_path, chunk_size=args.chunksize)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()