Source code for low_comm_tools.ms_fixes.fix_weights

from __future__ import annotations

import argparse
from collections.abc import Generator
from pathlib import Path
from typing import Any, NamedTuple

import astropy.units as u
import numpy as np
import numpy.typing as npt
from casacore.tables import table
from tqdm.auto import tqdm

from low_comm_tools import ms_utils
from low_comm_tools.log_config import logger
from low_comm_tools.ms_utils import FloatArray


def fd_to_fraction(
    fd_arr: npt.NDArray[np.floating[Any]],
) -> npt.NDArray[np.floating[Any]]:
    """Convert stored FD values to a fractional scaling factor.

    Each value is divided by 255 and the result is squared, so an input of
    255 maps to 1.0 and an input of 0 maps to 0.0.

    Args:
        fd_arr (npt.NDArray[np.floating[Any]]): Array of FD values.
            Presumably on a 0-255 scale -- TODO confirm against the writer
            of these values.

    Returns:
        npt.NDArray[np.floating[Any]]: ``(fd_arr / 255) ** 2``, with the
        same shape as ``fd_arr``.
    """
    normalised = fd_arr / 255
    return normalised**2
class DataChunk(NamedTuple):
    """A contiguous run of MeasurementSet rows as produced by ``_iter_chunks``.

    The two weight fields are ``None`` when the corresponding column could
    not be read for this run of rows.
    """

    # Values read from the WEIGHT column, or None if the read failed.
    weight: FloatArray | None
    # Values read from the WEIGHT_SPECTRUM column, or None if the read failed.
    weight_spectrum: FloatArray | None
    # Values read from the INTERVAL column, tagged with seconds.
    interval: u.Quantity[u.s]
    # Index of the first table row covered by this chunk.
    row_start: int
    # Number of rows actually held in this chunk.
    chunk_size: int
