""" A collection of solvers for antenna-based Jones matrices
Algorithm 1 of solve_jones iteratively updates the set of Jones matrices, but within each
iteration performs an independent least-squares optimisation for each matrix
without considering how the others are updating. This is equivalent to forming
a normal matrix and setting the off-diagonal terms to zero. While it does not
converge as fast as the normal-equation approach, it scales well in terms of
operations and memory, and allows the free parameter for each antenna/station
to be the 2x2 Jones matrix. It is based on the equivalent solver in the MWA
RealTime System (Mitchell et al., 2008, IEEE JSTSP, 2, JSTSP.2008.2005327).::
Accumulation option 0: easy to read
Accumulation option 1: a bit faster (default)
Algorithm 2 is a full normal-equation based linear least-squares algorithm. It
is based on the approach in Yandasoft.::
Accumulation option 0: via design matrix (default)
Accumulation option 1: direct accumulation of normal matrix with pre-summing of normal equation products in time and frequency.
For example::
solve_jones(vis, modelvis, jones, niter=25)
"""
import copy
import logging
import numpy as np
from rascil.data_models.memory_data_models import BlockVisibility, GainTable
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import lsmr, lsqr
from jones_solvers.processing_components.util import (
gen_cdm,
gen_pol_matrix,
update_design_matrix,
)
# Module-level logger named after this module (records still propagate to the
# root logger's handlers, so existing logging configuration keeps working).
log = logging.getLogger(__name__)
def solve_jones(
    vis: BlockVisibility,
    modelvis: BlockVisibility,
    jones: GainTable,
    niter=30,
    nu=None,
    tol=1e-6,
    algorithm=1,
    accum_opt=None,
    lin_solver="lsmr",
    lin_solver_normal=None,
    lin_solver_rcond=1e-6,
    testvis: BlockVisibility = None,
):
    """Solve Jones matrices by fitting observed visibilities to model visibilities

    A single set of matrices is found for all times and frequencies: only the
    first time/frequency slot of the GainTable, ``jones["gain"][0, :, 0]``, is
    solved for and updated.

    :param vis: BlockVisibility containing the observed visibility data
    :param modelvis: BlockVisibility containing the predicted data_model (updated on exit)
    :param jones: Existing GainTable containing the Jones matrices (updated on exit)
    :param niter: Maximum number of iterations (default 30)
    :param nu: iterative adaptation update factor (default: alternating 1.0 and
        0.5 for algorithm 1, 1.0 for algorithm 2)
    :param tol: Convergence fractional-change tolerance (default 1e-6)
    :param algorithm: Solver algorithm used, 1 or 2 (default 1)
    :param accum_opt: Accumulation option, 0 or 1 (default 1 for algorithm 1,
        0 for algorithm 2)
    :param lin_solver: linear solver used in each iteration of algorithm 2:
        "lsmr" (scipy, default), "lsqr" (scipy), "lstsq" (numpy) or "svd" (numpy)
    :param lin_solver_normal: whether or not to form normal equations before calling lin_solver
        (default False where allowed: accum_opt 0 with lsmr or lsqr). Otherwise True
    :param lin_solver_rcond: cutoff ratio used when lin_solver = "svd" or "lstsq" (default 1e-6)
    :param testvis: Optional BlockVisibility for comparisons (e.g. noiseless simulated data).
        When given, the summed squared difference between the current model and
        testvis, divided by 2 * (number of visibilities), is recorded before the
        first iteration and after every iteration.
    :return: numpy array of the recorded testvis comparison values as a function
        of iteration (empty when testvis is None)
    :raises AssertionError: if an input or option is invalid
    :raises ValueError: if algorithm is not 1 or 2
    """
    assert jones is not None, "Initial Jones matrix input is required"
    assert modelvis is not None, "Initial model visibility input is required"
    assert np.max(np.abs(modelvis.vis)) > 0.0, "Model visibility is zero"

    accum_opt, lin_solver_normal = _check_solver_options(
        algorithm, accum_opt, lin_solver, lin_solver_normal
    )

    nstations = jones["antenna"].shape[0]
    assert nstations > 0, "Jones GainTable has no antennas"

    # sorted list of all stations that are in at least one baseline
    subarray = np.unique(
        np.concatenate((vis["antenna1"].data, vis["antenna2"].data), axis=0)
    ).flatten()
    nsubarray = len(subarray)
    assert nsubarray > 0, "Dataset has no antennas"
    assert nsubarray <= nstations, "Dataset has more antennas than the GainTable"
    assert subarray[-1] < nstations, "Dataset antenna index exceeds the GainTable"
    # station index -> position in subarray (entries for stations that are not
    # in the subarray are left uninitialised and must not be read)
    subarray_lookup = np.empty(nstations, "int")
    subarray_lookup[subarray] = np.arange(nsubarray)

    log.debug("")
    log.debug("starting solver")
    log.debug(" - nstations = %s", nstations)
    log.debug(" - nsubarray = %s", nsubarray)
    log.debug(" - niter = %s", niter)
    log.debug(" - algorithm = %s", algorithm)
    log.debug(" - accum_opt = %s", accum_opt)
    if algorithm == 2:
        log.debug(" - lin_solver = %s", lin_solver)
        log.debug(" - lin_solver_normal = %s", lin_solver_normal)
        log.debug(" - lin_solver_rcond = %s", lin_solver_rcond)

    linear_pols = ["XX", "XY", "YX", "YY"]
    assert np.array_equal(
        vis["polarisation"].data, linear_pols
    ), "linear polarisations required in solver"
    assert np.array_equal(
        modelvis["polarisation"].data, linear_pols
    ), "linear polarisations required in solver"

    # could perhaps discard autos here (if not done earlier)
    stn1 = modelvis["antenna1"].data
    stn2 = modelvis["antenna2"].data
    nt = modelvis["datetime"].shape[0]
    nbl = modelvis["baselines"].shape[0]
    nch = modelvis["frequency"].shape[0]
    nvis = nt * nbl * nch

    # view the 4 polarisation products of each visibility as a 2x2 matrix.
    # np.reshape returns a view when it can, so in-place updates of vmdl
    # normally write straight through to modelvis.
    vmdl = np.reshape(modelvis["vis"].data, (nt, nbl, nch, 2, 2))
    vobs = np.reshape(vis["vis"].data, (nt, nbl, nch, 2, 2))
    # TODO: check for equal weights across polarisation, or flag and skip?
    # - for now just using the first pol weight
    wgt = modelvis["weight"].data[..., 0]
    # modified in place so that the GainTable is updated on exit
    J = jones["gain"].data

    # subarray positions of the two stations of every baseline
    idx1 = subarray_lookup[stn1]
    idx2 = subarray_lookup[stn2]
    cross = stn1 != stn2

    chisq = []
    vtst = None
    if testvis is not None:
        vtst = np.reshape(testvis["vis"].data, (nt, nbl, nch, 2, 2))
        # normalised exactly like the stored chisq values (the log used to
        # divide by nbl instead, which disagreed whenever nt * nch > 1)
        err = _sum_sq_error(vmdl, vtst) / 2.0 / nvis
        chisq.append(err)
        log.debug("init err = %.2e", err)

    if algorithm == 1:
        vfit = _solve_algorithm1(
            vobs, vmdl, wgt, J, subarray, idx1, idx2, cross,
            niter, nu, tol, accum_opt, vtst, chisq,
        )
    else:
        vfit = _solve_algorithm2(
            vobs, vmdl, wgt, J, subarray, idx1, idx2, cross,
            niter, nu, tol, accum_opt,
            lin_solver, lin_solver_normal, lin_solver_rcond, vtst, chisq,
        )

    # write the fitted model back explicitly (also covers the case where the
    # reshape above had to copy rather than view the modelvis data)
    modelvis["vis"].data = np.reshape(vfit, (nt, nbl, nch, 4))
    return np.array(chisq)


