from __future__ import annotations
import argparse
import itertools
from collections.abc import Collection
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from astropy.visualization import quantity_support
from casacore.tables import table, taql
from matplotlib.figure import Figure
from low_comm_tools import ms_utils
from low_comm_tools.constants import COARSE_CHAN_WIDTH, CORRELATION_POLS
from low_comm_tools.log_config import logger
from low_comm_tools.plotting.utils import grid_dims
# Select matplotlib's non-interactive Agg backend so that figures are rendered
# straight to image files and no GUI/display is needed when this module runs.
mpl.use("Agg")
def make_fringe_plot(
    ms_path: Path,
    pols: Collection[str] | None = None,
    freq_min_mhz: float = 0,
    freq_max_mhz: float = +np.inf,
    spw_id: int = 0,
    save_fig: bool = False,
    out_dir: Path | None = None,
    show_coarse_channels: bool = False,
    max_baselines: int | None = None,
    coarse_channel_min: int = 0,
    coarse_channel_max: int = 512,
    data_column: str = "DATA",
) -> Figure | None:
    """Plot visibility phase against frequency, one panel per baseline.

    For every antenna pair the selected channels of ``data_column`` are read,
    flagged samples are replaced by NaN, the visibilities are averaged over
    the returned rows and the phase of that average is plotted in degrees.

    Args:
        ms_path: Path to the measurement set to read.
        pols: Correlation products to plot; each entry must be present in
            ``CORRELATION_POLS``. Defaults to all of ``CORRELATION_POLS``.
        freq_min_mhz: Inclusive lower frequency bound in MHz.
        freq_max_mhz: Exclusive upper frequency bound in MHz.
        spw_id: Value the ``DATA_DESC_ID`` column must match.
        save_fig: If True, write the figure to a PNG file.
        out_dir: Directory for the PNG. Defaults to the parent of ``ms_path``.
        show_coarse_channels: If True, draw dashed vertical lines at
            ``COARSE_CHAN_WIDTH * (i - 0.5)`` for each coarse channel ``i``.
        max_baselines: Plot only the first ``max_baselines`` antenna pairs.
            ``None`` (or 0) plots every pair.
        coarse_channel_min: First coarse channel index to mark.
        coarse_channel_max: Last coarse channel index to mark (inclusive).
        data_column: Name of the visibility column to read.

    Returns:
        The populated figure, or ``None`` if any requested baseline has no
        matching rows (a warning is logged and the figure is discarded).

    Raises:
        ValueError: If no channel lies within the requested frequency range.
    """
    if pols is None:
        pols = CORRELATION_POLS

    freq = ms_utils.get_freq_from_ms(ms_path).to("MHz")
    stations = ms_utils.get_antenna_names_from_ms(ms_path)
    field = ms_utils.get_field_name_from_ms(ms_path)
    times = ms_utils.get_time_from_ms(ms_path)

    baselines = list(itertools.combinations(range(len(stations)), 2))
    if max_baselines:
        baselines = baselines[:max_baselines]

    chan_mask = (freq.value >= freq_min_mhz) & (freq.value < freq_max_mhz)
    freq_sel = freq[chan_mask]
    # Validate the selection before deriving channel indices from it.
    if len(freq_sel) == 0:
        msg = f"No frequencies selected! Min and max freq are {freq.to('MHz').min()} {freq.to('MHz').max()}"
        raise ValueError(msg)
    selected_chans = np.flatnonzero(chan_mask)
    chan_start = int(selected_chans[0])
    chan_end = int(selected_chans[-1])

    n_baselines = len(baselines)
    ncols, nrows = grid_dims(n_baselines)
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so `.flat`
    # below also works when only a single baseline is plotted.
    fig, axs = plt.subplots(
        nrows,
        ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols * 5, nrows * 3.5),
        squeeze=False,
    )

    with table(ms_path.as_posix()) as tab:
        # `tab`, `ant_1` and `ant_2` are only referenced as "$name" inside the
        # TaQL string; python-casacore substitutes them from this function's
        # local variables, so they must keep these exact names.
        _ = tab  # keep linters happy
        for ax_idx, (ant_2, ant_1) in enumerate(baselines):
            ax = axs.flat[ax_idx]
            # The two antenna orderings are OR-ed inside one pair of
            # parentheses: AND binds tighter than OR in TaQL, so without them
            # the DATA_DESC_ID filter would only apply to the second ordering.
            query = (
                f"SELECT {data_column}[{chan_start}:{chan_end}+1,] AS {data_column}, "
                f"FLAG[{chan_start}:{chan_end}+1,] AS FLAG "
                "FROM $tab WHERE "
                "((ANTENNA1==$ant_1 and ANTENNA2==$ant_2) or "
                "(ANTENNA1==$ant_2 and ANTENNA2==$ant_1)) "
                f"AND DATA_DESC_ID={spw_id}"
            )
            # Close the per-baseline result table as soon as it has been read.
            with taql(query) as sub_tab:
                if sub_tab.nrows() == 0:
                    logger.warning(f"No data for baseline {ant_1} - {ant_2}")
                    # The caller never receives this figure, so release it
                    # instead of leaving it registered with pyplot.
                    plt.close(fig)
                    return None
                data = sub_tab.getcol(data_column)
                flag = sub_tab.getcol("FLAG")

            # Blank flagged samples so they do not contribute to the average.
            data[flag] = np.nan
            baseline_m = ms_utils.get_baseline_length(
                ms_path=ms_path,
                ant_1=ant_1,
                ant_2=ant_2,
            ).value
            # Average over the first axis of the column (the table rows).
            data_mean = np.nanmean(data, axis=0)
            phase_deg = np.rad2deg(np.angle(data_mean))

            # Only label the outer edge of the shared-axis grid.
            row, col = divmod(ax_idx, ncols)
            if col == 0:
                ax.set_ylabel("phase / deg")
            if row == nrows - 1:
                ax.set_xlabel("Frequency / MHz")
            ax.set(title=f"{stations[ant_2]}::{stations[ant_1]} ({baseline_m:0.0f}m)")

            for p in pols:
                pol_idx = CORRELATION_POLS.index(p)
                ax.plot(
                    freq_sel,
                    phase_deg[..., pol_idx],
                    ".",
                    label=p,
                    lw=0.2,
                )

            if show_coarse_channels:
                for i in range(coarse_channel_min, coarse_channel_max + 1, 1):
                    # Label a single line only, so the figure legend gets one
                    # "Coarse channel" entry.
                    is_first_line = i == coarse_channel_min and ax_idx == 0
                    ax.axvline(
                        COARSE_CHAN_WIDTH * (i - 0.5),
                        lw=0.5,
                        ls="--",
                        c="k",
                        label="Coarse channel" if is_first_line else None,
                    )

            ax.set(xlim=(freq_sel.min(), freq_sel.max()))
            if ax_idx == 0:
                # Every panel has the same artists, so the first one is
                # enough to build the shared legend.
                fig.legend(title="Instrumental pol.")

    # Hide the grid cells that have no baseline assigned.
    for ax in axs.flat[n_baselines:]:
        ax.set_visible(False)

    t0_str = times[0].strftime("%H:%M")
    t1_str = times[-1].strftime("%H:%M")
    fig.suptitle(f"{field} — {t0_str} to {t1_str} UTC")
    fig.tight_layout()

    if save_fig:
        out_name = ms_path.with_suffix(f".fringes.{data_column}.png").name
        if out_dir is None:
            out_dir = ms_path.parent
        out_path = out_dir / out_name
        fig.savefig(out_path)
        logger.info(f"Saved to {out_path}")

    return fig
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the fringe plotting tool.

    Returns:
        A parser with one positional argument (``ms_path``) and optional
        arguments for the output directory, polarisations, frequency range,
        spectral window, baseline limit, coarse channel overlay and data
        column. Defaults are shown in ``--help`` via
        ``ArgumentDefaultsHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        description="Plot fringes for all baselines.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("ms_path", type=Path, help="Path to visibilities.")
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=None,
        help="Output directory for plots. Defaults to the MS directory.",
    )
    parser.add_argument(
        "--pols",
        nargs="+",
        choices=CORRELATION_POLS,
        default=None,
        help="Polarisations to plot.",
    )
    parser.add_argument(
        "--freq-min-mhz",
        type=float,
        default=0,
        help="Minimum frequency in MHz.",
    )
    parser.add_argument(
        "--freq-max-mhz",
        type=float,
        default=float("inf"),
        help="Maximum frequency in MHz.",
    )
    parser.add_argument(
        "--spw-id",
        type=int,
        default=0,
        help="Spectral window ID.",
    )
    parser.add_argument(
        "--max-baselines",
        type=int,
        default=None,
        help="Maximum number of baselines to plot.",
    )
    parser.add_argument(
        "--show-coarse-channels",
        action="store_true",
        default=False,
        help="Overlay coarse channel boundaries on the plot.",
    )
    parser.add_argument(
        "--coarse-channel-min",
        type=int,
        default=0,
        help="Minimum coarse channel index.",
    )
    parser.add_argument(
        "--coarse-channel-max",
        type=int,
        default=512,
        help="Maximum coarse channel index.",
    )
    parser.add_argument(
        "--data-column",
        type=str,
        default="DATA",
        help="Data column",
    )
    return parser
def main() -> None:
    """Command-line entry point: parse the arguments and save the plot.

    The figure is always written to disk (``save_fig=True``); the figure
    object returned by ``make_fringe_plot`` is not used.
    """
    cli = get_parser().parse_args()
    make_fringe_plot(
        cli.ms_path,
        pols=cli.pols,
        data_column=cli.data_column,
        spw_id=cli.spw_id,
        freq_min_mhz=cli.freq_min_mhz,
        freq_max_mhz=cli.freq_max_mhz,
        max_baselines=cli.max_baselines,
        show_coarse_channels=cli.show_coarse_channels,
        coarse_channel_min=cli.coarse_channel_min,
        coarse_channel_max=cli.coarse_channel_max,
        out_dir=cli.out_dir,
        save_fig=True,
    )


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()