# Source code for ska_sdp_func_python.calibration.solvers

"""
Functions to solve for antenna/station gain.

The public entry point is ``solve_gaintable``, which fits observed
visibilities to model visibilities (or to a point source model when no
model is given) using iterative substitution solvers for scalar gains,
2x2 gains with the cross-hand terms removed, and full 2x2 gains.
"""

# Only solve_gaintable is exported by ``from ... import *``;
# find_best_refant_from_vis and the underscore-prefixed helpers are not.
__all__ = ["solve_gaintable"]

import logging

import numpy
import scipy
from ska_sdp_datamodels.calibration.calibration_create import (
    create_gaintable_from_visibility,
)
from ska_sdp_datamodels.calibration.calibration_model import GainTable
from ska_sdp_datamodels.visibility.vis_model import Visibility

from ska_sdp_func_python.visibility.operations import divide_visibility

# Module logger. It uses the fixed name "func-python-logger" rather than
# __name__, presumably so it is shared with other modules of the package
# (NOTE(review): confirm).
log = logging.getLogger("func-python-logger")


def find_best_refant_from_vis(vis):
    """
    Rank antennas by their suitability as phase reference antenna.

    This method comes from katsdpcal.
    (https://github.com/ska-sa/katsdpcal/blob/
    200c2f6e60b2540f0a89e7b655b26a2b04a8f360/katsdpcal/calprocs.py#L332)
    Determine antenna whose FFT has the maximum peak to noise ratio (PNR) by
    taking the median PNR of the FFT over all baselines to each antenna.

    For multi-channel data every spectrum is Fourier transformed along the
    channel axis and circularly shifted so its peak sits on channel 0. The
    PNR is (peak - mean) / std of the FFT amplitude, with mean and std taken
    over the central half of the shifted FFT. An antenna's score is the
    median PNR over its cross-correlation baselines, times and polarisations.

    When the input vis has only one channel no FFT is possible; the score is
    then the sum of the flagged weights of all baselines of the antenna, plus
    a tiny offset that decreases with antenna index so ties favour lower
    indices.

    Antennas that take part in no cross-correlation baseline, or whose score
    is NaN, get a score of -inf and are therefore ranked last.

    :param vis: Visibilities
    :return: Array of indices of antennas in decreasing order of score
    """
    visdata = vis.visibility_acc.flagged_vis
    _, _, nchan, _ = visdata.shape
    baselines = numpy.array(vis.baselines.data.tolist())
    nants = vis.visibility_acc.nants

    # Baselines of each antenna; XOR excludes the autocorrelations
    ant_masks = [
        (baselines[:, 0] == ant) ^ (baselines[:, 1] == ant)
        for ant in range(nants)
    ]

    # Antennas without any baseline keep -inf and so are ranked last
    scores = numpy.full(nants, -numpy.inf)

    if nchan == 1:
        weightdata = vis.visibility_acc.flagged_weight
        for ant, mask in enumerate(ant_masks):
            if mask.any():
                scores[ant] = numpy.sum(weightdata[:, mask])
        # Break ties in favour of lower antenna indices
        scores += numpy.linspace(1e-8, 1e-9, nants)
    else:
        ft_amp = numpy.abs(scipy.fftpack.fft(visdata, axis=2))

        # Circularly shift each spectrum so its peak lands on channel 0:
        # shifted[c] = original[(c + peak_chan) % nchan], which is the same
        # as numpy.roll(range(nchan), -peak_chan) applied per spectrum.
        peak_chan = numpy.argmax(ft_amp, axis=2)  # [ntimes, nbl, npol]
        index = (
            numpy.arange(nchan)[numpy.newaxis, numpy.newaxis, :, numpy.newaxis]
            + peak_chan[:, :, numpy.newaxis, :]
        ) % nchan
        ft_amp = numpy.take_along_axis(ft_amp, index, axis=2)

        peak = numpy.max(ft_amp, axis=2)

        # Noise statistics from the central half of the shifted spectrum
        chan_slice = numpy.s_[
            nchan // 2 - nchan // 4 : nchan // 2 + nchan // 4 + 1
        ]
        mean = numpy.mean(ft_amp[:, :, chan_slice], axis=2)
        std = numpy.std(ft_amp[:, :, chan_slice], axis=2) + 1e-9

        for ant, mask in enumerate(ant_masks):
            if mask.any():
                pnr = (peak[:, mask] - mean[:, mask]) / std[:, mask]
                scores[ant] = numpy.median(pnr)

    # argsort places NaN last, which [::-1] would turn into "best";
    # make sure invalid scores are ranked last instead.
    scores[numpy.isnan(scores)] = -numpy.inf
    return numpy.argsort(scores)[::-1]


