"""
Functions for calibration, including creation of gaintables, application of
gaintables, and merging gaintables.
"""
# Names exported by ``from <module> import *``.
# NOTE(review): multiply_gaintables and concatenate_gaintables are defined in
# this module but are not listed here, so star-imports will not export them —
# confirm this is intended.
__all__ = [
    "append_gaintable",
    "create_gaintable_from_rows",
    "gaintable_plot",
]
import copy
import logging
from typing import Union
import matplotlib.pyplot as plt
import numpy.linalg
import xarray
from ska_sdp_datamodels.calibration.calibration_model import GainTable
# Module-level logger. NOTE(review): none of the functions in this section use
# it, and it uses the shared name "rascil-logger" rather than __name__.
log = logging.getLogger("rascil-logger")
[docs]
def append_gaintable(gt: GainTable, othergt: GainTable) -> GainTable:
    """Append othergt to gt along the time axis

    The two tables are joined with ``xarray.concat`` along ``"time"`` using
    the same options as ``concatenate_gaintables``. A new GainTable is
    returned; ``gt`` itself is not modified.

    :param gt: GainTable supplying the leading time samples
    :param othergt: GainTable whose time samples are appended after ``gt``'s
    :return: GainTable gt + othergt
    """
    assert gt.receptor_frame == othergt.receptor_frame
    # GainTable is handled as an xarray Dataset throughout this module (e.g.
    # gt["time"], gt["gain"].data); it has no ``data`` attribute to stack, so
    # concatenate the datasets themselves. Only variables that already carry
    # the "time" dimension are concatenated; the rest come from ``gt``.
    return xarray.concat(
        [gt, othergt],
        dim="time",
        data_vars="minimal",
        coords="minimal",
        compat="override",
    )
[docs]
def create_gaintable_from_rows(
    gt: GainTable, rows: numpy.ndarray, makecopy=True
) -> Union[GainTable, None]:
    """Create a GainTable from selected rows (time samples)

    :param gt: GainTable
    :param rows: Boolean array of row selection, one entry per time sample
    :param makecopy: Make a deep copy (True)
    :return: GainTable containing only the selected time samples, or None
        if ``rows`` is None or selects nothing
    """
    if rows is None or numpy.sum(rows) == 0:
        return None

    rows = numpy.asarray(rows)
    # The number of rows is the length of the "time" coordinate; a plain
    # ``gt.ntimes`` attribute is not available on the dataset.
    ntimes = gt["time"].shape[0]
    assert (
        len(rows) == ntimes
    ), "Length of rows does not agree with length of GainTable"

    # Select along the time dimension of every variable at once. A dataset
    # cannot be shrunk in place, so a new GainTable is returned even when
    # makecopy is False; in that case ``gt`` is left unchanged.
    selected = gt.isel(time=rows)
    if makecopy:
        return selected.copy(deep=True)
    return selected
[docs]
def gaintable_plot(
    gt: GainTable,
    cc="T",
    title="",
    ants=None,
    channels=None,
    label_max=0,
    min_amp=1e-5,
    cmap="rainbow",
    **kwargs,
):
    """Standard plot of gain table

    Creates a figure with three stacked panels sharing the x axis: residual,
    gain amplitude and gain phase. Each panel uses the first receptor pair
    (index ``[..., 0, 0]``). For ``cc == "B"`` the panels are images;
    otherwise the values are plotted against time (``time / 86400``).

    :param gt: Gaintable
    :param cc: Type of gain table e.g. 'T', 'G', 'B'
    :param title: Text prefixed to each panel title
    :param ants: Antenna indices to plot (default: all antennas)
    :param channels: Channel indices to plot (default: all channels)
    :param label_max: Legends are drawn (time plots only) when the
        configuration lists fewer than this many names; 0 disables them
    :param min_amp: In the time plots, samples whose amplitude in the first
        selected channel is not above this value are skipped
    :param cmap: Colour map used for the 'B' images
    :param kwargs: Unused
    :return: None
    """
    if ants is None:
        ants = range(gt.gaintable_acc.nants)
    if channels is None:
        channels = range(gt.gaintable_acc.nchan)
    ants = list(ants)
    channels = list(channels)

    # Labels are paired with ``ants`` by position (zip below), so plotting a
    # subset of antennas cannot index past the end of this list. Names are
    # converted to plain strings for use as legend labels.
    if gt.configuration is not None:
        names = gt.configuration.names.data
        labels = [str(names[ant]) for ant in ants]
    else:
        labels = ["" for _ in ants]

    time_axis = gt["time"].data / 86400.0
    ntimes = len(time_axis)
    nants = gt.gaintable_acc.nants
    gain = gt["gain"].data
    residual = gt["residual"].data[:, channels, 0, 0]

    _, ax = plt.subplots(3, 1, sharex=True)
    if cc == "B":
        ax[0].imshow(residual, cmap=cmap)
        ax[0].set_title("{title} RMS residual {cc}".format(title=title, cc=cc))
        ax[0].set_ylabel("RMS residual (Jy)")

        # One image row per (time, antenna) pair and one column per *selected*
        # channel, so the reshape must use len(channels), not nchan.
        selected = gain[:, :, channels, 0, 0].reshape(
            [ntimes * nants, len(channels)]
        )
        ax[1].imshow(numpy.abs(selected), cmap=cmap)
        ax[1].set_ylabel("Amplitude")
        ax[1].set_title("{title} Amplitude {cc}".format(title=title, cc=cc))
        ax[1].xaxis.set_tick_params(labelsize="small")

        ax[2].imshow(numpy.angle(selected), cmap=cmap)
        ax[2].set_ylabel("Phase (radian)")
        ax[2].set_title("{title} Phase {cc}".format(title=title, cc=cc))
        ax[2].xaxis.set_tick_params(labelsize="small")
    else:
        ax[0].plot(time_axis, residual, ".")
        ax[0].set_ylabel("Residual fit (Jy)")
        ax[0].set_title("{title} Residual {cc}".format(title=title, cc=cc))

        for ant, label in zip(ants, labels):
            ant_gain = gain[:, ant, channels, 0, 0]
            amp = numpy.abs(ant_gain)
            # Skip samples whose first selected channel is (near) zero.
            keep = amp[:, 0] > min_amp
            ax[1].plot(time_axis[keep], amp[keep], ".", label=label)
            ax[2].plot(
                time_axis[keep],
                numpy.angle(ant_gain)[keep],
                ".",
                label=label,
            )
        ax[1].set_ylabel("Amplitude (Jy)")
        ax[1].set_title("{title} Amplitude {cc}".format(title=title, cc=cc))
        ax[2].set_ylabel("Phase (rad)")
        ax[2].set_title("{title} Phase {cc}".format(title=title, cc=cc))
        ax[2].xaxis.set_tick_params(labelsize=8)
        plt.xticks(rotation=0)

        # ``ax`` is a 1-D array of three Axes: legend the two labelled panels.
        if (
            gt.configuration is not None
            and len(gt.configuration.names.data) < label_max
        ):
            ax[1].legend()
            ax[2].legend()
