"""Source code for low_comm_tools.plotting.closures: plot closure phases and closure amplitudes from a measurement set."""

from __future__ import annotations

import argparse
from itertools import combinations
from pathlib import Path
from typing import Literal, TypeAlias, cast

import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from astropy.visualization import quantity_support
from casacore.tables import table, taql
from matplotlib.axes import Axes
from matplotlib.collections import QuadMesh
from matplotlib.figure import Figure

from low_comm_tools import ms_utils
from low_comm_tools.constants import CORRELATION_POLS
from low_comm_tools.log_config import ch, logger
from low_comm_tools.plotting.utils import SKA_COLOURS_RGBA_DICT, grid_dims

# Non-interactive backend: figures are saved to file rather than shown.
mpl.use("Agg")
# Allow matplotlib to plot astropy Quantities (e.g. frequencies) directly.
_ = quantity_support()

# Tuples of antenna indices: one baseline, one closure triangle, one closure quad.
Baseline: TypeAlias = tuple[int, int]
Triple: TypeAlias = tuple[int, int, int]
Quad: TypeAlias = tuple[int, int, int, int]
class MeasurementSet:
    """Visibilities for a set of baselines from the same measurement set.

    Visibilities are stored per baseline, keyed by an ``(antenna1, antenna2)``
    tuple of antenna indices. They are combined into closure phases (over
    antenna triples) and closure amplitudes (over antenna quads) for plotting.
    """

    def __init__(self, ms_path: Path, data_column: str = "DATA"):
        """Read frequency, time and station metadata from ``ms_path``.

        No visibilities are read here; call ``import_baselines`` for that.

        Args:
            ms_path: Path to the measurement set.
            data_column: Name of the visibility column to read.
        """
        logger.info(f"Initialising measurement set from:\n\t{ms_path}")
        self.ms_path = ms_path
        self.frequencies = ms_utils.get_freq_from_ms(ms_path)
        self.times = ms_utils.get_time_from_ms(ms_path)
        self.data_column = data_column
        self.station_names = ms_utils.get_antenna_names_from_ms(ms_path)
        # (antenna1, antenna2) -> visibilities of that baseline
        self.visibilities: dict[Baseline, np.ndarray] = {}
        # Running [min, max] over plotted closure-amplitude waterfalls, so that
        # all panels of a figure can share a single colour scale
        self.global_vrange = [np.inf, -np.inf]
        # Antenna indices considered when forming closure triples/quads;
        # narrowed by import_baselines when a station subset is requested
        self.included_antennas: list[int] = list(range(len(self.station_names)))

    def _get_bandwidth(self) -> u.Quantity[u.Hz]:
        """Return the total bandwidth, assuming evenly spaced channels."""
        channel_width = self.frequencies[1] - self.frequencies[0]
        return self.frequencies[-1] - self.frequencies[0] + channel_width

    def add_baseline_data(
        self,
        baseline: Baseline,
        visibilities: np.ndarray,
        force_override: bool = False,
    ) -> None:
        """Store ``visibilities`` for ``baseline``.

        An existing entry is kept unless ``force_override`` is True.
        """
        if baseline in self.visibilities and not force_override:
            logger.debug(f"Baseline {baseline} already exists, skipping")
            return
        logger.info(f"Adding baseline {baseline}")
        self.visibilities[baseline] = visibilities

    def import_baselines(
        self,
        station_names: list[str] | tuple[str, ...] | None = None,
        uv_min_m: float | None = None,
        force_override: bool = True,
        include_autos: bool = False,
    ) -> None:
        """Read the selected baselines' visibilities with a single TaQL query.

        Flagged samples (FLAG column) are masked in the stored arrays.

        Args:
            station_names: Only import baselines whose two stations are both in
                this collection (duplicates are ignored). ``None`` imports all.
            uv_min_m: If given, only import baselines whose projected (u, v)
                length exceeds this value (in the UVW column's units,
                presumably metres).
            force_override: Replace visibilities already stored for a baseline.
            include_autos: Also import autocorrelations (ANTENNA1 == ANTENNA2).

        Raises:
            ValueError: If a requested station is not in the measurement set.
        """
        # All relevant visibilities are fetched in a single read. The data
        # column must be selected explicitly: the query result only holds the
        # selected columns, so reading e.g. CORRECTED_DATA would otherwise fail.
        self.taql_query = (
            f"SELECT ANTENNA1, ANTENNA2, {self.data_column}, FLAG, UVW FROM $tab"
        )
        self.taql_wheres: list[str] = []
        if station_names is not None:
            unique_station_names = sorted(set(station_names))  # drop duplicates
            if any(name not in self.station_names for name in unique_station_names):
                err = (
                    f"Invalid station name. Must be drawn from:\n\t{self.station_names}"
                )
                raise ValueError(err)
            self.included_antennas = [
                self.station_names.index(name) for name in unique_station_names
            ]
            ant_idxs_str = (
                "[" + ",".join(str(ant_idx) for ant_idx in self.included_antennas) + "]"
            )
            self.taql_wheres.append(
                f"ANTENNA1 IN {ant_idxs_str} AND ANTENNA2 IN {ant_idxs_str}"
            )
            logger.info(f"Importing baselines for stations {station_names}")
        else:
            self.included_antennas = list(range(len(self.station_names)))
        if uv_min_m is not None:
            # python-casacore's taql uses Python-style indexing (0-based,
            # end-exclusive), so UVW[0:2] is (u, v); UVW[0:1] would be u alone.
            self.taql_wheres.append(f"SQRT(SUMSQUARE(UVW[0:2])) > {uv_min_m}")
        if not include_autos:
            self.taql_wheres.append("ANTENNA1 != ANTENNA2")
        if self.taql_wheres:
            self.taql_query += " WHERE " + " AND ".join(self.taql_wheres)
        logger.debug(f'TAQL query = "{self.taql_query}"')

        with (
            table(self.ms_path.as_posix(), ack=False) as tab,
            taql(self.taql_query) as subtab,
        ):
            # `$tab` in the query refers to this local (python-casacore substitutes
            # `$name` from the calling frame); the assignment keeps linters happy.
            _ = tab
            if subtab.nrows() == 0:
                logger.warning("TaQL selection matched no rows; no baselines imported")
                return
            data = subtab.getcol(self.data_column)
            flags = subtab.getcol("FLAG")
            ant_1 = subtab.getcol("ANTENNA1")
            ant_2 = subtab.getcol("ANTENNA2")

        # Masking flagged samples lets the flags propagate through closure products
        masked_data = np.ma.masked_array(data, mask=flags)

        logger.info("Finished importing data from ms, now indexing by baseline")
        # Corner-turn to index the data by baseline: give each (ant1, ant2) pair a
        # unique integer key, then gather each baseline's rows with one stable
        # sort instead of scanning every row once per baseline. Stability keeps
        # each baseline's rows in their original order.
        key = ant_1 * (ant_2.max() + 1) + ant_2
        _, inverse, counts = np.unique(key, return_inverse=True, return_counts=True)
        order = np.argsort(inverse.ravel(), kind="stable")
        for rows in np.split(order, np.cumsum(counts)[:-1]):
            baseline = (int(ant_1[rows[0]]), int(ant_2[rows[0]]))
            self.add_baseline_data(
                baseline,
                masked_data[rows],
                force_override=force_override,
            )

    def get_baselines(self) -> list[Baseline]:
        """Return the baselines currently stored (including cached conjugates)."""
        return list(self.visibilities.keys())

    def __getitem__(self, baseline: Baseline) -> np.ndarray:
        """Return visibilities for ``baseline``, conjugating the reversed pair if needed.

        Raises:
            KeyError: If neither ``baseline`` nor its reverse has been imported.
        """
        logger.debug(f"Getting baseline {baseline}")
        if baseline in self.visibilities:
            return self.visibilities[baseline]
        logger.debug(f"Baseline {baseline} not found")
        if baseline[::-1] in self.visibilities:
            logger.debug(f"Conjugating {baseline[::-1]} --> {baseline}")
            # Cache the conjugate so repeated lookups don't redo the conjugation
            self.add_baseline_data(
                baseline,
                np.conj(self.visibilities[baseline[::-1]]),
            )
            return self.visibilities[baseline]
        # The baseline doesn't exist, or wasn't imported
        raise KeyError(baseline)

    @staticmethod
    def _nan_where_masked(values: np.ndarray, mask_source: np.ndarray) -> np.ndarray:
        """Return ``values`` as a plain float ndarray with NaN wherever ``mask_source`` is masked.

        ``np.asarray`` on a masked array silently drops the mask, which would let
        flagged data into the closures; NaNs are instead skipped by the ``nan*``
        reductions used for averaging.
        """
        out = np.array(np.ma.getdata(values), dtype=float)  # always a copy
        out[np.ma.getmaskarray(mask_source)] = np.nan
        return out

    def get_abs_visibilities(self, baseline: Baseline) -> np.ndarray:
        """Return visibility amplitudes for ``baseline`` (flagged samples as NaN).

        The reversed baseline is used when needed, since ``|conj(V)| == |V|``.

        Raises:
            KeyError: If neither ``baseline`` nor its reverse has been imported.
        """
        logger.debug(f"Getting absolute values of baseline {baseline}")
        if baseline in self.visibilities:
            vis = self.visibilities[baseline]
        elif baseline[::-1] in self.visibilities:
            vis = self.visibilities[baseline[::-1]]
        else:
            raise KeyError(baseline)
        return self._nan_where_masked(np.abs(np.ma.getdata(vis)), vis)

    def _triple_to_baselines(
        self, triple: Triple
    ) -> tuple[Baseline, Baseline, Baseline]:
        """Return the closed loop of baselines (0-1, 1-2, 2-0) for an antenna triple."""
        baseline1 = triple[0], triple[1]
        baseline2 = triple[1], triple[2]
        baseline3 = triple[2], triple[0]
        logger.debug(f"{baseline1 = }, {baseline2 = }, {baseline3 = }")
        return baseline1, baseline2, baseline3

    def calc_closure_phases(self, triple: Triple) -> np.ndarray:
        """Return closure phases (radians) around ``triple``; flagged samples are NaN."""
        logger.info(f"Calculating closure phase for antennas {triple}")
        baseline1, baseline2, baseline3 = self._triple_to_baselines(triple)
        triple_products = self[baseline1] * self[baseline2] * self[baseline3]
        logger.debug(f"{triple_products.shape = }")
        return self._nan_where_masked(
            np.angle(np.ma.getdata(triple_products)), triple_products
        )

    def calc_closure_amps(self, quad: Quad) -> np.ndarray:
        """Return closure amplitudes for ``quad``; flagged samples are NaN.

        Raises:
            ValueError: If no valid set of four baselines exists for ``quad``.
        """
        logger.info(f"Calculating closure amplitudes for antennas {quad}")
        baselines = self._find_valid_quad(quad)
        if baselines is None:
            err = f"{quad} is an invalid subset of stations for calculating closure amplitudes"
            raise ValueError(err)
        num_a, num_b, den_a, den_b = baselines
        num = self.get_abs_visibilities(num_a) * self.get_abs_visibilities(num_b)
        den = self.get_abs_visibilities(den_a) * self.get_abs_visibilities(den_b)
        return np.asarray(num / den)

    @staticmethod
    def _circular_nanmean(phases: np.ndarray, axis: int) -> np.ndarray:
        """Average angles (radians) on the circle along ``axis``, ignoring NaNs.

        An arithmetic mean of phases wrapped to [-pi, pi] is wrong near the
        branch cut (the mean of 3.1 and -3.1 would be 0 rather than ~pi), so the
        unit phasors are averaged and the angle of the result returned.
        """
        return np.angle(np.nanmean(np.exp(1j * phases), axis=axis))

    @staticmethod
    def _downsample_factor(
        freq_avg_mhz: float, channel_width_mhz: float, n_chan: int
    ) -> int:
        """Return how many adjacent channels to average to reach ~``freq_avg_mhz``.

        Clamped to [1, n_chan]: a target narrower than one channel means no
        averaging (instead of dividing by zero later), and a target wider than
        the band gives a single bin.
        """
        factor = int(np.round(freq_avg_mhz / abs(channel_width_mhz)))
        return min(max(factor, 1), max(n_chan, 1))

    def calc_closures(
        self,
        station_subset: Triple | Quad,
        closure_type: Literal["phase", "amp"] = "phase",
        freq_avg_mhz: float | None = None,
    ) -> tuple[np.ndarray, u.Quantity[u.Hz]]:
        """Compute closure quantities, averaged over frequency.

        Phases are averaged circularly; amplitudes arithmetically. NaN
        (flagged) samples are ignored.

        Args:
            station_subset: Antenna triple (phases) or quad (amplitudes).
            closure_type: ``"phase"`` or ``"amp"``.
            freq_avg_mhz: Target width of averaged channels. ``None`` averages
                over the whole band.

        Returns:
            The averaged closures and the matching frequencies. With
            ``freq_avg_mhz=None`` the frequency axis is collapsed and the
            original channel frequencies are returned unchanged; otherwise the
            closures keep a (binned) frequency axis and the bin-centre
            frequencies are returned. Trailing channels that don't fill a whole
            bin are dropped.
        """
        if closure_type == "phase":
            closures = self.calc_closure_phases(cast(Triple, station_subset))
            average = self._circular_nanmean
        else:  # closure_type == "amp"
            closures = self.calc_closure_amps(cast(Quad, station_subset))
            average = np.nanmean

        # If no averaging is explicitly asked for, average over the whole band
        if freq_avg_mhz is None:
            return average(closures, axis=1), self.frequencies

        channel_width_mhz = (
            (self.frequencies[1] - self.frequencies[0]).to("MHz").value
        )
        n_chan = len(self.frequencies)
        ffactor = self._downsample_factor(freq_avg_mhz, channel_width_mhz, n_chan)
        n_bins = n_chan // ffactor
        n_kept = n_bins * ffactor  # truncate to a whole number of bins

        f_avg = np.nanmean(
            self.frequencies[:n_kept].reshape(n_bins, ffactor), axis=1
        )
        n_pols = closures.shape[-1]
        closures = average(
            closures[:, :n_kept, :].reshape(-1, n_bins, ffactor, n_pols),
            axis=2,
        )
        return closures, f_avg

    def plot_closures(
        self,
        ax: Axes,
        station_subset: Triple | Quad,
        closure_type: Literal["phase", "amp"] = "phase",
        pol_idxs: list[int] | None = None,
        freq_avg_mhz: float | None = None,
    ) -> None:
        """Draw the closure quantities of one station subset onto ``ax``.

        With ``freq_avg_mhz=None`` the band-averaged closures are scatter-plotted
        against time, one series per polarisation in ``pol_idxs``. Otherwise a
        time-frequency waterfall of the first polarisation in ``pol_idxs`` is
        drawn. For amplitude waterfalls, ``self.global_vrange`` is widened to
        cover the plotted (finite) values so all panels can share one colour scale.

        Args:
            ax: Axes to draw on.
            station_subset: Antenna triple (phases) or quad (amplitudes).
            closure_type: ``"phase"`` or ``"amp"``.
            pol_idxs: Indices into ``CORRELATION_POLS``; defaults to all.
            freq_avg_mhz: See ``calc_closures``.
        """
        closures, f_avg = self.calc_closures(
            station_subset, closure_type, freq_avg_mhz=freq_avg_mhz
        )
        logger.info(f"Plotting closure {closure_type}s for {station_subset}")
        logger.debug(f"{self.times.datetime.shape = }")
        logger.debug(f"{closures.shape = }")
        pol_idxs = pol_idxs or list(range(len(CORRELATION_POLS)))

        if freq_avg_mhz is None:
            # Closures were averaged over the whole band, so they are a function
            # of time only (assumes one row per timestep for each baseline)
            for pol_idx in pol_idxs:
                pol = CORRELATION_POLS[pol_idx]
                ax.plot(
                    self.times.datetime,
                    closures[:, pol_idx],
                    ".",
                    color=SKA_COLOURS_RGBA_DICT[pol],
                    label=pol,
                )
            ax.tick_params(axis="x", labelrotation=45)  # rotate datetimes for readability
            return

        # Closures are a function of time *and* frequency: waterfall plot of the
        # first requested polarisation only
        waterfall = closures[:, :, pol_idxs[0]]
        freqs_mhz = f_avg.to("MHz").value
        if closure_type == "phase":
            ax.pcolormesh(
                freqs_mhz,
                self.times.datetime,
                waterfall,
                vmin=-np.pi,
                vmax=np.pi,
                cmap="twilight_shifted",
            )
            return

        # closure_type == "amp": widen the shared data range. NaNs (flags) are
        # excluded, since Python's min/max do not order NaN consistently.
        finite = waterfall[np.isfinite(waterfall)]
        if finite.size:
            self.global_vrange = [
                min(float(finite.min()), self.global_vrange[0]),
                max(float(finite.max()), self.global_vrange[1]),
            ]
        ax.pcolormesh(freqs_mhz, self.times.datetime, waterfall)

    def _baseline_exists(self, baseline: Baseline) -> bool:
        """Return True if ``baseline`` or its reverse has been imported."""
        baselines = self.visibilities.keys()
        return baseline in baselines or baseline[::-1] in baselines

    def _is_valid_triple(self, triple: Triple) -> bool:
        """Return True if all three baselines of ``triple`` are available."""
        baseline1, baseline2, baseline3 = self._triple_to_baselines(triple)
        return (
            self._baseline_exists(baseline1)
            and self._baseline_exists(baseline2)
            and self._baseline_exists(baseline3)
        )

    def _find_valid_quad(
        self, quad: Quad
    ) -> tuple[Baseline, Baseline, Baseline, Baseline] | None:
        """Pick four baselines for a closure amplitude of ``quad``.

        For antennas (0, 1, 2, 3), returns EITHER 0-1, 2-3, 0-2, 1-3
        OR 0-1, 2-3, 0-3, 1-2 OR 0-2, 1-3, 0-3, 1-2, whichever has all its
        baselines available (checked in that order), or None if none does.
        The first two baselines form the numerator, the last two the denominator.
        """
        # itertools.combinations emits pairs in lexicographic order of input
        # position, so this unpacking order is guaranteed.
        b01, b02, b03, b12, b13, b23 = combinations(quad, 2)
        b0123 = self._baseline_exists(b01) and self._baseline_exists(b23)
        b0213 = self._baseline_exists(b02) and self._baseline_exists(b13)
        b0312 = self._baseline_exists(b03) and self._baseline_exists(b12)
        if b0123 and b0213:
            return b01, b23, b02, b13
        if b0123 and b0312:
            return b01, b23, b03, b12
        if b0213 and b0312:
            return b02, b13, b03, b12
        return None

    def plot_all_closures(
        self,
        closure_type: Literal["phase", "amp"] = "phase",
        freq_avg_mhz: float | None = None,
    ) -> Figure:
        """Plot closures for every valid triple (phases) or quad (amplitudes).

        Args:
            closure_type: ``"phase"`` or ``"amp"``.
            freq_avg_mhz: ``None`` for closure-vs-time scatter plots; otherwise
                waterfalls with channels averaged to roughly this width, sharing
                a single colour bar.

        Returns:
            The figure, with one panel per station subset.

        Raises:
            ValueError: If no station subset has all its baselines imported.
        """
        # Each figure gets its own shared colour range
        self.global_vrange = [np.inf, -np.inf]

        station_subsets: list[Triple | Quad]
        if closure_type == "phase":
            station_subsets = [
                triple
                for triple in combinations(self.included_antennas, 3)
                if self._is_valid_triple(triple)
            ]
        else:  # closure_type == "amp"
            station_subsets = [
                quad
                for quad in combinations(self.included_antennas, 4)
                if self._find_valid_quad(quad) is not None
            ]
        if not station_subsets:
            subset_kind = "triples" if closure_type == "phase" else "quads"
            err = (
                f"No station {subset_kind} have all required baselines imported; "
                f"cannot plot closure {closure_type}s"
            )
            raise ValueError(err)

        ncols, nrows = grid_dims(len(station_subsets))
        # squeeze=False: axs is always a 2-D (nrows, ncols) array, so the row and
        # column indexing below is also right for single-row/single-column grids
        fig, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            sharex=True,
            sharey=True,
            figsize=(8 + 2 * ncols, 6 + 2 * nrows),
            squeeze=False,
        )
        all_axes = axs.flatten().tolist()
        for ax, station_subset in zip(all_axes, station_subsets, strict=False):
            self.plot_closures(
                ax, station_subset, closure_type, freq_avg_mhz=freq_avg_mhz
            )
            station_name_list = [
                self.station_names[ant_idx] for ant_idx in station_subset
            ]
            ax.set_title(" - ".join(station_name_list))

        if freq_avg_mhz is None:
            # Closure quantity vs time plots
            axs[0, -1].legend(loc="upper right")
            if closure_type == "phase":
                axs[0, 0].set_ylim([-np.pi, np.pi])  # sharey=True: applies to all
            else:
                axs[0, 0].set_yscale("log")
            unit_str = " (rad)" if closure_type == "phase" else ""
            for ax in axs[-1, :]:  # bottom row
                ax.set_xlabel("Datetime")
            for ax in axs[:, 0]:  # left column
                ax.set_ylabel(f"Closure {closure_type}{unit_str}")
            return fig

        # Waterfall plots
        for ax in axs[-1, :]:  # bottom row
            ax.set_xlabel("Frequency (MHz)")
        for ax in axs[:, 0]:  # left column
            ax.set_ylabel("Datetime")

        # Axes for one colour bar spanning the full height of the grid
        top = max(ax.get_position().y1 for ax in all_axes)
        bottom = min(ax.get_position().y0 for ax in all_axes)
        cax = fig.add_axes((0.92, bottom, 0.015, top - bottom))

        if closure_type == "phase":
            # Norm and cmap are fixed, so build a ScalarMappable directly
            norm = mpl.colors.Normalize(vmin=-np.pi, vmax=np.pi)
            sm = mpl.cm.ScalarMappable(norm=norm, cmap="twilight_shifted")
            sm.set_array([])
            fig.colorbar(sm, cax=cax, label="Closure phase (rad)")
            return fig

        # closure_type == "amp": every waterfall must use the range accumulated
        # by plot_closures so that the single colour bar is valid for all panels
        meshes = [
            child
            for ax in all_axes
            for child in ax.get_children()
            if isinstance(child, QuadMesh)
        ]
        vmin, vmax = self.global_vrange
        if np.isfinite(vmin) and np.isfinite(vmax):
            for mesh in meshes:
                mesh.set_clim(vmin, vmax)
        if meshes:
            fig.colorbar(meshes[0], cax=cax, label="Closure amplitude")
        return fig
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the closure-plotting script."""
    parser = argparse.ArgumentParser(
        description="Plot closure quantities for all/selected baselines.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("ms_path", type=Path, help="Path to visibilities.")
    parser.add_argument(
        "--stations",
        type=str,
        nargs="+",
        default=None,
        help="Only plot closure quantities involving the specified stations. Defaults to using all baselines.",
    )
    parser.add_argument(
        "--closure-type",
        type=str,
        choices=["amp", "phase"],
        default="phase",
        help="Whether to plot closure phases or closure amplitudes. Defaults to phases.",
    )
    parser.add_argument(
        "--out-png",
        type=Path,
        default=None,
        # Matches the default path built in main()
        help="Output filename for plot. Defaults to '<ms_stem>_closure-{phases/amps}.png' next to the measurement set.",
    )
    parser.add_argument(
        "--data-column",
        type=str,
        default="DATA",
        help="Data column",
    )
    parser.add_argument(
        "--waterfall",
        metavar="FREQ_AVG_MHZ",
        type=float,
        default=None,
        help="Display waterfall plots with frequencies averaged up to channels of width FREQ_AVG_MHZ. The default behaviour is to average over the whole band and plot the closure quantities vs time in a standard scatter plot.",
    )
    parser.add_argument(
        "--uv-min-m",
        type=float,
        default=None,
        help="Only plot closure quantities involving baselines whose projected (u, v) length exceeds this value (metres).",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set logging level. Defaults to 'INFO'",
    )
    return parser
def main() -> None:
    """Command-line entry point: plot closure quantities and save them as a PNG.

    Unless ``--out-png`` is given, the figure is written next to the measurement
    set as ``<ms_stem>_closure-<closure_type>s.png``.
    """
    args = get_parser().parse_args()
    # Apply the requested verbosity to both `logger` and `ch`
    for target in (logger, ch):
        target.setLevel(args.log_level)

    ms = MeasurementSet(args.ms_path, data_column=args.data_column)
    ms.import_baselines(station_names=args.stations, uv_min_m=args.uv_min_m)
    fig = ms.plot_all_closures(args.closure_type, freq_avg_mhz=args.waterfall)

    default_name = f"{args.ms_path.stem}_closure-{args.closure_type}s.png"
    out_png = args.out_png or (args.ms_path.parent / default_name).as_posix()
    fig.savefig(out_png, bbox_inches="tight", dpi=150)


if __name__ == "__main__":
    main()