Source code for low_comm_tools.plotting.baselines

from __future__ import annotations

import argparse
import asyncio
import io
import time
from concurrent.futures import ProcessPoolExecutor
from itertools import combinations
from pathlib import Path
from typing import Any, Literal, NamedTuple

import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from astropy.time import Time
from casacore.tables import table, taql
from matplotlib import colors
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from tqdm.asyncio import tqdm

from low_comm_tools.constants import CORRELATION_POLS
from low_comm_tools.exceptions import VisError
from low_comm_tools.log_config import logger
from low_comm_tools.ms_utils import (
    get_antenna_names_from_ms,
    get_baseline_length,
    get_freq_from_ms,
    get_time_from_table,
)

# Force the non-interactive Agg backend: this module renders figures to PNG
# bytes with ``savefig`` (see ``_process_baseline_sync``), so no GUI is needed.
mpl.use("Agg")


class SubtableData(NamedTuple):
    """Visibilities and metadata for one baseline, as built by ``_read_subtable``."""

    # Values of the requested data column with the FLAG column applied as mask.
    masked_data: np.ma.MaskedArray
    # Station names; indexed by antenna number when titling plots.
    station_names: list[str]
    # Channel frequencies; converted with ``.to("MHz")`` / ``.to("Hz")`` downstream.
    freq_chan: u.Quantity
    # Sample times with the earliest sample subtracted.
    # NOTE(review): annotated as ``Time`` but consumers call ``.to("s")`` on it,
    # which suggests a time-difference type - confirm and tighten the annotation.
    times: Time
    # Antenna indices that define the baseline.
    ant_1: int
    ant_2: int
def _read_subtable(
    ms_path: Path,
    ant_1: int,
    ant_2: int,
    data_column: str,
    station_names: list[str],
    freq_chan: u.Quantity,
) -> SubtableData | None:
    """Read the visibilities of a single baseline from a measurement set.

    Rows are selected where the antenna pair matches in either order.

    Args:
        ms_path: Path to the measurement set.
        ant_1: Index of the first antenna.
        ant_2: Index of the second antenna.
        data_column: Name of the column to read the visibilities from.
        station_names: Station names, passed through to the result unchanged.
        freq_chan: Channel frequencies, passed through to the result unchanged;
            its length is used to validate the frequency axis.

    Returns:
        The masked data and metadata for the baseline, or ``None`` when the
        selection contains no rows.

    Raises:
        VisError: If the data column is not three-dimensional or its time and
            frequency axes do not match the times / ``freq_chan`` lengths.
    """
    t0 = time.perf_counter()
    # NOTE: TaQL fills `$tab`, `$ant_1` and `$ant_2` in from the Python
    # variables of the same name, so these locals must keep their names.
    with (
        table(ms_path.as_posix(), ack=False) as tab,
        taql(
            "select from $tab where "
            "(ANTENNA1==$ant_1 and ANTENNA2==$ant_2) or "
            "(ANTENNA1==$ant_2 and ANTENNA2==$ant_1)"
        ) as subtab,
    ):
        _ = tab  # keep linters happy
        if subtab.nrows() == 0:
            return None
        data = np.array(subtab.getcol(data_column))  # (time, freq, pol)
        flags = np.array(subtab.getcol("FLAG"))
        time_centroid = get_time_from_table(subtab)

    # Check the rank first: unpacking the shape of a non-3D column would
    # otherwise raise a bare ValueError instead of the intended VisError.
    if data.ndim != 3:
        msg = (
            f"Column '{data_column}' has unexpected shape {data.shape} - "
            "expected 3 dimensions (time, freq, pol)"
        )
        raise VisError(msg)

    masked_data = np.ma.masked_array(data, mask=flags)
    times = time_centroid - time_centroid.min()

    _, _, npol = data.shape
    _expected_shape = (len(times), len(freq_chan), npol)
    if data.shape != _expected_shape:
        msg = f"Column '{data_column}' has unexpected shape {data.shape} - expected {_expected_shape}"
        raise VisError(msg)

    logger.info(
        f"Done reading baseline {ant_1} - {ant_2} in {time.perf_counter() - t0:.1f}s"
    )
    return SubtableData(
        masked_data=masked_data,
        station_names=station_names,
        freq_chan=freq_chan,
        times=times,
        ant_1=ant_1,
        ant_2=ant_2,
    )