def _check_solver_options(algorithm, accum_opt, lin_solver, lin_solver_normal):
    """Validate solver options and fill in their defaults.

    :return: (accum_opt, lin_solver_normal) with defaults applied
    :raises AssertionError: for an invalid option combination
    :raises ValueError: if algorithm is not 1 or 2
    """
    if algorithm == 1:
        if accum_opt is None:
            accum_opt = 1
        else:
            assert accum_opt in (0, 1), "Unknown algorithm accum_opt"
    elif algorithm == 2:
        if accum_opt is None:
            accum_opt = 0
        else:
            # only options 0 and 1 are implemented for algorithm 2
            assert accum_opt in (0, 1), "Unknown algorithm accum_opt"
        assert lin_solver in ("lsmr", "lsqr", "lstsq", "svd"), "Unknown linear solver"
        if accum_opt == 0 and lin_solver in ("lsmr", "lsqr"):
            # iterative solvers can work on the design matrix directly
            if lin_solver_normal is None:
                lin_solver_normal = False
        elif accum_opt == 0:
            assert lin_solver_normal is None or lin_solver_normal, (
                "must use lin_solver_normal with solver " + lin_solver
            )
            lin_solver_normal = True
        else:
            assert (
                lin_solver_normal is None or lin_solver_normal
            ), "must use lin_solver_normal with direct Normal accumulation"
            lin_solver_normal = True
    else:
        raise ValueError("Unknown algorithm index")
    return accum_opt, lin_solver_normal


