Source code for low_comm_tools.plotting.fringe_rate

from __future__ import annotations

import argparse
from itertools import combinations
from pathlib import Path
from typing import Literal

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from casacore.tables import table, taql
from matplotlib.figure import Figure
from tqdm.auto import tqdm

from low_comm_tools import ms_utils
from low_comm_tools.constants import CORRELATION_POLS
from low_comm_tools.log_config import logger


def _parse_freq_selection(
    freq_mhz: float | Literal["start", "mid", "end"],
    channel_freqs: u.Quantity[u.Hz],
) -> int:
    """Translate a frequency request into a channel index.

    Args:
        freq_mhz: Either a frequency in MHz, or one of the keywords
            ``"start"``, ``"mid"`` or ``"end"``.
        channel_freqs: Channel frequencies to select from.

    Returns:
        ``0`` for ``"start"``, ``len(channel_freqs) // 2`` for ``"mid"`` and
        ``len(channel_freqs) - 1`` for ``"end"``. For a numeric request, the
        index of the channel whose frequency is closest to ``freq_mhz``.

    Raises:
        ValueError: If ``freq_mhz`` is a string other than the supported
            keywords, or a negative number.
    """
    total_channels = len(channel_freqs)

    if not isinstance(freq_mhz, str):
        if freq_mhz < 0:
            msg = f"Frequencies cannot be negative! Got {freq_mhz=}"
            raise ValueError(msg)
        offsets_mhz = np.abs(channel_freqs.to("MHz").value - freq_mhz)
        return int(np.argmin(offsets_mhz))

    # "end" is spelled as an explicit last index rather than -1, so the
    # result stays a non-negative position that a width can be added to.
    keyword_index = {
        "start": 0,
        "mid": total_channels // 2,
        "end": total_channels - 1,
    }
    if freq_mhz in keyword_index:
        return keyword_index[freq_mhz]

    msg = f"Unsupported option '{freq_mhz}'"
    raise ValueError(msg)
def _parse_bandwidth_selection(
    bandwidth_mhz: float,
    channel_freqs: u.Quantity[u.Hz],
) -> int:
    """Convert a bandwidth in MHz into a whole number of channels.

    The width of one channel is estimated as the spread of ``channel_freqs``
    (maximum minus minimum) divided by the number of channels. The requested
    bandwidth is then rounded up to a whole number of such channels.

    Args:
        bandwidth_mhz: Requested bandwidth in MHz. Must not be negative.
        channel_freqs: Channel frequencies the selection will be applied to.

    Returns:
        Number of channels covering ``bandwidth_mhz``. Always at least 1, so
        a slice of this width is never empty. May exceed
        ``len(channel_freqs)`` when more bandwidth is requested than exists.

    Raises:
        ValueError: If ``bandwidth_mhz`` is negative or ``channel_freqs`` is
            empty.
    """
    if bandwidth_mhz < 0:
        msg = f"Bandwidths cannot be negative! Got {bandwidth_mhz=}"
        raise ValueError(msg)

    n_chan = len(channel_freqs)
    if n_chan == 0:
        msg = "Cannot select a bandwidth from an empty set of channels"
        raise ValueError(msg)

    total_bw = channel_freqs.max() - channel_freqs.min()
    chan_bw_mhz = (total_bw / n_chan).to("MHz").value  # bw per chan

    if chan_bw_mhz <= 0:
        # A single channel (or identical frequencies) has no spread, so there
        # is no width to divide by: dividing would give inf/nan, which int()
        # rejects. Every available channel is the only sensible selection.
        return n_chan

    # A zero bandwidth would round to zero channels; keep at least one.
    return max(1, int(np.ceil(bandwidth_mhz / chan_bw_mhz)))