def solve_gaintable(
    vis: Visibility,
    modelvis: Visibility = None,
    gain_table=None,
    phase_only=True,
    niter=200,
    tol=1e-6,
    crosspol=False,
    normalise_gains="mean",
    jones_type="T",
    timeslice=None,
    refant=0,
) -> GainTable:
    """
    Solve a gain table by fitting an observed visibility to a model visibility.

    If modelvis is None, a point source model is assumed.

    For every time row of the gain table, the (model-divided) visibilities
    inside that row's interval are weighted, summed over time (and over
    frequency when the gain table has a single channel), arranged into an
    antenna x antenna array and passed to an iterative substitution solver.

    :param vis: Visibility containing the observed data_models
    :param modelvis: Visibility containing the visibility predicted by a model
    :param gain_table: Existing gaintable; a new one is created from vis
                       if None
    :param phase_only: Solve only for the phases (default=True)
    :param niter: Maximum number of iterations (default 200)
    :param tol: Iteration stops when the maximum absolute change in the
                gain solution is below this tolerance
    :param crosspol: Do solutions including cross polarisations
                     i.e. XY, YX or RL, LR
    :param normalise_gains: Normalises the gains (default="mean")
                            options are None, "mean", "median".
                            None means no normalization. Normalisation is
                            only applied when phase_only is False.
    :param jones_type: Type of calibration matrix T or G or B
    :param timeslice: Time interval between solutions (s)
    :param refant: Reference antenna for the phase (default 0); an
                   antenna ranking from find_best_refant_from_vis is also
                   passed to the solver as fallback
    :return: GainTable containing solution
    :raises ValueError: if modelvis is given and is zero everywhere
    """
    if modelvis is not None:
        # pylint: disable=unneeded-not
        if not numpy.max(numpy.abs(modelvis.vis)) > 0.0:
            raise ValueError("solve_gaintable: Model visibility is zero")

    point_vis = (
        divide_visibility(vis, modelvis) if modelvis is not None else vis
    )

    if phase_only:
        log.debug("solve_gaintable: Solving for phase only")
    else:
        log.debug("solve_gaintable: Solving for complex gain")

    if gain_table is None:
        log.debug("solve_gaintable: creating new gaintable")
        gain_table = create_gaintable_from_visibility(
            vis, jones_type=jones_type, timeslice=timeslice
        )
    else:
        log.debug("solve_gaintable: starting from existing gaintable")

    nants = gain_table.gaintable_acc.nants
    nchan = gain_table.gaintable_acc.nchan
    npol = point_vis.visibility_acc.npol

    # Sum over time; with a single gain channel also sum over frequency
    axes = (0, 2) if nchan == 1 else 0

    for row, time in enumerate(gain_table.time):
        half_interval = gain_table.interval[row] / 2
        time_slice = {
            "time": slice(time - half_interval, time + half_interval)
        }
        pointvis_sel = point_vis.sel(time_slice)
        # pylint: disable=unneeded-not
        if not pointvis_sel.visibility_acc.ntimes > 0:
            log.warning(
                "Gaintable %s, vis time mismatch %s",
                gain_table.time,
                vis.time,
            )
            continue

        refant_sort = find_best_refant_from_vis(pointvis_sel)

        unflagged = 1 - pointvis_sel.flags.data
        x_b = numpy.sum(
            (pointvis_sel.vis.data * pointvis_sel.weight.data) * unflagged,
            axis=axes,
        )
        xwt_b = numpy.sum(pointvis_sel.weight.data * unflagged, axis=axes)

        # Scatter the per-baseline sums into an antenna x antenna array,
        # storing the conjugate in the (a1, a2) slot
        x = numpy.zeros([nants, nants, nchan, npol], dtype="complex")
        xwt = numpy.zeros([nants, nants, nchan, npol])
        for ibaseline, (a1, a2) in enumerate(point_vis.baselines.data):
            x[a1, a2, ...] = numpy.conjugate(x_b[ibaseline, ...])
            xwt[a1, a2, ...] = xwt_b[ibaseline, ...]
            x[a2, a1, ...] = x_b[ibaseline, ...]
            xwt[a2, a1, ...] = xwt_b[ibaseline, ...]

        mask = numpy.abs(xwt) > 0.0
        if numpy.any(mask):
            _solve_with_mask(
                crosspol,
                gain_table,
                mask,
                niter,
                phase_only,
                row,
                tol,
                vis,
                x,
                xwt,
                refant,
                refant_sort,
            )
        else:
            # No usable data: unit gains with zero weight
            gain_table["gain"].data[row, ...] = 1.0 + 0.0j
            gain_table["weight"].data[row, ...] = 0.0
            gain_table["residual"].data[row, ...] = 0.0

    if normalise_gains in ["median", "mean"] and not phase_only:
        normaliser = {
            "median": numpy.median,
            "mean": numpy.mean,
        }
        # NOTE(review): the statistic is taken over every element of the
        # gain array, including off-diagonal Jones terms that the 2x2
        # solvers set to zero; confirm this is intended.
        gabs = normaliser[normalise_gains](
            numpy.abs(gain_table["gain"].data[:])
        )
        gain_table["gain"].data[:] /= gabs

    return gain_table
