# pylint: disable=invalid-name,too-many-arguments,no-member
"""
Functions to solve for delta-TEC variations across the array
"""
__all__ = [
"solve_ionosphere",
"set_cluster_maps",
"set_coeffs_and_params",
"get_param_count",
"apply_phase_distortions",
"build_normal_equation",
"cluster_design_matrix",
"solve_normal_equation",
"update_gain_table",
]
import logging
import numpy
from astropy import constants as const
from ska_sdp_datamodels.calibration.calibration_create import (
create_gaintable_from_visibility,
)
from ska_sdp_datamodels.visibility.vis_model import Visibility
from ska_sdp_func_python.calibration.ionosphere_utils import zern_array
log = logging.getLogger("func-python-logger")
# pylint: disable=too-many-locals
# [docs]
def solve_ionosphere(
    vis: Visibility,
    modelvis: Visibility,
    xyz,
    cluster_id=None,
    zernike_limit=None,
    block_diagonal=False,
    niter=15,
    tol=1e-6,
):
    """
    Solve for ionospheric delay as a function of station location.

    Generate a gaintable by fitting for delta-TEC variations across the array.
    The resulting delta-TEC variations will be converted to antenna-dependent
    phase shifts and the gain_table updated.
    Fits are performed within user-defined station clusters

    :param vis: Visibility containing the observed data_model
    :param modelvis: Visibility containing the predicted data_model
    :param xyz: [n_antenna,3] array containing the antenna locations in the
        local horizontal frame
    :param cluster_id: [n_antenna] array containing the cluster ID of each
        antenna. Defaults to a single cluster comprising all stations
    :param zernike_limit: [n_cluster] list of Zernike index limits.
        Default is to leave unset when calling set_coeffs_and_params().
    :param block_diagonal: If true, each cluster will be solved for separately
        during each iteration. This is equivalent to setting all elements of
        the normal matrix to zero except for the block diagonal elements for
        the cluster in question. (default False)
    :param niter: Number of iterations (default 15)
    :param tol: Iteration stops when the fractional change in the gain solution
        is below this tolerance.
    :return: GainTable containing solutions
    :raises ValueError: if the model is identically zero or cluster_id has
        the wrong length
    """
    if numpy.all(modelvis.vis == 0.0):
        raise ValueError("solve_ionosphere: Model visibilities are zero")
    # Create a new gaintable based on the visibilities
    # In general it will be filled with antenna-based phase shifts per channel
    gain_table = create_gaintable_from_visibility(vis, jones_type="B")
    # Ensure that the gain table and the input cluster indices are consistent
    if cluster_id is None:
        cluster_id = numpy.zeros(len(gain_table.antenna), "int")
    n_cluster = numpy.amax(cluster_id) + 1
    # Could be less strict & require max(gain_table.antenna) < len(cluster_id)
    if len(gain_table.antenna) != len(cluster_id):
        raise ValueError(f"cluster_id has wrong size {len(cluster_id)}")
    # Calculate coefficients for each cluster and initialise parameter values
    [param, coeff] = set_coeffs_and_params(xyz, cluster_id, zernike_limit)
    n_param = get_param_count(param)[0]
    if n_cluster == 1:
        log.info(
            "Setting up iono solver for %d stations in a single cluster",
            len(gain_table.antenna),
        )
        log.info("There are %d total parameters in the cluster", n_param)
    else:
        log.info(
            "Setting up iono solver for %d stations in %d clusters",
            len(gain_table.antenna),
            n_cluster,
        )
        log.info(
            "There are %d total parameters: %d in c[0] + %d x c[1:%d]",
            n_param,
            len(param[0]),
            len(param[1]),
            len(param) - 1,
        )
    for it in range(niter):
        if not block_diagonal:
            # Build and solve a single normal equation spanning all clusters
            [AA, Ab] = build_normal_equation(
                vis, modelvis, param, coeff, cluster_id
            )
            # Solve the normal equations and update parameters in place
            param_update = solve_normal_equation(AA, Ab, param, it)
        else:
            # Solve each cluster separately against the current residuals
            param_update = []
            for cid in range(n_cluster):
                [AA, Ab] = build_normal_equation(
                    vis, modelvis, param, coeff, cluster_id, cid
                )
                # Solve the current incremental normal equations
                soln_vec = numpy.linalg.lstsq(AA, Ab, rcond=None)[0]
                # Damping factor applied to the incremental update
                nu = 0.5
                # nu = 1.0 - 0.5 * (it % 2)
                param_update.append(nu * soln_vec)
                param[cid] += param_update[cid]
        # Update the model
        apply_phase_distortions(modelvis, param_update, coeff, cluster_id)
        # test absolute relative change against tol
        # flag for non-zero parameters to test relative change against
        # ("float64" rather than the "float_" alias removed in NumPy 2.0)
        all_param = numpy.hstack(param).astype("float64")
        mask = numpy.abs(all_param) > 0.0
        if not numpy.any(mask):
            # All parameters are still zero: nothing to compare against,
            # so treat the solution as converged
            break
        change = numpy.max(
            numpy.abs(numpy.hstack(param_update)[mask].astype("float64"))
            / numpy.abs(all_param[mask])
        )
        if change < tol:
            break
    # Update and return the gain table
    update_gain_table(gain_table, param, coeff, cluster_id)
    return gain_table
