from __future__ import annotations
import argparse
from itertools import combinations
from pathlib import Path
from typing import Literal
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from casacore.tables import table, taql
from matplotlib.figure import Figure
from tqdm.auto import tqdm
from low_comm_tools import ms_utils
from low_comm_tools.constants import CORRELATION_POLS
from low_comm_tools.log_config import logger
[docs]
def _parse_freq_selection(
freq_mhz: float | Literal["start", "mid", "end"],
channel_freqs: u.Quantity[u.Hz],
) -> int:
n_chan = len(channel_freqs)
if isinstance(freq_mhz, str):
if freq_mhz == "start":
return 0
if freq_mhz == "mid":
return n_chan // 2
if freq_mhz == "end":
# Not using -1 here to allow spacing by bandwidth
return n_chan - 1
msg = f"Unsupported option '{freq_mhz}'" # type: ignore[unreachable]
raise ValueError(msg)
if freq_mhz < 0:
msg = f"Frequencies cannot be negative! Got {freq_mhz=}"
raise ValueError(msg)
return int(np.argmin(np.abs(channel_freqs.to("MHz").value - freq_mhz)))
[docs]
def _parse_bandwidth_selection(
bandwidth_mhz: float,
channel_freqs: u.Quantity[u.Hz],
) -> int:
if bandwidth_mhz < 0:
msg = f"Bandwidths cannot be negative! Got {bandwidth_mhz=}"
raise ValueError(msg)
total_bw = channel_freqs.max() - channel_freqs.min()
chan_bw = total_bw / len(channel_freqs) # bw per chan
return int(np.ceil(bandwidth_mhz / chan_bw.to(u.MHz).value))
[docs]
def channel_selection(
    channel_freqs: u.Quantity[u.Hz],
    freq_mhz: float | Literal["start", "mid", "end"],
    bandwidth_mhz: float,
) -> tuple[int, int]:
    """Choose a contiguous run of channels to average over.

    Args:
        channel_freqs: Frequency of every channel.
        freq_mhz: Requested start frequency in MHz, or one of ``"start"``,
            ``"mid"``, ``"end"``.
        bandwidth_mhz: Requested bandwidth in MHz.

    Returns:
        ``(start_channel, n_channels)``. When possible, the start is moved
        earlier so that ``n_channels`` channels fit before the end of the band.

    Raises:
        ValueError: Propagated from the frequency or bandwidth parsing.
    """
    first_chan = _parse_freq_selection(
        freq_mhz=freq_mhz,
        channel_freqs=channel_freqs,
    )
    n_chans = _parse_bandwidth_selection(
        bandwidth_mhz=bandwidth_mhz,
        channel_freqs=channel_freqs,
    )
    # Latest start index that keeps the whole window in the band (never < 0).
    last_valid_start = len(channel_freqs) - n_chans
    if last_valid_start < 0:
        last_valid_start = 0
    if first_chan > last_valid_start:
        first_chan = last_valid_start
    return first_chan, n_chans
