from __future__ import annotations
import argparse
from itertools import combinations
from pathlib import Path
from typing import Literal, TypeAlias, cast
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from astropy.visualization import quantity_support
from casacore.tables import table, taql
from matplotlib.axes import Axes
from matplotlib.collections import QuadMesh
from matplotlib.figure import Figure
from low_comm_tools import ms_utils
from low_comm_tools.constants import CORRELATION_POLS
from low_comm_tools.log_config import ch, logger
from low_comm_tools.plotting.utils import SKA_COLOURS_RGBA_DICT, grid_dims
# Select matplotlib's non-interactive Agg backend; this script only saves
# figures to file (see fig.savefig in main) rather than displaying them.
mpl.use("Agg")
[docs]
Baseline: TypeAlias = tuple[int, int]
[docs]
Triple: TypeAlias = tuple[int, int, int]
[docs]
Quad: TypeAlias = tuple[int, int, int, int]
class MeasurementSet:
    """
    Visibilities for a set of baselines from the same measurement set.

    Baselines are stored in ``visibilities``, keyed by ``(antenna1, antenna2)``
    antenna-index pairs. Imported values are masked arrays holding that
    baseline's rows, masked wherever the FLAG column is set. The averaging and
    plotting code treats them as (rows, channels, correlations). Closure phases
    (antenna triples) and closure amplitudes (antenna quads) can be computed
    from the stored baselines and plotted.
    """

    def __init__(self, ms_path: Path, data_column: str = "DATA"):
        """
        Read the frequency/time axes and station names of a measurement set.

        No visibilities are read here; call ``import_baselines`` for that.

        Args:
            ms_path: Path to the measurement set.
            data_column: Name of the visibility column to read (e.g. "DATA").
        """
        logger.info(f"Initialising measurement set from:\n\t{ms_path}")
        # Kept because import_baselines re-opens the table from this path
        self.ms_path = Path(ms_path)
        self.frequencies = ms_utils.get_freq_from_ms(ms_path)
        self.times = ms_utils.get_time_from_ms(ms_path)
        self.data_column = data_column
        self.station_names = ms_utils.get_antenna_names_from_ms(ms_path)
        self.visibilities: dict[Baseline, np.ndarray] = {}
        # Running [min, max] of plotted closure amplitudes, so that all
        # amplitude waterfall panels can share one colour scale
        self.global_vrange = [np.inf, -np.inf]

    def _get_bandwidth(self) -> u.Quantity[u.Hz]:
        """
        Return the total bandwidth covered by the channels.

        Assumes uniformly spaced channels: the span between the first and last
        channel plus one channel width.
        """
        channel_width = self.frequencies[1] - self.frequencies[0]
        return self.frequencies[-1] - self.frequencies[0] + channel_width

    def add_baseline_data(
        self,
        baseline: Baseline,
        visibilities: np.ndarray,
        force_override: bool = False,
    ) -> None:
        """
        Store the visibilities for one baseline.

        Args:
            baseline: ``(antenna1, antenna2)`` pair used as the key.
            visibilities: Visibility array for this baseline.
            force_override: Replace data already stored for ``baseline``.
                Otherwise an existing entry is left untouched.
        """
        if baseline in self.visibilities and not force_override:
            logger.debug(f"Baseline {baseline} already exists, skipping")
            return
        logger.info(f"Adding baseline {baseline}")
        self.visibilities[baseline] = visibilities

    def import_baselines(
        self,
        station_names: list[str] | tuple[str, ...] | None = None,
        uv_min_m: float | None = None,
        force_override: bool = True,
        include_autos: bool = False,
    ) -> None:
        """
        Read visibilities from the measurement set and store them per baseline.

        All selected rows are read with a single TaQL query. They are then
        grouped by (ANTENNA1, ANTENNA2) and stored via ``add_baseline_data``
        as masked arrays, with the FLAG column as the mask.

        Args:
            station_names: Only keep baselines where both stations are in
                this collection. ``None`` keeps all stations.
            uv_min_m: Only keep rows whose projected baseline length
                sqrt(u**2 + v**2) exceeds this value (UVW units, i.e. metres).
            force_override: Overwrite baselines that were already imported.
            include_autos: Also keep autocorrelations (ANTENNA1 == ANTENNA2).

        Raises:
            ValueError: If the data column name is not a plain identifier, or
                a station name is not present in the measurement set.
        """
        # The column name is interpolated into the query, so only accept
        # plain identifiers (e.g. DATA, CORRECTED_DATA)
        if not self.data_column.isidentifier():
            err = f"Invalid data column name: {self.data_column!r}"
            raise ValueError(err)
        # Construct taql query.
        # Aim is to get all relevant visibilities in a single read op. The
        # configured data column must be selected, since it is read back below.
        self.taql_query = (
            f"SELECT ANTENNA1, ANTENNA2, {self.data_column}, FLAG, UVW FROM $tab"
        )
        self.taql_wheres: list[str] = []
        if station_names is not None:
            # Remove duplicates
            unique_station_names = sorted(set(station_names))
            if any(name not in self.station_names for name in unique_station_names):
                err = (
                    f"Invalid station name. Must be drawn from:\n\t{self.station_names}"
                )
                raise ValueError(err)
            self.included_antennas = [
                self.station_names.index(name) for name in unique_station_names
            ]
            ant_idxs_str = (
                "[" + ",".join(str(ant_idx) for ant_idx in self.included_antennas) + "]"
            )
            self.taql_wheres.append(
                f"ANTENNA1 IN {ant_idxs_str} AND ANTENNA2 IN {ant_idxs_str}"
            )
            logger.info(f"Importing baselines for stations {station_names}")
        else:
            self.included_antennas = list(range(len(self.station_names)))
        if uv_min_m is not None:
            # python-casacore's taql() defaults to Python style (0-based,
            # end-exclusive slices), so UVW[0:2] selects (u, v)
            self.taql_wheres.append(f"SQRT(SUMSQUARE(UVW[0:2])) > {uv_min_m}")
        if not include_autos:
            self.taql_wheres.append("ANTENNA1 != ANTENNA2")
        if self.taql_wheres:
            self.taql_query += " WHERE " + " AND ".join(self.taql_wheres)
        logger.debug(f'TAQL query = "{self.taql_query}"')
        with (
            table(self.ms_path.as_posix(), ack=False) as tab,
            taql(self.taql_query) as subtab,
        ):
            _ = tab  # referenced as $tab inside the TaQL query
            if subtab.nrows() == 0:
                logger.warning("No rows matched the selection; nothing imported")
                return
            data = subtab.getcol(self.data_column)
            flags = subtab.getcol("FLAG")
            ant_1 = subtab.getcol("ANTENNA1")
            ant_2 = subtab.getcol("ANTENNA2")
        masked_data = np.ma.masked_array(data, mask=flags)
        logger.info("Finished importing data from ms, now indexing by baseline")
        # Corner-turn to index the data by baseline. Each (ant_1, ant_2) pair
        # maps to a unique integer key. A single stable sort then groups rows
        # per baseline while keeping their original (time) order. This is
        # O(n log n) rather than one full boolean scan per baseline.
        key = ant_1 * (ant_2.max() + 1) + ant_2
        _, inverse, counts = np.unique(key, return_inverse=True, return_counts=True)
        order = np.argsort(inverse.ravel(), kind="stable")
        for row_idxs in np.split(order, np.cumsum(counts)[:-1]):
            baseline = (int(ant_1[row_idxs[0]]), int(ant_2[row_idxs[0]]))
            self.add_baseline_data(
                baseline,
                masked_data[row_idxs],
                force_override=force_override,
            )

    def get_baselines(self) -> list[Baseline]:
        """Return the keys of all stored baselines, in insertion order."""
        return list(self.visibilities.keys())

    def __getitem__(self, baseline: Baseline) -> np.ndarray:
        """
        Return the visibilities for ``baseline``.

        If only the reversed baseline is stored, its complex conjugate is
        returned and also cached under ``baseline``.

        Raises:
            KeyError: If neither orientation of the baseline is stored.
        """
        logger.debug(f"Getting baseline {baseline}")
        if baseline in self.visibilities:
            return self.visibilities[baseline]
        logger.debug(f"Baseline {baseline} not found")
        reversed_baseline = baseline[::-1]
        if reversed_baseline not in self.visibilities:
            # The baseline doesn't exist, or wasn't imported
            raise KeyError(baseline)
        logger.debug(f"Conjugating {reversed_baseline} --> {baseline}")
        # Cache the conjugate so repeated lookups don't recompute it
        self.add_baseline_data(
            baseline,
            np.conj(self.visibilities[reversed_baseline]),
        )
        return self.visibilities[baseline]

    @staticmethod
    def _flags_to_nan(values: np.ndarray, source: np.ndarray) -> np.ndarray:
        """
        Return a floating-point copy of ``values`` with NaN where ``source`` is masked.

        ``values`` is an elementwise real-valued derivative of ``source`` (for
        example its amplitude or phase), with the same shape. Converting the
        mask to NaN lets the ``np.nan*`` reductions used for averaging skip
        flagged samples. ``np.asarray`` alone would drop the mask and keep
        the flagged values.
        """
        data = np.ma.getdata(values)
        dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
        filled = np.array(data, dtype=dtype)  # always a copy
        filled[np.ma.getmaskarray(source)] = np.nan
        return filled

    def get_abs_visibilities(self, baseline: Baseline) -> np.ndarray:
        """
        Return visibility amplitudes for ``baseline``, with flagged samples as NaN.

        Amplitudes are unchanged by conjugation, so a baseline stored in the
        reverse orientation is used directly.

        Raises:
            KeyError: If neither orientation of the baseline is stored.
        """
        logger.debug(f"Getting absolute values of baseline {baseline}")
        for key in (baseline, baseline[::-1]):
            if key in self.visibilities:
                vis = self.visibilities[key]
                return self._flags_to_nan(np.abs(np.ma.getdata(vis)), vis)
        raise KeyError(baseline)

    def _triple_to_baselines(
        self, triple: Triple
    ) -> tuple[Baseline, Baseline, Baseline]:
        """Return the baselines of the closed loop a -> b -> c -> a for ``(a, b, c)``."""
        baseline1 = triple[0], triple[1]
        baseline2 = triple[1], triple[2]
        baseline3 = triple[2], triple[0]
        logger.debug(f"{baseline1 = }, {baseline2 = }, {baseline3 = }")
        return baseline1, baseline2, baseline3

    def calc_closure_phases(self, triple: Triple) -> np.ndarray:
        """
        Return the closure phase (radians) for an antenna triple.

        This is the angle of the triple product V_ab * V_bc * V_ca. Samples
        flagged on any of the three baselines are NaN.
        """
        logger.info(f"Calculating closure phase for antennas {triple}")
        baseline1, baseline2, baseline3 = self._triple_to_baselines(triple)
        triple_products = self[baseline1] * self[baseline2] * self[baseline3]
        logger.debug(f"{triple_products.shape = }")
        return self._flags_to_nan(
            np.angle(np.ma.getdata(triple_products)), triple_products
        )

    def calc_closure_amps(self, quad: Quad) -> np.ndarray:
        """
        Return the closure amplitude for an antenna quad.

        Computed as |V_ij||V_kl| / (|V_mn||V_op|) using the baseline pairing
        chosen by ``_find_valid_quad``. Flagged samples are NaN.

        Raises:
            ValueError: If no valid baseline pairing exists for ``quad``.
        """
        logger.info(f"Calculating closure amplitudes for antennas {quad}")
        baselines = self._find_valid_quad(quad)
        if baselines is None:
            err = f"{quad} is an invalid subset of stations for calculating closure amplitudes"
            raise ValueError(err)
        b01, b23, b02, b13 = baselines
        num = self.get_abs_visibilities(b01) * self.get_abs_visibilities(b23)
        den = self.get_abs_visibilities(b02) * self.get_abs_visibilities(b13)
        return np.asarray(num / den)

    @staticmethod
    def _average_closures(
        closures: np.ndarray,
        axis: int,
        closure_type: Literal["phase", "amp"] = "phase",
    ) -> np.ndarray:
        """
        Average closure quantities along ``axis``, ignoring NaNs.

        Closure phases are wrapped angles. They are therefore averaged as
        unit phasors (circular mean), so that values near +pi and -pi do not
        average to ~0. Closure amplitudes use the NaN-aware arithmetic mean.
        """
        if closure_type == "phase":
            return np.angle(np.nanmean(np.exp(1j * closures), axis=axis))
        return np.nanmean(closures, axis=axis)

    def calc_closures(
        self,
        station_subset: Triple | Quad,
        closure_type: Literal["phase", "amp"] = "phase",
        freq_avg_mhz: float | None = None,
    ) -> tuple[np.ndarray, u.Quantity[u.Hz]]:
        """
        Compute closure quantities, averaged over frequency.

        Args:
            station_subset: Antenna triple (phase) or quad (amp).
            closure_type: "phase" or "amp".
            freq_avg_mhz: Width (MHz) of the averaged output channels. ``None``
                averages over the whole band.

        Returns:
            ``(closures, frequencies)``. With ``freq_avg_mhz=None``, closures
            has the channel axis averaged away and ``frequencies`` is the
            unaveraged frequency axis. Otherwise closures keep a (reduced)
            channel axis matching the returned averaged frequencies. Any
            leftover channels that don't fill a whole bin are dropped.
        """
        if closure_type == "phase":
            closures = self.calc_closure_phases(cast(Triple, station_subset))
        else:  # closure_type == "amp"
            closures = self.calc_closure_amps(cast(Quad, station_subset))
        # If no averaging is explicitly asked for, average over the whole band
        if freq_avg_mhz is None:
            return (
                self._average_closures(closures, axis=1, closure_type=closure_type),
                self.frequencies,
            )
        # Number of input channels averaged into each output channel, kept
        # within [1, nchan] so the reshape below is always valid
        nchan = len(self.frequencies)
        if nchan > 1:
            channel_width = abs(
                (self.frequencies[1] - self.frequencies[0]).to("MHz").value
            )
            ffactor = int(np.clip(np.round(freq_avg_mhz / channel_width), 1, nchan))
        else:
            ffactor = 1
        nbins = nchan // ffactor
        nkeep = nbins * ffactor
        f_avg = np.nanmean(self.frequencies[:nkeep].reshape(nbins, ffactor), axis=1)
        # Truncate to a whole number of bins. The number of correlations is
        # taken from the data rather than assumed.
        binned = closures[:, :nkeep, :].reshape(
            closures.shape[0], nbins, ffactor, closures.shape[-1]
        )
        return (
            self._average_closures(binned, axis=2, closure_type=closure_type),
            f_avg,
        )

    def plot_closures(
        self,
        ax: Axes,
        station_subset: Triple | Quad,
        closure_type: Literal["phase", "amp"] = "phase",
        pol_idxs: list[int] | None = None,
        freq_avg_mhz: float | None = None,
    ) -> None:
        """
        Plot closure quantities for one antenna subset onto ``ax``.

        With ``freq_avg_mhz=None`` the band-averaged closures are plotted
        against time, one series per polarisation in ``pol_idxs``. Otherwise a
        time/frequency waterfall of the first polarisation in ``pol_idxs`` is
        drawn. For amplitude waterfalls, ``global_vrange`` is widened to
        include the plotted (finite) values.
        """
        closures, f_avg = self.calc_closures(
            station_subset, closure_type, freq_avg_mhz=freq_avg_mhz
        )
        logger.info(f"Plotting closure {closure_type}s for {station_subset}")
        logger.debug(f"{self.times.datetime.shape = }")
        logger.debug(f"{closures.shape = }")
        pol_idxs = pol_idxs or list(range(len(CORRELATION_POLS)))
        if freq_avg_mhz is None:
            # Closures are band-averaged, so purely a function of time
            for pol_idx in pol_idxs:
                pol = CORRELATION_POLS[pol_idx]
                ax.plot(
                    self.times.datetime,
                    closures[:, pol_idx],
                    ".",
                    color=SKA_COLOURS_RGBA_DICT[pol],
                    label=pol,
                )
            # Rotate datetimes for readability
            ax.tick_params(axis="x", labelrotation=45)
            return
        # Closures depend on time *and* frequency: draw a waterfall of the
        # first requested polarisation only
        waterfall = closures[:, :, pol_idxs[0]]
        if closure_type == "phase":
            ax.pcolormesh(
                f_avg.to("MHz").value,
                self.times.datetime,
                waterfall,
                vmin=-np.pi,
                vmax=np.pi,
                cmap="twilight_shifted",
            )
            return
        # closure_type == "amp": widen the shared data range, ignoring
        # NaN (flagged) and infinite values
        finite = waterfall[np.isfinite(waterfall)]
        if finite.size > 0:
            self.global_vrange = [
                min(float(finite.min()), self.global_vrange[0]),
                max(float(finite.max()), self.global_vrange[1]),
            ]
        ax.pcolormesh(
            f_avg.to("MHz").value,
            self.times.datetime,
            waterfall,
        )

    def _baseline_exists(self, baseline: Baseline) -> bool:
        """Return True if ``baseline`` is stored in either orientation."""
        baselines = self.visibilities.keys()
        return baseline in baselines or baseline[::-1] in baselines

    def _is_valid_triple(self, triple: Triple) -> bool:
        """Return True if all three baselines of ``triple`` are stored."""
        return all(
            self._baseline_exists(baseline)
            for baseline in self._triple_to_baselines(triple)
        )

    def _find_valid_quad(
        self, quad: Quad
    ) -> tuple[Baseline, Baseline, Baseline, Baseline] | None:
        """
        For given antennas (0, 1, 2, 3), will return EITHER
            0-1, 2-3, 0-2, 1-3
        OR
            0-1, 2-3, 0-3, 1-2
        OR
            0-2, 1-3, 0-3, 1-2
        - whichever has all valid baselines - or None, if none of the above
        baseline sets can be used. The first two baselines form the numerator
        and the last two the denominator of the closure amplitude.
        """
        # itertools.combinations emits pairs in lexicographic order of input
        # position: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)
        b01, b02, b03, b12, b13, b23 = combinations(quad, 2)
        b0123 = self._baseline_exists(b01) and self._baseline_exists(b23)
        b0213 = self._baseline_exists(b02) and self._baseline_exists(b13)
        b0312 = self._baseline_exists(b03) and self._baseline_exists(b12)
        if b0123 and b0213:
            return b01, b23, b02, b13
        if b0123 and b0312:
            return b01, b23, b03, b12
        if b0213 and b0312:
            return b02, b13, b03, b12
        return None

    def plot_all_closures(
        self,
        closure_type: Literal["phase", "amp"] = "phase",
        freq_avg_mhz: float | None = None,
    ) -> Figure:
        """
        Plot closure quantities for every valid antenna subset on one figure.

        Subsets are all triples (phase) or quads (amp) drawn from
        ``included_antennas`` whose required baselines are stored. Each gets
        its own panel titled with the station names. Waterfall figures
        (``freq_avg_mhz`` given) get one shared colorbar.

        Raises:
            ValueError: If no valid antenna subset can be formed.
        """
        station_subsets: list[Triple | Quad]
        if closure_type == "phase":
            # Get all distinct triples of antennas
            station_subsets = [
                triple
                for triple in combinations(self.included_antennas, 3)
                if self._is_valid_triple(triple)
            ]
        else:  # closure_type == "amp"
            station_subsets = [
                quad
                for quad in combinations(self.included_antennas, 4)
                if self._find_valid_quad(quad) is not None
            ]
        if not station_subsets:
            err = f"No valid station subsets found for plotting closure {closure_type}s"
            raise ValueError(err)
        # The amplitude colour range is accumulated per figure
        self.global_vrange = [np.inf, -np.inf]
        # Set up grid of axes; squeeze=False keeps axs 2-D even for a single
        # row or column, so the row/column indexing below is always correct
        ncols, nrows = grid_dims(len(station_subsets))
        fig, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharex=True,
            sharey=True,
            figsize=(8 + 2 * ncols, 6 + 2 * nrows),
            squeeze=False,
        )
        for ax, station_subset in zip(axs.flat, station_subsets, strict=False):
            self.plot_closures(
                ax, station_subset, closure_type, freq_avg_mhz=freq_avg_mhz
            )
            ax.set_title(
                " - ".join(self.station_names[ant_idx] for ant_idx in station_subset)
            )
        if freq_avg_mhz is None:
            # Closure quantity vs time plots
            axs[0, -1].legend(loc="upper right")
            if closure_type == "phase":
                # Because sharey=True, this applies to all axes
                axs[0, 0].set_ylim([-np.pi, np.pi])
            else:
                axs[0, 0].set_yscale("log")
            for ax in axs[-1, :]:  # Bottom row of axes
                ax.set_xlabel("Datetime")
            unit_str = " (rad)" if closure_type == "phase" else ""
            for ax in axs[:, 0]:  # Left column of axes
                ax.set_ylabel(f"Closure {closure_type}{unit_str}")
            return fig
        # Waterfall plots
        for ax in axs[-1, :]:  # Bottom row of axes
            ax.set_xlabel("Frequency (MHz)")
        for ax in axs[:, 0]:  # Left column of axes
            ax.set_ylabel("Datetime")
        all_axes = axs.ravel().tolist()
        # Make a cax for the big colorbar spanning all rows
        top = max(ax.get_position().y1 for ax in all_axes)
        bottom = min(ax.get_position().y0 for ax in all_axes)
        cax = fig.add_axes((0.92, bottom, 0.015, top - bottom))
        if closure_type == "phase":
            # Norm and cmap are fixed, so build a ScalarMappable directly
            norm = mpl.colors.Normalize(vmin=-np.pi, vmax=np.pi)
            sm = mpl.cm.ScalarMappable(norm=norm, cmap="twilight_shifted")
            sm.set_array([])
            fig.colorbar(sm, cax=cax, label="Closure phase (rad)")
            return fig
        # closure_type == "amp": force every panel onto the shared
        # global_vrange so that one colorbar describes them all
        meshes = [
            child
            for ax in all_axes
            for child in ax.get_children()
            if isinstance(child, QuadMesh)
        ]
        vmin, vmax = self.global_vrange
        if np.isfinite(vmin) and np.isfinite(vmax):
            for mesh in meshes:
                mesh.set_clim(vmin, vmax)
        if meshes:
            fig.colorbar(meshes[0], cax=cax, label="Closure amplitude")
        return fig