def _compute_plot_arrays(
    subtable_data: SubtableData,
    plot_type: Literal["spectrum", "delay", "delay-rate"],
) -> tuple[
    npt.NDArray[np.floating[Any]],
    npt.NDArray[np.floating[Any]],
    npt.NDArray[np.complexfloating[Any, Any]],
    str,
    str,
]:
    """Build the axes, image cube and axis labels for one baseline plot.

    Args:
        subtable_data: Baseline data; ``freq_chan`` and ``times`` must support
            ``.to(unit).value`` and ``masked_data`` must be a masked array
            laid out as (time, freq, pol).
        plot_type: ``"spectrum"`` returns the data untouched against frequency
            and time; ``"delay"`` Fourier transforms the frequency axis;
            ``"delay-rate"`` Fourier transforms both time and frequency axes.

    Returns:
        ``(x_array, y_array, z_array, xlabel, ylabel)``.

    Raises:
        ValueError: If ``plot_type`` is not one of the supported values.
    """
    if plot_type not in ("spectrum", "delay", "delay-rate"):
        msg = f"Unknown `plot_type` '{plot_type}'"
        raise ValueError(msg)

    times_s = subtable_data.times.to("s").value

    if plot_type == "spectrum":
        return (
            subtable_data.freq_chan.to("MHz").value,
            times_s,
            subtable_data.masked_data,
            "Frequency / MHz",
            "Time / s",
        )

    # Flagged samples are zero-filled so they do not contribute to the FFT.
    visibilities = subtable_data.masked_data.filled(0 + 0j)
    freq_hz = subtable_data.freq_chan.to("Hz").value
    delay_s = np.fft.fftshift(
        np.fft.fftfreq(n=len(freq_hz), d=np.diff(freq_hz).mean())
    )

    if plot_type == "delay":
        delay_time_array = np.fft.fftshift(
            np.fft.fft(visibilities, axis=1), axes=1
        )
        return (
            delay_s * 1e6,
            times_s,
            delay_time_array,
            "Delay / µs",
            "Time / s",
        )

    delay_rate_array = np.fft.fftshift(
        np.fft.fft2(visibilities, axes=(0, 1)),
        axes=(0, 1),
    )
    # The sample spacing must be in seconds for the axis to be in Hz, as the
    # label states.
    delay_rate = np.fft.fftshift(
        np.fft.fftfreq(n=len(times_s), d=np.diff(times_s).mean())
    )
    return (
        delay_s * 1e6,
        delay_rate,
        delay_rate_array,
        "Delay / µs",
        "Delay rate / Hz",
    )
