from __future__ import annotations
import argparse
import time
from pathlib import Path
from typing import Any, NamedTuple
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from casacore.tables import table, taql
from matplotlib import animation
from matplotlib.collections import PathCollection
from matplotlib.figure import Figure
from tqdm.auto import tqdm
from low_comm_tools.constants import COARSE_CHAN_WIDTH, CORRELATION_POLS
from low_comm_tools.log_config import logger
from low_comm_tools.ms_utils import (
get_antenna_names_from_ms,
get_freq_from_ms,
get_time_from_table,
)
from low_comm_tools.plotting import baselines
from low_comm_tools.plotting.utils import SKA_HEX_COLOURS
[docs]
class BaselineData(NamedTuple):
    """
    Visibilities for a given baseline along with its corresponding Frequency
    and Time arrays.

    The fields are positional in the order ``(f, t, z)``; callers construct
    the tuple positionally, so the order must not change.
    """

    # 1D array of frequencies in MHz
    f: np.ndarray
    # 1D array of times in s
    t: np.ndarray
    # 3D array of complex baseline visibilities, indexed [time, frequency, polarisation]
    z: np.ndarray
[docs]
def animate(
    frame: int,
    scat_phases: list[PathCollection],
    scat_ampls: list[PathCollection],
    time_text: mpl.text.Text,
    f: np.ndarray,
    t: np.ndarray,
    z: np.ndarray,
    pols: list[str],
    deg: bool,
    log: bool,
) -> tuple[mpl.artist.Artist, ...]:
    """
    Move the phase/amplitude scatter points to the data of time bin ``frame``.

    Used as the per-frame callback of ``matplotlib.animation.FuncAnimation``.

    Args:
        frame: Index of the time bin to display (first axis of ``z``).
        scat_phases: Phase scatter artists, one per entry of ``pols``.
        scat_ampls: Amplitude scatter artists, one per entry of ``pols``.
        time_text: Text artist that is relabelled with the time of ``frame``.
        f: x coordinate (frequency) of every point.
        t: Time of each time bin, in seconds.
        z: Complex visibilities indexed ``[time, frequency, polarisation]``.
        pols: Polarisation labels, looked up in ``CORRELATION_POLS``.
        deg: Express phases in degrees instead of radians.
        log: Show log10 of the amplitude instead of the amplitude.

    Returns:
        Every artist that was modified (needed for blitting).
    """
    for k, pol_name in enumerate(pols):
        vis = z[frame, :, CORRELATION_POLS.index(pol_name)]
        phase = np.angle(vis, deg=deg)
        amplitude = np.abs(vis)
        if log:
            amplitude = np.log10(amplitude)
        scat_phases[k].set_offsets(np.column_stack((f, phase)))
        scat_ampls[k].set_offsets(np.column_stack((f, amplitude)))
    time_text.set_text(f"t = {t[frame]:.1f} s")
    return (*scat_phases, *scat_ampls, time_text)