def _sum_sq_error(vfit, vtst):
    """Return sum over visibilities of trace(E E^H), E = vfit - vtst.

    This equals the summed squared magnitude of every matrix element of E.
    """
    return float(np.sum(np.abs(vfit - vtst) ** 2))


def _apply_jones(jleft, v, jright):
    """Return jleft[k] @ v[t, k, f] @ jright[k]^H for every visibility.

    :param jleft: per-baseline 2x2 matrices, shape (nbl, 2, 2)
    :param v: coherency matrices, shape (nt, nbl, nch, 2, 2)
    :param jright: per-baseline 2x2 matrices, shape (nbl, 2, 2)
    """
    jl = jleft[np.newaxis, :, np.newaxis]
    jr_herm = np.conj(np.swapaxes(jright, -1, -2))[np.newaxis, :, np.newaxis]
    return jl @ v @ jr_herm


def _accumulate_alg1_loop(vobs, vmdl, wgt, idx1, idx2, cross, nsubarray):
    """Algorithm 1 sums, one visibility at a time (accum_opt 0, easy to read).

    For a cross-correlation between subarray stations i and j, with weight w,
    observed matrix V and model matrix M:
      station i: Som += w V M^H,  Smm += w M M^H
      station j: Som += w V^H M,  Smm += w M^H M
    Auto-correlations are skipped.
    """
    nt, nbl, nch = vmdl.shape[:3]
    Som = np.zeros((nsubarray, 2, 2), "complex")
    Smm = np.zeros((nsubarray, 2, 2), "complex")
    for t in range(nt):
        for f in range(nch):
            for k in range(nbl):
                if not cross[k]:
                    continue
                w = wgt[t, k, f]
                obs = vobs[t, k, f]
                mdl = vmdl[t, k, f]
                # index by subarray position, not raw station number
                i = idx1[k]
                j = idx2[k]
                Som[i] += w * obs @ mdl.conj().T
                Smm[i] += w * mdl @ mdl.conj().T
                Som[j] += w * obs.conj().T @ mdl
                Smm[j] += w * mdl.conj().T @ mdl
    return Som, Smm


def _accumulate_alg1_vector(vobs, vmdl, wgt, idx1, idx2, cross, nsubarray):
    """Algorithm 1 sums, vectorised over all visibilities (accum_opt 1).

    Produces the same sums as _accumulate_alg1_loop: the per-baseline products
    are summed over time and frequency, then scattered onto both stations.
    """
    w = wgt[:, cross, :, np.newaxis, np.newaxis]
    obs = vobs[:, cross]
    mdl = vmdl[:, cross]
    obs_h = np.conj(np.swapaxes(obs, -1, -2))
    mdl_h = np.conj(np.swapaxes(mdl, -1, -2))
    Som = np.zeros((nsubarray, 2, 2), "complex")
    Smm = np.zeros((nsubarray, 2, 2), "complex")
    # np.add.at accumulates correctly when a station appears on many baselines
    np.add.at(Som, idx1[cross], np.sum(w * (obs @ mdl_h), axis=(0, 2)))
    np.add.at(Smm, idx1[cross], np.sum(w * (mdl @ mdl_h), axis=(0, 2)))
    np.add.at(Som, idx2[cross], np.sum(w * (obs_h @ mdl), axis=(0, 2)))
    np.add.at(Smm, idx2[cross], np.sum(w * (mdl_h @ mdl), axis=(0, 2)))
    return Som, Smm


