Source code for low_comm_tools.plotting.sdp_calibrate

from __future__ import annotations

from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import xarray as xr
from matplotlib.figure import Figure
from ska_sdp_datamodels.calibration.calibration_model import GainTable
from ska_sdp_datamodels.visibility import Visibility

from low_comm_tools.sdp.utils import MetaData, phase


[docs] def plot_gains( vis: Visibility, gaintable: GainTable, metadata: MetaData, ) -> Figure: # ============================================================================ # # gain plots x = vis.frequency.data / 1e6 fig, axs = plt.subplots( 2, metadata.nstation, figsize=(14, 6), sharex=True, sharey=False ) fig.suptitle("gains (EveryBeam-based model)") for k in range(metadata.nstation): ax = axs[0, k] ax.plot(x, np.abs(gaintable.gain.data[0, k, :, 0, 0]), "b", label="J00") ax.plot(x, np.abs(gaintable.gain.data[0, k, :, 0, 1]), "c", label="J01") ax.plot(x, np.abs(gaintable.gain.data[0, k, :, 1, 0]), "m", label="J10") ax.plot(x, np.abs(gaintable.gain.data[0, k, :, 1, 1]), "r", label="J11") ax.grid() ax.set_title(f"|{metadata.stations[k]}|") ax = axs[1, k] ax.plot(x, phase(gaintable.gain.data[0, k, :, 0, 0]), "b") ax.plot(x, phase(gaintable.gain.data[0, k, :, 0, 1]), "c") ax.plot(x, phase(gaintable.gain.data[0, k, :, 1, 0]), "m") ax.plot(x, phase(gaintable.gain.data[0, k, :, 1, 1]), "r") ax.grid() ax.set_title(f"{metadata.stations[k]} phase") ax.set_xlabel("frequency (MHz)") if k == 0: fig.legend() return fig
[docs] class VisibilityPlots(NamedTuple):
[docs] raw_fig: Figure
[docs] model_fig: Figure
[docs] cal_fig: Figure
[docs] def plot_vis( vis: Visibility, calvis: Visibility | xr.Dataset, mdlvis: Visibility, metadata: MetaData, jones_eb: npt.NDArray[np.complex128], centre_correct: bool = False, ) -> VisibilityPlots: # ============================================================================ # # vis plots # If not done earlier, remove central beam response if not centre_correct: invjones_eb = np.linalg.pinv(jones_eb) corvis = vis.copy(deep=True) corvis.vis.data = np.einsum( "bfpx,tbfxy,bfqy->tbfpq", invjones_eb[metadata.ant1], calvis.vis.data.reshape((*vis.vis.shape[:3], 2, 2)), invjones_eb[metadata.ant2].conj(), ).reshape(vis.vis.shape) else: corvis = calvis nbl = len(vis.baselines) x = vis.frequency.data / 1e6 # ylim_abs = (-1, 31) # ylim_re = (-31, 31) # ylim_im = (-31, 31) ylim_abs = np.array([-0.05, 1.05]) * np.max(np.abs(vis.vis.data)) raw_fig, axs = plt.subplots(2, nbl, figsize=(14, 6), sharex=True, sharey=False) raw_fig.suptitle("Raw data") for k in range(nbl): tag = f"{metadata.stations[metadata.ant1[k]]} x {metadata.stations[metadata.ant2[k]]}" ax = axs[0, k] ax.plot(x, np.abs(vis.vis.data[0, k, :, 0]), "b", label="XX") ax.plot(x, np.abs(vis.vis.data[0, k, :, 1]), "c", label="XY") ax.plot(x, np.abs(vis.vis.data[0, k, :, 2]), "m", label="YX") ax.plot(x, np.abs(vis.vis.data[0, k, :, 3]), "r", label="YY") ax.grid() ax.set_ylim(ylim_abs) ax.set_title(f"|{tag}|") ax = axs[1, k] ax.plot(x, phase(vis.vis.data[0, k, :, 0]), "b") ax.plot(x, phase(vis.vis.data[0, k, :, 1]), "c") ax.plot(x, phase(vis.vis.data[0, k, :, 2]), "m") ax.plot(x, phase(vis.vis.data[0, k, :, 3]), "r") ax.grid() ax.set_title(f"{tag} phase") ax.set_xlabel("frequency (MHz)") if k == 0: raw_fig.legend() ylim_abs = np.array([-0.05, 1.05]) * np.max(np.abs(calvis.vis.data)) model_fig, axs = plt.subplots(2, nbl, figsize=(14, 6), sharex=True, sharey=False) model_fig.suptitle("Calibrated vis and EveryBeam-based model") for k in range(nbl): tag = f"{metadata.stations[metadata.ant1[k]]} x {metadata.stations[metadata.ant2[k]]}" ax = axs[0, k] ax.plot( x, np.abs(calvis.vis.data[0, k, :, 0]), "b", alpha=0.3, label="XX calibrated", ) ax.plot( x, np.abs(calvis.vis.data[0, k, :, 1]), "c", alpha=0.3, label="XY calibrated", ) ax.plot( x, np.abs(calvis.vis.data[0, k, :, 2]), "m", alpha=0.3, label="YX calibrated", ) ax.plot( x, np.abs(calvis.vis.data[0, k, :, 3]), "r", alpha=0.3, label="YY calibrated", ) ax.plot(x, np.abs(mdlvis.vis.data[0, k, :, 0]), "b", label="XX model") ax.plot(x, np.abs(mdlvis.vis.data[0, k, :, 1]), "c", label="XY model") ax.plot(x, np.abs(mdlvis.vis.data[0, k, :, 2]), "m", label="YX model") ax.plot(x, np.abs(mdlvis.vis.data[0, k, :, 3]), "r", label="YY model") ax.grid() ax.set_ylim(ylim_abs) ax.set_title(f"|{tag}|") ax = axs[1, k] ax.plot(x, phase(calvis.vis.data[0, k, :, 0]), "b", alpha=0.3) ax.plot(x, phase(calvis.vis.data[0, k, :, 1]), "c", alpha=0.3) ax.plot(x, phase(calvis.vis.data[0, k, :, 2]), "m", alpha=0.3) ax.plot(x, phase(calvis.vis.data[0, k, :, 3]), "r", alpha=0.3) ax.plot(x, phase(mdlvis.vis.data[0, k, :, 0]), "b") ax.plot(x, phase(mdlvis.vis.data[0, k, :, 1]), "c") ax.plot(x, phase(mdlvis.vis.data[0, k, :, 2]), "m") ax.plot(x, phase(mdlvis.vis.data[0, k, :, 3]), "r") ax.grid() ax.set_title(f"{tag} phase") ax.set_xlabel("frequency (MHz)") if k == 0: model_fig.legend() ylim_abs = np.array([-0.05, 1.05]) * np.max(np.abs(corvis.vis.data)) cal_fig, axs = plt.subplots(2, nbl, figsize=(14, 6), sharex=True, sharey=False) cal_fig.suptitle("Beam-corrected, calibratred vis") for k in range(nbl): tag = f"{metadata.stations[metadata.ant1[k]]} x {metadata.stations[metadata.ant2[k]]}" ax = axs[0, k] ax.plot(x, np.abs(corvis.vis.data[0, k, :, 0]), "b", label="XX") ax.plot(x, np.abs(corvis.vis.data[0, k, :, 1]), "c", label="XY") ax.plot(x, np.abs(corvis.vis.data[0, k, :, 2]), "m", label="YX") ax.plot(x, np.abs(corvis.vis.data[0, k, :, 3]), "r", label="YY") ax.grid() ax.set_ylim(ylim_abs) ax.set_title(f"|{tag}|") ax = axs[1, k] ax.plot(x, phase(corvis.vis.data[0, k, :, 0]), "b") ax.plot(x, phase(corvis.vis.data[0, k, :, 1]), "c") ax.plot(x, phase(corvis.vis.data[0, k, :, 2]), "m") ax.plot(x, phase(corvis.vis.data[0, k, :, 3]), "r") ax.grid() ax.set_title(f"{tag} phase") ax.set_xlabel("frequency (MHz)") if k == 0: cal_fig.legend() return VisibilityPlots( raw_fig=raw_fig, model_fig=model_fig, cal_fig=cal_fig, )