[docs] def _make_plot_sync( x_array: npt.NDArray[np.floating[Any]], y_array: npt.NDArray[np.floating[Any]], z_array: npt.NDArray[np.complexfloating[Any, Any]], ant_1: int, ant_2: int, station_names: list[str], baseline_length: u.Quantity[u.m], data_type: str, xlabel: str | None = None, ylabel: str | None = None, fast_plot: bool = True, norm: Normalize | None = None, ) -> Figure: t0 = time.perf_counter() fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(12, 10)) if norm is None: norm = Normalize() match data_type: case "amp": z_plot = np.abs(z_array) z_label = "Amplitude / Jy" cmap = "viridis" norm.vmin = np.nanmin(z_plot) norm.vmax = np.nanmax(z_plot) case "phase": z_plot = np.rad2deg(np.angle(z_array)) z_label = "Phase / deg" cmap = "twilight" norm.vmin = -180 norm.vmax = +180 case "real": z_plot = np.real(z_array) z_label = "Real / Jy" cmap = "RdBu_r" vmax = np.nanmax(np.abs(z_plot)) norm.vmin = -vmax norm.vmax = +vmax case "imag": z_plot = np.imag(z_array) z_label = "Imag. / Jy" cmap = "RdBu_r" vmax = np.nanmax(np.abs(z_plot)) norm.vmin = -vmax norm.vmax = +vmax case _: msg = f"Unknown `data_type` '{data_type}'" raise ValueError(msg) for i, pol in enumerate(CORRELATION_POLS): ax = axs.flatten()[i] if fast_plot: im = ax.pcolorfast( x_array, y_array, z_plot[:-1, :-1, i], norm=norm, cmap=cmap, ) else: im = ax.pcolormesh( x_array, y_array, z_plot[..., i], norm=norm, cmap=cmap, ) ax.set(aspect="auto", title=pol, xlabel=xlabel, ylabel=ylabel) fig.colorbar(im, ax=ax).set_label(z_label) fig.suptitle( f"Baseline {ant_1}::{ant_2} ({station_names[ant_1]}::{station_names[ant_2]}) ({baseline_length.value:0.0f}m)" ) fig.subplots_adjust( left=0.08, right=0.95, top=0.93, bottom=0.08, hspace=0.3, wspace=0.3 ) logger.info( f"Done making plot for baseline {ant_1} - {ant_2} in {time.perf_counter() - t0:.1f}s" ) return fig
def _process_baseline_sync(
    ms_path: Path,
    ant_1: int,
    ant_2: int,
    data_column: str,
    station_names: list[str],
    freq_chan: u.Quantity,
    plot_type: Literal["spectrum", "delay", "delay-rate"],
    data_type: Literal["amp", "phase", "real", "imag"],
    fast_plot: bool,
    norm: Normalize | None,
) -> bytes | None:
    """Full per-baseline pipeline: read → plot → render. Runs in a worker process.

    Args:
        ms_path: Path to the measurement set.
        ant_1: Index of the first antenna.
        ant_2: Index of the second antenna.
        data_column: Name of the data column to plot.
        station_names: Station names, indexed by antenna number.
        freq_chan: Channel frequencies.
        plot_type: Kind of plot to produce.
        data_type: Component of the complex data to display.
        fast_plot: Use ``pcolorfast`` instead of ``pcolormesh``.
        norm: Colour normalisation, or None for the default.

    Returns:
        The rendered PNG as bytes, or ``None`` when the baseline has no rows.
    """
    logger.info(f"Processing baseline {ant_1} - {ant_2}")
    t0 = time.perf_counter()

    subtable_data = _read_subtable(
        ms_path=ms_path,
        ant_1=ant_1,
        ant_2=ant_2,
        data_column=data_column,
        station_names=station_names,
        freq_chan=freq_chan,
    )
    if subtable_data is None:
        logger.warning(f"No data for baseline {ant_1} - {ant_2}")
        return None

    x_array, y_array, z_array, xlabel, ylabel = _compute_plot_arrays(
        subtable_data, plot_type
    )
    baseline_length = get_baseline_length(
        ms_path=ms_path,
        ant_1=ant_1,
        ant_2=ant_2,
    )
    figure = _make_plot_sync(
        x_array=x_array,
        y_array=y_array,
        z_array=z_array,
        ant_1=ant_1,
        ant_2=ant_2,
        station_names=station_names,
        baseline_length=baseline_length,
        data_type=data_type,
        xlabel=xlabel,
        ylabel=ylabel,
        fast_plot=fast_plot,
        norm=norm,
    )

    t_render = time.perf_counter()
    buf = io.BytesIO()
    try:
        figure.savefig(
            buf,
            format="png",
            dpi=300,
            bbox_inches=None,
            pil_kwargs={"compress_level": 1},
        )
    finally:
        # Always release the figure: worker processes are reused for many
        # baselines, so a figure left open after a failed render accumulates.
        plt.close(figure)
    png_bytes = buf.getvalue()

    logger.info(
        f"Done rendering baseline {ant_1} - {ant_2} in {time.perf_counter() - t_render:.1f}s"
    )
    logger.info(
        f"Done processing baseline {ant_1} - {ant_2} in {time.perf_counter() - t0:.1f}s"
    )
    return png_bytes
