Cheetah - SKA - PSS - Prototype Time Domain Search Pipeline
Fft.cu
1 #include "cheetah/fft/cuda/Fft.cuh"
2 #include "panda/Log.h"
3 
4 namespace ska {
5 namespace cheetah {
6 namespace fft {
7 namespace cuda {
8 namespace detail {
9 
/**
 * @brief A helper class for selecting the correct execution calls from cuFFT.
 *
 * @details The primary template is intentionally empty. Only the explicit
 *          specialisations for float and double provide the RealType /
 *          ComplexType typedefs and the r2c / c2r / c2c wrappers.
 *          Consequently, instantiating Fft::process with any other T fails
 *          at compile time, because `typename Cufft::RealType` does not exist.
 *
 * @tparam T  real scalar type of the transform
 */
template <typename T>
struct CufftHelper
{
};
27 
/**
 * @brief Single-precision binding of CufftHelper.
 *
 * Routes the generic r2c / c2r / c2c calls to cufftExecR2C, cufftExecC2R and
 * cufftExecC2C. The cuFFT exec functions take non-const input pointers, so
 * every wrapper strips const from its source pointer before forwarding.
 */
template <>
struct CufftHelper<float>
{
    typedef cufftComplex ComplexType;
    typedef cufftReal RealType;

    /// Real-to-complex transform via cufftExecR2C.
    static cufftResult r2c(cufftHandle handle, RealType const* src, ComplexType* dst)
    {
        return cufftExecR2C(handle, const_cast<RealType*>(src), dst);
    }

    /// Complex-to-real transform via cufftExecC2R.
    /// NOTE(review): cuFFT documents that out-of-place C2R transforms may
    /// overwrite their input buffer, so @p src may be clobbered despite being
    /// passed as const — confirm callers do not rely on it afterwards.
    static cufftResult c2r(cufftHandle handle, ComplexType const* src, RealType* dst)
    {
        return cufftExecC2R(handle, const_cast<ComplexType*>(src), dst);
    }

    /// Complex-to-complex transform via cufftExecC2C. @p direction is forwarded
    /// unchanged (callers in this file pass CUFFT_FORWARD or CUFFT_INVERSE).
    static cufftResult c2c(cufftHandle handle, ComplexType const* src, ComplexType* dst, int direction)
    {
        return cufftExecC2C(handle, const_cast<ComplexType*>(src), dst, direction);
    }
};
52 
/**
 * @brief Double-precision binding of CufftHelper.
 *
 * Routes the generic r2c / c2r / c2c calls to cufftExecD2Z, cufftExecZ2D and
 * cufftExecZ2Z. The cuFFT exec functions take non-const input pointers, so
 * every wrapper strips const from its source pointer before forwarding.
 */
template <>
struct CufftHelper<double>
{
    typedef cufftDoubleComplex ComplexType;
    typedef cufftDoubleReal RealType;

    /// Real-to-complex transform via cufftExecD2Z.
    static cufftResult r2c(cufftHandle handle, RealType const* src, ComplexType* dst)
    {
        return cufftExecD2Z(handle, const_cast<RealType*>(src), dst);
    }

    /// Complex-to-real transform via cufftExecZ2D.
    /// NOTE(review): cuFFT documents that out-of-place C2R/Z2D transforms may
    /// overwrite their input buffer, so @p src may be clobbered despite being
    /// passed as const — confirm callers do not rely on it afterwards.
    static cufftResult c2r(cufftHandle handle, ComplexType const* src, RealType* dst)
    {
        return cufftExecZ2D(handle, const_cast<ComplexType*>(src), dst);
    }

    /// Complex-to-complex transform via cufftExecZ2Z. @p direction is forwarded
    /// unchanged (callers in this file pass CUFFT_FORWARD or CUFFT_INVERSE).
    static cufftResult c2c(cufftHandle handle, ComplexType const* src, ComplexType* dst, int direction)
    {
        return cufftExecZ2Z(handle, const_cast<ComplexType*>(src), dst, direction);
    }
};
77 } // namespace detail
78 
79 
/**
 * @brief Perform a real-to-complex 1D FFT.
 *
 * @details Resizes @p output to input.size()/2 + 1 bins and sets its
 *          frequency step to 1 / (sampling_interval * input.size()) Hz.
 *          It then executes a single (batch = 1) R2C / D2Z transform through
 *          the cached plan.
 *
 * @param gpu     the GPU resource; only its device id is used here (logged)
 * @param input   real-valued time series; its raw pointer is handed to cuFFT
 * @param output  complex frequency series; resized and overwritten
 */
template <typename T, typename InputAlloc, typename OutputAlloc>
void Fft::process(ResourceType& gpu,
                  data::TimeSeries<cheetah::Cuda, T, InputAlloc> const& input,
                  data::FrequencySeries<cheetah::Cuda, typename data::ComplexTypeTraits<cheetah::Cuda,T>::type,OutputAlloc>& output)
{
    typedef detail::CufftHelper<T> Cufft;
    typedef typename Cufft::RealType RealType;
    typedef typename Cufft::ComplexType ComplexType;
    // The output buffer is reinterpreted as the cuFFT complex type below, so at
    // minimum the element sizes must agree.
    static_assert(sizeof(typename data::ComplexTypeTraits<cheetah::Cuda, T>::type) == sizeof(ComplexType),
                  "FrequencySeries complex element type must match the size of the cuFFT complex type");

    PANDA_LOG_DEBUG << "GPU ID: "<<gpu.device_id();

    // A length-N real transform yields N/2 + 1 non-redundant complex bins.
    output.resize(input.size()/2 + 1);
    // Bin spacing is the reciprocal of the total series duration.
    output.frequency_step((1.0f/(input.sampling_interval().value() * input.size())) * data::hz);

    RealType const* in = thrust::raw_pointer_cast(input.data());
    ComplexType* out = reinterpret_cast<ComplexType*>(thrust::raw_pointer_cast(output.data()));
    CUFFT_ERROR_CHECK(Cufft::r2c(_plan.plan<T>(R2C, input.size(), 1), in, out));
}
97 
/**
 * @brief Perform a complex-to-real 1D FFT.
 *
 * @details Treats @p input as the M = input.size() non-redundant bins of a
 *          real signal and produces 2*(M - 1) real samples. The output
 *          sampling interval is set to 1 / (frequency_step * output.size())
 *          seconds. A single (batch = 1) C2R / Z2D transform is executed
 *          through the cached plan. If M < 2 there is no non-empty real
 *          signal to reconstruct: @p output is resized to 0 and no transform
 *          is run.
 *
 * NOTE(review): the plan is requested with input.size() (the complex bin
 * count). cuFFT specifies C2R plans by the real output length, so this relies
 * on FftPlan::plan translating the size for C2R — TODO confirm.
 * NOTE(review): cuFFT documents that out-of-place C2R transforms may overwrite
 * their input, so @p input may be modified despite the const reference.
 *
 * @param gpu     the GPU resource; only its device id is used here (logged)
 * @param input   complex half-spectrum; its raw pointer is handed to cuFFT
 * @param output  real time series; resized and overwritten
 */
template <typename T, typename InputAlloc, typename OutputAlloc>
void Fft::process(ResourceType& gpu,
                  data::FrequencySeries<cheetah::Cuda, typename data::ComplexTypeTraits<cheetah::Cuda,T>::type, InputAlloc> const& input,
                  data::TimeSeries<cheetah::Cuda, T, OutputAlloc>& output)
{
    typedef detail::CufftHelper<T> Cufft;
    typedef typename Cufft::RealType RealType;
    typedef typename Cufft::ComplexType ComplexType;
    // The input buffer is reinterpreted as the cuFFT complex type below, so at
    // minimum the element sizes must agree.
    static_assert(sizeof(typename data::ComplexTypeTraits<cheetah::Cuda, T>::type) == sizeof(ComplexType),
                  "FrequencySeries complex element type must match the size of the cuFFT complex type");

    PANDA_LOG_DEBUG << "GPU ID: "<<gpu.device_id();

    // Without this guard, 2*(size - 1) wraps around for size == 0 (a huge
    // resize). For size == 1 it gives a zero-length buffer that cuFFT would
    // be asked to write into.
    if (input.size() < 2)
    {
        output.resize(0);
        return;
    }

    // M half-spectrum bins describe 2*(M - 1) real samples.
    output.resize(2*(input.size() - 1));
    // Sample spacing is the reciprocal of (bin spacing * number of samples).
    output.sampling_interval((1.0f/(input.frequency_step().value()*output.size())) * data::seconds);

    // Keep const on the input pointer; CufftHelper strips it only at the cuFFT boundary.
    ComplexType const* in = reinterpret_cast<ComplexType const*>(thrust::raw_pointer_cast(input.data()));
    RealType* out = thrust::raw_pointer_cast(output.data());
    CUFFT_ERROR_CHECK(Cufft::c2r(_plan.plan<T>(C2R,input.size(),1), in, out));
}
115 
/**
 * @brief Perform a forward complex-to-complex 1D FFT.
 *
 * @details Resizes @p output to input.size() bins and sets its frequency step
 *          to 1 / (sampling_interval * input.size()) Hz. It then executes a
 *          single (batch = 1) C2C / Z2Z transform with CUFFT_FORWARD.
 *
 * @param gpu     the GPU resource; only its device id is used here (logged)
 * @param input   complex time series; its raw pointer is handed to cuFFT
 * @param output  complex frequency series; resized and overwritten
 */
template <typename T, typename InputAlloc, typename OutputAlloc>
void Fft::process(ResourceType& gpu,
                  data::TimeSeries<cheetah::Cuda, thrust::complex<T>, InputAlloc> const& input,
                  data::FrequencySeries<cheetah::Cuda, typename data::ComplexTypeTraits<cheetah::Cuda,T>::type,OutputAlloc>& output)
{
    typedef detail::CufftHelper<T> Cufft;
    typedef typename Cufft::ComplexType ComplexType;
    // Both buffers are reinterpreted as the cuFFT complex type below, so at
    // minimum the element sizes must agree.
    static_assert(sizeof(thrust::complex<T>) == sizeof(ComplexType),
                  "thrust::complex<T> must match the size of the cuFFT complex type");
    static_assert(sizeof(typename data::ComplexTypeTraits<cheetah::Cuda, T>::type) == sizeof(ComplexType),
                  "FrequencySeries complex element type must match the size of the cuFFT complex type");

    PANDA_LOG_DEBUG << "GPU ID: "<<gpu.device_id();

    // A complex transform keeps all N bins.
    output.resize(input.size());
    // Bin spacing is the reciprocal of the total series duration.
    output.frequency_step((1.0f/(input.sampling_interval().value() * input.size())) * data::hz);

    // Keep const on the input pointer; CufftHelper strips it only at the cuFFT boundary.
    ComplexType const* in = reinterpret_cast<ComplexType const*>(thrust::raw_pointer_cast(input.data()));
    ComplexType* out = reinterpret_cast<ComplexType*>(thrust::raw_pointer_cast(output.data()));
    CUFFT_ERROR_CHECK(Cufft::c2c(_plan.plan<T>(C2C,input.size(),1), in, out, CUFFT_FORWARD));
}
132 
/**
 * @brief Perform an inverse complex-to-complex 1D FFT.
 *
 * @details Resizes @p output to input.size() samples and sets its sampling
 *          interval to 1 / (frequency_step * output.size()) seconds. It then
 *          executes a single (batch = 1) C2C / Z2Z transform with
 *          CUFFT_INVERSE. No 1/N normalisation is applied here; cuFFT's
 *          inverse transforms are documented as unnormalised, so callers must
 *          scale if they need a true inverse.
 *
 * @param gpu     the GPU resource; only its device id is used here (logged)
 * @param input   complex frequency series; its raw pointer is handed to cuFFT
 * @param output  complex time series; resized and overwritten
 */
template <typename T, typename InputAlloc, typename OutputAlloc>
void Fft::process(ResourceType& gpu,
                  data::FrequencySeries<cheetah::Cuda, thrust::complex<T>, InputAlloc> const& input,
                  data::TimeSeries<cheetah::Cuda, typename data::ComplexTypeTraits<cheetah::Cuda,T>::type,OutputAlloc>& output)
{
    typedef detail::CufftHelper<T> Cufft;
    typedef typename Cufft::ComplexType ComplexType;
    // Both buffers are reinterpreted as the cuFFT complex type below, so at
    // minimum the element sizes must agree.
    static_assert(sizeof(thrust::complex<T>) == sizeof(ComplexType),
                  "thrust::complex<T> must match the size of the cuFFT complex type");
    static_assert(sizeof(typename data::ComplexTypeTraits<cheetah::Cuda, T>::type) == sizeof(ComplexType),
                  "TimeSeries complex element type must match the size of the cuFFT complex type");

    PANDA_LOG_DEBUG << "GPU ID: "<<gpu.device_id();

    // A complex transform keeps all N samples.
    output.resize(input.size());
    // Sample spacing is the reciprocal of (bin spacing * number of samples).
    output.sampling_interval((1.0f/(input.frequency_step().value()*output.size())) * data::seconds);

    // Keep const on the input pointer; CufftHelper strips it only at the cuFFT boundary.
    ComplexType const* in = reinterpret_cast<ComplexType const*>(thrust::raw_pointer_cast(input.data()));
    ComplexType* out = reinterpret_cast<ComplexType*>(thrust::raw_pointer_cast(output.data()));
    CUFFT_ERROR_CHECK(Cufft::c2c(_plan.plan<T>(C2C,input.size(),1), in, out, CUFFT_INVERSE));
}
149 
150 } // namespace cuda
151 } // namespace fft
152 } // namespace cheetah
153 } // namespace ska
A helper class for selecting the correct execution calls from cuFFT.
Definition: Fft.cu:24
TimeType const & sampling_interval() const
Retrieve the sampling interval.
Definition: TimeSeries.cpp:64
void process(ResourceType &gpu, data::TimeSeries< cheetah::Cuda, T, InputAlloc > const &input, data::FrequencySeries< cheetah::Cuda, typename data::ComplexTypeTraits< cheetah::Cuda, T >::type, OutputAlloc > &output)
Perform a real-to-complex 1D FFT.
Definition: Fft.cu:81
A helper class to determine the type of complex data for different architectures. ...
A container of Fourier series data.
Some limits and constants for FLDO.
Definition: Brdz.h:35
Class for time series data.
Definition: TimeSeries.h:47
void resize(std::size_t size)
resize the data
Definition: Series.cpp:115
std::size_t size() const
the size of the series
Definition: Series.cpp:109