Cheetah - SKA - PSS - Prototype Time Domain Search Pipeline
RfimTester.cpp
1 /*
2  * The MIT License (MIT)
3  *
4  * Copyright (c) 2017 The SKA organisation
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "cheetah/rfim/test_utils/RfimTester.h"
25 #include "cheetah/sigproc/SigProcWriter.h"
26 #include "cheetah/generators/GaussianNoise.h"
27 #include "cheetah/generators/GaussianNoiseConfig.h"
28 #include "cheetah/generators/DispersedPulse.h"
29 #include "cheetah/generators/RfiScenario.h"
30 #include "cheetah/data/TimeFrequency.h"
31 #include "panda/test/TestDir.h"
32 #include <boost/units/systems/si/frequency.hpp>
33 #include <boost/units/systems/si/prefixes.hpp>
34 #include <limits>
35 #include <algorithm>
36 #include <functional>
37 #include <iostream>
38 #include <string>
39 
40 
41 namespace ska {
42 namespace cheetah {
43 namespace rfim {
44 namespace test {
45 
46 
47 template <typename TestTraits>
48 RfimTester<TestTraits>::RfimTester()
49  : cheetah::utils::test::AlgorithmTester<TestTraits>()
50 {
51 }
52 
53 template <typename TestTraits>
54 RfimTester<TestTraits>::~RfimTester()
55 {
56 }
57 
58 template<typename TestTraits>
59 void RfimTester<TestTraits>::SetUp()
60 {
61 }
62 
63 template<typename TestTraits>
64 void RfimTester<TestTraits>::TearDown()
65 {
66 }
67 
68 template<typename TestTraits>
69 void RfimTester<TestTraits>::verify_equal(typename TestTraits::DataType const& d1, typename TestTraits::DataType const& d2)
70 {
71  ASSERT_EQ(d1.number_of_samples(), d2.number_of_samples());
72  ASSERT_EQ(d1.number_of_channels(), d2.number_of_channels());
73  auto it = d2.begin();
74  for(auto const val : d1) {
75  ASSERT_DOUBLE_EQ(val, *it++);
76  }
77 }
78 
79 template<typename DataType>
80 static
81 std::shared_ptr<DataType> generate_time_freq_data(unsigned number_of_samples, unsigned number_of_channels)
82 {
83  typedef typename DataType::FrequencyType FrequencyType;
84 
85  std::shared_ptr<DataType> data = std::make_shared<DataType>(number_of_samples, number_of_channels);
86  FrequencyType delta( 1.0 * boost::units::si::mega * boost::units::si::hertz);
87  FrequencyType start( 100.0 * boost::units::si::mega * boost::units::si::hertz);
88 
89  data->set_channel_frequencies_const_width(start, delta);
90 
91  // generate a spectrum
92  return data;
93 }
94 
95 
/**
 * @brief Run the algorithm under test on the supplied data and hand the expected and
 *        produced flags to a caller-supplied check.
 * @details The flags currently held by @p data are copied before the algorithm is invoked
 *          and are passed to @p assertion as the expected result. A default-constructed
 *          TypeParam is used to apply the algorithm to data.tf_data().
 *          If the algorithm returns an empty pointer a fatal test failure is recorded and
 *          @p assertion is not called.
 * @tparam TypeParam  the tester traits type; must provide FlaggedDataType and apply_algorithm()
 * @tparam DeviceType the type of the device handed to apply_algorithm()
 * @param assertion callable invoked as assertion(expected_flags, flags_from_algorithm)
 * @param device    device passed through to apply_algorithm()
 * @param data      flagged data providing both the input (tf_data()) and the expected flags
 */