async def process_baseline(
    ms_path: Path,
    ant_1: int,
    ant_2: int,
    data_column: str,
    station_names: list[str],
    freq_chan: u.Quantity,
    plot_type: Literal["spectrum", "delay", "delay-rate"],
    data_type: Literal["amp", "phase", "real", "imag"],
    fast_plot: bool,
    norm: Normalize | None,
    out_dir: Path,
    prefix: str | None,
    executor: ProcessPoolExecutor,
    sem: asyncio.Semaphore,
) -> Path | None:
    """Render one baseline in the process pool and write the PNG to disk.

    The semaphore is held while the worker process renders the plot; the file
    is then written from a thread so the event loop is not blocked.

    Args:
        ms_path: Path to the measurement set.
        ant_1: Index of the first antenna.
        ant_2: Index of the second antenna.
        data_column: Name of the data column to plot.
        station_names: Station names, indexed by antenna number.
        freq_chan: Channel frequencies.
        plot_type: Kind of plot to produce.
        data_type: Component of the complex data to display.
        fast_plot: Use ``pcolorfast`` instead of ``pcolormesh``.
        norm: Colour normalisation, or None for the default.
        out_dir: Directory the PNG is written to.
        prefix: Optional string prepended to the file name.
        executor: Process pool that runs the rendering.
        sem: Semaphore limiting how many baselines render at once.

    Returns:
        Path of the written PNG, or ``None`` when the baseline has no data.
    """
    worker_args = (
        ms_path,
        ant_1,
        ant_2,
        data_column,
        station_names,
        freq_chan,
        plot_type,
        data_type,
        fast_plot,
        norm,
    )
    event_loop = asyncio.get_running_loop()
    async with sem:
        rendered = await event_loop.run_in_executor(
            executor, _process_baseline_sync, *worker_args
        )

    if rendered is None:
        return None

    file_name = f"baseline_{station_names[ant_1]}-{station_names[ant_2]}_{data_column}_{plot_type}_{data_type}.png"
    file_name = file_name if prefix is None else f"{prefix}_{file_name}"
    destination = out_dir / file_name

    started = time.perf_counter()
    await asyncio.to_thread(destination.write_bytes, rendered)
    logger.info(
        f"Done saving {destination.name} in {time.perf_counter() - started:.1f}s"
    )
    return destination
async def plot_baselines(
    ms_path: Path | str,
    fast_plot: bool = True,
    norm: Normalize | None = None,
    plot_type: Literal["spectrum", "delay", "delay-rate"] = "spectrum",
    data_column: str = "DATA",
    data_type: Literal["amp", "phase", "real", "imag"] = "amp",
    out_dir: Path | None = None,
    max_concurrent: int = 4,
    max_baselines: int | None = None,
) -> list[Path | None]:
    """Plot visibilities for every unique baseline concurrently.

    Args:
        ms_path: Path to the measurement set.
        fast_plot: Use ``pcolorfast`` instead of ``pcolormesh``. Defaults to True.
        norm: Colour-scale normalisation. Defaults to None (linear).
        plot_type: One of ``"spectrum"``, ``"delay"``, or ``"delay-rate"``.
        data_column: MS data column to plot. Defaults to ``"DATA"``.
        data_type: One of ``"amp"``, ``"phase"``, ``"real"``, or ``"imag"``.
            Defaults to ``"amp"``.
        out_dir: Output directory; created if it does not exist. Defaults to
            the MS parent directory.
        max_concurrent: Number of worker processes and max concurrent baselines.
            Defaults to 4.
        max_baselines: Plot only the first ``max_baselines`` antenna pairs.
            Defaults to None (all pairs).

    Returns:
        One entry per baseline, in antenna-pair order: the path of the written
        PNG, or ``None`` for baselines without data.
    """
    if isinstance(ms_path, str):
        ms_path = Path(ms_path)
    if out_dir is None:
        out_dir = ms_path.parent
    # Create the destination up front so a missing directory does not surface
    # only after the first baseline has already been rendered.
    out_dir.mkdir(parents=True, exist_ok=True)

    station_names = get_antenna_names_from_ms(ms_path)
    freq_chan = get_freq_from_ms(ms_path)
    n_ant = len(station_names)

    baseline_pairs = list(combinations(range(n_ant), 2))
    if max_baselines is not None:
        baseline_pairs = baseline_pairs[:max_baselines]

    sem = asyncio.Semaphore(max_concurrent)
    with ProcessPoolExecutor(max_workers=max_concurrent) as executor:
        coros = [
            process_baseline(
                ms_path=ms_path,
                ant_1=ant_1,
                ant_2=ant_2,
                data_column=data_column,
                station_names=station_names,
                freq_chan=freq_chan,
                plot_type=plot_type,
                data_type=data_type,
                fast_plot=fast_plot,
                norm=norm,
                out_dir=out_dir,
                prefix=ms_path.stem,
                executor=executor,
                sem=sem,
            )
            for ant_1, ant_2 in baseline_pairs
        ]
        # Await inside the `with` block: the pool must stay alive until every
        # baseline has been processed.
        return await tqdm.gather(*coros)
