Intrepid2
Intrepid2_DirectSumBasis.hpp
Go to the documentation of this file.
1 // @HEADER
2 // ************************************************************************
3 //
4 // Intrepid2 Package
5 // Copyright (2007) Sandia Corporation
6 //
7 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
8 // license for use of this work by or on behalf of the U.S. Government.
9 //
10 // Redistribution and use in source and binary forms, with or without
11 // modification, are permitted provided that the following conditions are
12 // met:
13 //
14 // 1. Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimer.
16 //
17 // 2. Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimer in the
19 // documentation and/or other materials provided with the distribution.
20 //
21 // 3. Neither the name of the Corporation nor the names of the
22 // contributors may be used to endorse or promote products derived from
23 // this software without specific prior written permission.
24 //
25 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 //
37 // Questions? Contact Kyungjoo Kim (kyukim@sandia.gov),
38 // Mauro Perego (mperego@sandia.gov), or
39 // Nate Roberts (nvrober@sandia.gov)
40 //
41 // ************************************************************************
42 // @HEADER
43 
49 #ifndef Intrepid2_DirectSumBasis_h
50 #define Intrepid2_DirectSumBasis_h
51 
52 #include <Kokkos_View.hpp>
53 #include <Kokkos_DynRankView.hpp>
54 
55 namespace Intrepid2
56 {
67  template<typename BasisBaseClass>
68  class Basis_DirectSumBasis : public BasisBaseClass
69  {
70  public:
71  using BasisBase = BasisBaseClass;
72  using BasisPtr = Teuchos::RCP<BasisBase>;
73 
74  using DeviceType = typename BasisBase::DeviceType;
75  using ExecutionSpace = typename BasisBase::ExecutionSpace;
76  using OutputValueType = typename BasisBase::OutputValueType;
77  using PointValueType = typename BasisBase::PointValueType;
78 
79  using OrdinalTypeArray1DHost = typename BasisBase::OrdinalTypeArray1DHost;
80  using OrdinalTypeArray2DHost = typename BasisBase::OrdinalTypeArray2DHost;
81  using OutputViewType = typename BasisBase::OutputViewType;
82  using PointViewType = typename BasisBase::PointViewType;
83  using ScalarViewType = typename BasisBase::ScalarViewType;
84  protected:
85  BasisPtr basis1_;
86  BasisPtr basis2_;
87 
88  std::string name_;
89  public:
94  Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
95  :
96  basis1_(basis1),basis2_(basis2)
97  {
98  INTREPID2_TEST_FOR_EXCEPTION(basis1->getBasisType() != basis2->getBasisType(), std::invalid_argument, "basis1 and basis2 must agree in basis type");
99  INTREPID2_TEST_FOR_EXCEPTION(basis1->getBaseCellTopology().getKey() != basis2->getBaseCellTopology().getKey(),
100  std::invalid_argument, "basis1 and basis2 must agree in cell topology");
101  INTREPID2_TEST_FOR_EXCEPTION(basis1->getCoordinateSystem() != basis2->getCoordinateSystem(),
102  std::invalid_argument, "basis1 and basis2 must agree in coordinate system");
103 
104  this->basisCardinality_ = basis1->getCardinality() + basis2->getCardinality();
105  this->basisDegree_ = std::max(basis1->getDegree(), basis2->getDegree());
106 
107  {
108  std::ostringstream basisName;
109  basisName << basis1->getName() << " + " << basis2->getName();
110  name_ = basisName.str();
111  }
112 
113  this->basisCellTopology_ = basis1->getBaseCellTopology();
114  this->basisType_ = basis1->getBasisType();
115  this->basisCoordinates_ = basis1->getCoordinateSystem();
116 
117  if (this->basisType_ == BASIS_FEM_HIERARCHICAL)
118  {
119  int degreeLength = basis1_->getPolynomialDegreeLength();
120  INTREPID2_TEST_FOR_EXCEPTION(degreeLength != basis2_->getPolynomialDegreeLength(), std::invalid_argument, "Basis1 and Basis2 must agree on polynomial degree length");
121 
122  this->fieldOrdinalPolynomialDegree_ = OrdinalTypeArray2DHost("DirectSumBasis degree lookup",this->basisCardinality_,degreeLength);
123  // our field ordinals start with basis1_; basis2_ follows
124  for (int fieldOrdinal1=0; fieldOrdinal1<basis1_->getCardinality(); fieldOrdinal1++)
125  {
126  int fieldOrdinal = fieldOrdinal1;
127  auto polynomialDegree = basis1->getPolynomialDegreeOfField(fieldOrdinal1);
128  for (int d=0; d<degreeLength; d++)
129  {
130  this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
131  }
132  }
133  for (int fieldOrdinal2=0; fieldOrdinal2<basis2_->getCardinality(); fieldOrdinal2++)
134  {
135  int fieldOrdinal = basis1->getCardinality() + fieldOrdinal2;
136 
137  auto polynomialDegree = basis2->getPolynomialDegreeOfField(fieldOrdinal2);
138  for (int d=0; d<degreeLength; d++)
139  {
140  this->fieldOrdinalPolynomialDegree_(fieldOrdinal,d) = polynomialDegree(d);
141  }
142  }
143  }
144 
145  // initialize tags
146  {
147  const auto & cardinality = this->basisCardinality_;
148 
149  // Basis-dependent initializations
150  const ordinal_type tagSize = 4; // size of DoF tag, i.e., number of fields in the tag
151  const ordinal_type posScDim = 0; // position in the tag, counting from 0, of the subcell dim
152  const ordinal_type posScOrd = 1; // position in the tag, counting from 0, of the subcell ordinal
153  const ordinal_type posDfOrd = 2; // position in the tag, counting from 0, of DoF ordinal relative to the subcell
154 
155  OrdinalTypeArray1DHost tagView("tag view", cardinality*tagSize);
156 
157  shards::CellTopology cellTopo = this->basisCellTopology_;
158 
159  unsigned spaceDim = cellTopo.getDimension();
160 
161  ordinal_type basis2Offset = basis1_->getCardinality();
162 
163  for (unsigned d=0; d<=spaceDim; d++)
164  {
165  unsigned subcellCount = cellTopo.getSubcellCount(d);
166  for (unsigned subcellOrdinal=0; subcellOrdinal<subcellCount; subcellOrdinal++)
167  {
168  ordinal_type subcellDofCount1 = basis1->getDofCount(d, subcellOrdinal);
169  ordinal_type subcellDofCount2 = basis2->getDofCount(d, subcellOrdinal);
170 
171  ordinal_type subcellDofCount = subcellDofCount1 + subcellDofCount2;
172  for (ordinal_type localDofID=0; localDofID<subcellDofCount; localDofID++)
173  {
174  ordinal_type fieldOrdinal;
175  if (localDofID < subcellDofCount1)
176  {
177  // first basis: field ordinal matches the basis1 ordinal
178  fieldOrdinal = basis1_->getDofOrdinal(d, subcellOrdinal, localDofID);
179  }
180  else
181  {
182  // second basis: field ordinal is offset by basis1 cardinality
183  fieldOrdinal = basis2Offset + basis2_->getDofOrdinal(d, subcellOrdinal, localDofID - subcellDofCount1);
184  }
185  tagView(fieldOrdinal*tagSize+0) = d; // subcell dimension
186  tagView(fieldOrdinal*tagSize+1) = subcellOrdinal;
187  tagView(fieldOrdinal*tagSize+2) = localDofID;
188  tagView(fieldOrdinal*tagSize+3) = subcellDofCount;
189  }
190  }
191  }
192  // // Basis-independent function sets tag and enum data in tagToOrdinal_ and ordinalToTag_ arrays:
193  // // tags are constructed on host
194  this->setOrdinalTagData(this->tagToOrdinal_,
195  this->ordinalToTag_,
196  tagView,
197  this->basisCardinality_,
198  tagSize,
199  posScDim,
200  posScOrd,
201  posDfOrd);
202  }
203  }
204 
210  virtual BasisValues<OutputValueType,DeviceType> allocateBasisValues( TensorPoints<PointValueType,DeviceType> points, const EOperator operatorType = OPERATOR_VALUE) const override
211  {
212  BasisValues<OutputValueType,DeviceType> basisValues1 = basis1_->allocateBasisValues(points, operatorType);
213  BasisValues<OutputValueType,DeviceType> basisValues2 = basis2_->allocateBasisValues(points, operatorType);
214 
215  const int numScalarFamilies1 = basisValues1.numTensorDataFamilies();
216  if (numScalarFamilies1 > 0)
217  {
218  // then both basis1 and basis2 should be scalar-valued; check that for basis2:
219  const int numScalarFamilies2 = basisValues2.numTensorDataFamilies();
220  INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() <=0, std::invalid_argument, "When basis1 has scalar value, basis2 must also");
221  std::vector< TensorData<OutputValueType,DeviceType> > scalarFamilies(numScalarFamilies1 + numScalarFamilies2);
222  for (int i=0; i<numScalarFamilies1; i++)
223  {
224  scalarFamilies[i] = basisValues1.tensorData(i);
225  }
226  for (int i=0; i<numScalarFamilies2; i++)
227  {
228  scalarFamilies[i+numScalarFamilies1] = basisValues2.tensorData(i);
229  }
230  return BasisValues<OutputValueType,DeviceType>(scalarFamilies);
231  }
232  else
233  {
234  // then both basis1 and basis2 should be vector-valued; check that:
235  INTREPID2_TEST_FOR_EXCEPTION(!basisValues1.vectorData().isValid(), std::invalid_argument, "When basis1 does not have tensorData() defined, it must have a valid vectorData()");
236  INTREPID2_TEST_FOR_EXCEPTION(basisValues2.numTensorDataFamilies() > 0, std::invalid_argument, "When basis1 has vector value, basis2 must also");
237 
238  const auto & vectorData1 = basisValues1.vectorData();
239  const auto & vectorData2 = basisValues2.vectorData();
240 
241  const int numFamilies1 = vectorData1.numFamilies();
242  const int numComponents = vectorData1.numComponents();
243  INTREPID2_TEST_FOR_EXCEPTION(numComponents != vectorData2.numComponents(), std::invalid_argument, "basis1 and basis2 must agree on the number of components in each vector");
244  const int numFamilies2 = vectorData2.numFamilies();
245 
246  const int numFamilies = numFamilies1 + numFamilies2;
247  std::vector< std::vector<TensorData<OutputValueType,DeviceType> > > vectorComponents(numFamilies, std::vector<TensorData<OutputValueType,DeviceType> >(numComponents));
248 
249  for (int i=0; i<numFamilies1; i++)
250  {
251  for (int j=0; j<numComponents; j++)
252  {
253  vectorComponents[i][j] = vectorData1.getComponent(i,j);
254  }
255  }
256  for (int i=0; i<numFamilies2; i++)
257  {
258  for (int j=0; j<numComponents; j++)
259  {
260  vectorComponents[i+numFamilies1][j] = vectorData2.getComponent(i,j);
261  }
262  }
263  VectorData<OutputValueType,DeviceType> vectorData(vectorComponents);
264  return BasisValues<OutputValueType,DeviceType>(vectorData);
265  }
266  }
267 
276  virtual void getDofCoords( ScalarViewType dofCoords ) const override {
277  const int basisCardinality1 = basis1_->getCardinality();
278  const int basisCardinality2 = basis2_->getCardinality();
279  const int basisCardinality = basisCardinality1 + basisCardinality2;
280 
281  auto dofCoords1 = Kokkos::subview(dofCoords, std::make_pair(0,basisCardinality1), Kokkos::ALL());
282  auto dofCoords2 = Kokkos::subview(dofCoords, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
283 
284  basis1_->getDofCoords(dofCoords1);
285  basis2_->getDofCoords(dofCoords2);
286  }
287 
299  virtual void getDofCoeffs( ScalarViewType dofCoeffs ) const override {
300  const int basisCardinality1 = basis1_->getCardinality();
301  const int basisCardinality2 = basis2_->getCardinality();
302  const int basisCardinality = basisCardinality1 + basisCardinality2;
303 
304  auto dofCoeffs1 = Kokkos::subview(dofCoeffs, std::make_pair(0,basisCardinality1), Kokkos::ALL());
305  auto dofCoeffs2 = Kokkos::subview(dofCoeffs, std::make_pair(basisCardinality1,basisCardinality), Kokkos::ALL());
306 
307  basis1_->getDofCoeffs(dofCoeffs1);
308  basis2_->getDofCoeffs(dofCoeffs2);
309  }
310 
311 
316  virtual
317  const char*
318  getName() const override {
319  return name_.c_str();
320  }
321 
322  // since the getValues() below only overrides the FEM variants, we specify that
323  // we use the base class's getValues(), which implements the FVD variant by throwing an exception.
324  // (It's an error to use the FVD variant on this basis.)
325  using BasisBase::getValues;
326 
338  virtual
339  void
341  const TensorPoints<PointValueType,DeviceType> inputPoints,
342  const EOperator operatorType = OPERATOR_VALUE ) const override
343  {
344  const int fieldStartOrdinal1 = 0;
345  const int numFields1 = basis1_->getCardinality();
346  const int fieldStartOrdinal2 = numFields1;
347  const int numFields2 = basis2_->getCardinality();
348 
349  auto basisValues1 = outputValues.basisValuesForFields(fieldStartOrdinal1, numFields1);
350  auto basisValues2 = outputValues.basisValuesForFields(fieldStartOrdinal2, numFields2);
351 
352  basis1_->getValues(basisValues1, inputPoints, operatorType);
353  basis2_->getValues(basisValues2, inputPoints, operatorType);
354  }
355 
374  virtual void getValues( OutputViewType outputValues, const PointViewType inputPoints,
375  const EOperator operatorType = OPERATOR_VALUE ) const override
376  {
377  int cardinality1 = basis1_->getCardinality();
378  int cardinality2 = basis2_->getCardinality();
379 
380  auto range1 = std::make_pair(0,cardinality1);
381  auto range2 = std::make_pair(cardinality1,cardinality1+cardinality2);
382  if (outputValues.rank() == 2) // F,P
383  {
384  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL());
385  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL());
386 
387  basis1_->getValues(outputValues1, inputPoints, operatorType);
388  basis2_->getValues(outputValues2, inputPoints, operatorType);
389  }
390  else if (outputValues.rank() == 3) // F,P,D
391  {
392  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL());
393  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL());
394 
395  basis1_->getValues(outputValues1, inputPoints, operatorType);
396  basis2_->getValues(outputValues2, inputPoints, operatorType);
397  }
398  else if (outputValues.rank() == 4) // F,P,D,D
399  {
400  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
401  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
402 
403  basis1_->getValues(outputValues1, inputPoints, operatorType);
404  basis2_->getValues(outputValues2, inputPoints, operatorType);
405  }
406  else if (outputValues.rank() == 5) // F,P,D,D,D
407  {
408  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
409  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
410 
411  basis1_->getValues(outputValues1, inputPoints, operatorType);
412  basis2_->getValues(outputValues2, inputPoints, operatorType);
413  }
414  else if (outputValues.rank() == 6) // F,P,D,D,D,D
415  {
416  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
417  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
418 
419  basis1_->getValues(outputValues1, inputPoints, operatorType);
420  basis2_->getValues(outputValues2, inputPoints, operatorType);
421  }
422  else if (outputValues.rank() == 7) // F,P,D,D,D,D,D
423  {
424  auto outputValues1 = Kokkos::subview(outputValues, range1, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
425  auto outputValues2 = Kokkos::subview(outputValues, range2, Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL(), Kokkos::ALL());
426 
427  basis1_->getValues(outputValues1, inputPoints, operatorType);
428  basis2_->getValues(outputValues2, inputPoints, operatorType);
429  }
430  else
431  {
432  INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Unsupported outputValues rank");
433  }
434  }
435  };
436 } // end namespace Intrepid2
437 
438 #endif /* Intrepid2_DirectSumBasis_h */
virtual void getValues(BasisValues< OutputValueType, DeviceType > outputValues, const TensorPoints< PointValueType, DeviceType > inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell, using point and output value containers that allow preservation of tensor-product structure.
virtual void getDofCoords(ScalarViewType dofCoords) const override
Fills in spatial locations (coordinates) of degrees of freedom (nodes) on the reference cell...
virtual BasisValues< OutputValueType, DeviceType > allocateBasisValues(TensorPoints< PointValueType, DeviceType > points, const EOperator operatorType=OPERATOR_VALUE) const override
Allocate BasisValues container suitable for passing to the getValues() variant that takes a TensorPoints argument.
View-like interface to tensor points; point components are stored separately; the appropriate coordinate is determined from the composite point index requested.
BasisValues< Scalar, ExecSpaceType > basisValuesForFields(const int &fieldStartOrdinal, const int &numFields)
Field start and length must align with families in vectorData_ or tensorDataFamilies_ (whichever is valid).
The data containers in Intrepid2 that support sum factorization and other reduced-data optimizations ...
virtual void getValues(OutputViewType outputValues, const PointViewType inputPoints, const EOperator operatorType=OPERATOR_VALUE) const override
Evaluation of a FEM basis on a reference cell.
A basis that is the direct sum of two other bases.
virtual const char * getName() const override
Returns basis name.
EOperator
Enumeration of primitive operators available in Intrepid. Primitive operators act on reconstructed functions or basis functions.
virtual void getDofCoeffs(ScalarViewType dofCoeffs) const override
Fills in coefficients of degrees of freedom for Lagrangian basis on the reference cell...
TensorDataType & tensorData()
TensorData accessor for single-family scalar data.
View-like interface to tensor data; tensor components are stored separately and multiplied together at access time.
const VectorDataType & vectorData() const
VectorData accessor.
Reference-space field values for a basis, designed to support typical vector-valued bases...
Basis_DirectSumBasis(BasisPtr basis1, BasisPtr basis2)
Constructor.