[docs]
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Make an animation (default = GIF) of amplitudes an phases for a given baseline over time",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("ms_path", type=Path, help="Path to visibilities")
parser.add_argument(
"stations", nargs=2, help="The two stations comprising the desired baseline"
)
parser.add_argument(
"--tfactor",
type=int,
default=1,
help="Number of time bins to average together (default: do no averaging)",
)
parser.add_argument(
"--ffactor",
type=int,
default=1,
help="Number of channels to average together (default: do no averaging)",
)
parser.add_argument(
"--rad",
action="store_true",
help="Plot phase angles in radians instead of the default degrees",
)
parser.add_argument(
"--linear",
action="store_true",
help="Plot amplitudes on a linear scale instead of the default log scale",
)
parser.add_argument(
"--delay_samples",
type=int,
default=0,
help="Apply an overall phase slope across all channels corresponding to a delay of DELAY_SAMPLES*1.08us of pre-channelised data",
)
parser.add_argument(
"--intra_chan_delay_samples",
type=int,
default=0,
help="Apply a phase slope across coarse channels corresponding to a delay of DELAY_SAMPLES*1.08us of post-channelised data",
)
parser.add_argument(
"--output_animation",
help="Output (path to) file name where the animation gets written (default = 'STATION1_STATION2_anim.gif'). If given the special value SHOW, then plt.show() is called instead of writing the result to file.",
type=str,
default=None,
)
parser.add_argument(
"--pols",
default=f"{CORRELATION_POLS[0]},{CORRELATION_POLS[-1]}",
help=f"Comma-separated list of polarisations to plot. Supported polarisations are {CORRELATION_POLS}.",
type=str,
)
return parser
# This differs from the _read_subtable in plotting.baselines by keeping strictly to the
# order of stations supplied by the user. If the order is "wrong" as per the measurement
# set, this function (deliberately) returns None, which the caller must handle.
[docs]
def _read_subtable(
    ms_path: Path,
    ant_1: int,
    ant_2: int,
    data_column: str,
    station_names: list[str],
    freq_chan: u.Quantity,
) -> baselines.SubtableData | None:
    """
    Read the rows of one baseline, strictly in the given antenna order.

    Only rows with ``ANTENNA1 == ant_1`` and ``ANTENNA2 == ant_2`` are
    selected; the reversed pair is not tried here.

    Args:
        ms_path: Path to the measurement set.
        ant_1: Antenna index matched against the ANTENNA1 column.
        ant_2: Antenna index matched against the ANTENNA2 column.
        data_column: Name of the visibility column to read (e.g. "DATA").
        station_names: Passed through unchanged into the returned SubtableData.
        freq_chan: Passed through unchanged into the returned SubtableData.

    Returns:
        A ``baselines.SubtableData`` whose ``masked_data`` has every sample
        with FLAG set masked out, and whose ``times`` are relative to the
        earliest selected time; or ``None`` if no row matches this antenna
        order.
    """
    t0 = time.perf_counter()
    # python-casacore's taql() substitutes `$tab`, `$ant_1` and `$ant_2` from
    # the calling scope's variables, so these local names must not be renamed.
    with (
        table(ms_path.as_posix(), ack=False) as tab,
        taql(
            "select from $tab where (ANTENNA1==$ant_1 and ANTENNA2==$ant_2)"
        ) as subtab,
    ):
        _ = tab  # `tab` is only referenced inside the TaQL string; keep linters happy
        if subtab.nrows() == 0:
            # Baseline not stored in this antenna order; the caller decides what to do.
            return None
        data = np.array(subtab.getcol(data_column))
        flags = np.array(subtab.getcol("FLAG"))
        time_centroid = get_time_from_table(subtab)
        # Samples whose FLAG is True become masked.
        masked_data = np.ma.masked_array(data, mask=flags)
        # Times relative to the first selected sample.
        times = time_centroid - time_centroid.min()
    logger.info(
        f"Done reading baseline {ant_1} - {ant_2} in {time.perf_counter() - t0:.1f}s"
    )
    return baselines.SubtableData(
        masked_data=masked_data,
        station_names=station_names,
        freq_chan=freq_chan,
        times=times,
        ant_1=ant_1,
        ant_2=ant_2,
    )