def get_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the closure-plotting script.

    Returns:
        Parser whose namespace provides ``ms_path``, ``stations``,
        ``closure_type``, ``out_png``, ``data_column``, ``waterfall``,
        ``uv_min_m`` and ``log_level``.
    """
    parser = argparse.ArgumentParser(
        description="Plot closure quantities for all/selected baselines.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("ms_path", type=Path, help="Path to visibilities.")
    parser.add_argument(
        "--stations",
        type=str,
        nargs="+",
        default=None,
        help="Only plot closure quantities involving the specified stations. Defaults to using all baselines.",
    )
    parser.add_argument(
        "--closure-type",
        type=str,
        choices=["amp", "phase"],
        default="phase",
        help="Whether to plot closure phases or closure amplitudes. Defaults to phases.",
    )
    parser.add_argument(
        "--out-png",
        type=Path,
        default=None,
        # Must match the default path constructed in main()
        help="Output filename for plot. Defaults to '<MS stem>_closure-<phase|amp>s.png' in the same directory as the measurement set.",
    )
    parser.add_argument(
        "--data-column",
        type=str,
        default="DATA",
        help="Data column",
    )
    parser.add_argument(
        "--waterfall",
        metavar="FREQ_AVG_MHZ",
        type=float,
        default=None,
        help="Display waterfall plots with frequencies averaged up to channels of width FREQ_AVG_MHZ. The default behaviour is to average over the whole band and plot the closure quantities vs time in a standard scatter plot.",
    )
    parser.add_argument(
        "--uv-min-m",
        type=float,
        default=None,
        help="Only plot closure quantities involving baselines longer than the specified distance (projected uv length, in metres)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set logging level. Defaults to 'INFO'",
    )
    return parser
def main() -> None:
    """
    Command-line entry point: load an MS, plot closure quantities, save a PNG.

    The output defaults to ``<MS stem>_closure-<type>s.png`` next to the
    measurement set unless ``--out-png`` is given.
    """
    parser = get_parser()
    args = parser.parse_args()
    logger.setLevel(args.log_level)
    ch.setLevel(args.log_level)
    ms = MeasurementSet(args.ms_path, data_column=args.data_column)
    ms.import_baselines(
        station_names=args.stations,
        uv_min_m=args.uv_min_m,
    )
    fig = ms.plot_all_closures(args.closure_type, freq_avg_mhz=args.waterfall)
    default_path = (
        args.ms_path.parent / f"{args.ms_path.stem}_closure-{args.closure_type}s.png"
    )
    out_png = args.out_png or default_path
    fig.savefig(out_png, bbox_inches="tight", dpi=150)
    # pyplot keeps figures alive until explicitly closed
    plt.close(fig)
    logger.info(f"Saved closure {args.closure_type} plot to {out_png}")


if __name__ == "__main__":
    main()