def _solve_algorithm1(
    vobs, vmdl, wgt, J, subarray, idx1, idx2, cross,
    niter, nu, tol, accum_opt, vtst, chisq,
):
    """Algorithm 1: independent per-station 2x2 updates.

    Updates J[0, subarray, 0] and vmdl in place and appends the testvis
    comparison (if vtst is given) to chisq after every iteration.

    :return: vmdl (the updated model visibilities)
    """
    nsubarray = len(subarray)
    nvis = vmdl.shape[0] * vmdl.shape[1] * vmdl.shape[2]
    if accum_opt == 0:
        accumulate = _accumulate_alg1_loop
    else:
        accumulate = _accumulate_alg1_vector
    identity = np.eye(2)
    for it in range(niter):
        Som, Smm = accumulate(vobs, vmdl, wgt, idx1, idx2, cross, nsubarray)
        # by default alternate between full (1.0) and half (0.5) steps
        nu_it = (1.0 - 0.5 * (it % 2)) if nu is None else nu
        oldJ = np.array(J)
        # per-station update U = I + nu (Som Smm^-1 - I), applied as J -> U J
        update = identity + nu_it * (Som @ np.linalg.inv(Smm) - identity)
        J[0, subarray, 0] = update @ J[0, subarray, 0]
        # keep the model consistent with the new Jones matrices: V -> U_1 V U_2^H
        vmdl[...] = _apply_jones(update[idx1], vmdl, update[idx2])
        if vtst is not None:
            err = _sum_sq_error(vmdl, vtst) / 2.0 / nvis
            chisq.append(err)
            log.debug("iter %2d = %.2e", it, err)
        if np.max(np.abs(np.array(J) - oldJ)) / np.max(np.abs(oldJ)) < tol:
            break
    return vmdl


def _presum_products(vobs, vsky, wgt, cross):
    """Pre-sum 4x4 visibility outer products over time and frequency.

    With m and o the 4-vectors (XX, XY, YX, YY) of model and observed data:
      Smm[k] = sum w conj(m)^T m,  Som[k] = sum w conj(m)^T o
    Auto-correlation baselines are left as zero.

    :return: (Som, Smm), each of shape (nbl, 4, 4)
    """
    nt, nbl, nch = vsky.shape[:3]
    mvec = np.reshape(vsky, (nt, nbl, nch, 4))[:, cross]
    ovec = np.reshape(vobs, (nt, nbl, nch, 4))[:, cross]
    w = wgt[:, cross]
    Som = np.zeros((nbl, 4, 4), "complex")
    Smm = np.zeros((nbl, 4, 4), "complex")
    Som[cross] = np.einsum("tkf,tkfp,tkfq->kpq", w, mvec.conj(), ovec)
    Smm[cross] = np.einsum("tkf,tkfp,tkfq->kpq", w, mvec.conj(), mvec)
    return Som, Smm


def _build_design_matrix(vres, vsky, wgt, idx1, idx2, cross, gX, gY, dXY, dYX):
    """Build the real-valued design matrix A and residual data vector dv.

    Rows: real/imag parts of every visibility, in XX, YY, XY, YX blocks of
    2*nvis rows. Columns: real/imag parts of every free parameter, in gX, gY,
    dXY, dYX blocks of 2*nsubarray columns. Residuals and model visibilities
    are scaled by sqrt(weight). The derivative entries themselves are written
    by update_design_matrix.

    TODO: extra rows could be appended to A and dv to add constraints, e.g. a
    reference-station phase of zero or zero-sum leakage terms.
    """
    nt, nbl, nch = vsky.shape[:3]
    nvis = nt * nbl * nch
    nsubarray = len(gX)
    A = np.zeros([8 * nvis, 8 * nsubarray])
    dv = np.zeros([8 * nvis, 1])
    for t in range(nt):
        for k in np.flatnonzero(cross):
            i = idx1[k]
            j = idx2[k]
            # parameter indices (i.e. design matrix columns)
            iXX = 2 * i
            iYY = iXX + 2 * nsubarray
            iXY = iXX + 4 * nsubarray
            iYX = iXX + 6 * nsubarray
            jXX = 2 * j
            jYY = jXX + 2 * nsubarray
            jXY = jXX + 4 * nsubarray
            jYX = jXX + 6 * nsubarray
            for f in range(nch):
                # visibility indices (i.e. design matrix rows)
                kXX = 2 * (t * nbl * nch + k * nch + f)
                kYY = kXX + 2 * nvis
                kXY = kXX + 4 * nvis
                kYX = kXX + 6 * nvis
                sqrt_wgt = np.sqrt(wgt[t, k, f])
                sres = sqrt_wgt * vres[t, k, f]
                smdl = sqrt_wgt * vsky[t, k, f]
                for row, value in (
                    (kXX, sres[0, 0]),
                    (kYY, sres[1, 1]),
                    (kXY, sres[0, 1]),
                    (kYX, sres[1, 0]),
                ):
                    dv[row] = np.real(value)
                    dv[row + 1] = np.imag(value)
                update_design_matrix(
                    A,
                    kXX, kYY, kXY, kYX,
                    iXX, iYY, iXY, iYX,
                    jXX, jYY, jXY, jYX,
                    smdl[0, 0], smdl[1, 1], smdl[0, 1], smdl[1, 0],
                    gX[i], gY[i], dXY[i], dYX[i],
                    gX[j], gY[j], dXY[j], dYX[j],
                )
    return A, dv


