Source code for ska_sdp_func_python.calibration.operations

"""
Functions for calibration operations.
"""

__all__ = [
    "apply_gaintable",
    "concatenate_gaintables",
    "multiply_gaintables",
]

import copy
import logging

import numpy.linalg
import xarray
from astropy.time import Time
from ska_sdp_datamodels.calibration.calibration_model import GainTable
from ska_sdp_datamodels.visibility.vis_model import Visibility

log = logging.getLogger("func-python-logger")


def apply_gaintable(
    vis: Visibility,
    gt: GainTable,
    inverse=False,
    use_flags=False,
) -> Visibility:
    """
    Apply a GainTable to a Visibility.

    The corrected visibility is::

        V_corrected = {g_i * g_j^*}^-1 V_obs

    see https://github.com/ratt-ru/fundamentals_of_interferometry/blob/
    master/8_Calibration/8_1_calibration_least_squares_problem.ipynb

    If the visibility data are polarised, e.g. polarisation_frame("linear"),
    then the inverse operator represents an actual inverse of the gains.

    :param vis: Visibility to have gains applied
    :param gt: GainTable to be applied
    :param inverse: Apply the inverse (default=False)
    :param use_flags: Use the visibility flags and flagged weights when
        applying gains (default=False)
    :return: input Visibility with gains applied
    """
    ntimes, nants, nchan, _, _ = gt.gain.shape

    if inverse:
        log.debug("apply_gaintable: Apply inverse gaintable")
    else:
        log.debug("apply_gaintable: Apply gaintable")

    if vis.visibility_acc.npol == 1:
        log.debug("apply_gaintable: scalar gains")

    row_numbers = numpy.arange(len(vis.time))

    for row in range(ntimes):
        vis_rows = (
            numpy.abs(vis.time.data - gt.time.data[row])
            < gt.interval.data[row] / 2.0
        )
        vis_rows = row_numbers[vis_rows]
        if len(vis_rows) > 0:
            # Look up the gain for this set of visibilities
            gain = gt["gain"].data[row]
            cgain = numpy.conjugate(gt["gain"].data[row])
            nant = gain.shape[0]
            # Get the number of channels from the GainTable
            nchan = gain.shape[1]
            baselines = vis.baselines.data

            # Get the number of channels from the Visibility -> vchan
            vchan = vis["vis"].data.shape[2]
            gchan_id = numpy.arange(nchan)
            # Count the repetitions needed to map gain channels
            # onto visibility channels
            repeat_times = numpy.ceil(vchan / nchan).astype(int)
            # Repeat elements
            vchan_to_nchan = numpy.repeat(gchan_id, repeat_times)[:vchan]

            # By default ignore visibility flags when applying gains.
            # This should have no impact on the result and saves the time
            # of applying the flags.
            flagged = (
                use_flags
                and numpy.max(vis["flags"][vis_rows].data) > 0.0
            )
            if flagged:
                log.debug("apply_gaintable: Applying flags")
                original = vis.visibility_acc.flagged_vis[vis_rows]
                applied = copy.deepcopy(original)
                appliedwt = copy.deepcopy(
                    vis.visibility_acc.flagged_weight[vis_rows]
                )
            else:
                log.debug(
                    "apply_gaintable: flags are absent or being ignored"
                )
                original = vis["vis"].data[vis_rows]
                applied = copy.deepcopy(original)
                appliedwt = copy.deepcopy(vis["weight"].data[vis_rows])

            if vis.visibility_acc.npol == 1:
                if inverse:
                    lgain = numpy.zeros_like(gain)
                    try:
                        numpy.putmask(
                            lgain, numpy.abs(gain) > 0.0, 1.0 / gain
                        )
                    except FloatingPointError:
                        pass
                else:
                    lgain = gain

                # Optimized (SIM-423)
                # smueller1 = numpy.ones([nchan, nant, nant], dtype='complex')
                smueller1 = numpy.einsum(
                    "ijlm,kjlm->jik", lgain, numpy.conjugate(lgain)
                )

                for sub_vis_row in range(original.shape[0]):
                    for ibaseline, (a1, a2) in enumerate(baselines):
                        for chan in range(vchan):
                            if (
                                numpy.abs(
                                    smueller1[vchan_to_nchan[chan], a1, a2]
                                )
                                > 0.0
                            ):
                                applied[sub_vis_row, ibaseline, chan, 0] = (
                                    original[sub_vis_row, ibaseline, chan, 0]
                                    * smueller1[vchan_to_nchan[chan], a1, a2]
                                )
                            else:
                                applied[sub_vis_row, ibaseline, chan, 0] = 0.0
                                appliedwt[
                                    sub_vis_row, ibaseline, chan, 0
                                ] = 0.0
            elif vis.visibility_acc.npol == 2:
                has_inverse_ant = numpy.zeros([nant, vchan], dtype="bool")
                if inverse:
                    igain = gain.copy()
                    cigain = cgain.copy()
                    for a1 in range(nants):
                        for chan in range(vchan):
                            try:
                                igain[a1, vchan_to_nchan[chan], :, :] = (
                                    numpy.linalg.inv(
                                        gain[a1, vchan_to_nchan[chan], :, :]
                                    )
                                )
                                cigain[a1, vchan_to_nchan[chan], :, :] = (
                                    numpy.conjugate(
                                        igain[a1, vchan_to_nchan[chan], :, :]
                                    )
                                )
                                has_inverse_ant[a1, chan] = True
                            except numpy.linalg.LinAlgError:
                                has_inverse_ant[a1, chan] = False

                    for sub_vis_row in range(original.shape[0]):
                        for ibaseline, (a1, a2) in enumerate(baselines):
                            for chan in range(vchan):
                                if (
                                    has_inverse_ant[a1, chan]
                                    and has_inverse_ant[a2, chan]
                                ):
                                    cfs = numpy.diag(
                                        original[
                                            sub_vis_row, ibaseline, chan, ...
                                        ]
                                    )
                                    applied[
                                        sub_vis_row, ibaseline, chan, ...
                                    ] = numpy.diag(
                                        igain[a1, vchan_to_nchan[chan], :, :]
                                        @ cfs
                                        @ cigain[
                                            a2, vchan_to_nchan[chan], :, :
                                        ]
                                    ).reshape([2])
                                else:
                                    applied[
                                        sub_vis_row, ibaseline, chan, 0
                                    ] = 0.0
                                    appliedwt[
                                        sub_vis_row, ibaseline, chan, 0
                                    ] = 0.0
                else:
                    for sub_vis_row in range(original.shape[0]):
                        for ibaseline, (a1, a2) in enumerate(baselines):
                            for chan in range(vchan):
                                cfs = numpy.diag(
                                    original[sub_vis_row, ibaseline, chan, ...]
                                )
                                applied[sub_vis_row, ibaseline, chan, ...] = (
                                    numpy.diag(
                                        gain[a1, vchan_to_nchan[chan], :, :]
                                        @ cfs
                                        @ cgain[a2, vchan_to_nchan[chan], :, :]
                                    ).reshape([2])
                                )
            elif vis.visibility_acc.npol == 4:
                has_inverse_ant = numpy.zeros([nant, vchan], dtype="bool")
                if inverse:
                    igain = gain.copy()
                    cigain = cgain.copy()
                    for a1 in range(nants):
                        for chan in range(vchan):
                            try:
                                igain[a1, vchan_to_nchan[chan], :, :] = (
                                    numpy.linalg.inv(
                                        gain[a1, vchan_to_nchan[chan], :, :]
                                    )
                                )
                                cigain[a1, vchan_to_nchan[chan], :, :] = (
                                    numpy.conjugate(
                                        igain[a1, vchan_to_nchan[chan], :, :]
                                    )
                                )
                                has_inverse_ant[a1, chan] = True
                            except numpy.linalg.LinAlgError:
                                has_inverse_ant[a1, chan] = False

                    for sub_vis_row in range(original.shape[0]):
                        for ibaseline, baseline in enumerate(baselines):
                            for chan in range(vchan):
                                if (
                                    has_inverse_ant[baseline[0], chan]
                                    and has_inverse_ant[baseline[1], chan]
                                ):
                                    cfs = original[
                                        sub_vis_row, ibaseline, chan, ...
                                    ].reshape([2, 2])
                                    applied[
                                        sub_vis_row, ibaseline, chan, ...
                                    ] = (
                                        igain[
                                            baseline[0],
                                            vchan_to_nchan[chan],
                                            :,
                                            :,
                                        ]
                                        @ cfs
                                        @ cigain[
                                            baseline[1],
                                            vchan_to_nchan[chan],
                                            :,
                                            :,
                                        ]
                                    ).reshape([4])
                                else:
                                    applied[
                                        sub_vis_row, ibaseline, chan, ...
                                    ] = 0.0
                                    appliedwt[
                                        sub_vis_row, ibaseline, chan, ...
                                    ] = 0.0
                else:
                    for sub_vis_row in range(original.shape[0]):
                        for ibaseline, baseline in enumerate(baselines):
                            for chan in range(vchan):
                                cfs = original[
                                    sub_vis_row, ibaseline, chan, ...
                                ].reshape([2, 2])
                                applied[sub_vis_row, ibaseline, chan, ...] = (
                                    gain[
                                        baseline[0], vchan_to_nchan[chan], :, :
                                    ]
                                    @ cfs
                                    @ cgain[
                                        baseline[1], vchan_to_nchan[chan], :, :
                                    ]
                                ).reshape([4])
        else:
            times = Time(vis.time / 86400.0, format="mjd", scale="utc")
            log.warning(
                "No row in gaintable for visibility "
                "row, time range %s to %s",
                times[0].isot,
                times[-1].isot,
            )

        vis["vis"].data[vis_rows] = applied
        vis["weight"].data[vis_rows] = appliedwt
    return vis
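
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): the scalar-gain relation
# quoted in the docstring, V_corrected = {g_i * g_j^*}^-1 V_obs, written out
# with plain numpy for a single baseline (i, j). All names below (g, v_true,
# v_obs, i, j) are made up for the example.
#
#     import numpy
#
#     rng = numpy.random.default_rng(seed=1)
#     g = rng.normal(size=3) + 1j * rng.normal(size=3)  # one gain per antenna
#     i, j = 0, 2                                       # a single baseline
#     v_true = 1.5 + 0.0j                               # model visibility
#     v_obs = g[i] * numpy.conjugate(g[j]) * v_true     # corrupted visibility
#     v_corrected = v_obs / (g[i] * numpy.conjugate(g[j]))
#     assert numpy.isclose(v_corrected, v_true)
#
# In library terms this corresponds to calling
# apply_gaintable(vis, gt, inverse=True) on a Visibility `vis` and a solved
# GainTable `gt` (both assumed to exist already).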


def multiply_gaintables(
    gt: GainTable, dgt: GainTable, time_tolerance=1e-3
) -> GainTable:
    """
    Multiply two GainTables. Returns gt * dgt.

    :param gt: First GainTable
    :param dgt: Second GainTable
    :param time_tolerance: Maximum tolerated time separation (in seconds)
        between the two GainTables
    :return: Multiplication product
    """
    # Test whether the times of the two GainTables align
    mismatch = numpy.max(numpy.abs(gt["time"].data - dgt["time"].data))
    if mismatch > time_tolerance:
        raise ValueError(
            f"Gaintables not aligned in time: max mismatch {mismatch} seconds"
        )
    if dgt.gaintable_acc.nrec == gt.gaintable_acc.nrec:
        if dgt.gaintable_acc.nrec == 2:
            gt["gain"].data = numpy.einsum(
                "...ik,...ij->...kj", gt["gain"].data, dgt["gain"].data
            )
            gt["weight"].data *= dgt["weight"].data
        elif dgt.gaintable_acc.nrec == 1:
            gt["gain"].data *= dgt["gain"].data
            gt["weight"].data *= dgt["weight"].data
        else:
            raise ValueError(
                f"Gain tables have illegal structures {str(gt)} {str(dgt)}"
            )
    else:
        raise ValueError(
            f"Gain tables have different structures {str(gt)} {str(dgt)}"
        )
    return gt
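
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): for nrec == 2 the einsum
# "...ik,...ij->...kj" above is equivalent to transpose(a) @ b applied to the
# trailing 2x2 axes, broadcast over the leading (time, antenna, channel)
# axes. A numpy-only check of that correspondence, with made-up array names:
#
#     import numpy
#
#     rng = numpy.random.default_rng(seed=2)
#     a = rng.normal(size=(5, 3, 4, 2, 2))  # (time, antenna, channel, 2, 2)
#     b = rng.normal(size=(5, 3, 4, 2, 2))
#     combined = numpy.einsum("...ik,...ij->...kj", a, b)
#     reference = numpy.swapaxes(a, -2, -1) @ b
#     assert numpy.allclose(combined, reference)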


def concatenate_gaintables(gt_list, dim="time"):
    """
    Concatenate a list of GainTables.

    :param gt_list: List of GainTables
    :param dim: Dimension along which to concatenate (default "time")
    :return: Concatenated GainTable
    """
    if len(gt_list) == 0:
        raise ValueError("GainTable list is empty")

    return xarray.concat(
        gt_list,
        dim=dim,
        data_vars="minimal",
        coords="minimal",
        compat="override",
    )
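
# ---------------------------------------------------------------------------
# Usage sketch (not part of the module; names are assumptions): GainTables
# solved for consecutive time chunks can be stitched back together along the
# "time" dimension, e.g.
#
#     combined = concatenate_gaintables([gt_chunk_1, gt_chunk_2], dim="time")
#
# where gt_chunk_1 and gt_chunk_2 are assumed to be GainTables covering
# disjoint, time-ordered ranges of the same observation.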