DP3
NormalEquationsSolver.h
Go to the documentation of this file.
1 // Copyright (C) 2020 ASTRON (Netherlands Institute for Radio Astronomy)
2 // SPDX-License-Identifier: GPL-3.0-or-later
3 
4 #ifndef NORMALEQ_SOLVER_H
5 #define NORMALEQ_SOLVER_H
6 
7 #include "LLSSolver.h"
8 
9 #include <algorithm>
10 #include <complex>
11 #include <vector>
12 
13 namespace dp3 {
14 namespace ddecal {
15 
/* LAPACK CPOSV prototype (Fortran symbol, hence the trailing underscore).
 * Solves A X = B for a complex Hermitian positive-definite n x n matrix A
 * using a Cholesky factorization. On exit `a` holds the Cholesky factor and
 * `b` the solution X. info == 0 means success; info < 0 flags an illegal
 * argument; info > 0 means the leading minor of that order is not positive
 * definite, so no solution was computed.
 * NOTE(review): Fortran compilers may expect a hidden length argument for the
 * character argument `uplo`; this declaration omits it, as is common for
 * LAPACK bindings — confirm against the toolchain in use. */
extern "C" void cposv_(char* uplo, int* n, int* nrhs, std::complex<float>* a,
                       int* lda, std::complex<float>* b, int* ldb, int* info);
19 
20 class NormalEquationsSolver final : public LLSSolver {
21  public:
22  NormalEquationsSolver(int m, int n, int nrhs)
23  : LLSSolver(m, n, nrhs), adaggera_(n_ * n_), adaggerb_(n_ * nrhs_) {}
24 
33  bool Solve(std::complex<float>* a, std::complex<float>* b) override {
34  // compute a^dagger a
35  for (int n = 0; n < n_; ++n) {
36  for (int np = n; np < n_; ++np) { // fill upper triangle only
37  adaggera_[n + np * n_] = std::complex<float>(0.0, 0.0);
38  for (int m = 0; m < m_; ++m) {
39  adaggera_[n + np * n_] += std::conj(a[m + n * m_]) * a[m + np * m_];
40  }
41  }
42  }
43  for (int p = 0; p < nrhs_; ++p) {
44  for (int n = 0; n < n_; ++n) {
45  adaggerb_[n + p * n_] = std::complex<float>(0.0, 0.0);
46  for (int m = 0; m < m_; ++m) {
47  adaggerb_[n + p * n_] += std::conj(a[m + n * m_]) * b[m + p * m_];
48  }
49  }
50  }
51 
52  int info;
53  char uplo = 'U';
54  int ldb = n_;
55 
56  // solve Hermitian system of normal equations using Cholesky decomposition
57  cposv_(&uplo, &n_, &nrhs_, adaggera_.data(), &n_, adaggerb_.data(), &ldb,
58  &info);
59 
60  std::copy_n(adaggerb_.data(), n_ * nrhs_, b);
61 
62  // Check for full rank
63  return info == 0;
64  }
65 
66  private:
67  std::vector<std::complex<float>> adaggera_;
68  std::vector<std::complex<float>> adaggerb_;
69 };
70 
71 } // namespace ddecal
72 } // namespace dp3
73 
74 #endif
Definition: LLSSolver.h:25
int m_
Definition: LLSSolver.h:44
int n_
Definition: LLSSolver.h:45
int nrhs_
Definition: LLSSolver.h:46
Definition: NormalEquationsSolver.h:20
bool Solve(std::complex< float > *a, std::complex< float > *b) override
Definition: NormalEquationsSolver.h:33
NormalEquationsSolver(int m, int n, int nrhs)
Definition: NormalEquationsSolver.h:22
void cposv_(char *uplo, int *n, int *nrhs, std::complex< float > *a, int *lda, std::complex< float > *b, int *ldb, int *info)
This file has generic helper routines for testing steps.
Definition: AntennaConfig.h:53