def _solve_with_mask(
    crosspol,
    gain_table,
    mask,
    niter,
    phase_only,
    row,
    tol,
    vis,
    x,
    xwt,
    refant,
    refant_sort,
):
    """
    Normalise the accumulated data and solve one row of the gain table.

    Helper extracted from solve_gaintable to reduce its complexity; it is
    only called when ``numpy.sum(mask) > 0``.

    ``x`` and ``xwt`` are modified in place: where ``mask`` is set, ``x`` is
    divided by ``xwt`` and ``xwt`` is scaled so its maximum is 1; elsewhere
    both are set to zero.

    The solver is selected from the number of polarisations of ``vis``:

    * npol == 2, or npol == 4 without ``crosspol``: 2x2 solver with the
      cross-hand terms removed;
    * npol == 4 with ``crosspol``: full 2x2 matrix solver;
    * otherwise: scalar solver.

    :param crosspol: Solve including the cross polarisations
    :param gain_table: GainTable whose "gain", "weight" and "residual" data
        at index ``row`` are overwritten with the solution
    :param mask: Boolean array selecting the entries of x and xwt that
        carry data
    :param niter: Maximum number of iterations
    :param phase_only: Solve only for the phases
    :param row: Index of the gain table row to solve
    :param tol: Convergence tolerance on the change in the gain solution
    :param vis: Visibility; only its number of polarisations is used
    :param x: Weighted sum of point source visibilities
        [nants, nants, nchan, npol]
    :param xwt: Sum of weights [nants, nants, nchan, npol]
    :param refant: Requested reference antenna
    :param refant_sort: Antennas in decreasing order of preference, used
        when ``refant`` has no valid data
    """
    x[mask] = x[mask] / xwt[mask]
    x[~mask] = 0.0
    xwt[mask] = xwt[mask] / numpy.max(xwt[mask])
    xwt[~mask] = 0.0

    npol = vis.visibility_acc.npol
    if npol == 2 or (npol == 4 and not crosspol):
        solver = _solve_antenna_gains_itsubs_nocrossdata
    elif npol == 4 and crosspol:
        solver = _solve_antenna_gains_itsubs_matrix
    else:
        solver = _solve_antenna_gains_itsubs_scalar

    (
        gain_table["gain"].data[row, ...],
        gain_table["weight"].data[row, ...],
        gain_table["residual"].data[row, ...],
    ) = solver(
        gain_table["gain"].data[row, ...],
        gain_table["weight"].data[row, ...],
        x,
        xwt,
        phase_only=phase_only,
        niter=niter,
        tol=tol,
        refant=refant,
        refant_sort=refant_sort,
    )


