// Cheetah - SKA - PSS - Prototype Time Domain Search Pipeline
// Kernels.cu
// Plain templatization of Linus's code; not changing it for now.
2 
// Truth values for flag entries; count_flags and sum_threshold compare entries against TRUE.
3 #define TRUE 1
4 #define FALSE 0
// Multiplier applied to the standard deviation inside sum_threshold.
5 #define BASE_SENSITIVITY 1.0f
// Number of passes sum_threshold makes; its window width doubles on every pass.
6 #define MAX_ITERS 7
// Base factor of the sum_threshold threshold; it is multiplied by 1.5^i on pass i.
7 #define FIRST_THRESHOLD 6.0f
// Credit sir_operator adds for a set flag; a clear flag adds SIR_VALUE - 1.
8 #define SIR_VALUE 0.4f
9 
// Device-side helper that exchanges the contents of a and b.
template<typename Type>
__device__ inline void swap(Type & a, Type & b) {
    const Type held = b;
    b = a;
    a = held;
}
16 
// Block-cooperative bitonic sorting network over values[0 .. n).
//
// Each thread handles the element at index threadIdx.x and its partner
// (threadIdx.x XOR stride); a block-wide barrier follows every compare/exchange
// step, so every thread of the block has to call this function.
// NOTE(review): the network only yields a fully ascending array when n is a
// power of two and the block has n threads - neither is checked here.
//
// values      array to sort in place
// n           number of elements
// nr_flagged  number of leading elements to skip when picking the return value
//
// Returns values[nr_flagged + (n - nr_flagged) / 2] as found after the last step.
template<typename Type>
__device__ Type bitonic_sort(Type * values, int n, int nr_flagged) {
    const int self = threadIdx.x;

    for ( int span = 2; span <= n; span *= 2 ) {
        // Indices with the 'span' bit clear sort upwards, the others downwards.
        const bool ascending = ((self & span) == 0);

        for ( int stride = span / 2; stride > 0; stride /= 2 ) {
            const int partner = self ^ stride;

            // Only the lower index of each pair performs the compare/exchange.
            if ( partner > self ) {
                const bool out_of_order = ascending
                    ? (values[self] > values[partner])
                    : (values[self] < values[partner]);
                if ( out_of_order ) {
                    swap(values[self], values[partner]);
                }
            }
            __syncthreads();
        }
    }

    const int middle = nr_flagged + (n - nr_flagged) / 2;
    return values[middle];
}
40 
// Block-wide tree reduction: adds up values[0 .. blockDim.x) in place and
// returns the total to every calling thread.
//
// Requirements (not checked):
//  - every thread of the block calls this function (it contains barriers);
//  - values holds at least blockDim.x elements, already made visible to the
//    block (callers synchronise after filling the array);
//  - blockDim.x is a power of two, otherwise some elements are left out.
//
// The array contents are destroyed (overwritten with partial sums).
//
// Every step is followed by a block barrier. The earlier version unrolled the
// last six steps without any synchronisation, which relied on implicit
// warp-synchronous execution, read past the array for blocks smaller than 64
// threads, and let other warps read values[0] before the reduction finished.
template<typename Type>
__device__ Type sum_values(Type * values) {
    const unsigned int tid = threadIdx.x;

    for ( unsigned int s = blockDim.x / 2; s > 0; s >>= 1 ) {
        if ( tid < s ) {
            values[tid] += values[tid + s];
        }
        __syncthreads();
    }

    const Type total = values[0];
    // Callers refill the array right after this call; make sure every thread
    // has picked up the total before anybody overwrites values[0].
    __syncthreads();
    return total;
}
63 
// Atomically adds one to *nr_flagged when the calling thread's own entry,
// flags[threadIdx.x], equals TRUE; otherwise does nothing.
__device__ void count_flags(unsigned int * nr_flagged, LocalFlagsType * flags) {
    const bool is_set = (flags[threadIdx.x] == TRUE);
    if ( !is_set ) {
        return;
    }
    atomicAdd(nr_flagged, 1);
}
70 
// Windowed thresholding over values[0 .. n) that sets entries of flags.
//
// MAX_ITERS passes are made with window widths 1, 2, 4, ... On pass i the
// per-element threshold is
//     median + stddev * BASE_SENSITIVITY * FIRST_THRESHOLD * 1.5^i / window.
// The first thread of each window adds up its window, substituting the
// threshold for elements that are already flagged, and flags the whole window
// when the sum reaches window * threshold.
//
// Every thread of the block has to call this function: a block barrier
// separates the passes, because a window of pass i+1 contains flags written
// by a different thread during pass i.
//
// values  samples, indexed like flags
// flags   flag array, updated in place
// median  centre value added to the threshold
// stddev  spread estimate scaling the threshold
// n       number of elements; windows that do not fit inside [0, n) are skipped
template<typename Type>
__device__ void sum_threshold(Type * values, LocalFlagsType * flags, float median, float stddev, int n) {
    const int tid = threadIdx.x;
    const float factor = stddev * BASE_SENSITIVITY;
    int window = 1;

    for ( int i = 0; i < MAX_ITERS; i++ ) {
        const float threshold = fmaf((FIRST_THRESHOLD * powf(1.5f, i) / window), factor, median);

        // The bounds test keeps the window inside the arrays when n is not a
        // multiple of the window width.
        if ( (tid % window == 0) && (tid + window <= n) ) {
            float sum = 0.0f;
            for ( int pos = tid; pos < tid + window; pos++ ) {
                if ( flags[pos] != TRUE ) {
                    sum += values[pos];
                } else {
                    sum += threshold;
                }
            }
            if ( sum >= window * threshold ) {
                for ( int pos = tid; pos < tid + window; pos++ ) {
                    flags[pos] = TRUE;
                }
            }
        }
        // Publish this pass's flags before a wider window reads them.
        __syncthreads();
        window *= 2;
    }
}
100 
// Credit-based flag dilation over rows of n flags; block b works on
// d_flags[b * n .. b * n + n).
//
// A forward and a backward sweep each keep a running credit: a set flag adds
// SIR_VALUE, a clear flag adds SIR_VALUE - 1, and negative credit is dropped
// before the next element is added. An element is flagged when the credit is
// non-negative. The forward sweep overwrites the flags; the backward sweep
// ORs its result into them.
//
// NOTE(review): the backward sweep reads the flags already rewritten by the
// forward sweep rather than the original input - confirm this is intended.
//
// d_flags  flag rows, updated in place
// n        number of flags per row
__global__ void sir_operator(LocalFlagsType * d_flags, int n) {
    // The sweeps are strictly sequential and update the row in place, so only
    // one thread per block may run them; additional threads would race on
    // the same elements. (Assumes a one-dimensional block.)
    if ( threadIdx.x != 0 ) {
        return;
    }

    LocalFlagsType * flags = &(d_flags[(blockIdx.x * n)]);
    float credit = 0.0f;
    float w;
    float max_credit0;

    // Forward sweep.
    for ( int i = 0; i < n; i++ ) {
        w = flags[i] ? SIR_VALUE : SIR_VALUE - 1.0f;
        max_credit0 = credit > 0.0f ? credit : 0.0f;
        credit = max_credit0 + w;
        flags[i] = credit >= 0.0f;
    }

    // Backward sweep; runs down to and including element 0 (the previous
    // bound, i > 0, never visited the first element).
    credit = 0.0f;
    for ( int i = n - 1; i >= 0; i-- ) {
        w = flags[i] ? SIR_VALUE : SIR_VALUE - 1.0f;
        max_credit0 = credit > 0.0f ? credit : 0.0f;
        credit = max_credit0 + w;
        flags[i] = (credit >= 0.0f) | flags[i];
    }
}
121 
// MODIFIED, not equivalent to Linus code because our data structures are different
//
// Adds up all channels of one sample: block b writes
//     results[b] = sum over channel of values[channel * number_of_samples + b].
// Each thread accumulates the channels threadIdx.x, threadIdx.x + blockDim.x, ...
// and the per-thread partial sums are combined with sum_values.
// Dynamic shared memory: one Type element per thread of the block.
template<typename Type>
__global__ void reduce_freq(Type * values, Type * results, unsigned int number_of_channels, unsigned int number_of_samples) {
    extern __shared__ Type shared[];
    const unsigned int lane = threadIdx.x;
    const unsigned int sample = blockIdx.x;
    Type partial = 0;

    // NOTE: Terrible memory access pattern (neighbouring threads read
    // elements that are number_of_samples apart).
    unsigned int channel = lane;
    while ( channel < number_of_channels ) {
        partial += values[(channel * number_of_samples) + sample];
        channel += blockDim.x;
    }

    shared[lane] = partial;
    __syncthreads();

    const Type total = sum_values(shared);
    if ( lane == 0 ) {
        results[sample] = total;
    }
}
139 
// Clamps the tails of the array to the values found at the 10% and 90%
// positions of the range that follows the first nr_flagged elements.
//
// The calling thread only ever rewrites its own entry, shared[threadIdx.x]:
// entries below the lower cut receive the value stored at the lower cut,
// entries above the upper cut receive the value stored at the upper cut.
// NOTE(review): presumably expects the array to be sorted already - the
// callers in this file run bitonic_sort first.
//
// shared      array to modify in place
// nr_flagged  number of leading elements excluded from the percentile range
// n           total number of elements
template<typename Type>
__device__ void winsorize(Type * shared, int nr_flagged, int n) {
    const int unflagged = n - nr_flagged;
    const float lower_cut = 0.1f * unflagged + nr_flagged;
    const float upper_cut = 0.9f * unflagged + nr_flagged;
    const unsigned int tid = threadIdx.x;

    if ( tid < lower_cut ) {
        shared[tid] = shared[(int) lower_cut];
    }
    if ( tid > upper_cut ) {
        shared[tid] = shared[(int) upper_cut];
    }
}
149 
// MODIFIED, not equivalent to Linus code because our data structures are different
//
// Adds up all samples of one row: block b writes
//     results[b] = sum over sample of values[b * number_of_samples + sample].
// Each thread accumulates the samples threadIdx.x, threadIdx.x + blockDim.x, ...
// and the per-thread partial sums are combined with sum_values.
// Dynamic shared memory: one Type element per thread of the block.
template<typename Type>
__global__ void reduce_time(Type * values, Type * results, unsigned int number_of_samples) {
    extern __shared__ Type shared[];
    Type result = 0;

    // Step by the block size so that every sample is added exactly once
    // (stepping by one made the threads sum overlapping ranges).
    for ( unsigned int sample = threadIdx.x; sample < number_of_samples; sample += blockDim.x ) {
        result += values[(blockIdx.x * number_of_samples) + sample];
    }
    shared[threadIdx.x] = result;
    __syncthreads();
    result = sum_values(shared);
    if ( threadIdx.x == 0 ) {
        // One output slot per block, as in reduce_freq (was indexed by blockDim.x).
        results[blockIdx.x] = result;
    }
}
166 
// MODIFIED, not equivalent to Linus code because our data structures are different
//
// Flags outliers along the channel axis. Block b handles sample b; thread t
// handles the element values[t * number_of_samples + b].
//
// Two rounds are run. Each round derives a median (bitonic_sort), a sum and a
// sum of squares of the winsorized data, turns those into a standard
// deviation, applies sum_threshold to the data with flagged entries zeroed,
// and finally recounts the flags into nr_flagged[b]. The resulting flags are
// ORed into global_flags.
//
// Launch requirements (not checked):
//  - NOTE(review): assumes blockDim.x == number_of_channels and that it is a
//    power of two, as required by bitonic_sort and sum_values - confirm with
//    the launch code;
//  - dynamic shared memory holds number_of_channels elements of Type followed
//    by one LocalFlagsType entry per thread;
//  - nr_flagged[b] is read before it is first written here, so it must hold
//    the initial flag count on entry.
//
// NOTE(review): at the start of the second round the array still contains the
// data loaded before the newest flags were set, while the count passed to
// bitonic_sort already includes them - confirm whether this is intended.
template<typename Type>
__global__ void flagger_freq(Type * values, LocalFlagsType * global_flags, unsigned int * nr_flagged, unsigned int number_of_channels, unsigned int number_of_samples) {
    extern __shared__ Type shared[];
    LocalFlagsType * local_flags = (LocalFlagsType *) &(shared[number_of_channels]);
    const unsigned int tid = threadIdx.x;
    // NOTE: terrible memory access pattern (neighbouring threads are
    // number_of_samples elements apart).
    const unsigned int item = (threadIdx.x * number_of_samples) + blockIdx.x;
    Type median;
    Type stddev;

    shared[tid] = values[item];
    local_flags[tid] = 0;
    __syncthreads();

    for ( unsigned int i = 0; i < 2; i++ ) {
        Type sum = 0;
        Type squared_sum = 0;
        unsigned int local_nr_flagged = nr_flagged[blockIdx.x];

        median = bitonic_sort(shared, number_of_channels, local_nr_flagged);
        if ( tid >= local_nr_flagged ) {
            winsorize(shared, local_nr_flagged, number_of_channels);
        }
        __syncthreads();
        sum = sum_values(shared);
        // Everybody must hold the total before the array is refilled.
        __syncthreads();

        shared[tid] = values[item];
        if ( local_flags[tid] ) {
            shared[tid] = 0;
        }
        __syncthreads();
        bitonic_sort(shared, number_of_channels, local_nr_flagged);
        if ( tid >= local_nr_flagged ) {
            winsorize(shared, local_nr_flagged, number_of_channels);
        }
        // winsorize reads entries owned by other threads; let it finish
        // everywhere before those entries are squared in place.
        __syncthreads();
        if ( tid >= local_nr_flagged ) {
            shared[tid] *= shared[tid];
        }
        __syncthreads();
        squared_sum = sum_values(shared);
        __syncthreads();
        stddev = sqrtf(squared_sum / number_of_channels - (sum / number_of_channels * sum / number_of_channels));

        shared[tid] = values[item];
        if ( local_flags[tid] ) {
            shared[tid] = 0;
        }
        __syncthreads();
        sum_threshold(shared, local_flags, median, stddev, number_of_channels);
        // A thread's flag may have been set by another thread.
        __syncthreads();

        // Reset the counter exactly once, and only then start counting;
        // resetting from every thread could wipe out increments already made.
        if ( tid == 0 ) {
            nr_flagged[blockIdx.x] = 0;
        }
        __syncthreads();
        count_flags(&(nr_flagged[blockIdx.x]), local_flags);
        // The next round reads the finished count.
        __syncthreads();
    }

    global_flags[item] = local_flags[tid] | global_flags[item];
}
220 
// MODIFIED, not equivalent to Linus code because our data structures are different
//
// Flags outliers along the sample axis. Block b handles row b; thread t
// handles the element values[b * number_of_samples + t].
//
// Two rounds are run. Each round derives a median (bitonic_sort), a sum and a
// sum of squares of the winsorized data, turns those into a standard
// deviation, applies sum_threshold to the data with flagged entries zeroed,
// and finally recounts the flags into nr_flagged[b]. The resulting flags are
// ORed into global_flags.
//
// Launch requirements (not checked):
//  - NOTE(review): assumes blockDim.x == number_of_samples and that it is a
//    power of two, as required by bitonic_sort and sum_values - confirm with
//    the launch code;
//  - dynamic shared memory holds number_of_samples elements of Type followed
//    by one LocalFlagsType entry per thread;
//  - nr_flagged[b] is read before it is first written here, so it must hold
//    the initial flag count on entry.
//
// NOTE(review): at the start of the second round the array still contains the
// data loaded before the newest flags were set, while the count passed to
// bitonic_sort already includes them - confirm whether this is intended.
template<typename Type>
__global__ void flagger_time(Type * values, LocalFlagsType * global_flags, unsigned int * nr_flagged, unsigned int number_of_samples) {
    extern __shared__ Type shared[];
    LocalFlagsType * local_flags = (LocalFlagsType *) &(shared[number_of_samples]);
    const unsigned int tid = threadIdx.x;
    const unsigned int item = (blockIdx.x * number_of_samples) + threadIdx.x;
    Type median;
    Type stddev;

    shared[tid] = values[item];
    local_flags[tid] = 0;
    __syncthreads();

    for ( unsigned int i = 0; i < 2; i++ ) {
        Type sum = 0;
        Type squared_sum = 0;
        unsigned int local_nr_flagged = nr_flagged[blockIdx.x];

        median = bitonic_sort(shared, number_of_samples, local_nr_flagged);
        if ( tid >= local_nr_flagged ) {
            winsorize(shared, local_nr_flagged, number_of_samples);
        }
        __syncthreads();
        sum = sum_values(shared);
        // Everybody must hold the total before the array is refilled.
        __syncthreads();

        shared[tid] = values[item];
        if ( local_flags[tid] ) {
            shared[tid] = 0;
        }
        __syncthreads();
        bitonic_sort(shared, number_of_samples, local_nr_flagged);
        if ( tid >= local_nr_flagged ) {
            winsorize(shared, local_nr_flagged, number_of_samples);
        }
        // winsorize reads entries owned by other threads; let it finish
        // everywhere before those entries are squared in place.
        __syncthreads();
        if ( tid >= local_nr_flagged ) {
            shared[tid] *= shared[tid];
        }
        __syncthreads();
        squared_sum = sum_values(shared);
        __syncthreads();
        stddev = sqrtf(squared_sum / number_of_samples - (sum / number_of_samples * sum / number_of_samples));

        shared[tid] = values[item];
        if ( local_flags[tid] ) {
            shared[tid] = 0;
        }
        __syncthreads();
        sum_threshold(shared, local_flags, median, stddev, number_of_samples);
        // A thread's flag may have been set by another thread.
        __syncthreads();

        // Reset the counter exactly once, and only then start counting;
        // resetting from every thread could wipe out increments already made.
        if ( tid == 0 ) {
            nr_flagged[blockIdx.x] = 0;
        }
        __syncthreads();
        count_flags(&(nr_flagged[blockIdx.x]), local_flags);
        // The next round reads the finished count.
        __syncthreads();
    }

    global_flags[item] = local_flags[tid] | global_flags[item];
}
270 
271 
272 
273 
274 
275 
276 
277 
278