40 #ifndef TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP 41 #define TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP 43 #include <Tpetra_ConfigDefs.hpp> 45 #ifdef HAVE_TPETRA_TSQR 49 # include "Tsqr_NodeTsqrFactory.hpp" 51 # include "Tsqr_DistTsqr.hpp" 54 # include "Tsqr_TeuchosMessenger.hpp" 55 # include "Tpetra_MultiVector.hpp" 56 # include "Teuchos_ParameterListAcceptorDefaultBase.hpp" 60 # include "Tpetra_TsqrAdaptor.hpp" 70 template <
class Storage,
class LO,
class GO,
class Node>
71 class TsqrAdaptor<
Tpetra::MultiVector< Sacado::MP::Vector<Storage>,
73 public Teuchos::ParameterListAcceptorDefaultBase {
// Multivector type handled by this specialization: a Tpetra::MultiVector
// whose entries are Sacado::MP::Vector<Storage>.
75 typedef Tpetra::MultiVector< Sacado::MP::Vector<Storage>, LO, GO,
Node > MV;
// Dense matrix type used for the R factor in factorExplicit / revealRank,
// and the magnitude type used for revealRank's tolerance argument.
82 typedef Teuchos::SerialDenseMatrix<ordinal_type, scalar_type> dense_matrix_type;
83 typedef typename Teuchos::ScalarTraits<scalar_type>::magnitudeType magnitude_type;
// Factory from which nodeTsqr_ is obtained; parameterized on MV's device type.
86 using node_tsqr_factory_type =
88 typename MV::device_type>;
// TSQR component types: the node-level TSQR, the distributed TSQR, and the
// combined Tsqr object that is constructed from both.
89 using node_tsqr_type = TSQR::NodeTsqr<ordinal_type, scalar_type>;
90 using dist_tsqr_type = TSQR::DistTsqr<ordinal_type, scalar_type>;
91 using tsqr_type = TSQR::Tsqr<ordinal_type, scalar_type>;
// Construct the adaptor with a user-supplied parameter list.
// Builds the node-level TSQR via node_tsqr_factory_type::getNodeTsqr().
// Allocates a new dist_tsqr_type and a tsqr_type that combines the two.
// Then forwards plist to setParameterList.
100 TsqrAdaptor (
const Teuchos::RCP<Teuchos::ParameterList>& plist) :
101 nodeTsqr_ (node_tsqr_factory_type::getNodeTsqr ()),
102 distTsqr_ (new dist_tsqr_type),
103 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
106 setParameterList (plist);
// Second constructor (presumably the default constructor).
// It does the same component setup but passes Teuchos::null to
// setParameterList. setParameterList then falls back to a copy of
// getValidParameters().
111 nodeTsqr_ (node_tsqr_factory_type::getNodeTsqr ()),
112 distTsqr_ (new dist_tsqr_type),
113 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
116 setParameterList (Teuchos::null);
// Return the list of valid parameters for this adaptor.
// The list is built on the first call and cached in defaultParams_,
// which is declared mutable so this const method can fill it.
// The list is named "TSQR implementation" and holds two sublists,
// "NodeTsqr" and "DistTsqr", taken from the valid parameters of
// nodeTsqr_ and distTsqr_ respectively.
// NOTE(review): the lazy initialization is unsynchronized; confirm this is
// only called from one thread at a time.
119 Teuchos::RCP<const Teuchos::ParameterList>
120 getValidParameters ()
const 124 using Teuchos::ParameterList;
125 using Teuchos::parameterList;
127 if (defaultParams_.is_null()) {
128 RCP<ParameterList> params = parameterList (
"TSQR implementation");
// Each component's valid list is dereferenced, so it is passed to set()
// by value rather than as an RCP.
129 params->set (
"NodeTsqr", *(nodeTsqr_->getValidParameters ()));
130 params->set (
"DistTsqr", *(distTsqr_->getValidParameters ()));
131 defaultParams_ = params;
133 return defaultParams_;
// Apply a parameter list to the adaptor and its TSQR components.
// If plist is null, a fresh copy of getValidParameters() is used instead.
// Otherwise plist itself is used, not a copy.
// The "NodeTsqr" and "DistTsqr" sublists are handed to nodeTsqr_ and
// distTsqr_. The whole list is then recorded with setMyParamList.
137 setParameterList (
const Teuchos::RCP<Teuchos::ParameterList>& plist)
139 using Teuchos::ParameterList;
140 using Teuchos::parameterList;
142 using Teuchos::sublist;
144 RCP<ParameterList> params = plist.is_null() ?
145 parameterList (*getValidParameters ()) : plist;
// NOTE(review): Teuchos::sublist is expected to create a missing sublist
// in params, so a non-null caller list may be modified here. Confirm
// against the Teuchos::sublist documentation.
146 nodeTsqr_->setParameterList (sublist (params,
"NodeTsqr"));
147 distTsqr_->setParameterList (sublist (params,
"DistTsqr"));
149 this->setMyParamList (params);
// Compute a QR factorization of A with TSQR.
// Inputs:
//   A, Q - raw (pointer, leading dimension) views are taken with
//          getNonConstView.
//   R    - written through R.values() and R.stride().
//   forceNonnegativeDiagonal - forwarded unchanged to
//          tsqr_->factorExplicitRaw.
// Q presumably receives the explicit Q factor. Whether A is overwritten
// depends on factorExplicitRaw; confirm before relying on A afterwards.
174 factorExplicit (MV& A,
176 dense_matrix_type& R,
177 const bool forceNonnegativeDiagonal=
false)
// NOTE(review): both calls write the same numRows/numCols variables, so the
// values from Q win. A and Q are assumed to have identical dimensions.
188 getNonConstView (numRows, numCols, A_ptr, LDA, A);
189 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
// The views come straight from the multivectors' storage, so the data is
// never laid out as contiguous cache blocks.
190 const bool contiguousCacheBlocks =
false;
191 tsqr_->factorExplicitRaw (numRows, numCols, A_ptr, LDA,
192 Q_ptr, LDQ, R.values (), R.stride (),
193 contiguousCacheBlocks,
194 forceNonnegativeDiagonal);
229 dense_matrix_type& R,
230 const magnitude_type& tol)
// Take a raw view of Q. Then pass Q, R (values and stride) and the
// tolerance tol to tsqr_->revealRankRaw. Its return value is handed back to
// the caller unchanged. As in factorExplicit, the data is not stored as
// contiguous cache blocks.
242 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
243 const bool contiguousCacheBlocks =
false;
244 return tsqr_->revealRankRaw (numRows, numCols, Q_ptr, LDQ,
245 R.values (), R.stride (), tol,
246 contiguousCacheBlocks);
// Node-level TSQR object, obtained from node_tsqr_factory_type::getNodeTsqr().
251 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
// Distributed TSQR object. Its messenger is installed by prepareDistTsqr().
254 Teuchos::RCP<dist_tsqr_type> distTsqr_;
// Combined TSQR object built from nodeTsqr_ and distTsqr_.
// factorExplicitRaw and revealRankRaw are called on it.
257 Teuchos::RCP<tsqr_type> tsqr_;
// Cache for getValidParameters(). It is mutable so that const method can
// fill it on first use.
260 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
// Finish setting up TSQR for the given multivector. It does this by
// initializing the distributed component through prepareDistTsqr(mv).
286 prepareTsqr (
const MV& mv)
289 prepareDistTsqr (mv);
// Initialize distTsqr_ with a messenger over mv's communicator.
// Steps:
//   1. Wrap mv.getMap()->getComm() in a TSQR::TeuchosMessenger<scalar_type>.
//   2. Upcast it to TSQR::MessengerBase<scalar_type>.
//   3. Pass it to distTsqr_->init.
301 prepareDistTsqr (
const MV& mv)
304 using Teuchos::rcp_implicit_cast;
305 typedef TSQR::TeuchosMessenger<scalar_type> mess_type;
306 typedef TSQR::MessengerBase<scalar_type> base_mess_type;
308 RCP<const Teuchos::Comm<int> > comm = mv.getMap()->getComm();
309 RCP<mess_type> mess (
new mess_type (comm));
// distTsqr_->init receives the messenger through its base-class RCP type.
310 RCP<base_mess_type> messBase = rcp_implicit_cast<base_mess_type> (mess);
311 distTsqr_->init (messBase);
// TSQR works on a (pointer, leading dimension) view of the data. A
// multivector without constant stride cannot be described that way, so it
// is rejected with std::invalid_argument.
332 TEUCHOS_TEST_FOR_EXCEPTION
333 (! A.isConstantStride (), std::invalid_argument,
334 "TSQR does not currently support Tpetra::MultiVector " 335 "inputs that do not have constant stride.");
// flat_array_type is the array type of the device-side view in MV's dual
// view.
345 typedef typename MV::dual_view_type view_type;
346 typedef typename view_type::t_dev::array_type flat_array_type;
// NOTE(review): no sync to device is visible before taking the device view.
// Confirm that callers guarantee the device data is current.
352 flat_array_type flat_mv = A.getLocalViewDevice();
// Dimensions and the data pointer come straight from the view. The extents
// are narrowed to ordinal_type with static_cast, with no range check.
354 numRows =
static_cast<ordinal_type> (flat_mv.extent(0));
355 numCols =
static_cast<ordinal_type> (flat_mv.extent(1));
356 A_ptr = flat_mv.data ();
// stride() writes one stride per dimension into strides. The leading
// dimension is presumably taken from the column stride (strides[1]);
// confirm.
359 flat_mv.stride (strides);
366 #endif // HAVE_TPETRA_TSQR 368 #endif // TPETRA_TSQR_ADAPTOR_MP_VECTOR_HPP
KokkosClassic::DefaultNode::DefaultNodeType Node