Sacado Package Browser (Single Doxygen Collection)  Version of the Day
Sacado_Fad_Exp_Atomic.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ***********************************************************************
3 //
4 // Sacado Package
5 // Copyright (2006) Sandia Corporation
6 //
7 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8 // the U.S. Government retains certain rights in this software.
9 //
10 // This library is free software; you can redistribute it and/or modify
11 // it under the terms of the GNU Lesser General Public License as
12 // published by the Free Software Foundation; either version 2.1 of the
13 // License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 //
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
23 // USA
24 // Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
25 // (etphipp@sandia.gov).
26 //
27 // ***********************************************************************
28 // @HEADER
29 
30 #ifndef SACADO_FAD_EXP_ATOMIC_HPP
31 #define SACADO_FAD_EXP_ATOMIC_HPP
32 
33 #include "Sacado_ConfigDefs.h"
34 #if defined(HAVE_SACADO_KOKKOSCORE)
35 
37 #include "Kokkos_Atomic.hpp"
38 #include "impl/Kokkos_Error.hpp"
39 
40 namespace Sacado {
41 
42  namespace Fad {
43  namespace Exp {
44 
45  // Overload of Kokkos::atomic_add for ViewFad types.
//
// Adds the expression xx into the Fad referenced by dst using separate
// Kokkos::atomic_add calls on individual components (fastAccessDx(i) and
// val()); the Fad as a whole is not updated as one atomic unit.
//
//   dst  Destination ViewFadPtr; its size is read through dst->size().
//   xx   Source expression; xx.derived() supplies size(), fastAccessDx(i)
//        and val().
//
// Calls Kokkos::abort when the source has more derivative components than
// the destination, or when both sizes are non-zero but differ.  When either
// size is zero only the value component is added.
//
// NOTE(review): listing lines 47, 68 and 71 are absent from this rendering.
// The index `i` used below is never declared in the visible text, so the
// statement that introduces it (presumably a derivative-loop macro on line
// 68) was lost in extraction -- confirm against the real header.
46  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
48  void atomic_add(ViewFadPtr<ValT,sl,ss,U> dst, const Expr<T>& xx) {
// Make the Kokkos overloads visible to the unqualified atomic_add calls below.
49  using Kokkos::atomic_add;
50 
51  const typename Expr<T>::derived_type& x = xx.derived();
52 
53  const int xsz = x.size();
54  const int sz = dst->size();
55 
56  // We currently cannot handle resizing since that would need to be
57  // done atomically.
58  if (xsz > sz)
59  Kokkos::abort(
60  "Sacado error: Fad resize within atomic_add() not supported!");
61 
// After the check above xsz <= sz, so this fires only for 0 < xsz < sz.
62  if (xsz != sz && sz > 0 && xsz > 0)
63  Kokkos::abort(
64  "Sacado error: Fad assignment of incompatiable sizes!");
65 
66 
// Derivative components: only touched when both sides carry derivatives.
67  if (sz > 0 && xsz > 0) {
69  atomic_add(&(dst->fastAccessDx(i)), x.fastAccessDx(i));
70  }
// Value component is always accumulated.
72  atomic_add(&(dst->val()), x.val());
73  }
74 
75  namespace Impl {
76  // Our implementation of Kokkos::atomic_oper_fetch() and
77  // Kokkos::atomic_fetch_oper() for Sacado types
//
// atomic_oper_fetch_impl: computes op.apply(*dest, x.derived()), stores the
// result into *dest and returns that new value.  The read-modify-write is
// bracketed by Kokkos::Impl::lock_address_*_space / unlock_address_*_space
// calls on dest_val.
//
//   op        Object providing apply(current, operand).
//   dest      Pointer-like handle; *dest is both read and assigned.
//   dest_val  Address handed (as void*) to the lock/unlock calls; it is not
//             dereferenced in this function.
//   x         Expression operand; x.derived() is what is passed to op.apply.
//
// NOTE(review): the #if/#elif chain below has no #else branch, so when none
// of the three macros is defined the function body contains no return
// statement.
// NOTE(review): listing line 79 is absent from this rendering (presumably a
// function-qualifier macro) -- confirm against the real header.
78  template <typename Oper, typename DestPtrT, typename ValT, typename T>
80  typename Sacado::BaseExprType< Expr<T> >::type
81  atomic_oper_fetch_impl(const Oper& op, DestPtrT dest, ValT* dest_val,
82  const Expr<T>& x)
83  {
84  typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
85  const typename Expr<T>::derived_type& val = x.derived();
86 
87 #if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST)
// Host: spin until the lock call succeeds, update between two fences, unlock.
88  while (!Kokkos::Impl::lock_address_host_space((void*)dest_val))
89  ;
90  Kokkos::memory_fence();
91  return_type return_val = op.apply(*dest, val);
92  *dest = return_val;
93  Kokkos::memory_fence();
94  Kokkos::Impl::unlock_address_host_space((void*)dest_val);
95  return return_val;
96 #elif defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_CUDA)
97  // It is not allowed to define SACADO_VIEW_CUDA_HIERARCHICAL or
98  // SACADO_VIEW_CUDA_HIERARCHICAL_DFAD and use Sacado inside a team-based
99  // kernel without Sacado hierarchical parallelism. So use the
100  // team-based version only if blockDim.x > 1 (i.e., a team policy)
101 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) || defined(SACADO_VIEW_CUDA_HIERARCHICAL_DFAD)
102  const bool use_team = (blockDim.x > 1);
103 #else
104  const bool use_team = false;
105 #endif
106  if (use_team) {
// Team path: only threadIdx.x == 0 attempts the lock; its result is then
// passed through Kokkos::shfl(go, 0, blockDim.x) (presumably a broadcast
// from lane 0) so every thread leaves the loop together.
107  int go = 1;
108  while (go) {
109  if (threadIdx.x == 0)
110  go = !Kokkos::Impl::lock_address_cuda_space((void*)dest_val);
111  go = Kokkos::shfl(go, 0, blockDim.x);
112  }
113  Kokkos::memory_fence();
// Every thread that reaches this point executes the apply and the store.
114  return_type return_val = op.apply(*dest, val);
115  *dest = return_val;
116  Kokkos::memory_fence();
// The thread that took the lock is the one that releases it.
117  if (threadIdx.x == 0)
118  Kokkos::Impl::unlock_address_cuda_space((void*)dest_val);
119  return return_val;
120  }
121  else {
122  return_type return_val;
123  // This is a way to (hopefully) avoid dead lock in a warp
// `active` records the ballot of threads entering; each thread retries the
// lock once per iteration until it has done its update, and the loop ends
// when the ballot of `done` flags equals `active`.
124  int done = 0;
125 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK
126  unsigned int mask = KOKKOS_IMPL_CUDA_ACTIVEMASK;
127  unsigned int active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, 1);
128 #else
129  unsigned int active = KOKKOS_IMPL_CUDA_BALLOT(1);
130 #endif
131  unsigned int done_active = 0;
132  while (active != done_active) {
133  if (!done) {
134  if (Kokkos::Impl::lock_address_cuda_space((void*)dest_val)) {
135  Kokkos::memory_fence();
136  return_val = op.apply(*dest, val);
137  *dest = return_val;
138  Kokkos::memory_fence();
139  Kokkos::Impl::unlock_address_cuda_space((void*)dest_val);
140  done = 1;
141  }
142  }
143 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK
144  done_active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, done);
145 #else
146  done_active = KOKKOS_IMPL_CUDA_BALLOT(done);
147 #endif
148  }
149  return return_val;
150  }
151 #elif defined(__HIP_DEVICE_COMPILE__)
152  // FIXME_HIP
// Kokkos::abort is called unconditionally first.  The statements after it
// mirror the CUDA ballot loop but with the lock/unlock calls commented out,
// so they provide no mutual exclusion (presumably never reached, assuming
// Kokkos::abort does not return -- TODO confirm).
153  Kokkos::abort("atomic_oper_fetch not implemented for large types.");
154  return_type return_val;
155  int done = 0;
156  unsigned int active = __ballot(1);
157  unsigned int done_active = 0;
158  while (active != done_active) {
159  if (!done) {
160  // if (Kokkos::Impl::lock_address_hip_space((void*)dest_val))
161  {
162  return_val = op.apply(*dest, val);
163  *dest = return_val;
164  // Kokkos::Impl::unlock_address_hip_space((void*)dest_val);
165  done = 1;
166  }
167  }
168  done_active = __ballot(done);
169  }
170  return return_val;
171 #endif
172  }
173 
174  template <typename Oper, typename DestPtrT, typename ValT, typename T>
176  typename Sacado::BaseExprType< Expr<T> >::type
177  atomic_fetch_oper_impl(const Oper& op, DestPtrT dest, ValT* dest_val,
178  const Expr<T>& x)
179  {
180  typedef typename Sacado::BaseExprType< Expr<T> >::type return_type;
181  const typename Expr<T>::derived_type& val = x.derived();
182 
183 #ifdef KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST
184  while (!Kokkos::Impl::lock_address_host_space((void*)dest_val))
185  ;
186  Kokkos::memory_fence();
187  return_type return_val = *dest;
188  *dest = op.apply(return_val, val);
189  Kokkos::memory_fence();
190  Kokkos::Impl::unlock_address_host_space((void*)dest_val);
191  return return_val;
192 #elif defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_CUDA)
193  // It is not allowed to define SACADO_VIEW_CUDA_HIERARCHICAL or
194  // SACADO_VIEW_CUDA_HIERARCHICAL_DFAD and use Sacado inside a team-based
195  // kernel without Sacado hierarchical parallelism. So use the
196  // team-based version only if blockDim.x > 1 (i.e., a team policy)
197 #if defined(SACADO_VIEW_CUDA_HIERARCHICAL) || defined(SACADO_VIEW_CUDA_HIERARCHICAL_DFAD)
198  const bool use_team = (blockDim.x > 1);
199 #else
200  const bool use_team = false;
201 #endif
202  if (use_team) {
203  int go = 1;
204  while (go) {
205  if (threadIdx.x == 0)
206  go = !Kokkos::Impl::lock_address_cuda_space((void*)dest_val);
207  go = Kokkos::shfl(go, 0, blockDim.x);
208  }
209  Kokkos::memory_fence();
210  return_type return_val = *dest;
211  *dest = op.apply(return_val, val);
212  Kokkos::memory_fence();
213  if (threadIdx.x == 0)
214  Kokkos::Impl::unlock_address_cuda_space((void*)dest_val);
215  return return_val;
216  }
217  else {
218  return_type return_val;
219  // This is a way to (hopefully) avoid dead lock in a warp
220  int done = 0;
221 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK
222  unsigned int mask = KOKKOS_IMPL_CUDA_ACTIVEMASK;
223  unsigned int active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, 1);
224 #else
225  unsigned int active = KOKKOS_IMPL_CUDA_BALLOT(1);
226 #endif
227  unsigned int done_active = 0;
228  while (active != done_active) {
229  if (!done) {
230  if (Kokkos::Impl::lock_address_cuda_space((void*)dest_val)) {
231  Kokkos::memory_fence();
232  return_val = *dest;
233  *dest = op.apply(return_val, val);
234  Kokkos::memory_fence();
235  Kokkos::Impl::unlock_address_cuda_space((void*)dest_val);
236  done = 1;
237  }
238  }
239 #ifdef KOKKOS_IMPL_CUDA_SYNCWARP_NEEDS_MASK
240  done_active = KOKKOS_IMPL_CUDA_BALLOT_MASK(mask, done);
241 #else
242  done_active = KOKKOS_IMPL_CUDA_BALLOT(done);
243 #endif
244  }
245  return return_val;
246  }
247 #elif defined(__HIP_DEVICE_COMPILE__)
248  // FIXME_HIP
249  Kokkos::abort("atomic_oper_fetch not implemented for large types.");
250  return_type return_val;
251  int done = 0;
252  unsigned int active = __ballot(1);
253  unsigned int done_active = 0;
254  while (active != done_active) {
255  if (!done) {
256  // if (Kokkos::Impl::lock_address_hip_space((void*)dest_val))
257  {
258  return_val = *dest;
259  *dest = op.apply(return_val, val);
260  // Kokkos::Impl::unlock_address_hip_space((void*)dest_val);
261  done = 1;
262  }
263  }
264  done_active = __ballot(done);
265  }
266  return return_val;
267 #endif
268  }
269 
270  // Overloads of Kokkos::atomic_oper_fetch/Kokkos::atomic_fetch_oper
271  // for Sacado types
//
// atomic_oper_fetch, GeneralFad<S>* destination: forwards to
// atomic_oper_fetch_impl, passing the address of dest->val() as the address
// used for locking, and returns that call's result.
272  template <typename Oper, typename S>
273  SACADO_INLINE_FUNCTION GeneralFad<S>
274  atomic_oper_fetch(const Oper& op, GeneralFad<S>* dest,
275  const GeneralFad<S>& val)
276  {
277  return Impl::atomic_oper_fetch_impl(op, dest, &(dest->val()), val);
278  }
// atomic_oper_fetch, ViewFadPtr destination: same forwarding, with the lock
// address taken from dest.val().
// NOTE(review): listing line 281 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
279  template <typename Oper, typename ValT, unsigned sl, unsigned ss,
280  typename U, typename T>
282  atomic_oper_fetch(const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
283  const Expr<T>& val)
284  {
285  return Impl::atomic_oper_fetch_impl(op, dest, &dest.val(), val);
286  }
287 
// atomic_fetch_oper, GeneralFad<S>* destination: forwards to
// atomic_fetch_oper_impl, passing the address of dest->val() as the address
// used for locking, and returns that call's result.
288  template <typename Oper, typename S>
289  SACADO_INLINE_FUNCTION GeneralFad<S>
290  atomic_fetch_oper(const Oper& op, GeneralFad<S>* dest,
291  const GeneralFad<S>& val)
292  {
293  return Impl::atomic_fetch_oper_impl(op, dest, &(dest->val()), val);
294  }
// atomic_fetch_oper, ViewFadPtr destination: same forwarding, with the lock
// address taken from dest.val().
// NOTE(review): listing line 297 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
295  template <typename Oper, typename ValT, unsigned sl, unsigned ss,
296  typename U, typename T>
298  atomic_fetch_oper(const Oper& op, ViewFadPtr<ValT,sl,ss,U> dest,
299  const Expr<T>& val)
300  {
301  return Impl::atomic_fetch_oper_impl(op, dest, &dest.val(), val);
302  }
303 
304  // Our definition of the various Oper classes to be more type-flexible
//
// MaxOper: stateless operation tag.  apply(val1, val2) returns
// max(val1, val2); the call is unqualified, so which max is used depends on
// the argument types.  The two operands may have different types and the
// return type is whatever that max expression yields.
305  struct MaxOper {
306  template <class Scalar1, class Scalar2>
307  KOKKOS_FORCEINLINE_FUNCTION
308  static auto apply(const Scalar1& val1, const Scalar2& val2)
309  -> decltype(max(val1,val2))
310  {
311  return max(val1,val2);
312  }
313  };
// MinOper: stateless operation tag.  apply(val1, val2) returns
// min(val1, val2); the call is unqualified, so which min is used depends on
// the argument types.  The return type is whatever that expression yields.
314  struct MinOper {
315  template <class Scalar1, class Scalar2>
316  KOKKOS_FORCEINLINE_FUNCTION
317  static auto apply(const Scalar1& val1, const Scalar2& val2)
318  -> decltype(min(val1,val2))
319  {
320  return min(val1,val2);
321  }
322  };
// AddOper: stateless operation tag.  apply(val1, val2) returns val1 + val2;
// the operands may have different types and the return type is the type of
// that sum expression.
323  struct AddOper {
324  template <class Scalar1, class Scalar2>
325  KOKKOS_FORCEINLINE_FUNCTION
326  static auto apply(const Scalar1& val1, const Scalar2& val2)
327  -> decltype(val1+val2)
328  {
329  return val1 + val2;
330  }
331  };
// SubOper: stateless operation tag.  apply(val1, val2) returns val1 - val2
// (first argument is the minuend); the return type is the type of that
// difference expression.
332  struct SubOper {
333  template <class Scalar1, class Scalar2>
334  KOKKOS_FORCEINLINE_FUNCTION
335  static auto apply(const Scalar1& val1, const Scalar2& val2)
336  -> decltype(val1-val2)
337  {
338  return val1 - val2;
339  }
340  };
// MulOper: stateless operation tag.  apply(val1, val2) returns val1 * val2;
// the operands may have different types and the return type is the type of
// that product expression.
341  struct MulOper {
342  template <class Scalar1, class Scalar2>
343  KOKKOS_FORCEINLINE_FUNCTION
344  static auto apply(const Scalar1& val1, const Scalar2& val2)
345  -> decltype(val1*val2)
346  {
347  return val1 * val2;
348  }
349  };
// DivOper: stateless operation tag.  apply(val1, val2) returns val1 / val2
// (first argument is the dividend); no check on val2 is performed here.  The
// return type is the type of that quotient expression.
350  struct DivOper {
351  template <class Scalar1, class Scalar2>
352  KOKKOS_FORCEINLINE_FUNCTION
353  static auto apply(const Scalar1& val1, const Scalar2& val2)
354  -> decltype(val1/val2)
355  {
356  return val1 / val2;
357  }
358  };
359 
360  } // Impl
361 
362  // Overload of Kokkos::atomic_*_fetch() and Kokkos::atomic_fetch_*()
363  // for Sacado types
//
// atomic_max_fetch: replaces *dest with max(*dest, val) by calling
// Impl::atomic_oper_fetch with Impl::MaxOper, and returns the value stored.
364  template <typename S>
365  SACADO_INLINE_FUNCTION GeneralFad<S>
366  atomic_max_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
367  return Impl::atomic_oper_fetch(Impl::MaxOper(), dest, val);
368  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 370 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
369  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
371  atomic_max_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
372  return Impl::atomic_oper_fetch(Impl::MaxOper(), dest, val);
373  }
// atomic_min_fetch: replaces *dest with min(*dest, val) by calling
// Impl::atomic_oper_fetch with Impl::MinOper, and returns the value stored.
374  template <typename S>
375  SACADO_INLINE_FUNCTION GeneralFad<S>
376  atomic_min_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
377  return Impl::atomic_oper_fetch(Impl::MinOper(), dest, val);
378  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 380 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
379  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
381  atomic_min_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
382  return Impl::atomic_oper_fetch(Impl::MinOper(), dest, val);
383  }
// atomic_add_fetch: replaces *dest with *dest + val by calling
// Impl::atomic_oper_fetch with Impl::AddOper, and returns the value stored.
384  template <typename S>
385  SACADO_INLINE_FUNCTION GeneralFad<S>
386  atomic_add_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
387  return Impl::atomic_oper_fetch(Impl::AddOper(), dest, val);
388  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 390 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
389  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
391  atomic_add_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
392  return Impl::atomic_oper_fetch(Impl::AddOper(), dest, val);
393  }
// atomic_sub_fetch: replaces *dest with *dest - val by calling
// Impl::atomic_oper_fetch with Impl::SubOper, and returns the value stored.
394  template <typename S>
395  SACADO_INLINE_FUNCTION GeneralFad<S>
396  atomic_sub_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
397  return Impl::atomic_oper_fetch(Impl::SubOper(), dest, val);
398  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 400 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
399  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
401  atomic_sub_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
402  return Impl::atomic_oper_fetch(Impl::SubOper(), dest, val);
403  }
// atomic_mul_fetch: replaces *dest with *dest * val by calling
// Impl::atomic_oper_fetch with Impl::MulOper, and returns the value stored.
// The helper is named with its Impl:: qualifier, as in the other
// atomic_*_fetch wrappers, so the call does not depend on argument-dependent
// lookup.
404  template <typename S>
405  SACADO_INLINE_FUNCTION GeneralFad<S>
406  atomic_mul_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
407  return Impl::atomic_oper_fetch(Impl::MulOper(), dest, val);
408  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 410 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
409  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
411  atomic_mul_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
412  return Impl::atomic_oper_fetch(Impl::MulOper(), dest, val);
413  }
// atomic_div_fetch: replaces *dest with *dest / val by calling
// Impl::atomic_oper_fetch with Impl::DivOper, and returns the value stored.
414  template <typename S>
415  SACADO_INLINE_FUNCTION GeneralFad<S>
416  atomic_div_fetch(GeneralFad<S>* dest, const GeneralFad<S>& val) {
417  return Impl::atomic_oper_fetch(Impl::DivOper(), dest, val);
418  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 420 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
419  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
421  atomic_div_fetch(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
422  return Impl::atomic_oper_fetch(Impl::DivOper(), dest, val);
423  }
424 
// atomic_fetch_max: replaces *dest with max(*dest, val) by calling
// Impl::atomic_fetch_oper with Impl::MaxOper, and returns the value *dest
// held before the update.
425  template <typename S>
426  SACADO_INLINE_FUNCTION GeneralFad<S>
427  atomic_fetch_max(GeneralFad<S>* dest, const GeneralFad<S>& val) {
428  return Impl::atomic_fetch_oper(Impl::MaxOper(), dest, val);
429  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 431 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
430  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
432  atomic_fetch_max(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
433  return Impl::atomic_fetch_oper(Impl::MaxOper(), dest, val);
434  }
// atomic_fetch_min: replaces *dest with min(*dest, val) by calling
// Impl::atomic_fetch_oper with Impl::MinOper, and returns the value *dest
// held before the update.
435  template <typename S>
436  SACADO_INLINE_FUNCTION GeneralFad<S>
437  atomic_fetch_min(GeneralFad<S>* dest, const GeneralFad<S>& val) {
438  return Impl::atomic_fetch_oper(Impl::MinOper(), dest, val);
439  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 441 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
440  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
442  atomic_fetch_min(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
443  return Impl::atomic_fetch_oper(Impl::MinOper(), dest, val);
444  }
// atomic_fetch_add: replaces *dest with *dest + val by calling
// Impl::atomic_fetch_oper with Impl::AddOper, and returns the value *dest
// held before the update.
445  template <typename S>
446  SACADO_INLINE_FUNCTION GeneralFad<S>
447  atomic_fetch_add(GeneralFad<S>* dest, const GeneralFad<S>& val) {
448  return Impl::atomic_fetch_oper(Impl::AddOper(), dest, val);
449  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 451 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
450  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
452  atomic_fetch_add(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
453  return Impl::atomic_fetch_oper(Impl::AddOper(), dest, val);
454  }
// atomic_fetch_sub: replaces *dest with *dest - val by calling
// Impl::atomic_fetch_oper with Impl::SubOper, and returns the value *dest
// held before the update.
455  template <typename S>
456  SACADO_INLINE_FUNCTION GeneralFad<S>
457  atomic_fetch_sub(GeneralFad<S>* dest, const GeneralFad<S>& val) {
458  return Impl::atomic_fetch_oper(Impl::SubOper(), dest, val);
459  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 461 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
460  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
462  atomic_fetch_sub(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
463  return Impl::atomic_fetch_oper(Impl::SubOper(), dest, val);
464  }
// atomic_fetch_mul: replaces *dest with *dest * val by calling
// Impl::atomic_fetch_oper with Impl::MulOper, and returns the value *dest
// held before the update.
465  template <typename S>
466  SACADO_INLINE_FUNCTION GeneralFad<S>
467  atomic_fetch_mul(GeneralFad<S>* dest, const GeneralFad<S>& val) {
468  return Impl::atomic_fetch_oper(Impl::MulOper(), dest, val);
469  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 471 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
470  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
472  atomic_fetch_mul(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
473  return Impl::atomic_fetch_oper(Impl::MulOper(), dest, val);
474  }
// atomic_fetch_div: replaces *dest with *dest / val by calling
// Impl::atomic_fetch_oper with Impl::DivOper, and returns the value *dest
// held before the update.
475  template <typename S>
476  SACADO_INLINE_FUNCTION GeneralFad<S>
477  atomic_fetch_div(GeneralFad<S>* dest, const GeneralFad<S>& val) {
478  return Impl::atomic_fetch_oper(Impl::DivOper(), dest, val);
479  }
// ViewFadPtr destination, arbitrary expression operand.
// NOTE(review): listing line 481 (presumably the qualifier/return-type line)
// is absent from this rendering -- confirm against the real header.
480  template <typename ValT, unsigned sl, unsigned ss, typename U, typename T>
482  atomic_fetch_div(ViewFadPtr<ValT,sl,ss,U> dest, const Expr<T>& val) {
483  return Impl::atomic_fetch_oper(Impl::DivOper(), dest, val);
484  }
485 
486  } // namespace Exp
487  } // namespace Fad
488 
489 } // namespace Sacado
490 
491 #endif // HAVE_SACADO_KOKKOSCORE
492 #endif // SACADO_FAD_EXP_ATOMIC_HPP
#define SACADO_FAD_THREAD_SINGLE
expr val()
#define T
Definition: Sacado_rad.hpp:573
SimpleFad< ValueT > min(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
#define SACADO_FAD_DERIV_LOOP(I, SZ)
Get the base Fad type from a view/expression.
T derived_type
Typename of derived object, returned by derived()
SimpleFad< ValueT > max(const SimpleFad< ValueT > &a, const SimpleFad< ValueT > &b)
#define SACADO_INLINE_FUNCTION