def channel_selection(
    channel_freqs: u.Quantity[u.Hz],
    freq_mhz: float | Literal["start", "mid", "end"],
    bandwidth_mhz: float,
) -> tuple[int, int]:
    """Pick the first channel and the number of channels to use.

    Args:
        channel_freqs: Channel frequencies to select from.
        freq_mhz: Frequency in MHz, or one of ``"start"``, ``"mid"``,
            ``"end"``, locating the first channel of the selection.
        bandwidth_mhz: Bandwidth in MHz to convert into a channel count.

    Returns:
        ``(first_channel, n_channels)``. The first channel is moved back when
        needed so that ``n_channels`` channels fit before the end of the
        band, but never below index 0.
    """
    requested_start = _parse_freq_selection(
        freq_mhz=freq_mhz,
        channel_freqs=channel_freqs,
    )
    width = _parse_bandwidth_selection(
        bandwidth_mhz=bandwidth_mhz,
        channel_freqs=channel_freqs,
    )

    # Latest start index that still leaves `width` channels inside the band.
    latest_start = len(channel_freqs) - width
    if latest_start < 0:
        latest_start = 0

    start = requested_start if requested_start <= latest_start else latest_start
    return start, width
def make_fringe_rate_plots(
    ms_path: Path,
    freq_mhz: float | Literal["start", "mid", "end"],
    bandwidth_mhz: float,
    uv_min_m: float = 500,
    out_dir: Path | None = None,
    data_column: str = "DATA",
) -> list[Figure]:
    """Plot visibility phase against time for the baselines of each station.

    For every station a 2x2 figure is made. Each baseline that involves the
    station, has rows in the measurement set and is at least ``uv_min_m``
    long is read, averaged over the selected channels and plotted as phase
    (degrees) versus time, one panel per entry of ``CORRELATION_POLS``. Each
    figure is saved as
    ``<ms stem>_<station>_<data_column>_fringe_rate.png`` in ``out_dir``.

    Args:
        ms_path: Path to the measurement set.
        freq_mhz: Frequency in MHz, or ``"start"``/``"mid"``/``"end"``,
            locating the first channel to average.
        bandwidth_mhz: Bandwidth in MHz to average over.
        uv_min_m: Minimum baseline length for a baseline to be plotted.
            Compared against the bare ``.value`` of
            ``ms_utils.get_baseline_length`` (assumed to be in metres; verify
            against that helper).
        out_dir: Directory for the PNG files. Defaults to the directory
            containing ``ms_path``. Created if it does not exist.
        data_column: Name of the visibility column to read.

    Returns:
        The figures, one per station, in station order. They are not closed
        here; the caller owns them.
    """
    if out_dir is None:
        out_dir = ms_path.parent

    channel_freqs = ms_utils.get_freq_from_ms(ms_path)
    chan_select, bw_select = channel_selection(
        channel_freqs=channel_freqs,
        freq_mhz=freq_mhz,
        bandwidth_mhz=bandwidth_mhz,
    )
    chan_slice = slice(chan_select, chan_select + bw_select)
    sub_freqs = channel_freqs[chan_slice]

    # Same for every station, so format it once for the log and the titles.
    band_summary = (
        f"{(sub_freqs.max() - sub_freqs.min()).to('MHz'):0.1f}"
        f" @ {sub_freqs.to('MHz').mean():0.1f}"
    )
    logger.info(f"Looking at fringes over {band_summary}")

    # savefig does not create missing directories.
    out_dir.mkdir(parents=True, exist_ok=True)

    stations = ms_utils.get_antenna_names_from_ms(ms_path)
    baselines = list(combinations(range(len(stations)), 2))

    figures: list[Figure] = []
    for s in tqdm(stations, desc="Station"):
        fig, axs = plt.subplots(2, 2, figsize=(16, 8), sharex=True, sharey=True)

        for ant_2, ant_1 in tqdm(baselines, desc="Baseline"):
            if s not in (stations[ant_1], stations[ant_2]):
                continue

            # The TaQL text refers to the Python locals `tab`, `ant_1` and
            # `ant_2` through its `$name` placeholders, so these names (and
            # calling taql directly from this function) must stay as they are.
            with (
                table(ms_path.as_posix(), ack=False) as tab,
                taql(
                    "select from $tab where "
                    "(ANTENNA1==$ant_1 and ANTENNA2==$ant_2) or "
                    "(ANTENNA1==$ant_2 and ANTENNA2==$ant_1)"
                ) as subtab,
            ):
                _ = tab  # keep linters happy
                if subtab.nrows() == 0:
                    continue
                if ms_utils.get_baseline_length(ms_path, ant_1, ant_2).value < uv_min_m:
                    continue

                data = np.array(subtab.getcol(data_column))  # time, freq, pol
                times = ms_utils.get_time_from_table(subtab)
                # Average the selected channels, leaving (time, pol).
                sub_data = np.nanmean(data[:, chan_slice], axis=1)

                for i, (pol, ax, title) in enumerate(
                    zip(
                        sub_data.T,
                        axs.flatten(),
                        CORRELATION_POLS,
                        strict=False,
                    )
                ):
                    ax.plot(
                        times.datetime,
                        np.rad2deg(np.angle(pol)),
                        ".",
                        # Label only the first panel so each baseline appears
                        # once in the figure-level legend.
                        label=f"{stations[ant_1]}::{stations[ant_2]}"
                        if i == 0
                        else None,
                        lw=1,
                    )
                    ax.set(title=title, xlabel="time", ylabel="phase")
                    ax.grid(visible=True)

        fig.legend(ncol=3)
        fig.suptitle(f"All (>{uv_min_m}m) baselines to {s} - {band_summary}")

        out_name = f"{ms_path.stem}_{s}_{data_column}_fringe_rate.png"
        out_path = out_dir / out_name
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
        logger.info(f"Wrote {out_path}")
        figures.append(fig)

    return figures