# Colour normalisations selectable by name (used for the ``--norm`` option).
norms: dict[str, colors.Normalize] = {
    "lin": Normalize(),  # linear scale
    "log": colors.LogNorm(),  # logarithmic scale
    "sqrt": colors.PowerNorm(gamma=0.5),  # power law with exponent 0.5
}
def _positive_int(text: str) -> int:
    """Convert a command-line value to an int and reject anything below 1."""
    try:
        value = int(text)
    except ValueError:
        msg = f"invalid int value: {text!r}"
        raise argparse.ArgumentTypeError(msg) from None
    if value < 1:
        msg = f"must be a positive integer, got {value}"
        raise argparse.ArgumentTypeError(msg)
    return value


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the per-baseline plotting tool.

    Returns:
        Parser with the positional ``ms_path`` and the options ``--out-dir``,
        ``--no-fast-plot``, ``--norm``, ``--plot-type``, ``--data-column``,
        ``--data-type``, ``--max-concurrent`` and ``--max-baselines``.
    """
    parser = argparse.ArgumentParser(
        description="Plot visibilities per baseline.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("ms_path", type=Path, help="Path to visibilities.")
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=None,
        help="Output directory for plots. Defaults to the MS directory.",
    )
    parser.add_argument(
        "--no-fast-plot",
        action="store_true",
        help="Use pcolormesh instead of pcolorfast.",
    )
    parser.add_argument(
        "--norm",
        type=str,
        default="lin",
        choices=["lin", "log", "sqrt"],
        help="Plot normalisation.",
    )
    parser.add_argument(
        "--plot-type",
        type=str,
        choices=["spectrum", "delay", "delay-rate"],
        default="spectrum",
        help="Type of plot.",
    )
    parser.add_argument(
        "--data-column",
        type=str,
        default="DATA",
        help="Data column to plot.",
    )
    parser.add_argument(
        "--data-type",
        type=str,
        choices=["amp", "phase", "real", "imag"],
        default="amp",
        help="Type of data to plot.",
    )
    # Both counts must be >= 1: zero workers cannot run a process pool, and a
    # negative baseline limit would silently drop baselines from the end of
    # the list when used as a slice bound.
    parser.add_argument(
        "--max-concurrent",
        type=_positive_int,
        default=4,
        help="Number of worker processes.",
    )
    parser.add_argument(
        "--max-baselines",
        type=_positive_int,
        default=None,
        help="Maximum number of baselines to plot.",
    )
    return parser
def main() -> None:
    """Command-line entry point: parse the arguments and plot every baseline."""
    args = get_parser().parse_args()
    plot_job = plot_baselines(
        ms_path=args.ms_path,
        fast_plot=not args.no_fast_plot,
        norm=norms.get(args.norm),
        plot_type=args.plot_type,
        data_column=args.data_column,
        data_type=args.data_type,
        out_dir=args.out_dir,
        max_concurrent=args.max_concurrent,
        max_baselines=args.max_baselines,
    )
    # `plot_baselines` is a coroutine function; drive it to completion here.
    asyncio.run(plot_job)
# Run the command-line interface when the module is executed as a script.
if __name__ == "__main__":
    main()