[docs]
def make_fringe_rate_plots(
    ms_path: Path,
    freq_mhz: float | Literal["start", "mid", "end"],
    bandwidth_mhz: float,
    uv_min_m: float = 500,
    out_dir: Path | None = None,
    data_column: str = "DATA",
) -> list[Figure]:
    """Plot visibility phase against time for every baseline to each station.

    One 2x2 figure is made per station, with one panel per polarisation
    (titled from ``CORRELATION_POLS``). Every baseline that includes the
    station and is at least ``uv_min_m`` long adds one trace per panel: the
    phase, in degrees, of ``data_column`` averaged (NaN-aware) over the
    selected channels. Each figure is saved as
    ``<ms stem>_<station>_<data_column>_fringe_rate.png`` in ``out_dir``.

    Args:
        ms_path: Path to the measurement set.
        freq_mhz: Start frequency of the channel selection in MHz, or one of
            ``"start"``, ``"mid"``, ``"end"``.
        bandwidth_mhz: Bandwidth in MHz to average over.
        uv_min_m: Baselines shorter than this (metres) are skipped.
        out_dir: Directory for the PNG files. It is created if missing.
            Defaults to the directory containing ``ms_path``.
        data_column: Visibility column to read.

    Returns:
        The figures, one per station and in station order. They are not closed.
    """
    if out_dir is None:
        out_dir = ms_path.parent
    # savefig does not create missing directories.
    out_dir.mkdir(parents=True, exist_ok=True)

    channel_freqs = ms_utils.get_freq_from_ms(ms_path)
    chan_select, bw_select = channel_selection(
        channel_freqs=channel_freqs,
        freq_mhz=freq_mhz,
        bandwidth_mhz=bandwidth_mhz,
    )
    chan_slice = slice(chan_select, chan_select + bw_select)
    sub_freqs = channel_freqs[chan_slice]
    band_desc = (
        f"{(sub_freqs.max() - sub_freqs.min()).to('MHz'):0.1f} "
        f"@ {sub_freqs.to('MHz').mean():0.1f}"
    )
    logger.info(f"Looking at fringes over {band_desc}")

    stations = ms_utils.get_antenna_names_from_ms(ms_path)
    baselines = list(combinations(range(len(stations)), 2))

    figures: list[Figure] = []
    # Open the MS once and run the per-baseline TaQL selections against it.
    with table(ms_path.as_posix(), ack=False) as tab:
        _ = tab  # only referenced via ``$tab`` in the TaQL string; keep linters happy
        for s in tqdm(stations, desc="Station"):
            fig, axs = plt.subplots(
                2, 2, figsize=(16, 8), sharex=True, sharey=True
            )
            # Pairs are unpacked as (ant_2, ant_1), so ant_1 is the higher
            # index; the legend label order below depends on this.
            station_baselines = [
                (ant_2, ant_1)
                for ant_2, ant_1 in baselines
                if s in (stations[ant_1], stations[ant_2])
            ]
            for ant_2, ant_1 in tqdm(station_baselines, desc="Baseline"):
                # Skip short baselines before querying their visibilities.
                baseline_len = ms_utils.get_baseline_length(ms_path, ant_1, ant_2)
                if baseline_len.value < uv_min_m:
                    continue
                # python-casacore's taql substitutes $tab/$ant_1/$ant_2 from
                # the calling scope, so these names must stay local here.
                with taql(
                    "select from $tab where "
                    "(ANTENNA1==$ant_1 and ANTENNA2==$ant_2) or "
                    "(ANTENNA1==$ant_2 and ANTENNA2==$ant_1)"
                ) as subtab:
                    if subtab.nrows() == 0:
                        continue
                    data = np.array(subtab.getcol(data_column))  # time, freq, pol
                    times = ms_utils.get_time_from_table(subtab)
                sub_data = np.nanmean(data[:, chan_slice], axis=1)
                label = f"{stations[ant_1]}::{stations[ant_2]}"
                for i, (pol, ax, title) in enumerate(
                    zip(sub_data.T, axs.flatten(), CORRELATION_POLS, strict=False)
                ):
                    ax.plot(
                        times.datetime,
                        np.rad2deg(np.angle(pol)),
                        ".",
                        # Label only the first panel so the figure legend
                        # lists each baseline once.
                        label=label if i == 0 else None,
                        lw=1,
                    )
                    ax.set(title=title, xlabel="time", ylabel="phase")
                    ax.grid(visible=True)
            fig.legend(ncol=3)
            fig.suptitle(f"All (>{uv_min_m}m) baselines to {s} - {band_desc}")
            out_path = out_dir / f"{ms_path.stem}_{s}_{data_column}_fringe_rate.png"
            fig.savefig(out_path, dpi=300, bbox_inches="tight")
            logger.info(f"Wrote {out_path}")
            figures.append(fig)
    return figures
[docs]
def _parse_freq(value: str) -> float | str:
if value in {"start", "mid", "end"}:
return value
try:
return float(value)
except ValueError as e:
msg = "Must be a float or one of: start, mid, end"
raise argparse.ArgumentTypeError(msg) from e
[docs]
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the fringe-rate plotting script.

    Returns:
        An ``ArgumentParser`` whose parsed namespace provides ``ms_path``,
        ``out_dir``, ``freq_mhz``, ``bw_mhz``, ``uv_min_m`` and
        ``data_column``.
    """
    parser = argparse.ArgumentParser(
        description="Plot phase vs time over baselines per station.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("ms_path", type=Path, help="Path to visibilities.")
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=None,
        help="Output directory for plots. Defaults to the MS directory.",
    )
    parser.add_argument(
        "--freq-mhz",
        type=_parse_freq,
        default="mid",
        help="Frequency to select in MHz or special values of 'start', 'mid', or 'end'",
    )
    parser.add_argument(
        "--bw-mhz", type=float, default=1, help="Bandwidth to average over in MHz"
    )
    parser.add_argument(
        "--uv-min-m",
        type=float,
        default=500,
        help="Minimum UV distance in metres for a baseline to be included",
    )
    parser.add_argument(
        "--data-column",
        type=str,
        default="DATA",
        help="Data column to plot.",
    )
    return parser
[docs]
def main() -> None:
    """Command-line entry point: parse arguments and write the fringe-rate plots."""
    parser = get_parser()
    cli_args = parser.parse_args()
    make_fringe_rate_plots(
        cli_args.ms_path,
        cli_args.freq_mhz,
        cli_args.bw_mhz,
        uv_min_m=cli_args.uv_min_m,
        out_dir=cli_args.out_dir,
        data_column=cli_args.data_column,
    )
# Run the CLI when this module is executed as a script.
if __name__ == "__main__":
    main()