def _accumulate_normal_matrix(Som, Smm, idx1, idx2, cross, gX, gY, dXY, dYX):
    """Accumulate normal equations AA, Av directly from pre-summed products.

    :return: (AA, Av) with shapes (8*nsubarray, 8*nsubarray) and (8*nsubarray, 1)
    """
    nsubarray = len(gX)
    n2 = 2 * nsubarray
    AA = np.zeros([8 * nsubarray, 8 * nsubarray])
    Av = np.zeros([8 * nsubarray, 1])
    for k in np.flatnonzero(cross):
        i = idx1[k]
        j = idx2[k]
        gains = (gX[i], gY[i], dXY[i], dYX[i], gX[j], gY[j], dXY[j], dYX[j])
        # 4x4 complex derivative matrices, interleaved (real part, imag part)
        params = np.asarray(gen_cdm(*gains))
        paramsRe = params[0:16:2]
        paramsIm = params[1:16:2]
        values = np.asarray(gen_pol_matrix(*gains))
        # pos[m] is the column of the parameter differentiated by paramsRe[m]
        # (its imaginary part is column pos[m] + 1)
        pos = (
            2 * i, 2 * j,  # gX
            2 * i + n2, 2 * j + n2,  # gY
            2 * i + 2 * n2, 2 * j + 2 * n2,  # dXY
            2 * i + 3 * n2, 2 * j + 3 * n2,  # dYX
        )
        # column p: Som[:, p] - Smm @ values[p, :] (independent of the parameter)
        resid_t = (Som[k] - Smm[k] @ values.T).T
        for pos1, p1Re, p1Im in zip(pos, paramsRe, paramsIm):
            Av[pos1] += np.sum(np.real(np.conj(p1Re) * resid_t))
            Av[pos1 + 1] += np.sum(np.real(np.conj(p1Im) * resid_t))
            hermp1Re = p1Re.conj().T
            hermp1Im = p1Im.conj().T
            for pos2, p2Re, p2Im in zip(pos, paramsRe, paramsIm):
                AA[pos1, pos2] += np.sum(np.real((hermp1Re @ p2Re) * Smm[k]))
                AA[pos1, pos2 + 1] += np.sum(np.real((hermp1Re @ p2Im) * Smm[k]))
                AA[pos1 + 1, pos2] += np.sum(np.real((hermp1Im @ p2Re) * Smm[k]))
                AA[pos1 + 1, pos2 + 1] += np.sum(np.real((hermp1Im @ p2Im) * Smm[k]))
    return AA, Av