# [docs]
def set_cluster_maps(cluster_id):
    """
    Generate vectors to help convert between station and cluster indices.
    :param: cluster_id: [n_antenna] array of antenna cluster indices
    :return n_cluster: total number of clusters
    :return cid2stn: mapping from cluster index to List of station indices
    :return stn2cid: mapping from station index to cluster index
    """
    n_station = len(cluster_id)
    # Clusters are numbered contiguously from zero
    n_cluster = numpy.amax(cluster_id) + 1
    station_index = numpy.arange(n_station, dtype=int)
    # Forward map: station index -> cluster index
    stn2cid = numpy.empty(n_station, "int")
    # Reverse map: cluster index -> array of member station indices
    cid2stn = []
    for cid in range(n_cluster):
        in_cluster = cluster_id == cid
        stn2cid[in_cluster] = cid
        cid2stn.append(station_index[in_cluster])
    return n_cluster, cid2stn, stn2cid
# [docs]
def get_param_count(param):
    """
    Return the total number of parameters across all clusters.
    This function also returns the starting index of each cluster in stacked
    parameter vectors.
    :param param: [n_cluster] list of solution vectors, one for each cluster
    :return n_param: int, total number of parameters
    :return pidx0: [n_cluster], starting index of each cluster in param vectors
    """
    # Number of parameters held by each cluster
    sizes = [len(cluster_param) for cluster_param in param]
    # Each cluster starts where the previous one ended
    pidx0 = numpy.zeros(len(param), "int")
    if len(sizes) > 1:
        pidx0[1:] = numpy.cumsum(sizes[:-1])
    return sum(sizes), pidx0
# [docs]
def set_coeffs_and_params(
    xyz,
    cluster_id,
    zernike_limit=None,
):
    r"""
    Initialise coefficients and parameters.
    Calculate a vector of basis function values for each cluster and initialise
    parameter values for each station.
    :param xyz: [n_antenna,3] array containing the antenna locations in the
        local horizontal frame
    :param cluster_id: [n_antenna] array of antenna cluster indices
    :param zernike_limit: [n_cluster] list of Zernike index limits:
        Generate all Zernikes with n + \|m\| <= zernike_limit[cluster_id].
        Default: [6,2,2,...,2]
    :return param: [n_cluster] list of solution vectors
    :return coeff: [n_station] list of basis-func value vectors
        Stored as a numpy dtype=object array of variable-length coeff vectors
    """
    # Get common mapping vectors between stations and clusters
    [n_cluster, cid2stn, _] = set_cluster_maps(cluster_id)
    n_station = len(cluster_id)
    coeff = [None] * n_station
    param = [None] * n_cluster
    # Check list of polynomial degree limits
    if zernike_limit is None:
        # set a TEC offset and ramp for most clusters
        zernike_limit = [2] * n_cluster
        # but assume cluster zero contains the large central core
        zernike_limit[0] = 6
    elif len(zernike_limit) != n_cluster:
        # Invalid user input: log and return empty outputs rather than raise
        log.error("Incorrect length for zernike_limit parameter")
        return numpy.empty(0), numpy.empty(0)
    for cid in range(0, n_cluster):
        # Generate the required Zernike polynomials for each station
        zern_params = zern_array(
            zernike_limit[cid], xyz[cid2stn[cid], 0], xyz[cid2stn[cid], 1]
        )
        # Set coefficients
        for idx, stn in enumerate(cid2stn[cid]):
            coeff[stn] = zern_params[idx]
        # Initialise parameters to zero
        # (an empty cluster keeps its placeholder None entry)
        if len(cid2stn[cid]) > 0:
            param[cid] = numpy.zeros(len(coeff[cid2stn[cid][0]]))
    return param, numpy.array(coeff, dtype=object)