[docs]
def _get_baseline(
    ms_path: Path,
    station1: str,
    station2: str,
    delay_samples: int = 0,
    intra_chan_delay_samples: int = 0,
) -> BaselineData:
    """
    Read one baseline from a measurement set, oriented as ``station1 - station2``.

    The baseline is first read with ``station1`` as ANTENNA1. If the MS stores
    it the other way round, it is re-read with the antennas swapped, and the
    visibilities are conjugated so that the phases follow the requested order.
    Optional phase ramps emulating a delay are then applied.

    Args:
        ms_path: Path to the measurement set.
        station1: Name of the first station of the baseline.
        station2: Name of the second station of the baseline.
        delay_samples: If non-zero, apply a phase slope across the whole band
            equivalent to this many samples of ``1 / COARSE_CHAN_WIDTH``
            (~1.08 us).
        intra_chan_delay_samples: If non-zero, apply the same kind of slope,
            but relative to the centre of each coarse channel.

    Returns:
        BaselineData with frequencies (MHz), times, and visibilities indexed
        ``[time, frequency, polarisation]``.

    Raises:
        ValueError: If a station name is not in the MS, or if no rows exist for
            the baseline in either antenna order.
    """
    stations = get_antenna_names_from_ms(ms_path)
    missing = [name for name in (station1, station2) if name not in stations]
    if missing:
        msg = (
            f"Station not found: {', '.join(missing)}. "
            f"Available stations are: {', '.join(stations)}"
        )
        raise ValueError(msg)
    ant_1 = stations.index(station1)
    ant_2 = stations.index(station2)

    logger.info(f"Reading baseline {station1}_{station2}")
    # Read the frequency axis once; it is shared by both read attempts.
    freq_chan = get_freq_from_ms(ms_path)
    switched_order = False
    subtable_data = _read_subtable(
        ms_path=ms_path,
        ant_1=ant_1,
        ant_2=ant_2,
        data_column="DATA",
        station_names=[station1, station2],
        freq_chan=freq_chan,
    )
    if subtable_data is None:
        switched_order = True
        logger.info(f"Switching to reading baseline {station2}_{station1}")
        subtable_data = _read_subtable(
            ms_path=ms_path,
            ant_1=ant_2,
            ant_2=ant_1,
            data_column="DATA",
            station_names=[station1, station2],
            freq_chan=freq_chan,
        )
    if subtable_data is None:
        msg = (
            "Unable to read subtable data: no rows for baseline "
            f"{station1} - {station2} in either antenna order"
        )
        raise ValueError(msg)

    f, t, z, _, _ = baselines._compute_plot_arrays(
        subtable_data,
        "spectrum",
    )
    if switched_order:
        # This ensures the phases are correct up to the given order of antennas.
        # NOTE(review): reversing a baseline also swaps the cross-hand products
        # (XY <-> YX); only conjugation is applied here -- confirm against the
        # CORRELATION_POLS ordering before relying on cross-hand phases.
        z = np.conj(z)

    # Sample period of the coarse channelisation in microseconds (= 1.08).
    # Frequencies are in MHz, so MHz * us is dimensionless.
    coarse_channel_sample_rate = (1 / COARSE_CHAN_WIDTH).to(u.us).value
    chan_width_mhz = COARSE_CHAN_WIDTH.to_value(u.MHz)
    if delay_samples != 0:
        logger.info(
            f"Applying a phase ramp equivalent to {delay_samples} samples across the whole observation"
        )
        z *= np.exp(
            -2j
            * np.pi
            * delay_samples
            * coarse_channel_sample_rate
            * f[np.newaxis, :, np.newaxis]
        )
    if intra_chan_delay_samples != 0:
        # Offset of every fine channel from the centre of its coarse channel.
        chan_ctr_freqs = np.round(f / chan_width_mhz) * chan_width_mhz
        fine_chan_freqs = f - chan_ctr_freqs
        logger.info(
            f"Applying a phase ramp equivalent to {intra_chan_delay_samples} samples across each coarse channel"
        )
        z *= np.exp(
            -2j
            * np.pi
            * intra_chan_delay_samples
            * coarse_channel_sample_rate
            * fine_chan_freqs[np.newaxis, :, np.newaxis]
        )
    return BaselineData(f, t, z)
