from __future__ import annotations
import argparse
from collections.abc import Generator
from pathlib import Path
from typing import Any, NamedTuple
import astropy.units as u
import numpy as np
import numpy.typing as npt
from casacore.tables import table
from tqdm.auto import tqdm
from low_comm_tools import ms_utils
from low_comm_tools.log_config import logger
from low_comm_tools.ms_utils import FloatArray
[docs]
def fd_to_fraction(
    fd_arr: npt.NDArray[np.floating[Any]],
) -> npt.NDArray[np.floating[Any]]:
    """Convert FD values into a fraction using ``(fd / 255) ** 2``.

    The mapping is quadratic: an FD of 0 gives 0.0 and an FD of 255 gives 1.0.

    NOTE(review): FD looks like an 8-bit code (0-255) stored in the
    WEIGHT / WEIGHT_SPECTRUM columns; confirm the encoding against SKB-1081.

    Args:
        fd_arr: FD values, any shape.

    Returns:
        Array with the same shape as ``fd_arr`` holding the squared,
        255-normalised values.
    """
    normalised = np.true_divide(fd_arr, 255)
    return np.square(normalised)
[docs]
class DataChunk(NamedTuple):
    """A contiguous block of rows read from a MeasurementSet main table."""

    # WEIGHT values for the rows, or None if the column could not be read.
    weight: FloatArray | None
    # WEIGHT_SPECTRUM values for the rows, or None if the column could not be read.
    weight_spectrum: FloatArray | None
    # INTERVAL values for the rows, as an astropy Quantity in seconds.
    interval: u.Quantity[u.s]
    # Index of the first table row covered by this chunk.
    row_start: int
    # Number of rows actually held in this chunk (the last one may be short).
    chunk_size: int
[docs]
def _read_optional_column(
    ms_table: table,
    column: str,
    startrow: int,
    nrow: int,
) -> FloatArray | None:
    """Read ``nrow`` rows of ``column`` from ``startrow``, or None if casacore raises RuntimeError."""
    try:
        return ms_table.getcol(column, startrow=startrow, nrow=nrow)
    except RuntimeError:
        # Presumably the column is absent or holds no data; treat it as optional.
        return None


def _iter_chunks(
    ms_table: table,
    chunk_size: int,
) -> Generator[DataChunk, None, None]:
    """Iterate over rows of a MeasurementSet main table in chunks.

    WEIGHT and WEIGHT_SPECTRUM are read when possible. If casacore raises
    ``RuntimeError`` while reading either column, the matching field of the
    yielded chunk is ``None``. INTERVAL is always read and attached in seconds.
    A progress bar tracks the number of rows processed.

    Args:
        ms_table (table): An opened MeasurementSet main table (read access
            is sufficient for this function).
        chunk_size (int): Maximum number of rows to read per chunk; must be >= 1.

    Yields:
        DataChunk: One chunk at a time. ``row_start`` is the first row index
        and ``chunk_size`` is the number of rows actually read.

    Raises:
        ValueError: If ``chunk_size`` is less than 1 (raised when iteration
            starts, since this is a generator).
    """
    # A non-positive step would never advance past the end of the table.
    if chunk_size < 1:
        msg = f"chunk_size must be a positive integer, got {chunk_size}"
        raise ValueError(msg)

    table_length = len(ms_table)
    lower_row = 0
    with tqdm(desc="Computing weights", total=table_length, unit="rows") as pbar:
        while lower_row < table_length:
            # The final chunk may hold fewer than chunk_size rows.
            n_rows = min(chunk_size, table_length - lower_row)
            weight = _read_optional_column(ms_table, "WEIGHT", lower_row, n_rows)
            weight_spectrum = _read_optional_column(
                ms_table, "WEIGHT_SPECTRUM", lower_row, n_rows
            )
            interval = ms_table.getcol("INTERVAL", startrow=lower_row, nrow=n_rows)
            yield DataChunk(
                weight=weight,
                weight_spectrum=weight_spectrum,
                interval=interval * u.s,
                row_start=lower_row,
                chunk_size=n_rows,
            )
            lower_row += n_rows
            pbar.update(n_rows)
[docs]
def _compute_weight_spectrum(chunk: DataChunk, chan_bw_khz: FloatArray) -> FloatArray:
    """Return interval [s] * per-channel bandwidth [kHz] * FD fraction for WEIGHT_SPECTRUM."""
    weight_spectrum = chunk.weight_spectrum
    n_rows = chunk.interval.shape[0]
    n_chan = chan_bw_khz.size
    # Explicit check (not assert) so it survives ``python -O``; broadcasting
    # below could otherwise silently succeed on a degenerate axis.
    if weight_spectrum.ndim != 3 or weight_spectrum.shape[:2] != (n_rows, n_chan):
        msg = (
            f"WEIGHT_SPECTRUM chunk has shape {weight_spectrum.shape}; expected "
            f"({n_rows}, {n_chan}, npol) from INTERVAL and the MS channel frequencies"
        )
        raise ValueError(msg)
    # (rows, 1) * (1, chans) -> (rows, chans); the trailing new axis then
    # broadcasts over however many polarisations are stored.
    base = chunk.interval.to(u.s).value[:, np.newaxis] * chan_bw_khz[np.newaxis, :]
    return base[..., np.newaxis] * fd_to_fraction(weight_spectrum)