def _determine_refant(refant, bad_ant, refant_sort):
    """
    Determine the final reference antenna.

    If ``refant`` is not a bad antenna it is returned unchanged. Otherwise
    the first antenna of ``refant_sort`` that is not bad is returned, and a
    warning is logged. If no such antenna exists, a warning is logged and
    the original ``refant`` is returned.

    :param refant: the requested reference antenna
    :param bad_ant: a list including all bad antennas
    :param refant_sort: antennas in decreasing order of preference
    :return: reference antenna to use
    """
    if refant not in bad_ant:
        return refant

    for ant_id in refant_sort:
        if ant_id not in bad_ant:
            log.warning(
                "warning, ant: %d is masked, change refant to ant: %d",
                refant,
                ant_id,
            )
            return ant_id

    # No suitable antenna found: fall back to the requested one
    log.warning(
        "warning, Cannot find a suitable reference antenna, "
        "use initial settings: %d",
        refant,
    )
    return refant


def _solve_antenna_gains_itsubs_scalar(
    gain,
    gwt,
    x,
    xwt,
    niter=200,
    tol=1e-6,
    phase_only=True,
    damping=0.5,
    refant=0,
    refant_sort=None,
):
    """Solve for the antenna gains.

    x(antenna2, antenna1) = gain(antenna1) conj(gain(antenna2))

    This uses an iterative substitution algorithm due to Larry
    D'Addario c 1980'ish (see ThompsonDaddario1982 Appendix 1). Used
    in the original VLA Dec-10 Antsol.

    ``x`` and ``xwt`` are modified in place: their diagonal
    (autocorrelations) is zeroed and their upper triangle is overwritten
    from the lower triangle (conjugated for ``x``).

    :param gain: gains [nants, nchan, nrec, nrec]
    :param gwt: gain weight
    :param x: Equivalent point source visibility [nants, nants, nchan, npol]
    :param xwt: Equivalent point source weight [nants, nants, nchan, npol]
    :param niter: Maximum number of iterations
    :param tol: tolerance on the maximum absolute change of the solution
    :param phase_only: Do solution for only the phase? (default True)
    :param damping: Weight of the previous solution in each update
    :param refant: Reference antenna for phase (default=0)
    :param refant_sort: Sorted list of fallback reference antennas
    :return: gain [nants, ...], weight [nants, ...], residual [nchan, ...]
    """
    if refant_sort is None:
        refant_sort = []

    nants = x.shape[0]
    # Remove autocorrelations and make the data Hermitian
    i_diag = numpy.diag_indices(nants)
    x[i_diag[0], i_diag[1], ...] = 0.0
    xwt[i_diag[0], i_diag[1], ...] = 0.0
    i_lower = numpy.tril_indices(nants, -1)
    i_upper = (i_lower[1], i_lower[0])
    x[i_upper] = numpy.conjugate(x[i_lower])
    xwt[i_upper] = xwt[i_lower]

    reduce_oneside_x = numpy.abs(numpy.einsum("ij...->j...", x * xwt))
    gainmask = reduce_oneside_x <= 0.0

    # An antenna is bad if everything in channel 0 is masked.
    # numpy.all returns numpy.bool_, so it must not be compared with "is".
    bad_ant = [
        iant for iant in range(nants) if numpy.all(gainmask[iant, 0])
    ]
    refant = _determine_refant(refant, bad_ant, refant_sort)

    numpy.putmask(gain, gainmask, 0.0)
    for _ in range(niter):
        gain_last = gain
        gain, gwt = _gain_substitution_scalar(gain, x, xwt)
        if phase_only:
            mask = numpy.abs(gain) > 0.0
            gain[mask] = gain[mask] / numpy.abs(gain[mask])
        gain = (1.0 - damping) * gain + damping * gain_last
        change = numpy.max(numpy.abs(gain - gain_last))
        if change < tol:
            break
    else:
        log.warning(
            "solve_antenna_gains_itsubs_scalar: "
            "gain solution failed, retaining gain solutions"
        )

    # Damping moves phase-only gains off the unit circle; renormalise
    if phase_only:
        mask = numpy.abs(gain) > 0.0
        gain[mask] = gain[mask] / numpy.abs(gain[mask])
    # Reference all phases to the reference antenna
    angles = numpy.angle(gain)
    gain *= numpy.exp(-1j * angles)[refant, ...]
    numpy.putmask(gain, gainmask, 1.0)
    return gain, gwt, _solution_residual_scalar(gain, x, xwt)