[docs]
def _plot_coarse_chan_edges(
    f: np.ndarray,
    axs: np.ndarray,
    alpha: float = 0.3,
    c: str = "k",
    ls: str = "dashed",
    **kwargs: Any,
) -> None:
    """
    Draw vertical lines at the coarse-channel boundaries on every axis.

    Coarse channel ``k`` is taken to be centred on ``k * COARSE_CHAN_WIDTH``,
    so its edges lie at ``(k - 0.5)`` and ``(k + 0.5)`` channel widths. For the
    ``n`` coarse channels spanned by ``f``, ``n + 1`` edges are drawn, which
    includes both the lower and the upper edge of the band.

    Args:
        f: Frequencies in MHz; only the minimum and maximum are used.
        axs: Axes (array, list, or single Axes) to draw the lines on.
        alpha: Line transparency.
        c: Line colour.
        ls: Line style.
        **kwargs: Extra keyword arguments forwarded to ``Axes.axvline``.
    """
    if np.size(f) == 0:
        return  # no frequencies -> no channels to mark
    chan_width_mhz = COARSE_CHAN_WIDTH.to_value(u.MHz)
    # Indices of the lowest and highest coarse channels covered by ``f``.
    chan_lo = int(np.round(np.min(f) / chan_width_mhz))
    chan_hi = int(np.round(np.max(f) / chan_width_mhz))
    # Exclusive stop of chan_hi + 2 gives n + 1 edges for n = chan_hi - chan_lo + 1
    # channels (the previous code stopped one short and dropped the top edge).
    edges = (np.arange(chan_lo, chan_hi + 2) - 0.5) * chan_width_mhz
    for ax in np.ravel(axs):
        for edge in edges:
            ax.axvline(edge, ls=ls, alpha=alpha, c=c, **kwargs)