def _parse_freq(value: str) -> float | str:
    """Parse the ``--freq-mhz`` command-line value.

    Args:
        value: Raw argument text.

    Returns:
        ``value`` unchanged when it is ``"start"``, ``"mid"`` or ``"end"``;
        otherwise ``value`` converted to ``float``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is neither a keyword nor
            text that ``float`` accepts.
    """
    keywords = {"start", "mid", "end"}
    try:
        parsed = value if value in keywords else float(value)
    except ValueError as err:
        msg = "Must be a float or one of: start, mid, end"
        raise argparse.ArgumentTypeError(msg) from err
    return parsed
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the fringe-rate plotting script.

    Returns:
        A parser with one positional argument, ``ms_path``, and the options
        ``--out-dir``, ``--freq-mhz``, ``--bw-mhz``, ``--uv-min-m`` and
        ``--data-column``. Defaults are shown in the help output.
    """
    parser = argparse.ArgumentParser(
        description="Plot phase vs time over baselines per station.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("ms_path", type=Path, help="Path to visibilities.")
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=None,
        help="Output directory for plots. Defaults to the MS directory.",
    )
    parser.add_argument(
        "--freq-mhz",
        type=_parse_freq,
        default="mid",
        help="Frequency to select in MHz or special values of 'start', 'mid', or 'end'",
    )
    parser.add_argument(
        "--bw-mhz", type=float, default=1, help="Bandwidth to average over in MHz"
    )
    parser.add_argument(
        "--uv-min-m",
        type=float,
        default=500,
        help="Minimum UV distance in metres for a baseline to be included",
    )
    parser.add_argument(
        "--data-column",
        type=str,
        default="DATA",
        help="Data column to plot.",
    )
    return parser
def main() -> None:
    """Parse the command line and write the fringe-rate plots.

    The figures returned by ``make_fringe_rate_plots`` are discarded.
    """
    cli = get_parser().parse_args()
    plot_options = {
        "ms_path": cli.ms_path,
        "freq_mhz": cli.freq_mhz,
        "bandwidth_mhz": cli.bw_mhz,
        "out_dir": cli.out_dir,
        "uv_min_m": cli.uv_min_m,
        "data_column": cli.data_column,
    }
    make_fringe_rate_plots(**plot_options)
# Run the command-line interface when this file is executed as a script.
if __name__ == "__main__":
    main()