[docs] def _iter_chunks( ms_table: table, chunk_size: int, ) -> Generator[DataChunk, None, None]: """Iterate over rows of a MeasurementSet main table in chunks. Args: ms_table (table): An opened, writable MeasurementSet main table. present_columns (list[str]): Visibility column names to read. chunk_size (int): Number of rows to read per chunk. Yields: BaselineChunk: One chunk of rows at a time. """ table_length = len(ms_table) lower_row = 0 with tqdm(desc="Computing weights", total=table_length, unit="rows") as pbar: while lower_row < table_length: # On the final iteration the table may have fewer rows than chunk_size; # casacore will return however many rows remain. actual_chunk = min(chunk_size, table_length - lower_row) try: weight = ms_table.getcol( "WEIGHT", startrow=lower_row, nrow=actual_chunk )[:] except RuntimeError: weight = None try: weight_spectrum = ms_table.getcol( "WEIGHT_SPECTRUM", startrow=lower_row, nrow=actual_chunk )[:] except RuntimeError: weight_spectrum = None interval = ms_table.getcol( "INTERVAL", startrow=lower_row, nrow=actual_chunk )[:] yield DataChunk( weight=weight, weight_spectrum=weight_spectrum, interval=interval * u.s, row_start=lower_row, chunk_size=actual_chunk, ) lower_row += chunk_size pbar.update(actual_chunk)
def make_new_weights(
    ms_path: Path,
    chunk_size: int = 1000,
    dry_run: bool = False,
) -> Path:
    """Recompute the WEIGHT and WEIGHT_SPECTRUM columns of a MeasurementSet.

    The new weight of every sample is::

        (integration time in s) * (bandwidth in kHz) * fd_to_fraction(old value)

    See SKB-1081 for detail. ``WEIGHT_SPECTRUM`` uses the per-channel
    bandwidth, ``WEIGHT`` uses the mean channel bandwidth. Bandwidths are
    derived from the spacing of the channel frequencies returned by
    ``ms_utils.get_freq_from_ms``. A column that cannot be read is skipped.

    Args:
        ms_path (Path): Path to the MeasurementSet to update in place.
        chunk_size (int): Number of rows to process per chunk.
        dry_run (bool): If True, compute the new weights but write nothing;
            the table is then opened read-only.

    Returns:
        Path: ``ms_path``, unchanged.

    Raises:
        ValueError: If a recomputed chunk does not have the same shape as
            the chunk read from the column.
    """
    freqs = ms_utils.get_freq_from_ms(ms_path=ms_path)
    # np.gradient is signed: take the magnitude so that a descending
    # frequency axis cannot produce negative weights.
    chan_bw = np.abs(np.gradient(freqs))
    # Unit conversions are loop-invariant, so do them once up front.
    chan_bw_khz = chan_bw.to(u.kHz).value
    mean_chan_bw_khz = chan_bw.mean().to(u.kHz).value

    # A dry run never writes, so it does not need write access to the table.
    with table(ms_path.as_posix(), readonly=dry_run, ack=False) as tab:
        column_names = tab.colnames()
        if "WEIGHT" in column_names:
            logger.info("Will update WEIGHT column")
        if "WEIGHT_SPECTRUM" in column_names:
            logger.info("Will update WEIGHT_SPECTRUM column")
        if dry_run:
            logger.info("Dry run: no columns will be written")

        for chunk in _iter_chunks(tab, chunk_size):
            row_start = chunk.row_start
            n_rows = chunk.chunk_size
            interval_s = chunk.interval.to(u.s).value

            if chunk.weight_spectrum is not None:
                # Broadcast (1, chan, 1) * (row, 1, 1) against the stored
                # (row, chan, pol) values; the polarisation count therefore
                # comes from the data instead of being hard-coded.
                new_weight_spectrum_chunk = (
                    chan_bw_khz[np.newaxis, :, np.newaxis]
                    * interval_s[:, np.newaxis, np.newaxis]
                ) * fd_to_fraction(chunk.weight_spectrum)
                if new_weight_spectrum_chunk.shape != chunk.weight_spectrum.shape:
                    msg = (
                        "New WEIGHT_SPECTRUM chunk has shape "
                        f"{new_weight_spectrum_chunk.shape}, expected "
                        f"{chunk.weight_spectrum.shape}"
                    )
                    raise ValueError(msg)
                if not dry_run:
                    tab.putcol(
                        "WEIGHT_SPECTRUM",
                        new_weight_spectrum_chunk,
                        startrow=row_start,
                        nrow=n_rows,
                    )

            if chunk.weight is not None:
                if (chunk.weight == 0).all():
                    msg = "Chunk of WEIGHT column is all 0! Assuming this is NOT actually the case..."
                    logger.critical(msg)
                    data_fraction = np.ones_like(chunk.weight)
                else:
                    data_fraction = fd_to_fraction(chunk.weight)
                # Broadcast (row, 1) against the stored (row, pol) values.
                new_weight_chunk = (interval_s * mean_chan_bw_khz)[
                    :, np.newaxis
                ] * data_fraction
                if new_weight_chunk.shape != chunk.weight.shape:
                    msg = (
                        f"New WEIGHT chunk has shape {new_weight_chunk.shape}, "
                        f"expected {chunk.weight.shape}"
                    )
                    raise ValueError(msg)
                if not dry_run:
                    tab.putcol(
                        "WEIGHT",
                        new_weight_chunk,
                        startrow=row_start,
                        nrow=n_rows,
                    )

    return ms_path
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the weight-fixing script.

    Returns:
        argparse.ArgumentParser: Parser accepting the positional ``ms_path``
        plus the ``-c/--chunksize`` and ``-d/--dry-run`` options.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Fix WEIGHT/WEIGHT_SPECTRUM column using FD, INTERVAL, and channel bandwidth in MS",
    )

    # Positional: the MeasurementSet to operate on.
    cli.add_argument("ms_path", type=Path, help="Path to target MeasurementSet")

    # Optional tuning and safety switches.
    cli.add_argument(
        "-c",
        "--chunksize",
        default=1000,
        type=int,
        help="Number of rows to process in a chunk",
    )
    cli.add_argument(
        "-d",
        "--dry-run",
        help="Do not write new columns to MS",
        action="store_true",
    )

    return cli
def main() -> None:
    """Command-line entry point: parse the arguments and fix the weights."""
    cli_args = get_parser().parse_args()
    # The returned path is not needed here; only the side effect matters.
    make_new_weights(
        ms_path=cli_args.ms_path,
        chunk_size=cli_args.chunksize,
        dry_run=cli_args.dry_run,
    )
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__": main()