# [docs]
def apply_phase_distortions(
    vis: Visibility,
    param,
    coeff,
    cluster_id,
):
    """
    Update visibility model with new fit solutions.
    The visibilities are multiplied in place by phase factors derived from
    the incremental parameter updates.
    :param vis: Visibility containing the data_models to be distorted
    :param param: [n_cluster] list of solution vectors, one for each cluster
    :param coeff: [n_station] list of basis-func value vectors, one per station
        Stored as a numpy dtype=object array of variable-length coeff vectors
    :param cluster_id: [n_antenna] array of antenna cluster indices
    """
    # Get common mapping vectors between stations and clusters
    [n_cluster, _, stn2cid] = set_cluster_maps(cluster_id)
    # set up a few references and constants
    ant1 = vis.antenna1.data
    ant2 = vis.antenna2.data
    vis_data = vis.vis.data
    # exclude auto-correlations from the mask
    mask0 = ant1 != ant2
    # Use einsum calls to average over parameters for all combinations of
    # baseline and frequency
    # Loop over pairs of clusters and update the associated baselines
    for cid1 in range(0, n_cluster):
        for cid2 in range(0, n_cluster):
            # A mask for all baselines in this cluster pair
            mask = mask0 * (stn2cid[ant1] == cid1) * (stn2cid[ant2] == cid2)
            if numpy.sum(mask) == 0:
                continue
            vis_data[0, mask, :, 0] *= numpy.exp(
                # combine params for [n_baseline] then scale for [n_freq]
                numpy.einsum(
                    "b,f->bf",
                    (
                        # combine params for ant i in baselines
                        # ("float64" rather than the NumPy-2-removed "float_")
                        numpy.einsum(
                            "bp,p->b",
                            numpy.vstack(coeff[ant1[mask]]).astype("float64"),
                            param[cid1],
                        )
                        # combine params for ant j in baselines
                        - numpy.einsum(
                            "bp,p->b",
                            numpy.vstack(coeff[ant2[mask]]).astype("float64"),
                            param[cid2],
                        )
                    ),
                    # phase scaling with frequency (proportional to wavelength)
                    1j * 2.0 * numpy.pi * const.c.value / vis.frequency.data,
                )
            )