def _solve_linear(lhs, rhs, lin_solver, lin_solver_normal, rcond):
    """Solve one iteration's linear least-squares problem lhs @ x = rhs.

    lhs, rhs are the normal equations (AA, Av) or, for lsmr/lsqr without
    normal equations, the design matrix and data vector (A, dv).

    Notes from testing:
    - numerical approaches (lsmr & lsqr) seem more stable, but not really more
      efficient (scaling to large numbers of free params is untested)
    - low rcond can lead to early wandering; high rcond can lead to
      convergence limits (similar to algorithm 1)
    - solving A,dv rather than AA,Av seems more accurate at high SNR

    :return: 1-D array of parameter increments
    """
    if lin_solver == "svd":
        assert lin_solver_normal, (
            "linear solver " + lin_solver + " requires normal equations"
        )
        svdu, svds, svdvh = np.linalg.svd(lhs, full_matrices=True)
        # truncated pseudo-inverse: drop singular values below rcond * s_max
        keep = svds / svds[0] >= rcond
        sinv = np.zeros_like(svds)
        sinv[keep] = 1.0 / svds[keep]
        gfit = svdvh.T @ np.diag(sinv) @ svdu.T @ rhs
    elif lin_solver == "lstsq":
        # lstsq can be run on A,dv, but is unstable in the early iterations
        assert lin_solver_normal, (
            "linear solver " + lin_solver + " requires normal equations"
        )
        gfit = np.linalg.lstsq(lhs, rhs, rcond=rcond)[0]
    elif lin_solver == "lsmr":
        gfit = lsmr(csc_matrix(lhs, dtype=float), np.ravel(rhs))[0]
    elif lin_solver == "lsqr":
        gfit = lsqr(csc_matrix(lhs, dtype=float), np.ravel(rhs))[0]
    else:
        raise ValueError("unknown linear solver " + lin_solver)
    return np.ravel(gfit)


def _solve_algorithm2(
    vobs, vmdl, wgt, J, subarray, idx1, idx2, cross,
    niter, nu, tol, accum_opt,
    lin_solver, lin_solver_normal, lin_solver_rcond, vtst, chisq,
):
    """Algorithm 2: full normal-equation linear least squares.

    The free parameters per station are gX, gY, dXY, dYX, with
    J = [[gX, gX dXY], [-gY dYX, gY]]. Updates J[0, subarray, 0] in place and
    appends the testvis comparison (if vtst is given) to chisq.

    :return: the fitted model visibilities J_1 V_sky J_2^H
    """
    nsubarray = len(subarray)
    nvis = vmdl.shape[0] * vmdl.shape[1] * vmdl.shape[2]
    Jsub = J[0, subarray, 0]
    gX = Jsub[:, 0, 0].astype(complex)
    gY = Jsub[:, 1, 1].astype(complex)
    dXY = (Jsub[:, 0, 1] / Jsub[:, 0, 0]).astype(complex)
    dYX = (-Jsub[:, 1, 0] / Jsub[:, 1, 1]).astype(complex)

    vsky = vmdl.copy()
    vfit = vmdl.copy()
    if accum_opt == 1:
        # these products depend only on the data and sky model, so pre-sum
        # them in time and frequency once
        Som, Smm = _presum_products(vobs, vsky, wgt, cross)

    for it in range(niter):
        if accum_opt == 0:
            A, dv = _build_design_matrix(
                vobs - vfit, vsky, wgt, idx1, idx2, cross, gX, gY, dXY, dYX
            )
            if lin_solver_normal:
                lhs, rhs = A.T @ A, A.T @ dv
            else:
                lhs, rhs = A, dv
        else:
            lhs, rhs = _accumulate_normal_matrix(
                Som, Smm, idx1, idx2, cross, gX, gY, dXY, dYX
            )
        gfit = _solve_linear(lhs, rhs, lin_solver, lin_solver_normal, lin_solver_rcond)

        nu_it = 1.0 if nu is None else nu
        # gfit holds interleaved (real, imag) increments in gX, gY, dXY, dYX blocks
        dparam = gfit[0::2] + 1j * gfit[1::2]
        gX += nu_it * dparam[0:nsubarray]
        gY += nu_it * dparam[nsubarray : 2 * nsubarray]
        dXY += nu_it * dparam[2 * nsubarray : 3 * nsubarray]
        dYX += nu_it * dparam[3 * nsubarray : 4 * nsubarray]

        jsub = np.empty((nsubarray, 2, 2), "complex")
        jsub[:, 0, 0] = gX
        jsub[:, 0, 1] = gX * dXY
        jsub[:, 1, 0] = -gY * dYX
        jsub[:, 1, 1] = gY
        vfit = _apply_jones(jsub[idx1], vsky, jsub[idx2])

        if vtst is not None:
            err = _sum_sq_error(vfit, vtst) / 2.0 / nvis
            chisq.append(err)
            log.debug("iter %2d = %.2e", it, err)

        oldJ = np.array(J)
        J[0, subarray, 0] = jsub
        if np.max(np.abs(np.array(J) - oldJ)) / np.max(np.abs(oldJ)) < tol:
            break
    return vfit