def _compute_weight(chunk: DataChunk, mean_chan_bw_khz: float) -> FloatArray:
    """Return interval [s] * mean channel bandwidth [kHz] * FD fraction for WEIGHT."""
    weight = chunk.weight
    n_rows = chunk.interval.shape[0]
    if weight.ndim != 2 or weight.shape[0] != n_rows:
        msg = f"WEIGHT chunk has shape {weight.shape}; expected ({n_rows}, npol)"
        raise ValueError(msg)
    if (weight == 0).all():
        msg = "Chunk of WEIGHT column is all 0! Assuming this is NOT actually the case..."
        logger.critical(msg)
        data_fraction = np.ones_like(weight)
    else:
        data_fraction = fd_to_fraction(weight)
    base = chunk.interval.to(u.s).value * mean_chan_bw_khz
    # (rows, 1) broadcasts across the stored polarisation axis.
    return base[:, np.newaxis] * data_fraction


def make_new_weights(
    ms_path: Path,
    chunk_size: int = 1000,
    dry_run: bool = False,
) -> Path:
    """Recompute the WEIGHT / WEIGHT_SPECTRUM columns of a MeasurementSet.

    New weight (see SKB-1081)::

        (integration time in s) * (channel bandwidth in kHz) * fraction

    where ``fraction = fd_to_fraction(current column value)``. WEIGHT_SPECTRUM
    uses the per-channel bandwidth and WEIGHT uses the mean channel bandwidth.
    The number of polarisations comes from the stored arrays through
    broadcasting; it is not hard-coded.

    NOTE(review): the current column contents are interpreted as FD codes, so
    running this twice on the same MS does not give the same result as
    running it once.

    Args:
        ms_path (Path): Path to the MeasurementSet.
        chunk_size (int): Number of rows to process per chunk; must be >= 1.
        dry_run (bool): If True, compute the new weights but do not write
            them. The table is then opened read-only.

    Returns:
        Path: ``ms_path``, unchanged.

    Raises:
        ValueError: Raised in either of these cases:
            - ``chunk_size`` < 1.
            - A chunk's WEIGHT / WEIGHT_SPECTRUM shape does not match the
              rows and channels implied by INTERVAL and the channel
              frequencies.
    """
    if chunk_size < 1:
        msg = f"chunk_size must be a positive integer, got {chunk_size}"
        raise ValueError(msg)

    freqs = ms_utils.get_freq_from_ms(ms_path=ms_path)
    # Channel widths from the spacing of adjacent channel frequencies.
    # Assumes get_freq_from_ms returns an astropy Quantity (``.to`` is used).
    chan_bw = np.gradient(freqs)
    # Unit conversions are loop-invariant, so do them once up front.
    chan_bw_khz = chan_bw.to(u.kHz).value
    mean_chan_bw_khz = chan_bw.mean().to(u.kHz).value

    # A dry run never writes, so there is no need to open the table writable.
    with table(ms_path.as_posix(), readonly=dry_run, ack=False) as tab:
        colnames = tab.colnames()
        has_weight = "WEIGHT" in colnames
        has_weight_spectrum = "WEIGHT_SPECTRUM" in colnames
        if has_weight:
            logger.info("Will update WEIGHT column")
        if has_weight_spectrum:
            logger.info("Will update WEIGHT_SPECTRUM column")
        if not (has_weight or has_weight_spectrum):
            logger.warning(
                f"Neither WEIGHT nor WEIGHT_SPECTRUM found in {ms_path}; nothing to update"
            )
            return ms_path
        if dry_run:
            logger.info("Dry run: new weights will not be written")

        for chunk in _iter_chunks(tab, chunk_size):
            if chunk.weight_spectrum is not None:
                new_weight_spectrum = _compute_weight_spectrum(chunk, chan_bw_khz)
                if not dry_run:
                    tab.putcol(
                        "WEIGHT_SPECTRUM",
                        new_weight_spectrum,
                        startrow=chunk.row_start,
                        nrow=chunk.chunk_size,
                    )
            if chunk.weight is not None:
                new_weight = _compute_weight(chunk, mean_chan_bw_khz)
                if not dry_run:
                    tab.putcol(
                        "WEIGHT",
                        new_weight,
                        startrow=chunk.row_start,
                        nrow=chunk.chunk_size,
                    )
    return ms_path
[docs]
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the weight-fixing script.

    Returns:
        argparse.ArgumentParser: Parser yielding ``ms_path`` (Path),
        ``chunksize`` (int >= 1) and ``dry_run`` (bool).
    """

    def _positive_int(value: str) -> int:
        # Chunk sizes below 1 make no sense for row-chunked processing,
        # so reject them here with a clean CLI error instead of failing later.
        try:
            number = int(value)
        except ValueError:
            msg = f"invalid int value: {value!r}"
            raise argparse.ArgumentTypeError(msg) from None
        if number < 1:
            msg = f"must be a positive integer, got {value!r}"
            raise argparse.ArgumentTypeError(msg)
        return number

    parser = argparse.ArgumentParser(
        description="Fix WEIGHT/WEIGHT_SPECTRUM column using FD, INTERVAL, and channel bandwidth in MS",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("ms_path", help="Path to target MeasurementSet", type=Path)
    parser.add_argument(
        "-c",
        "--chunksize",
        type=_positive_int,
        default=1000,
        help="Number of rows to process in a chunk",
    )
    parser.add_argument(
        "-d",
        "--dry-run",
        action="store_true",
        help="Do not write new columns to MS",
    )
    return parser
[docs]
def main() -> None:
    """Command-line entry point: parse the arguments and rewrite the MS weights."""
    cli = get_parser().parse_args()
    make_new_weights(
        ms_path=cli.ms_path,
        chunk_size=cli.chunksize,
        dry_run=cli.dry_run,
    )


# Only run the CLI when executed as a script, not on import.
if __name__ == "__main__":
    main()