# [docs]
def build_normal_equation(
    vis: Visibility,
    modelvis: Visibility,
    param,
    coeff,
    cluster_id,
    cid=None,
):
    # pylint: disable=too-many-locals
    """
    Build normal equations.
    :param vis: Visibility containing the observed data_models
    :param modelvis: Visibility containing the predicted data_models
    :param param: [n_cluster] list of solution vectors, one for each cluster
    :param coeff: [n_station] list of basis-func value vectors, one per station
        Stored as a numpy dtype=object array of variable-length coeff vectors
    :param cluster_id: [n_antenna] array of antenna cluster indices
    :param cid: index of current cluster. Defaults to None, which will build
        a single large matrix for all clusters.
    :return AA: [n_param, n_param] normal matrix
    :return Ab: [n_param] data vector
    :raises ValueError: if the polarisation frame is neither stokesI nor
        linear
    """
    # If no cluster index is given, build matrix for all clusters
    generate_full_equation = cid is None
    # Get common mapping vectors between stations and clusters
    [n_cluster, _, stn2cid] = set_cluster_maps(cluster_id)
    # Set up a few refs/consts to use in loops and function calls
    wl_const = 2.0 * numpy.pi * const.c.value / vis.frequency.data
    ant1 = vis.antenna1.data
    ant2 = vis.antenna2.data
    vis_data = vis.vis.data
    mdl_data = modelvis.vis.data
    # Exclude auto-correlations from the mask
    mask = ant1 != ant2
    # Loop over frequency and accumulate normal equations
    # Could probably handle frequency within an einsum as well.
    # It is also a natural axis for parallel calculation of AA and Ab.
    if generate_full_equation:
        [n_param, pidx0] = get_param_count(param)
    else:
        n_param = len(param[cid])
    # If StokesI is available, use that, if linear, use XX + YY
    pols = [0]
    if vis.visibility_acc.polarisation_frame.type.startswith("stokesI"):
        pols = [numpy.argwhere(vis.polarisation.data == "I")[0][0]]
    elif vis.visibility_acc.polarisation_frame.type.startswith("linear"):
        pols = [
            numpy.argwhere(vis.polarisation.data == "XX")[0][0],
            numpy.argwhere(vis.polarisation.data == "YY")[0][0],
        ]
    else:
        raise ValueError("build_normal_equation: Unsupported polarisations")
    AA = numpy.zeros((n_param, n_param))
    Ab = numpy.zeros(n_param)
    for chan in range(len(vis.frequency.data)):
        # Could accumulate AA and Ab directly, but go via a
        # design matrix for clarity. Update later if need be.
        # V = M * exp(i * 2*pi * wl * fit)
        # imag(V*conj(M)) = imag(|M|^2 * exp(i * 2*pi * wl * fit))
        #                 ~ |M|^2 * 2*pi * wl * fit
        # real(M*conj(M)) = |M|^2
        for pol in pols:
            if generate_full_equation:
                # ("complex128" rather than the NumPy-2-removed "complex_")
                A = numpy.zeros((n_param, len(vis.baselines)), "complex128")
                # Loop over clusters and update the design matrix for the
                # associated baselines
                for _cid in range(0, n_cluster):
                    pid = numpy.arange(
                        pidx0[_cid], pidx0[_cid] + len(param[_cid])
                    )
                    A[pid, :] += cluster_design_matrix(
                        mdl_data[0, :, chan, pol],
                        mask,
                        ant1,
                        ant2,
                        coeff,
                        stn2cid,
                        wl_const[chan],
                        len(param[_cid]),
                        _cid,
                    )
            else:
                A = cluster_design_matrix(
                    mdl_data[0, :, chan, pol],
                    mask,
                    ant1,
                    ant2,
                    coeff,
                    stn2cid,
                    wl_const[chan],
                    n_param,
                    cid,
                )
            # incorporate flags into weights
            wgt = vis.weight.data[0, :, chan, pol] * (
                1 - vis.flags.data[0, :, chan, pol]
            )
            # Average over all baselines for each param pair
            AA += numpy.real(
                numpy.einsum("pb,b,qb->pq", numpy.conj(A), wgt, A)
            )
            Ab += numpy.imag(
                numpy.einsum(
                    "pb,b,b->p",
                    numpy.conj(A),
                    wgt,
                    vis_data[0, :, chan, pol] - mdl_data[0, :, chan, pol],
                )
            )
    return AA, Ab