[docs]
def _plot_single_integration_phase_and_ampl(
    x: np.ndarray,
    z: np.ndarray,
    n: int,
    ax_phase: mpl.axes.Axes | None,
    ax_ampl: mpl.axes.Axes | None,
    log: bool = True,
    pols: list[str] | None = None,
    deg: bool = True,
    s: float = 3.0,
    c: list[str] = SKA_HEX_COLOURS,
    **kwargs: Any,
) -> tuple[list[PathCollection], list[PathCollection]]:
    """
    Scatter-plot phase and amplitude against frequency for a single time bin.

    The returned artists may later be updated in place to show other time
    bins. For that reason, the amplitude y-limits are derived from the finite
    amplitudes of ALL time bins and ALL plotted polarisations.

    Args:
        x: Frequencies (MHz) used as the x coordinate.
        z: Complex visibilities indexed ``[time, frequency, polarisation]``.
        n: Index of the time bin to plot.
        ax_phase: Axes for the phase plot, or ``None`` to skip it.
        ax_ampl: Axes for the amplitude plot, or ``None`` to skip it.
        log: Plot log10 of the amplitude instead of the amplitude.
        pols: Polarisations to plot; defaults to the first and last entries of
            ``CORRELATION_POLS``.
        deg: Express phases in degrees (otherwise radians).
        s: Marker size.
        c: Colours, cycled over ``pols``.
        **kwargs: Extra keyword arguments forwarded to ``Axes.scatter``.

    Returns:
        ``(scat_phases, scat_ampls)``: one scatter artist per polarisation for
        each axis that was provided (an empty list for a skipped axis).
    """
    scat_phases: list[PathCollection] = []
    scat_ampls: list[PathCollection] = []
    if pols is None:
        pols = [CORRELATION_POLS[0], CORRELATION_POLS[-1]]

    if ax_phase is not None:
        for i, pol in enumerate(pols):
            pol_idx = CORRELATION_POLS.index(pol)
            scat_phases.append(
                ax_phase.scatter(
                    x,
                    np.angle(z[n, :, pol_idx], deg=deg),
                    c=c[i % len(c)],
                    s=s,
                    label=pol,
                    **kwargs,  # type:ignore[arg-type]
                )
            )
        ax_phase.set_ylim(
            [-180.0, 180.0] if deg else [-np.pi, np.pi],  # type:ignore[arg-type]
        )
        ax_phase.set_ylabel(f"Phase ({'deg' if deg else 'rad'})")
        ax_phase.set_xlabel("Frequency (MHz)")
        ax_phase.set_yticks(
            [-180.0, -90.0, 0.0, 90.0, 180.0]
            if deg
            else [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
        )

    if ax_ampl is not None:
        ax_ampl.set_ylabel("log10 (Amplitude (a.u.))" if log else "Amplitude (a.u.)")
        ampl_lo, ampl_hi = np.inf, -np.inf
        for i, pol in enumerate(pols):
            pol_idx = CORRELATION_POLS.index(pol)
            # Amplitudes of every time bin, so the y-range covers all frames.
            ampl = np.abs(z[:, :, pol_idx])
            if log:
                ampl = np.log10(ampl)
            scat_ampls.append(
                ax_ampl.scatter(
                    x,
                    ampl[n, :],
                    c=c[i % len(c)],
                    s=s,
                    label=pol,
                    **kwargs,  # type:ignore[arg-type]
                )
            )
            # Accumulate the range over all polarisations (previously only the
            # last one was used), skipping masked and non-finite values such as
            # log10(0) = -inf, which set_ylim rejects.
            finite = np.ma.masked_invalid(ampl).compressed()
            if finite.size:
                ampl_lo = min(ampl_lo, float(finite.min()))
                ampl_hi = max(ampl_hi, float(finite.max()))
        ax_ampl.set_xlabel("Frequency (MHz)")
        if ampl_lo <= ampl_hi:
            ax_ampl.set_ylim(
                [ampl_lo, ampl_hi],  # type:ignore[arg-type]
            )
    return scat_phases, scat_ampls
[docs]
def _scrunch(
baseline_data: BaselineData,
tfactor: int = 1,
ffactor: int = 1,
) -> tuple[Any, Any, Any]:
# Truncate as necessary to get int number of downsamples
f = baseline_data.f[: (len(baseline_data.f) // ffactor) * ffactor]
t = baseline_data.t[: (len(baseline_data.t) // tfactor) * tfactor]
z = baseline_data.z[
: (len(baseline_data.t) // tfactor) * tfactor,
: (len(baseline_data.f) // ffactor) * ffactor,
]
f_avg = np.mean(f.reshape(len(f) // ffactor, ffactor), axis=1)
t_avg = np.mean(t.reshape(len(t) // tfactor, tfactor), axis=1)
z_avg = np.mean(
z.reshape(len(t) // tfactor, tfactor, len(f) // ffactor, ffactor, 4),
axis=(1, 3),
)
return f_avg, t_avg, z_avg
[docs]
def _parse_pols(pols: str) -> list[str]:
    """Split a comma-separated polarisation string and validate every entry."""
    pol_list = [pol.strip() for pol in pols.split(",") if pol.strip()]
    unknown = [pol for pol in pol_list if pol not in CORRELATION_POLS]
    if not pol_list or unknown:
        msg = (
            f"Invalid polarisation selection {pols!r}. "
            f"Supported polarisations are {CORRELATION_POLS}."
        )
        raise ValueError(msg)
    return pol_list


def make_fringe_animation(
    ms_path: Path,
    deg: bool,
    log: bool,
    tfactor: int,
    ffactor: int,
    delay_samples: int,
    intra_chan_delay_samples: int,
    stations: tuple[str, str],
    output_animation: str | None,
    pols: str,
) -> tuple[Figure, animation.FuncAnimation]:
    """
    Animate phase and amplitude vs frequency of one baseline, one frame per
    (averaged) time bin, and either save the result or show it.

    Args:
        ms_path: Path to the measurement set.
        deg: Plot phases in degrees (otherwise radians).
        log: Plot log10 amplitudes (otherwise linear).
        tfactor: Number of time bins to average together.
        ffactor: Number of channels to average together.
        delay_samples: Whole-band delay phase ramp, in coarse-channel samples.
        intra_chan_delay_samples: Per-coarse-channel delay phase ramp, in samples.
        stations: The two station names of the baseline; phases follow this order.
        output_animation: Output file name. ``None`` selects
            ``<ms_path.parent>/<ms_path.stem>.<STATION1>_<STATION2>_anim.gif``;
            the special value ``"SHOW"`` calls ``plt.show()`` instead of saving.
        pols: Comma-separated polarisations to plot; whitespace around entries
            is ignored.

    Returns:
        The figure and the ``FuncAnimation``. Keep a reference to the
        animation for as long as it should run.

    Raises:
        ValueError: If ``pols`` selects no valid polarisation, or propagated
            from reading the baseline.
    """
    _default_path = ms_path.parent / f"{ms_path.stem}.{'_'.join(stations)}_anim.gif"
    output_animation = output_animation or _default_path.as_posix()
    # Validate the polarisations before the (potentially slow) MS read.
    pol_list = _parse_pols(pols)
    # F = Frequencies
    # T = Times
    # Z = complex-valued samples
    # Similarly for f, t, z (after downsampling)
    baseline_data = _get_baseline(
        ms_path,
        stations[0],
        stations[1],
        delay_samples=delay_samples,
        intra_chan_delay_samples=intra_chan_delay_samples,
    )
    f, t, z = _scrunch(baseline_data, tfactor=tfactor, ffactor=ffactor)

    plt.close()
    fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(12, 6))
    scat_phases, scat_ampls = _plot_single_integration_phase_and_ampl(
        f,
        z,
        0,  # The first time bin; only needed to create the matplotlib artists
        axs[0],
        axs[1],
        pols=pol_list,
        deg=deg,
        log=log,
    )
    _plot_coarse_chan_edges(f, axs)
    axs[0].set_title(f"{', '.join(pol_list)} of {stations[0]} - {stations[1]}")
    time_text = axs[1].text(
        0.05,
        0.95,
        "t = 0.0 s",
        transform=axs[1].transAxes,
        va="top",
        ha="left",
    )
    axs[1].legend()
    plt.tight_layout()

    progress = tqdm(total=len(t), desc="Rendering animation")

    def _update(frame: int, *args: Any) -> tuple[mpl.artist.Artist, ...]:
        # Tick the rendering progress bar, then redraw the requested frame.
        progress.update(1)
        return animate(frame, *args)

    # ``save_count`` is intentionally not passed: with an integer ``frames`` it
    # is ignored by matplotlib (and newer versions warn about it).
    ani = animation.FuncAnimation(
        fig,
        _update,
        frames=len(t),
        interval=50,
        blit=True,
        fargs=(scat_phases, scat_ampls, time_text, f, t, z, pol_list, deg, log),
    )
    try:
        if output_animation == "SHOW":
            plt.show()
        else:
            with tqdm(total=len(t), desc="Saving animation") as pbar:
                ani.save(
                    output_animation,
                    progress_callback=lambda _frame, _total: pbar.update(1),
                )
    finally:
        # The rendering bar was previously never closed.
        progress.close()
    return fig, ani
[docs]
def main() -> None:
    """Command-line entry point: parse the arguments and build the fringe animation."""
    args = get_parser().parse_args()
    # The CLI exposes "--rad"/"--linear" opt-outs; the library takes the positive flags.
    make_fringe_animation(
        ms_path=args.ms_path,
        stations=args.stations,
        pols=args.pols,
        deg=not args.rad,
        log=not args.linear,
        tfactor=args.tfactor,
        ffactor=args.ffactor,
        delay_samples=args.delay_samples,
        intra_chan_delay_samples=args.intra_chan_delay_samples,
        output_animation=args.output_animation,
    )


if __name__ == "__main__":
    main()