def _gain_substitution_scalar(gain, x, xwt):
    """
    Substitute gains across all baselines of gain for point source
    equivalent visibilities.

    Each new gain is the weighted least-squares estimate given the other
    antennas' current gains: sum_i g_i w_ij x_ij / sum_i |g_i|^2 w_ij.
    Where the denominator is not positive the gain and weight are set to 0.

    :param gain: gains (numpy.array of shape [nant, nchan, nrec, nrec])
    :param x: Equivalent point source visibility [nants, nants, nchan, npol]
    :param xwt: Equivalent point source weight [nants, nants, nchan, npol]
    :return: gain [nants, nchan, nrec, nrec], weight [nants, nchan, nrec, nrec]
    """
    nants, nchan, nrec, _ = gain.shape
    newgain1 = numpy.ones_like(gain, dtype="complex128")
    gwt1 = numpy.zeros_like(gain, dtype="double")

    xxwt = x * xwt
    gcg = gain * numpy.conjugate(gain)
    n_top = numpy.einsum("ik...,ijk...->jk...", gain, xxwt)
    n_bot = numpy.einsum("ik...,ijk...->jk...", gcg, xwt).real

    # The quotient is only kept where n_bot > 0, so the divide-by-zero
    # warnings from the other entries are irrelevant
    with numpy.errstate(divide="ignore", invalid="ignore"):
        numpy.putmask(newgain1, n_bot > 0.0, n_top / n_bot)
    numpy.putmask(newgain1, n_bot <= 0.0, 0.0)

    gwt1[:, :] = n_bot
    numpy.putmask(gwt1, n_bot <= 0.0, 0.0)

    newgain1 = newgain1.reshape([nants, nchan, nrec, nrec])
    gwt1 = gwt1.reshape([nants, nchan, nrec, nrec])
    return newgain1, gwt1


def _solve_antenna_gains_itsubs_nocrossdata(
    gain,
    gwt,
    x,
    xwt,
    niter=200,
    tol=1e-6,
    phase_only=True,
    refant=0,
    refant_sort=None,
):
    """Solve for the antenna gains using full matrix expressions,
    but no cross hands.

    x(antenna2, antenna1) = gain(antenna1) conj(gain(antenna2))

    See Appendix D, section D.1 in:

    J. P. Hamaker, "Understanding radio polarimetry - IV.
    The full-coherency analogue of scalar self-calibration:
    Self-alignment, dynamic range and polarimetric fidelity,"
    Astronomy and Astrophysics Supplement Series, vol. 143,
    no. 3, pp. 515-534, May 2000.

    Two-polarisation data are expanded to four polarisations with zero
    cross hands; for four-polarisation data the cross hands of ``x`` and
    ``xwt`` are zeroed in place. The result is solved with
    _solve_antenna_gains_itsubs_matrix.

    :param gain: gains
    :param gwt: gain weight
    :param x: Equivalent point source visibility [nants, nants, nchan, npol]
    :param xwt: Equivalent point source weight [nants, nants, nchan, npol]
    :param niter: Maximum number of iterations
    :param tol: tolerance on solution change
    :param phase_only: Do solution for only the phase? (default True)
    :param refant: Reference antenna for phase (default=0)
    :param refant_sort: Sorted list of fallback reference antennas
    :return: gain [nants, nchan, nrec, nrec], weight [nants, nchan, nrec, nrec],
        residual [nchan, nrec, nrec]
    """
    # This implementation is sub-optimal. TODO: Reimplement IQ, IV calibration
    if refant_sort is None:
        refant_sort = []

    nants, _, nchan, npol = x.shape
    if npol == 2:
        newshape = (nants, nants, nchan, 4)
        x_fill = numpy.zeros(newshape, dtype="complex")
        x_fill[..., 0] = x[..., 0]
        x_fill[..., 3] = x[..., 1]
        xwt_fill = numpy.zeros(newshape, dtype="float")
        xwt_fill[..., 0] = xwt[..., 0]
        xwt_fill[..., 3] = xwt[..., 1]
    else:
        x_fill = x
        x_fill[..., 1] = 0.0
        x_fill[..., 2] = 0.0
        xwt_fill = xwt
        xwt_fill[..., 1] = 0.0
        xwt_fill[..., 2] = 0.0

    return _solve_antenna_gains_itsubs_matrix(
        gain,
        gwt,
        x_fill,
        xwt_fill,
        niter=niter,
        tol=tol,
        phase_only=phase_only,
        refant=refant,
        refant_sort=refant_sort,
    )