# [docs]
def cluster_design_matrix(
    mdl_data,
    mask0,
    ant1,
    ant2,
    coeff,
    stn2cid,
    wl_const,
    n_param,
    cid,
):
    """
    Generate elements of the design matrix for the current cluster.
    Dereference outside of loops and the function call to avoid overheads.
    :param mdl_data: [n_baseline] predicted model vis for the current channel
        and polarisation (the caller passes mdl_data[0, :, chan, pol])
    :param mask0: [n_baseline] mask of wanted data samples
    :param ant1: [n_baseline] station index of first antenna in each baseline
    :param ant2: [n_baseline] station index of second antenna in each baseline
    :param coeff: [n_station] list of basis-func value vectors, one per station
    :param stn2cid: [n_station] mapping from station index to cluster index
    :param wl_const: 2*pi*lambda for the current frequency channel
    :param n_param: number of parameters in Normal equation
    :param cid: index of current cluster
    :return A: [n_param, n_baseline] design matrix for this cluster
    """
    n_baselines = len(mask0)
    # ("complex128" rather than the "complex_" alias removed in NumPy 2.0)
    A = numpy.zeros((n_param, n_baselines), "complex128")
    blidx_all = numpy.arange(n_baselines)
    # Get all masked baselines with ant1 in this cluster
    blidx = blidx_all[mask0 * (stn2cid[ant1] == cid)]
    if len(blidx) > 0:
        # [nvis] A0 terms x [nvis,nparam] coeffs (1st antenna)
        # all masked antennas have the same number of coeffs so can vstack
        A[:, blidx] += numpy.einsum(
            "b,bp->pb",
            wl_const * mdl_data[blidx],
            numpy.vstack(coeff[ant1[blidx]]).astype("float64"),
        )
    # Get all masked baselines with ant2 in this cluster
    blidx = blidx_all[mask0 * (stn2cid[ant2] == cid)]
    if len(blidx) > 0:
        # [nvis] A0 terms x [nvis,nparam] coeffs (2nd antenna)
        # all masked antennas have the same number of coeffs so can vstack
        A[:, blidx] -= numpy.einsum(
            "b,bp->pb",
            wl_const * mdl_data[blidx],
            numpy.vstack(coeff[ant2[blidx]]).astype("float64"),
        )
    return A
# [docs]
def solve_normal_equation(
    AA,
    Ab,
    param,
    it=0,  # pylint: disable=unused-argument
):
    """
    Solve the normal equations and update parameters.
    Using the SVD-based DGELSD solver via numpy.linalg.lstsq.
    Could use the LU-decomposition-based DGESV solver in numpy.linalg.solve,
    but the normal matrix may not be full rank.
    If n_param gets large (~ 100) it may be better to use a numerical solver
    like lsmr or lsqr.
    :param AA: [n_param, n_param] normal equation
    :param Ab: [n_param] data vector
    :param param: [n_cluster] list of solution vectors, one for each cluster.
        Updated in place with the damped incremental solution
    :param it: int, current iteration
    :return param_update: the current incremental param update
    """
    # Least-squares solution of the stacked normal equations
    soln_vec = numpy.linalg.lstsq(AA, Ab, rcond=None)[0]
    # Starting offset of each cluster within the stacked solution vector
    [_, pidx0] = get_param_count(param)
    # Damping factor applied to each incremental update
    # StefCal-like algorithms work well with an alternating factor like this
    # Some early tests of this algorithm did as well. Come back to this
    # nu = 1.0 - 0.5 * (it % 2)
    nu = 0.5
    # Apply the damped update to each cluster, keeping a copy of the
    # increments for the caller's convergence test
    param_update = []
    for cid, cluster_param in enumerate(param):
        start = pidx0[cid]
        increment = nu * soln_vec[start : start + len(cluster_param)]
        param_update.append(increment)
        cluster_param += increment
    return param_update
# [docs]
def update_gain_table(
    gain_table,
    param,
    coeff,
    cluster_id,
):
    """
    Add new solutions to gaintable.
    Expand solutions for all stations and frequency channels and insert in the
    gain table
    :param gain_table: GainTable to be updated. The leading diagonal gain
        element [time 0, station, channel, 0, 0] is written in place
    :param param: [n_cluster] list of solution vectors, one for each cluster
    :param coeff: [n_station] list of basis-func value vectors, one per station
        Stored as a numpy dtype=object array of variable-length coeff vectors
    :param cluster_id: [n_antenna] array of antenna cluster indices
    """
    # Get common mapping vectors between stations and clusters
    [n_cluster, cid2stn, _] = set_cluster_maps(cluster_id)
    # Wavelength of each frequency channel in the table
    wl = const.c.value / gain_table.frequency.data
    table_data = gain_table.gain.data
    for cid in range(0, n_cluster):
        # combine params for [n_station] phase terms then scale for [n_freq]
        # ("float64" rather than the "float_" alias removed in NumPy 2.0)
        table_data[0, cid2stn[cid], :, 0, 0] = numpy.exp(
            numpy.einsum(
                "s,f->sf",
                numpy.einsum(
                    "sp,p->s",
                    numpy.vstack(coeff[cid2stn[cid]]).astype("float64"),
                    param[cid],
                ),
                1j * 2.0 * numpy.pi * wl,
            )
        )