Source code for low_comm_tools.ms_fixes.reverse_baselines

#!/usr/bin/env python

from __future__ import annotations

import argparse
from collections.abc import Generator
from pathlib import Path
from typing import Any, NamedTuple, TypeAlias, cast

import numpy as np
import numpy.typing as npt
from casacore.tables import table
from tqdm.auto import tqdm

from low_comm_tools.exceptions import VisError
from low_comm_tools.log_config import logger
from low_comm_tools.ms_utils import ComplexArray
from low_comm_tools.vis_rotate import hermite_transpose

[docs] IntArray: TypeAlias = npt.NDArray[np.int64]
[docs] FloatArray: TypeAlias = npt.NDArray[np.floating[Any]]
class BaselineChunk(NamedTuple):
    """One contiguous block of main-table rows from a MeasurementSet.

    Bundles every column that baseline reversal has to read and write back,
    together with the row window the block was read from.
    """

    #: Column arrays keyed by column name, exactly as returned by ``getcol``.
    #: The last axis is treated as the correlation axis by the consumer.
    #: NOTE(review): annotated as complex, but the caller also stores WEIGHT
    #: and FLAG columns here - confirm the intended element type and rank.
    data: dict[str, ComplexArray]

    #: UVW coordinates for the rows in this chunk (presumably shape
    #: ``(chunk_size, 3)`` - verify against the table).
    uvws: FloatArray

    #: ANTENNA1 index for each row in this chunk.
    ant_1: IntArray

    #: ANTENNA2 index for each row in this chunk.
    ant_2: IntArray

    #: Index, within the full table, of the first row of this chunk.
    row_start: int

    #: Number of rows in this chunk; the final chunk may be shorter than requested.
    chunk_size: int
[docs] def _iter_baseline_chunks( ms_table: table, present_columns: list[str], chunk_size: int, ) -> Generator[BaselineChunk, None, None]: """Iterate over rows of a MeasurementSet main table in chunks. Args: ms_table (table): An opened, writable MeasurementSet main table. present_columns (list[str]): Visibility column names to read. chunk_size (int): Number of rows to read per chunk. Yields: BaselineChunk: One chunk of rows at a time. """ table_length = len(ms_table) lower_row = 0 with tqdm(desc="Reversing baselines", total=table_length, unit="rows") as pbar: while lower_row < table_length: # On the final iteration the table may have fewer rows than chunk_size; # casacore will return however many rows remain. actual_chunk = min(chunk_size, table_length - lower_row) data = { col: ms_table.getcol(col, startrow=lower_row, nrow=actual_chunk) for col in present_columns } uvws = ms_table.getcol("UVW", startrow=lower_row, nrow=actual_chunk) ant_1 = ms_table.getcol("ANTENNA1", startrow=lower_row, nrow=actual_chunk) ant_2 = ms_table.getcol("ANTENNA2", startrow=lower_row, nrow=actual_chunk) yield BaselineChunk( data=data, uvws=uvws, ant_1=ant_1, ant_2=ant_2, row_start=lower_row, chunk_size=actual_chunk, ) lower_row += chunk_size pbar.update(chunk_size)
def check_ant_1_lt_ant_2(ms_path: Path) -> bool:
    """Check the baseline ordering in an MS.

    Rows where ANTENNA1 equals ANTENNA2 are compatible with either ordering,
    so a table made up only of such rows (or with no rows) reports True.

    Args:
        ms_path (Path): Path to MS

    Raises:
        VisError: If baseline ordering is mixed

    Returns:
        bool: True if ANTENNA1<=ANTENNA2 for all rows, False if
        ANTENNA1>=ANTENNA2 for all rows
    """
    # This is a pure inspection, so open read-only: no write permission or
    # write lock is needed just to look at the antenna columns.
    with table(ms_path.as_posix(), readonly=True) as tab:
        antenna_1 = tab.getcol("ANTENNA1")
        antenna_2 = tab.getcol("ANTENNA2")

    # Evaluate each reduction once and reuse it for both the check and the result.
    all_ant_1_le_ant_2 = cast(bool, (antenna_1 <= antenna_2).all())
    all_ant_1_ge_ant_2 = cast(bool, (antenna_1 >= antenna_2).all())

    if not (all_ant_1_ge_ant_2 or all_ant_1_le_ant_2):
        msg = f"{ms_path} does not have all ANTENNA1>ANTENNA2 or ANTENNA1<ANTENNA2!"
        raise VisError(msg)

    return all_ant_1_le_ant_2