template <typename TypeParam, typename DeviceType>
void test_rfi_algorithm(
    std::function<void (
        typename TypeParam::FlaggedDataType::TimeFrequencyFlagsType const &,
        typename TypeParam::FlaggedDataType::TimeFrequencyFlagsType const &)> assertion,
    DeviceType &device, typename TypeParam::FlaggedDataType &data)
{
    using FlaggedDataType = typename TypeParam::FlaggedDataType;
    using TimeFrequencyFlags = typename FlaggedDataType::TimeFrequencyFlagsType;

    // snapshot of the flags taken before the algorithm gets to see the data
    TimeFrequencyFlags expected_flags(data.rfi_flags());

    TypeParam tester;
    std::shared_ptr<FlaggedDataType> result = tester.apply_algorithm(device, data.tf_data());

    // report a test failure rather than dereferencing an empty pointer
    ASSERT_TRUE(result.get() != nullptr) << "apply_algorithm returned an empty result";

    assertion(expected_flags, result->rfi_flags());
}
113 
114 namespace {
115  struct TestDataWriter
116  {
117  public:
118  TestDataWriter(std::string const& prefix, bool record_files)
119  : _test_dir(!record_files)
120  , _prefix(_test_dir.path() / prefix)
121  , _record(record_files)
122  {
123  _test_dir.create();
124  boost::filesystem::create_directory(_prefix);
125  }
126 
127  template<typename DataType>
128  void write(std::string const& test_name, DataType const& data) {
129  if(_record)
130  {
131  boost::filesystem::path dir(_prefix/test_name);
132  if(!boost::filesystem::create_directory(dir)) {
133  panda::Error e("unable to create directory:");
134  e << dir;
135  throw e;
136  }
137  std::cout << "writing to " << dir;
138  sigproc::SigProcWriter<> sigproc_writer(_prefix / test_name);
139  sigproc_writer << data::ExtractTimeFrequencyDataType<DataType>::extract(data);
140  }
141  }
142 
143  private:
144  panda::test::TestDir _test_dir;
145  boost::filesystem::path _prefix;
146  bool _record;
147  };
148 }
149 
/**
 * @brief Primary (fallback) launcher, used whenever the enable_if-guarded partial
 *        specialisation for @p Num is not selected.
 * @details exec() deliberately does nothing: there is no scenario to run for this @p Num.
 * @tparam Num        index of the RFI scenario
 * @tparam TypeParam  the tester traits type
 * @tparam DeviceType type of the device the algorithm would run on
 * @tparam Enable     SFINAE hook used by the partial specialisation
 */
template<int Num, typename TypeParam, typename DeviceType, typename Enable = void>
struct RfiScenarioLauncher
{
    typedef typename TypeParam::DataType DataType;