def multiply_gaintables(
    gt: GainTable, dgt: GainTable, time_tolerance=1e-3
) -> GainTable:
    """Multiply two gaintables

    Returns gt * dgt. ``gt`` is modified in place (its "gain" and "weight"
    arrays are overwritten) and is also returned.

    :param gt: GainTable to be updated
    :param dgt: GainTable multiplied into ``gt``; must have the same time
        samples and the same number of receptors
    :param time_tolerance: Maximum allowed absolute difference between
        corresponding time samples of the two tables
    :return: ``gt`` after multiplication
    :raises ValueError: if the time samples differ in number or by more than
        ``time_tolerance``, or if the receptor structures differ or are not
        1 or 2 receptors
    """
    gt_time = gt["time"].data
    dgt_time = dgt["time"].data
    # Compare lengths first: otherwise numpy either raises an opaque
    # broadcasting error or silently broadcasts a single-sample table.
    if len(gt_time) != len(dgt_time):
        raise ValueError(
            f"Gaintables not aligned in time: {len(gt_time)} and "
            f"{len(dgt_time)} time samples"
        )

    mismatch = numpy.max(numpy.abs(gt_time - dgt_time))
    if mismatch > time_tolerance:
        raise ValueError(
            f"Gaintables not aligned in time: max mismatch {mismatch} seconds"
        )

    nrec = gt.gaintable_acc.nrec
    if dgt.gaintable_acc.nrec != nrec:
        raise ValueError(
            "Gain tables have different structures {} {}".format(str(gt), str(dgt))
        )

    if nrec == 2:
        # NOTE(review): this contracts the first receptor index of both
        # operands, i.e. transpose(gt) @ dgt per element rather than
        # gt @ dgt. The two agree for diagonal Jones matrices — confirm the
        # intended convention before relying on full-polarisation results.
        gt["gain"].data = numpy.einsum(
            "...ik,...ij->...kj", gt["gain"].data, dgt["gain"].data
        )
    elif nrec == 1:
        gt["gain"].data *= dgt["gain"].data
    else:
        raise ValueError(
            "Gain tables have illegal structures {} {}".format(str(gt), str(dgt))
        )
    gt["weight"].data *= dgt["weight"].data
    return gt
def concatenate_gaintables(gt_list, dim="time"):
    """Concatenate a list of gaintables

    :param gt_list: Iterable of gaintables (a list, tuple, generator, ...)
    :param dim: Dimension to concatenate along (default "time")
    :return: Concatenated gaintable
    :raises ValueError: if ``gt_list`` contains no gaintables
    """
    # Materialise first so that one-shot iterables are accepted and emptiness
    # can be checked without calling len() on them.
    gt_list = list(gt_list)
    if not gt_list:
        raise ValueError("GainTable list is empty")
    # Only variables that already have ``dim`` are concatenated; the
    # remaining ones are taken from the first table without comparison.
    return xarray.concat(
        gt_list, dim=dim, data_vars="minimal", coords="minimal", compat="override"
    )