def _solve_antenna_gains_itsubs_matrix(
    gain,
    gwt,
    x,
    xwt,
    niter=200,
    tol=1e-6,
    phase_only=True,
    refant=0,
    refant_sort=None,
):
    """Solve for the antenna gains using full matrix expressions.

    x(antenna2, antenna1) = gain(antenna1) conj(gain(antenna2))

    See Appendix D, section D.1 in:

    J. P. Hamaker, "Understanding radio polarimetry - IV.
    The full-coherency analogue of scalar self-calibration:
    Self-alignment, dynamic range and polarimetric fidelity,"
    Astronomy and Astrophysics Supplement Series, vol. 143,
    no. 3, pp. 515-534, May 2000.

    The off-diagonal terms of ``gain`` are zeroed in place before solving.

    :param gain: gains [nants, nchan, 2, 2]
    :param gwt: gain weight
    :param x: Equivalent point source visibility [nants, nants, nchan, 4]
    :param xwt: Equivalent point source weight [nants, nants, nchan, 4]
    :param niter: Maximum number of iterations
    :param tol: tolerance on solution change
    :param phase_only: Do solution for only the phase? (default True)
    :param refant: Reference antenna for phase (default=0)
    :param refant_sort: Sorted list of fallback reference antennas
    :return: gain [nants, nchan, nrec, nrec], weight [nants, nchan, nrec, nrec],
        residual [nchan, nrec, nrec]
    :raises ValueError: if x does not have 4 polarisations
    """
    if refant_sort is None:
        refant_sort = []

    nants, _, nchan, npol = x.shape
    if npol != 4:
        raise ValueError(
            f"_solve_antenna_gains_itsubs_matrix: expected npol == 4, "
            f"got {npol}"
        )
    newshape = (nants, nants, nchan, 2, 2)
    x = x.reshape(newshape)
    xwt = xwt.reshape(newshape)

    # Remove autocorrelations and make the data Hermitian
    i_diag = numpy.diag_indices(nants)
    x[i_diag[0], i_diag[1], ...] = 0.0
    xwt[i_diag[0], i_diag[1], ...] = 0.0
    i_lower = numpy.tril_indices(nants, -1)
    i_upper = (i_lower[1], i_lower[0])
    x[i_upper] = numpy.conjugate(x[i_lower])
    xwt[i_upper] = xwt[i_lower]

    gain[..., 0, 1] = 0.0
    gain[..., 1, 0] = 0.0

    reduce_oneside_x = numpy.abs(numpy.einsum("ij...->j...", x * xwt))
    gainmask = reduce_oneside_x <= 0.0

    # Masked cross-pol terms fall back to 0 rather than 1
    cross_mask = gainmask.copy()
    cross_mask[..., 0, 0] = False
    cross_mask[..., 1, 1] = False

    # The current judgment uses channel 0: if all polarisations of this
    # channel are masked, the antenna is considered bad.
    # numpy.all returns numpy.bool_, so it must not be compared with "is".
    bad_ant = [
        iant for iant in range(nants) if numpy.all(gainmask[iant, 0])
    ]
    refant = _determine_refant(refant, bad_ant, refant_sort)

    numpy.putmask(gain, gainmask, 0.0)
    for _ in range(niter):
        gain_last = gain
        gain, gwt = _gain_substitution_matrix(gain, x, xwt)
        if phase_only:
            mask = numpy.abs(gain) > 0.0
            gain[mask] = gain[mask] / numpy.abs(gain[mask])
        change = numpy.max(numpy.abs(gain - gain_last))
        gain = 0.5 * (gain + gain_last)
        if change < tol:
            break
    else:
        log.warning(
            "solve_antenna_gains_itsubs_matrix: "
            "gain solution failed, retaining gain solutions"
        )

    # Reference all phases to the reference antenna
    angles = numpy.angle(gain)
    gain *= numpy.exp(-1j * angles)[refant, ...]
    numpy.putmask(gain, gainmask, 1.0)
    numpy.putmask(gain, cross_mask, 0.0)
    return gain, gwt, _solution_residual_matrix(gain, x, xwt)