    /// no-op: both arguments are ignored
    inline static
    void exec(DataType const&, DeviceType&) {
        //std::string const test_name("RfiScenario<" + std::to_string(Num) + ">");
        //std::cout << "test does not exist: " << test_name << std::endl;
    }
};
159 
// Partial specialisation selected (through std::enable_if) when
// generators::RfiScenario<Num, TypeParam> is default constructible.
// exec() applies the scenario to a copy of the supplied data, runs the algorithm
// under test on that copy and prints the resulting flag metrics to stdout.
160 template<int Num, typename TypeParam, typename DeviceType>
161 struct RfiScenarioLauncher<Num, TypeParam, DeviceType, typename std::enable_if<std::is_constructible<generators::RfiScenario<Num, TypeParam>>::value>::type >
162 {
163  typedef typename TypeParam::DataType DataType;
164  typedef typename TypeParam::FlaggedDataType FlaggedDataType;
165  typedef typename FlaggedDataType::TimeFrequencyFlagsType TimeFrequencyFlags;
166  typedef typename DataType::NumericalRep NumRepType;
167 
168  // apply the N'th pre-defined Rfi Scenario to the data
 // base_data is only read: the scenario is applied to a local copy
 // device is handed on to test_rfi_algorithm()
169  inline static
170  void exec(data::RfimFlaggedData<DataType> const& base_data, DeviceType& device) {
171 
 // NOTE(review): `Scenario` is used below but its declaration is not visible in this
 // listing (source line 172 is absent) - presumably an alias of generators::RfiScenario; confirm.
173  std::string const test_name("RfiScenario<" + std::to_string(Num) + ">: " + std::string(Scenario::description));
174  std::cout << "running test: " << test_name << std::endl;
175 
 // work on a copy so that the caller's data is left untouched
176  data::RfimFlaggedData<DataType> data(base_data);
177  Scenario scenario;
178  scenario.next(data);
179  TestDataWriter file_writer(test_name, false); // DEBUG set to true to record data files during debugging (should be false otherwise)
180  file_writer.write("tf_data_input", data);
 // the lambda receives the flags held by `data` before the algorithm ran (expected)
 // and the flags returned by the algorithm (given)
181  test_rfi_algorithm<TypeParam>(
182  [test_name] (TimeFrequencyFlags const &expected, TimeFrequencyFlags const &given)
183  {
184  SCOPED_TRACE( test_name );
185  rfim::Metrics m(expected, given);
186  std::cout << test_name << " metrics\n"
187  << "\t" << "rfi found=" << m.num_correct() << " of " << m.num_rfi() << " (" << m.rfi_detected_percentage() << "%)\n"
188  << "\t" << "false +ve=" << m.num_false_positives() << " (" << m.false_positives_percentage() << "%)\n"
189  << "\t" << "false -ve=" << m.num_false_negatives() << " (" << m.false_negatives_percentage() << "%)\n"
190  << "\t" << "total correct=" << m.correct_percentage() << "%\n"
191  ;
 // NOTE(review): source line 192 is absent from this listing - any pass/fail check on `m`
 // would sit here; confirm against the original file.
193  }, device, data);
 // NOTE(review): source line 194 is absent from this listing - presumably it launches the
 // next scenario (Num + 1), as callers only ever start with Num = 0; confirm.
195  }
196 
197 };
198 
// Test: run the RFI scenarios on a 2048 sample x 1024 channel block of gaussian noise.
199 ALGORITHM_TYPED_TEST_P(RfimTester, gaussian_noise_wth_rfi)
200 {
201  // Generates a series of tests each with a gaussian noise with the RfiScenario<n> imposed on top
202  // n.b RfiScenario<0> is no RFI - i.e just gaussian noise.
203  typedef typename TypeParam::DataType DataType;
204  typedef typename TypeParam::FlaggedDataType FlaggedDataType;
205  using TimeFrequencyFlags = typename FlaggedDataType::TimeFrequencyFlagsType;
206  using NumericalRep = typename DataType::NumericalRep;
 // channel layout: first channel at 100 MHz, constant 1 MHz channel width
207  typename DataType::FrequencyType delta( 1.0 * boost::units::si::mega * boost::units::si::hertz);
208  typename DataType::FrequencyType start( 100.0 * boost::units::si::mega * boost::units::si::hertz);
209 
 // NOTE(review): source lines 210-211 are absent from this listing - `noise`, used below, is
 // presumably declared there (a gaussian noise generator); confirm against the original file.
212  std::shared_ptr<DataType> time_frequency_data(new DataType(data::DimensionSize<data::Time>(2048), data::DimensionSize<data::Frequency>(1024))); // at least one channel
213  time_frequency_data->sample_interval(typename DataType::TimeType( 0.01 * boost::units::si::seconds));
214  time_frequency_data->set_channel_frequencies_const_width(start, delta);
215  data::RfimFlaggedData<DataType> flagged_time_frequency_data(time_frequency_data);
 // start from a clean slate: no sample is flagged as RFI
216  auto& flags = flagged_time_frequency_data.rfi_flags();
217  std::fill(flags.begin(), flags.end(), false);
218 
219  // set up a gaussian noise signal
220  noise.next(*time_frequency_data);
221 
222  // execute all RFI Scenarios
223  //typedef typename std::decay<decltype(device)>::type::element_type DeviceType;
 // `device` is not declared in this body - presumably supplied by ALGORITHM_TYPED_TEST_P (confirm).
 // DeviceType is exactly the declared type of `device`, including any reference qualifier.
224  typedef decltype(device) DeviceType;
225  RfiScenarioLauncher<0, TypeParam, DeviceType>::exec(flagged_time_frequency_data, device);
226 }
227 
// Test: run the RFI scenarios on a 16384 sample x 1024 channel block of gaussian noise
// that has had the output of a DispersedPulse generator added to it.
228 ALGORITHM_TYPED_TEST_P(RfimTester, single_pulse_on_gaussian_noise_wth_rfi)
229 {
230  // Generates a series of tests, each with a dispersed pulse on gaussian noise and the RfiScenario<n> imposed on top
231  // n.b RfiScenario<0> is no RFI - i.e just the pulse on gaussian noise.
232  typedef typename TypeParam::DataType DataType;
233  typedef typename TypeParam::FlaggedDataType FlaggedDataType;
234  using TimeFrequencyFlags = typename FlaggedDataType::TimeFrequencyFlagsType;
235  using NumericalRep = typename DataType::NumericalRep;
 // channel layout: first channel at 100 MHz, constant 1 MHz channel width
236  typename DataType::FrequencyType delta( 1.0 * boost::units::si::mega * boost::units::si::hertz);
237  typename DataType::FrequencyType start( 100.0 * boost::units::si::mega * boost::units::si::hertz);
238 
 // NOTE(review): source lines 239-241 are absent from this listing - `noise` and `pulse_config`,
 // used below, are presumably declared there; confirm against the original file.
242  generators::DispersedPulse<NumericalRep> pulse_generator(pulse_config);
243 
244  std::shared_ptr<DataType> time_frequency_data(new DataType(data::DimensionSize<data::Time>(16384), data::DimensionSize<data::Frequency>(1024))); // at least one channel
245  time_frequency_data->set_channel_frequencies_const_width(start, delta);
246  time_frequency_data->sample_interval(typename DataType::TimeType( 0.01 * boost::units::si::seconds));
 // NOTE(review): unlike gaussian_noise_wth_rfi the flags are not explicitly reset to false here -
 // confirm that RfimFlaggedData initialises them.
247  data::RfimFlaggedData<DataType> flagged_time_frequency_data(time_frequency_data);
248 
249  // set up a gaussian noise signal, then add the pulse to it
250  noise.next(*time_frequency_data);
251  pulse_generator.next(flagged_time_frequency_data);
252 
253  // execute all RFI Scenarios
254  //typedef typename std::decay<decltype(device)>::type::element_type DeviceType;
 // `device` is not declared in this body - presumably supplied by ALGORITHM_TYPED_TEST_P (confirm).
 // DeviceType is exactly the declared type of `device`, including any reference qualifier.
255  typedef decltype(device) DeviceType;
256  RfiScenarioLauncher<0, TypeParam, DeviceType>::exec(flagged_time_frequency_data, device);
257 }
258 
259 // each test defined by ALGORITHM_TYPED_TEST_P must be added to the
260 // test register (each one as an element of the comma separated list)
261 REGISTER_TYPED_TEST_CASE_P(RfimTester, gaussian_noise_wth_rfi, single_pulse_on_gaussian_noise_wth_rfi);
262 
263 } // namespace test
264 } // namespace rfim
265 } // namespace cheetah
266 } // namespace ska
float correct_percentage() const
return the total of any correct flags as a percentage
Definition: Metrics.cpp:58
float rfi_detected_percentage() const
return the correctly identified rfi flags as a percentage of the expected rfi flags
Definition: Metrics.cpp:51
std::size_t num_correct() const
return the total of any correctly identified flags found
Definition: Metrics.cpp:41
TimeFrequency data with flags representing rfim detection.
Some limits and constants for FLDO.
Definition: Brdz.h:35
std::size_t num_false_positives() const
return the total of any false positives detected
Definition: Metrics.cpp:78
Configuration parameters for the DispersedPulse generator.
std::size_t num_false_negatives() const
return the number of any false negatives detected
Definition: Metrics.cpp:88
A class for analysing and storing the results of the difference between two sets of flags representin...
Definition: Metrics.h:39
std::size_t num_rfi() const
return the number of rfi flags in the expected data
Definition: Metrics.cpp:46
Collection of RFI scenarios.
Definition: RfiScenario.h:43
Inject a single pulse at a specified dispersion measure.
Specialise to set non-standard pass/fail criteria.
Definition: RfimTester.h:116