def check_baseline_ordering(ms_path: Path) -> None:
    """Log which baseline ordering an MS uses.

    Delegates the actual check to `check_ant_1_lt_ant_2` and reports the
    outcome at INFO level. Nothing is returned and the MS is not modified
    here; any exception raised by the check propagates to the caller.

    Args:
        ms_path (Path): Path to MS
    """
    ant_1_first = check_ant_1_lt_ant_2(ms_path)
    msg = (
        f"{ms_path} has ANTENNA1<ANTENNA2 for all rows"
        if ant_1_first
        else f"{ms_path} has ANTENNA1>ANTENNA2 for all rows"
    )
    logger.info(msg)
def reverse_baselines(ms_path: Path, chunk_size: int = 1000) -> Path:
    """Reverse the baseline order of an MS, in place.

    The table is processed in chunks of rows. For every chunk:

    1. Each present data-like column is written back after passing through
       `hermite_transpose` (4-correlation data, viewed as 2x2 matrices) or
       a plain complex conjugate (any other correlation count).
    2. The ANTENNA1 and ANTENNA2 columns are swapped.
    3. The UVW column is negated.

    NOTE(review): chunks are written as they are processed, so an
    interruption part-way through leaves the MS partially reversed.

    Args:
        ms_path (Path): Path to MS
        chunk_size (int, optional): Number of rows to process in a chunk.
            Defaults to 1000.

    Raises:
        RuntimeError: If none of the recognised data-like columns are present

    Returns:
        Path: Path to modified MS (the same path that was passed in)
    """
    candidate_columns = [
        "DATA",
        "CORRECTED_DATA",
        "MODEL_DATA",
        "WEIGHT",
        "WEIGHT_SPECTRUM",
        "FLAG",
    ]

    with table(ms_path.as_posix(), readonly=False) as tab:
        available = tab.colnames()
        present_columns = [name for name in candidate_columns if name in available]
        if not present_columns:
            msg = (
                f"No recognised data columns found in {ms_path}. "
                f"Expected one of: {candidate_columns}"
            )
            raise RuntimeError(msg)

        logger.info(f"Processing {tab.nrows()} rows in: {ms_path}")
        logger.info(f"Will reverse baselines in {present_columns}")

        for chunk in _iter_baseline_chunks(tab, present_columns, chunk_size):
            # Every write for this chunk targets the row window it was read from.
            window = {"startrow": chunk.row_start, "nrow": chunk.chunk_size}

            for col, vis in chunk.data.items():
                original_shape = vis.shape
                if original_shape[-1] == 4:
                    # Four correlations: view the last axis as a 2x2 matrix,
                    # transform it, then flatten back to the stored layout.
                    as_matrices = vis.reshape(*original_shape[:-1], 2, 2)
                    vis_corr = hermite_transpose(as_matrices).reshape(original_shape)
                else:
                    # 1- or 2-pol: no 2x2 matrix structure, conjugate only
                    vis_corr = vis.conj()

                assert vis_corr.shape == vis.shape, (
                    f"Column {col} changed shape! ({vis.shape} -> ({vis_corr.shape}))"
                )
                tab.putcol(col, vis_corr, **window)

            # Swap the antenna labels: each column receives the other's values.
            tab.putcol("ANTENNA1", chunk.ant_2, **window)
            tab.putcol("ANTENNA2", chunk.ant_1, **window)

            # Flip the sign of the baseline vector.
            tab.putcol("UVW", -chunk.uvws, **window)

    return ms_path
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the baseline-reversal tool.

    Returns:
        argparse.ArgumentParser: Parser with a positional ``ms_path``, the
        ``-c/--chunksize`` option (default 1000) and the ``-d/--dry-run``
        flag. Defaults are shown in the generated help text.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Reverse baseline ordering in a MeasurementSet",
    )

    # Positional: the MeasurementSet to operate on.
    cli.add_argument("ms_path", help="Path to MS file", type=Path)

    # Optional: how many rows are handled per chunk.
    cli.add_argument(
        "-c",
        "--chunksize",
        help="Number of rows to process in a chunk",
        default=1000,
        type=int,
    )

    # Optional: report the ordering only, without modifying anything.
    cli.add_argument(
        "-d",
        "--dry-run",
        help="Do not execute baseline reversal - just check the ANTENNA1/2 ordering",
        action="store_true",
    )

    return cli
def main() -> None:
    """Command-line entry point.

    Parses the arguments, reports the current ANTENNA1/ANTENNA2 ordering of
    the MS, and then reverses the baselines unless ``--dry-run`` was given.
    """
    args = get_parser().parse_args()
    ms_path = args.ms_path

    check_baseline_ordering(ms_path)

    # In a dry run only the ordering report is wanted; leave the MS untouched.
    if not args.dry_run:
        reverse_baselines(ms_path=ms_path, chunk_size=args.chunksize)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__": main()