24 #include "cheetah/sps/astroaccelerate/Sps.h" 25 #include "cheetah/cuda_utils/cuda_errorhandling.h" 26 #include "cheetah/data/DmTrialsMetadata.h" 27 #include "cheetah/data/DmTrials.h" 28 #include "cheetah/data/SpCcl.h" 29 #include "cheetah/data/TimeFrequency.h" 30 #include "cheetah/data/Units.h" 31 #include "cheetah/data/DedispersionMeasure.h" 32 #include "cheetah/cuda_utils/nvtx.h" 33 #include "panda/Resource.h" 34 #include "panda/Log.h" 35 #include "panda/Error.h" 44 namespace astroaccelerate {
// Everything from here up to the matching "#endif // ENABLE_ASTROACCELERATE" is compiled only
// when ENABLE_ASTROACCELERATE is defined.
/**
 * @brief Call operator of the unconstrained Sps<SpsTraits, Enable> template for a CPU pool resource.
 * @details Performs no processing: it unconditionally throws panda::Error telling the user to
 *          deactivate this algorithm. Presumably this template is the one selected when SpsTraits
 *          does not describe uint8_t data (the EnableIfIsUint8T<SpsTraits> specialisation handles
 *          that case) — NOTE(review): confirm against the declarations in Sps.h.
 * @throws panda::Error always.
 */
