DP3
IterativeDiagonalSolverCuda.h
Go to the documentation of this file.
1 // Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
2 // SPDX-License-Identifier: GPL-3.0-or-later
3 
4 #ifndef DDECAL_GAIN_SOLVERS_ITERATIVE_DIAGONAL_SOLVER_CUDA_H_
5 #define DDECAL_GAIN_SOLVERS_ITERATIVE_DIAGONAL_SOLVER_CUDA_H_
6 
7 #include <vector>
8 
9 #include <cudawrappers/cu.hpp>
10 
12 #include "SolverBase.h"
13 #include "SolveData.h"
14 #include "common/Timer.h"
15 
16 namespace dp3 {
17 namespace ddecal {
18 
19 template <typename VisMatrix>
21  public:
// Constructs the CUDA iterative diagonal solver.
// keep_buffers: NOTE(review): presumably stored in keep_buffers_ and used to
// keep the GPU/host buffers allocated between Solve() calls — confirm in .cpp.
// Marked `explicit` so that a lone bool cannot implicitly convert into a solver.
22  explicit IterativeDiagonalSolverCuda(bool keep_buffers = false);
24  std::vector<std::vector<DComplex>>& solutions, double time,
25  std::ostream* stat_stream) override;
26 
// Number of polarization values per solution; this solver always reports two
// (NOTE(review): presumably the two diagonal terms — confirm).
27  size_t NSolutionPolarizations() const override {
      constexpr size_t kPolarizationCount = 2;
      return kPolarizationCount;
    }
28 
// Reports that DD (presumably direction-dependent) solution intervals are
// supported by this solver; the answer is unconditionally true.
29  bool SupportsDdSolutionIntervals() const override {
      constexpr bool kDdIntervalsSupported = true;
      return kDdIntervalsSupported;
    }
30 
31  private:
// Buffer management helpers. Allocation is sized from `data`.
// NOTE(review): together with gpu_buffers_initialized_ /
// host_buffers_initialized_ / keep_buffers_ these presumably implement lazy
// (re)allocation across Solve() calls — confirm in the .cpp.
// NOTE(review): there is no DeallocateGPUBuffers() counterpart; GPU memory is
// presumably released by its cu::DeviceMemory owners in gpu_buffers_.
32  void AllocateGPUBuffers(const SolveData<VisMatrix>& data);
33  void DeallocateHostBuffers();
34  void AllocateHostBuffers(const SolveData<VisMatrix>& data);
35 
// Per channel-block staging helpers.
// CopyHostToHost: NOTE(review): from its name and parameters, appears to copy
// `data` and the current `solutions` for channel block `ch_block` into the
// host staging buffers (host_buffers_); `first_iteration` presumably lets
// unchanged inputs be copied only once — confirm.
36  void CopyHostToHost(size_t ch_block, bool first_iteration,
37  const SolveData<VisMatrix>& data,
38  const std::vector<DComplex>& solutions,
39  cu::Stream& stream);
40 
// CopyHostToDevice: NOTE(review): presumably uploads the staged data of
// `ch_block` into GPU buffer slot `buffer_id` (the <2> double buffers in
// GPUBuffers) on `stream`, recording `event` for synchronization — confirm.
41  void CopyHostToDevice(size_t ch_block, size_t buffer_id, cu::Stream& stream,
42  cu::Event& event, const SolveData<VisMatrix>& data);
43 
44  void PostProcessing(size_t& iteration, double time,
45  bool has_previously_converged, bool& has_converged,
46  bool& constraints_satisfied, bool& done,
48  std::vector<std::vector<DComplex>>& solutions,
49  SolutionSpan& next_solutions,
50  std::vector<double>& step_magnitudes,
51  std::ostream* stat_stream);
52 
// Whether the GPU / host buffers are currently allocated
// (NOTE(review): presumably consulted by the Allocate*/Deallocate* helpers).
54  bool gpu_buffers_initialized_ = false;
56  bool host_buffers_initialized_ = false;
57 
// NOTE(review): presumably set from the constructor's keep_buffers argument;
// likely keeps buffers alive across Solve() calls instead of freeing them.
60  bool keep_buffers_ = false;
61 
// CUDA handles. Non-static members are destroyed in reverse declaration order:
// the three streams go first, then context_, then device_; gpu_buffers_ and
// host_buffers_ (declared after these) are destroyed before all of them.
// NOTE(review): presumably relied upon so CUDA resources are released while
// the context is still alive — preserve this ordering when editing.
// NOTE(review): std::unique_ptr is used, but <memory> is not included directly
// in this header; it currently relies on a transitive include.
62  std::unique_ptr<cu::Device> device_;
63  std::unique_ptr<cu::Context> context_;
64  std::unique_ptr<cu::Stream> execute_stream_;
65  std::unique_ptr<cu::Stream> host_to_device_stream_;
66  std::unique_ptr<cu::Stream> device_to_host_stream_;
67 
// Device-side buffers.
// Comment notation used below: <k> is the number of std::vector entries
// (separate allocations), [..] is the layout inside each allocation, followed by
// the element type. Members held by unique_ptr are a single allocation.
// NOTE(review): the <2> vectors presumably enable double buffering of
// transfers vs. compute — confirm in the .cpp.
93  struct GPUBuffers {
94  // <2>[n_antennas][2], uint32_t
// NOTE(review): HostBuffers::antenna_pairs is documented as [n_visibilities]
// pairs; check whether [n_antennas] above should read [n_visibilities].
95  std::vector<cu::DeviceMemory> antenna_pairs;
96  // <2>[n_directions][n_visibilities], uint32_t
97  std::vector<cu::DeviceMemory> solution_map;
98  // <2>[n_visibilities], DComplex
99  std::vector<cu::DeviceMemory> solutions;
100  // <2>[n_visibilities], DComplex
101  std::vector<cu::DeviceMemory> next_solutions;
102  // <2>[n_directions][n_visibilities], MC2x2F
103  std::vector<cu::DeviceMemory> model;
104  // <3>[n_visibilities], MC2x2F
105  std::vector<cu::DeviceMemory> residual;
106  // [n_antennas][n_directions], MC2x2FDiag
107  std::unique_ptr<cu::DeviceMemory> numerator;
108  // [n_antennas][n_directions_solutions], float
109  std::unique_ptr<cu::DeviceMemory> denominator;
110  } gpu_buffers_;
111 
// Host-side staging buffers, one std::vector entry per channel block
// (<n_channelblocks>), except next_solutions, which is a single allocation
// covering all channel blocks.
// Notation: <k> = number of std::vector entries, [..] = layout of each
// allocation, followed by the element type.
122  struct HostBuffers {
123  // <n_channelblocks>[n_directions][n_visibilities], MC2x2F
124  std::vector<cu::HostMemory> model;
125  // <n_channelblocks>[n_visibilities], MC2x2F
126  std::vector<cu::HostMemory> residual;
127  // <n_channelblocks>[n_visibilities], DComplex
128  std::vector<cu::HostMemory> solutions;
129  // [n_channelblocks][n_antennas][n_polarizations], DComplex
130  std::unique_ptr<cu::HostMemory> next_solutions;
131  // <n_channelblocks>[n_visibilities], std::pair<uint32_t, uint32_t>
132  std::vector<cu::HostMemory> antenna_pairs;
133  // <n_channelblocks>[n_directions][n_visibilities], uint32_t
134  std::vector<cu::HostMemory> solution_map;
135  } host_buffers_;
136 };
137 
138 } // namespace ddecal
139 } // namespace dp3
140 
141 #endif // DDECAL_GAIN_SOLVERS_ITERATIVE_DIAGONAL_SOLVER_CUDA_H_
Definition: IterativeDiagonalSolverCuda.h:20
IterativeDiagonalSolverCuda(bool keep_buffers=false)
size_t NSolutionPolarizations() const override
Definition: IterativeDiagonalSolverCuda.h:27
SolveResult Solve(const SolveData< VisMatrix > &data, std::vector< std::vector< DComplex >> &solutions, double time, std::ostream *stat_stream) override
bool SupportsDdSolutionIntervals() const override
Definition: IterativeDiagonalSolverCuda.h:29
Definition: SolveData.h:29
Definition: SolverBase.h:24
aocommon::xt::Span< std::complex< double >, 4 > SolutionSpan
Definition: Solutions.h:20
This file has generic helper routines for testing steps.
Definition: AntennaConfig.h:53
Definition: SolverBase.h:61