DP3
SolverBase.h
Go to the documentation of this file.
1 // Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
2 // SPDX-License-Identifier: GPL-3.0-or-later
3 
4 #ifndef DDECAL_SOLVER_BASE_H
5 #define DDECAL_SOLVER_BASE_H
6 
7 #include <algorithm>
8 #include <cassert>
9 #include <complex>
10 #include <iosfwd>
11 #include <memory>
12 #include <stdexcept>
13 #include <vector>
14 
15 #include <aocommon/recursivefor.h>
16 
17 #include "../constraints/Constraint.h"
18 #include "../linear_solvers/LLSSolver.h"
19 #include "SolveData.h"
20 
21 namespace dp3 {
22 namespace ddecal {
23 
24 class SolverBase {
25  public:
26  typedef std::complex<double> DComplex;
27  typedef std::complex<float> Complex;
28 
29  class Matrix final {
30  public:
31  Matrix() : data_(), columns_(0) {}
32  Matrix(size_t columns, size_t rows)
33  : data_(columns * rows, Complex(0.0, 0.0)), columns_(columns) {}
34  void SetZero() { std::fill(data_.begin(), data_.end(), Complex(0.0, 0.0)); }
35  Complex& operator()(size_t column, size_t row) {
36  return data_[column + row * columns_];
37  }
38  Complex* data() { return data_.data(); }
39 
40  // Re-initialize the object to a new size.
41  //
42  // Like the constructor this sets all data element to 0+0i.
43  void Reset(size_t columns, size_t rows) {
44  size_t size = columns * rows;
45  // Minimize the number of elements to modify.
46  if (size < data_.size()) {
47  data_.resize(size);
48  SetZero();
49  } else {
50  SetZero();
51  data_.resize(size, Complex(0.0, 0.0));
52  }
53  columns_ = columns;
54  }
55 
56  private:
57  std::vector<Complex> data_;
58  size_t columns_;
59  };
60 
61  struct SolveResult {
62  size_t iterations = 0;
64  std::vector<std::vector<ConstraintResult>> results;
65  };
66 
68 
69  virtual ~SolverBase() = default;
70 
76  virtual void Initialize(size_t n_antennas,
77  const std::vector<size_t>& n_solutions_per_direction,
78  size_t n_channel_blocks);
79 
83  virtual size_t NSolutionPolarizations() const = 0;
84 
89  void AddConstraint(std::unique_ptr<Constraint> constraint) {
90  assert(constraint);
91  constraints_.push_back(std::move(constraint));
92  }
93 
99  const std::vector<std::unique_ptr<Constraint>>& GetConstraints() {
100  return constraints_;
101  }
102 
109  bool GetPhaseOnly() const { return phase_only_; }
110  void SetPhaseOnly(bool phase_only) { phase_only_ = phase_only; }
117  size_t GetMaxIterations() const { return max_iterations_; }
118  void SetMaxIterations(size_t max_iterations) {
119  max_iterations_ = max_iterations;
120  }
127  size_t GetMinIterations() const { return min_iterations_; }
128  void SetMinIterations(size_t min_iterations) {
129  min_iterations_ = min_iterations;
130  }
137  void SetAccuracy(double accuracy) { accuracy_ = accuracy; }
138  double GetAccuracy() const { return accuracy_; }
145  void SetConstraintAccuracy(double constraint_accuracy) {
146  constraint_accuracy_ = constraint_accuracy;
147  }
148  double GetConstraintAccuracy() const { return constraint_accuracy_; }
156  void SetStepSize(double step_size) { step_size_ = step_size; }
157  double GetStepSize() const { return step_size_; }
164  void SetDetectStalling(bool detect_stalling, double step_diff_sigma) {
165  detect_stalling_ = detect_stalling;
166  step_diff_sigma_ = step_diff_sigma;
167  }
168  bool GetDetectStalling() const { return detect_stalling_; }
174  void GetTimings(std::ostream& os, double duration) const;
175 
176  void SetLLSSolverType(LLSSolverType solver_type);
177  LLSSolverType GetLLSSolverType() const { return lls_solver_type_; }
178 
185  virtual bool SupportsDdSolutionIntervals() const { return false; }
186 
193  virtual std::vector<SolverBase*> ConstraintSolvers() { return {this}; }
194 
205  virtual SolveResult Solve(const FullSolveData& data,
206  std::vector<std::vector<DComplex>>& solutions,
207  double time) {
208  throw std::logic_error(
209  "Full-visibility solver called for a solver that does not "
210  "support full-visibility solving");
211  }
212 
213  virtual SolveResult Solve(const DuoSolveData& data,
214  std::vector<std::vector<DComplex>>& solutions,
215  double time) {
216  throw std::logic_error(
217  "Duo-visibility (xx/yy) solver called for a solver that does not "
218  "support duo-visibility solving");
219  }
220 
221  virtual SolveResult Solve(const UniSolveData& data,
222  std::vector<std::vector<DComplex>>& solutions,
223  double time) {
224  throw std::logic_error(
225  "Single-visibility (xx/yy) solver called for a solver that does not "
226  "support single-visibility solving");
227  }
228 
230  void SetDdConstraintWeights(const std::vector<std::vector<double>>& weights);
231 
232  protected:
233  void Step(const std::vector<std::vector<DComplex>>& solutions,
234  SolutionTensor& next_solutions) const;
235 
236  bool DetectStall(size_t iteration,
237  const std::vector<double>& step_magnitudes);
238 
240  std::vector<std::vector<DComplex>>& solutions);
241 
243  std::vector<std::vector<DComplex>>& solutions);
244 
246  std::vector<std::vector<DComplex>>& solutions);
247 
249 
250  bool ApplyConstraints(size_t iteration, double time,
251  bool has_previously_converged,
252  SolutionTensor& next_solutions) const;
253  bool ApplyConstraints(size_t iteration, double time,
254  bool has_previously_converged,
255  SolutionSpan& next_solutions) const;
256 
257  SolveResult MakeResult(size_t iteration, bool has_converged,
258  bool constraints_satisfied) const;
259 
265  bool AssignSolutions(std::vector<std::vector<DComplex>>& solutions,
266  SolutionTensor& new_solutions,
267  bool use_constraint_accuracy, double& avg_abs_diff,
268  std::vector<double>& step_magnitudes) const;
269  bool AssignSolutions(std::vector<std::vector<DComplex>>& solutions,
270  SolutionSpan& new_solutions,
271  bool use_constraint_accuracy, double& avg_abs_diff,
272  std::vector<double>& step_magnitudes) const;
273 
274  bool ReachedStoppingCriterion(size_t iteration, bool has_converged,
275  bool constraints_satisfied,
276  const std::vector<double>& step_magnitudes) {
277  bool has_stalled = false;
278  if (detect_stalling_ && constraints_satisfied)
279  has_stalled = DetectStall(iteration, step_magnitudes);
280 
281  const bool is_ready = iteration >= max_iterations_ ||
282  (has_converged && constraints_satisfied) ||
283  has_stalled;
284  return iteration >= min_iterations_ && is_ready;
285  }
286 
287  size_t NAntennas() const { return n_antennas_; }
288  size_t NDirections() const { return n_directions_; }
289  size_t NChannelBlocks() const { return n_channel_blocks_; }
295  size_t NSubSolutions() const { return n_sub_solutions_; }
299  size_t NVisibilities() const {
300  return NChannelBlocks() * NAntennas() * NSubSolutions() *
302  }
303 
312  size_t NSubThreads() const;
313 
321  std::unique_ptr<aocommon::RecursiveFor> MakeOptionalRecursiveFor() const;
322 
327  std::unique_ptr<LLSSolver> CreateLLSSolver(size_t m, size_t n,
328  size_t nrhs) const;
329 
330  private:
331  size_t n_antennas_;
332  size_t n_directions_;
333  size_t n_channel_blocks_;
334  size_t n_sub_solutions_;
335 
340  size_t min_iterations_;
341  size_t max_iterations_;
342  double accuracy_;
343  double constraint_accuracy_;
344  double step_size_;
345  bool detect_stalling_;
346  double step_diff_sigma_;
347 
348  bool phase_only_;
349  std::vector<std::unique_ptr<Constraint>> constraints_;
350 
351  LLSSolverType lls_solver_type_;
352 
357  size_t n_var_count_;
358  double step_mean_;
359  double step_var_;
361 };
362 
363 } // namespace ddecal
364 } // namespace dp3
365 
366 #endif
Definition: SolveData.h:29
Definition: SolverBase.h:29
void Reset(size_t columns, size_t rows)
Definition: SolverBase.h:43
Matrix(size_t columns, size_t rows)
Definition: SolverBase.h:32
Complex * data()
Definition: SolverBase.h:38
void SetZero()
Definition: SolverBase.h:34
Matrix()
Definition: SolverBase.h:31
Complex & operator()(size_t column, size_t row)
Definition: SolverBase.h:35
Definition: SolverBase.h:24
bool GetDetectStalling() const
Definition: SolverBase.h:168
double GetStepSize() const
Definition: SolverBase.h:157
LLSSolverType GetLLSSolverType() const
Definition: SolverBase.h:177
void SetStepSize(double step_size)
Definition: SolverBase.h:156
bool ReachedStoppingCriterion(size_t iteration, bool has_converged, bool constraints_satisfied, const std::vector< double > &step_magnitudes)
Definition: SolverBase.h:274
SolveResult MakeResult(size_t iteration, bool has_converged, bool constraints_satisfied) const
bool DetectStall(size_t iteration, const std::vector< double > &step_magnitudes)
size_t NSubThreads() const
size_t GetMaxIterations() const
Definition: SolverBase.h:117
void SetMinIterations(size_t min_iterations)
Definition: SolverBase.h:128
double GetAccuracy() const
Definition: SolverBase.h:138
virtual void Initialize(size_t n_antennas, const std::vector< size_t > &n_solutions_per_direction, size_t n_channel_blocks)
void SetLLSSolverType(LLSSolverType solver_type)
bool ApplyConstraints(size_t iteration, double time, bool has_previously_converged, SolutionTensor &next_solutions) const
const std::vector< std::unique_ptr< Constraint > > & GetConstraints()
Definition: SolverBase.h:99
static void MakeSolutionsFinite1Pol(std::vector< std::vector< DComplex >> &solutions)
size_t NAntennas() const
Definition: SolverBase.h:287
virtual bool SupportsDdSolutionIntervals() const
Definition: SolverBase.h:185
bool GetPhaseOnly() const
Definition: SolverBase.h:109
void SetAccuracy(double accuracy)
Definition: SolverBase.h:137
void SetPhaseOnly(bool phase_only)
Definition: SolverBase.h:110
double GetConstraintAccuracy() const
Definition: SolverBase.h:148
bool AssignSolutions(std::vector< std::vector< DComplex >> &solutions, SolutionSpan &new_solutions, bool use_constraint_accuracy, double &avg_abs_diff, std::vector< double > &step_magnitudes) const
std::complex< float > Complex
Definition: SolverBase.h:27
void Step(const std::vector< std::vector< DComplex >> &solutions, SolutionTensor &next_solutions) const
size_t NChannelBlocks() const
Definition: SolverBase.h:289
static void MakeSolutionsFinite4Pol(std::vector< std::vector< DComplex >> &solutions)
size_t NDirections() const
Definition: SolverBase.h:288
std::complex< double > DComplex
Definition: SolverBase.h:26
virtual ~SolverBase()=default
void SetMaxIterations(size_t max_iterations)
Definition: SolverBase.h:118
bool ApplyConstraints(size_t iteration, double time, bool has_previously_converged, SolutionSpan &next_solutions) const
void SetConstraintAccuracy(double constraint_accuracy)
Definition: SolverBase.h:145
std::unique_ptr< LLSSolver > CreateLLSSolver(size_t m, size_t n, size_t nrhs) const
static void MakeSolutionsFinite2Pol(std::vector< std::vector< DComplex >> &solutions)
void SetDdConstraintWeights(const std::vector< std::vector< double >> &weights)
virtual std::vector< SolverBase * > ConstraintSolvers()
Definition: SolverBase.h:193
size_t NSubSolutions() const
Definition: SolverBase.h:295
virtual SolveResult Solve(const DuoSolveData &data, std::vector< std::vector< DComplex >> &solutions, double time)
Definition: SolverBase.h:213
size_t NVisibilities() const
Definition: SolverBase.h:299
void GetTimings(std::ostream &os, double duration) const
virtual SolveResult Solve(const FullSolveData &data, std::vector< std::vector< DComplex >> &solutions, double time)
Definition: SolverBase.h:205
bool AssignSolutions(std::vector< std::vector< DComplex >> &solutions, SolutionTensor &new_solutions, bool use_constraint_accuracy, double &avg_abs_diff, std::vector< double > &step_magnitudes) const
virtual SolveResult Solve(const UniSolveData &data, std::vector< std::vector< DComplex >> &solutions, double time)
Definition: SolverBase.h:221
virtual size_t NSolutionPolarizations() const =0
void SetDetectStalling(bool detect_stalling, double step_diff_sigma)
Definition: SolverBase.h:164
std::unique_ptr< aocommon::RecursiveFor > MakeOptionalRecursiveFor() const
size_t GetMinIterations() const
Definition: SolverBase.h:127
void AddConstraint(std::unique_ptr< Constraint > constraint)
Definition: SolverBase.h:89
xt::xtensor< std::complex< double >, 4 > SolutionTensor
Definition: Solutions.h:19
aocommon::xt::Span< std::complex< double >, 4 > SolutionSpan
Definition: Solutions.h:20
LLSSolverType
Definition: LLSSolver.h:18
This file has generic helper routines for testing steps.
Definition: AntennaConfig.h:53
Definition: SolverBase.h:61
size_t iterations
Definition: SolverBase.h:62
size_t constraint_iterations
Definition: SolverBase.h:63
std::vector< std::vector< ConstraintResult > > results
Definition: SolverBase.h:64