47 #ifdef ENABLE_ASTROACCELERATE 48 template<
class SpsTraits,
typename Enable>
49 template<
typename DmHandler,
typename SpHandler,
typename OtherBufferType>
50 void Sps<SpsTraits, Enable>::operator()(panda::PoolResource<cheetah::Cpu>&
56 throw panda::Error(
"astroaccelerate::Sps can only handle uint8_t - please deactivate this algorithm in your config");
/**
 * @brief Run the uint8_t search on a CUDA GPU.
 * @details Dispatches to the per-device runner stored in _cuda_runner under gpu.device_id().
 *          If that device has no runner yet, one is created lazily through
 *          set_dedispersion_strategy(). The new runner reuses the GPU-memory figure from
 *          device 0's dedispersion strategy and takes the first block of agg_buf's composition
 *          as reference data.
 * @param gpu         GPU pool resource; its device_id() selects (or creates) the runner.
 * @param dm_handler  passed through to the selected runner.
 * @param sp_handler  passed through to the selected runner.
 */
60 template<
class SpsTraits>
61 template<
typename DmHandler,
typename SpHandler>
62 void Sps<SpsTraits, EnableIfIsUint8T<SpsTraits>>::operator()(panda::PoolResource<panda::nvidia::Cuda>& gpu
64 , DmHandler& dm_handler
65 , SpHandler& sp_handler
// Look for a runner already configured for this GPU.
68 auto it = this->_cuda_runner.find(gpu.device_id());
69 if(it==this->_cuda_runner.end()) {
// First call on this device: derive its dedispersion strategy from the first data block.
// NOTE(review): assumes agg_buf's composition is non-empty (front() is dereferenced) — confirm.
70 TimeFrequencyType
const& data=*(agg_buf.composition().front());
// at(0U) requires a runner for device 0 to exist already, e.g. from the two-argument
// set_dedispersion_strategy(). If _cuda_runner is a std::map, at() throws std::out_of_range otherwise.
// NOTE(review): this lazily inserts into _cuda_runner. If operator() can run concurrently for
// different GPUs, the container access is unsynchronised — confirm that callers serialise these calls.
71 set_dedispersion_strategy(this->_cuda_runner.at(0U).dedispersion_strategy().get_gpu_memory(), data, gpu.device_id());
// Re-lookup: the three-argument set_dedispersion_strategy() has just inserted the runner.
72 it=this->_cuda_runner.find(gpu.device_id());
// Delegate the actual processing to the device-specific runner.
74 (*it).second(gpu, agg_buf, dm_handler, sp_handler);
/**
 * @brief Overload of the uint8_t specialisation that accepts an arbitrary OtherBufferType.
 * @details Performs no processing: it unconditionally throws panda::Error.
 *          NOTE(review): presumably this overload catches buffer types other than the one the
 *          panda::nvidia::Cuda overload handles. Its resource parameter is typed on cheetah::Cuda
 *          rather than panda::nvidia::Cuda — confirm both name the same resource type. The error
 *          text ("can only handle uint8_t") may also be misleading inside the uint8_t specialisation.
 * @throws panda::Error always.
 */
77 template<
class SpsTraits>
78 template<
typename DmHandler,
typename SpHandler,
typename OtherBufferType>
79 void Sps<SpsTraits, EnableIfIsUint8T<SpsTraits>>::operator()(panda::PoolResource<cheetah::Cuda>&
85 throw panda::Error(
"astroaccelerate::Sps can only handle uint8_t - please deactivate this algorithm in your config");
/**
 * @brief Construct the uint8_t Sps specialisation from the single-pulse-search configuration.
 * @details Passes both config.astroaccelerate_config() and the full config to BaseT.
 * @param config single-pulse-search configuration.
 */
88 template<
class SpsTraits>
89 Sps<SpsTraits, EnableIfIsUint8T<SpsTraits>>::Sps(sps::Config
const& config)
90 : BaseT(config.astroaccelerate_config(), config)
/**
 * @brief Return the buffer overlap reported by the device-0 runner.
 * @details Delegates to _cuda_runner.at(0U).buffer_overlap(). A runner for device 0 must
 *          therefore already exist, i.e. set_dedispersion_strategy() must have been called.
 *          If _cuda_runner is a std::map, at() throws std::out_of_range otherwise.
 * @return the value of the device-0 runner's buffer_overlap().
 */
94 template<
class SpsTraits>
95 std::size_t Sps<SpsTraits, EnableIfIsUint8T<SpsTraits>>::buffer_overlap()
const 97 return _cuda_runner.at(0U).buffer_overlap();
/**
 * @brief Discard all per-device runners and configure a fresh runner for device 0.
 * @details Clears _cuda_runner, then forwards to the three-argument overload with device_id 0.
 *          Runners for other devices are recreated lazily by the GPU operator().
 * @param min_gpu_memory passed through to the device-0 runner's set_dedispersion_strategy().
 * @param tf_data        reference data, passed through to the same call.
 * @return whatever the device-0 runner's set_dedispersion_strategy() returns.
 */
100 template<
class SpsTraits>
101 std::size_t Sps<SpsTraits, EnableIfIsUint8T<SpsTraits>>::set_dedispersion_strategy(std::size_t min_gpu_memory, TimeFrequencyType
const& tf_data)
// Drop the runners for every device so that no stale strategy survives.
103 _cuda_runner.clear();
104 return set_dedispersion_strategy(min_gpu_memory, tf_data, 0);
/**
 * @brief Configure the dedispersion strategy of the runner for one device.
 * @details If device_id has no runner yet, creates a SpsCuda runner from _algo_config. An
 *          existing runner is reused, not replaced. The call then invokes that runner's
 *          set_dedispersion_strategy().
 * @param min_gpu_memory forwarded to the runner's set_dedispersion_strategy().
 * @param tf_data        forwarded to the runner's set_dedispersion_strategy().
 * @param device_id      key of the runner in _cuda_runner.
 * @return the value returned by the runner's set_dedispersion_strategy().
 */
107 template<
class SpsTraits>
108 std::size_t Sps<SpsTraits, EnableIfIsUint8T<SpsTraits>>::set_dedispersion_strategy(std::size_t min_gpu_memory, TimeFrequencyType
const& tf_data,
unsigned device_id)
// Create the runner on first use only.
110 if(_cuda_runner.count(device_id) == 0) {
111 _cuda_runner.insert(std::make_pair(device_id, SpsCuda(_algo_config)));
113 return _cuda_runner.at(device_id).set_dedispersion_strategy(min_gpu_memory, tf_data);
117 #endif // ENABLE_ASTROACCELERATE Some limits and constants for FLDO.