def _gain_substitution_matrix(gain, x, xwt):
    """
    Substitute gains across all baselines of gain for point source
    equivalent visibilities.

    The 2x2 terms are combined element-wise; the update is structurally the
    scalar one applied to every (row, column) element. Where the
    denominator is not positive the new gain is 0.

    :param gain: gains (numpy.array of shape [nant, nchan, nrec, nrec])
    :param x: Equivalent point source visibility, reshapeable to
        [nants, nants, nchan, nrec, nrec]
    :param xwt: Equivalent point source weight, same shape as x
    :return: gain [nants, nchan, nrec, nrec], weight [nants, nchan, nrec, nrec]
    """
    nants, nchan, nrec, _ = gain.shape

    # Work in the Jones 2x2 matrix layout
    x = x.reshape([nants, nants, nchan, nrec, nrec])
    xwt = xwt.reshape([nants, nants, nchan, nrec, nrec])

    # diag is 1 everywhere except on the autocorrelations (ant, ant)
    diag = numpy.ones_like(x)
    i_diag = numpy.diag_indices(nants)
    diag[i_diag[0], i_diag[1], ...] = 0

    gain_conj = numpy.conjugate(gain)
    n_top1 = numpy.einsum("ij...->j...", xwt * diag * x * gain[:, None, ...])
    n_bot = diag * xwt * gain_conj * gain
    n_bot1 = numpy.einsum("ij...->i...", n_bot)

    # n_bot1 is real-valued (stored as complex); use putmask rather than
    # boolean indexing for speed
    no_weight = n_bot1.real <= 0
    n_top2 = n_top1.copy()
    numpy.putmask(n_top2, no_weight, 0.0)
    n_bot2 = n_bot1.copy()
    numpy.putmask(n_bot2, no_weight, 1.0)

    newgain1 = n_top2 / n_bot2
    gwt1 = n_bot1.real
    return newgain1, gwt1


def _solution_residual_scalar(gain, x, xwt):
    """Calculate residual across all baselines of gain for point source
    equivalent visibilities.

    Only the [0, 0] element of each gain is used; autocorrelations are
    excluded.

    :param gain: gains (numpy.array of shape [nant, nchan, nrec, nrec])
    :param x: Equivalent point source visibility [nants, nants, nchan, npol]
    :param xwt: Equivalent point source weight [nants, nants, nchan, npol]
    :return: residual[nchan, nrec, nrec]
    """
    nant, nchan, nrec, _ = gain.shape
    x = x.reshape(nant, nant, nchan, nrec, nrec)
    xwt = xwt.reshape(nant, nant, nchan, nrec, nrec)

    residual = numpy.zeros([nchan, nrec, nrec])
    sumwt = numpy.zeros([nchan, nrec, nrec])

    for chan in range(nchan):
        lgain = gain[:, chan, 0, 0]
        # Model visibility: smueller[i, j] = conj(g_i) g_j
        smueller = numpy.outer(numpy.conjugate(lgain), lgain)
        error = x[:, :, chan, 0, 0] - smueller
        numpy.fill_diagonal(error, 0.0)
        residual[chan] += numpy.sum(
            error * xwt[:, :, chan, 0, 0] * numpy.conjugate(error)
        ).real
        sumwt[chan] += numpy.sum(xwt[:, :, chan, 0, 0])

    has_weight = sumwt > 0.0
    residual[has_weight] = numpy.sqrt(
        residual[has_weight] / sumwt[has_weight]
    )
    residual[sumwt <= 0.0] = 0.0
    return residual


def _solution_residual_matrix(gain, x, xwt):
    """Calculate residual across all baselines of gain for point source
    equivalent visibilities.

    :param gain: gains (numpy.array of shape [nant, nchan, nrec, nrec])
    :param x: Equivalent point source visibility
        [nants, nants, nchan, nrec, nrec]
    :param xwt: Equivalent point source weight, same shape as x
    :return: residual[nchan, nrec, nrec]
    """
    # Model visibility: n_gain[i, j] = conj(g_i) g_j (element-wise)
    n_gain = numpy.einsum("i...,j...->ij...", numpy.conjugate(gain), gain)
    n_error = numpy.conjugate(x - n_gain)
    nn_residual = (n_error * xwt * numpy.conjugate(n_error)).real
    n_residual = numpy.einsum("ijk...->k...", nn_residual)
    n_sumwt = numpy.einsum("ijk...->k...", xwt)

    has_weight = n_sumwt > 0.0
    n_residual[has_weight] = numpy.sqrt(
        n_residual[has_weight] / n_sumwt[has_weight]
    )
    n_residual[n_sumwt <= 0.0] = 0.0
    return n_residual