SimiLie
Loading...
Searching...
No Matches
tensor_impl.hpp
1// SPDX-FileCopyrightText: 2024 Baptiste Legouix
2// SPDX-License-Identifier: MIT
3
4#pragma once
5
6#include <iostream>
7
8#include <ddc/ddc.hpp>
9
10#include <similie/misc/portable_stl.hpp>
11#include <similie/misc/specialization.hpp>
12
13namespace sil {
14
15namespace tensor {
16
17// struct representing an index mu or nu in a tensor Tmunu.
18template <class... CDim>
20{
21 static constexpr bool is_tensor_index = true;
22 static constexpr bool is_tensor_natural_index = true;
23 static constexpr bool is_explicitely_stored_tensor = true;
24
25 using type_seq_dimensions = ddc::detail::TypeSeq<CDim...>;
26
27 using subindices_domain_t = ddc::DiscreteDomain<>;
28
29 KOKKOS_FUNCTION static constexpr subindices_domain_t subindices_domain()
30 {
31 return ddc::DiscreteDomain<>();
32 }
33
34 KOKKOS_FUNCTION static constexpr std::size_t rank()
35 {
36 return sizeof...(CDim) != 0;
37 }
38
39 KOKKOS_FUNCTION static constexpr std::size_t size()
40 {
41 if constexpr (rank() == 0) {
42 return 1;
43 } else {
44 return sizeof...(CDim);
45 }
46 }
47
48 KOKKOS_FUNCTION static constexpr std::size_t mem_size()
49 {
50 return size();
51 }
52
53 KOKKOS_FUNCTION static constexpr std::size_t access_size()
54 {
55 return size();
56 }
57
58 template <class ODim>
59 KOKKOS_FUNCTION static constexpr std::size_t mem_id()
60 {
61 if constexpr (rank() == 0) {
62 return 0;
63 } else {
64 return ddc::type_seq_rank_v<ODim, type_seq_dimensions>;
65 }
66 }
67
68 KOKKOS_FUNCTION static constexpr std::size_t mem_id(std::size_t const natural_id)
69 {
70 return natural_id;
71 }
72
73 KOKKOS_FUNCTION static constexpr std::size_t access_id(std::size_t const natural_id)
74 {
75 return natural_id;
76 }
77
78 KOKKOS_FUNCTION static constexpr std::size_t access_id_to_mem_id(std::size_t access_id)
79 {
80 return access_id;
81 }
82
83 template <class Tensor, class Elem, class Id, class FunctorType>
84 KOKKOS_FUNCTION static constexpr Tensor::element_type process_access(
85 const FunctorType& access,
86 Tensor tensor,
87 Elem elem)
88 {
89 return access(tensor, elem);
90 }
91
92 KOKKOS_FUNCTION static constexpr std::array<std::size_t, rank()>
94 {
95 assert(mem_id < mem_size());
96 if constexpr (rank() == 0) {
97 return std::array<std::size_t, rank()> {};
98 } else {
99 return std::array<std::size_t, rank()> {mem_id};
100 }
101 }
102};
103
104template <class DDim>
105concept TensorIndex = requires {
106 { DDim::is_tensor_index } -> std::convertible_to<bool>;
107} && DDim::is_tensor_index;
108
109template <class DDim>
110concept TensorNatIndex = requires {
111 { DDim::is_tensor_natural_index } -> std::convertible_to<bool>;
112} && DDim::is_tensor_natural_index;
113
114// ScalarIndex, a generic rank-0 index
116{
117};
118
119// natural_domain_t is obtained using concept and specialization
120namespace detail {
121
122template <class Index>
123struct NaturalDomainType;
124
125template <class Index>
127struct NaturalDomainType<Index>
128{
129 using type = typename Index::subindices_domain_t;
130};
131
132template <TensorNatIndex Index>
133struct NaturalDomainType<Index>
134{
135 using type = ddc::DiscreteDomain<Index>;
136};
137
138} // namespace detail
139
140template <TensorIndex Index>
141using natural_domain_t = typename detail::NaturalDomainType<Index>::type;
142
143// Helpers to build the access_id() function which computes the ids of subindices of an index. This cumbersome logic is necessary because subindices do not necessarily have the same rank.
144namespace detail {
145// For Tmunu and index=nu, returns 1
146template <class Index, class...>
147struct NbDimsBeforeIndex;
148
149template <class Index, class IndexHead, class... IndexTail>
150struct NbDimsBeforeIndex<Index, ddc::detail::TypeSeq<IndexHead, IndexTail...>>
151{
152 static constexpr std::size_t run(std::size_t nb_dims_before_index)
153 {
154 if constexpr (std::is_same_v<IndexHead, Index>) {
155 return nb_dims_before_index;
156 } else {
157 return NbDimsBeforeIndex<Index, ddc::detail::TypeSeq<IndexTail...>>::run(
158 nb_dims_before_index + IndexHead::rank());
159 }
160 }
161};
162
163// Offset and index sequence
// Shifts every value of a std::integer_sequence<std::size_t, ...> by Offset.
template <std::size_t Offset, class IndexSeq>
struct OffsetIndexSeq;

template <std::size_t Offset, std::size_t... Is>
struct OffsetIndexSeq<Offset, std::integer_sequence<std::size_t, Is...>>
{
    using type = std::integer_sequence<std::size_t, (Is + Offset)...>;
};

// Convenience alias for OffsetIndexSeq<Offset, IndexSeq>::type.
template <std::size_t Offset, class IndexSeq>
using offset_index_seq_t = typename OffsetIndexSeq<Offset, IndexSeq>::type;
175
176// Returns dimensions from integers (ie. for Tmunu, <1> gives nu)
177template <class CDimTypeSeq, class IndexSeq>
178struct TypeSeqDimsAtInts;
179
180template <class CDimTypeSeq, std::size_t... Is>
181struct TypeSeqDimsAtInts<CDimTypeSeq, std::integer_sequence<std::size_t, Is...>>
182{
183 using type = ddc::detail::TypeSeq<ddc::type_seq_element_t<Is, CDimTypeSeq>...>;
184};
185
186template <class CDimTypeSeq, class IndexSeq>
187using type_seq_dims_at_ints_t = TypeSeqDimsAtInts<CDimTypeSeq, IndexSeq>::type;
188
189// Returns Index::access_id but from a type seq (in place of a variadic template CDim...)
190template <class Index, class SubindicesDomain, class TypeSeqDims>
191struct IdFromTypeSeqDims;
192
193template <class Index, class... Subindex, class... CDim>
194struct IdFromTypeSeqDims<Index, ddc::DiscreteDomain<Subindex...>, ddc::detail::TypeSeq<CDim...>>
195{
196 static constexpr std::size_t run()
197 {
198 static_assert(sizeof...(Subindex) == sizeof...(CDim));
199 if constexpr (TensorNatIndex<Index>) {
200 return Index::access_id(
201 ddc::type_seq_rank_v<CDim, typename Index::type_seq_dimensions>...);
202 } else {
203 return Index::access_id(std::array<
204 std::size_t,
205 sizeof...(Subindex)> {ddc::type_seq_rank_v<
206 typename ddc::type_seq_element_t<
207 ddc::type_seq_rank_v<Subindex, ddc::detail::TypeSeq<Subindex...>>,
208 ddc::detail::TypeSeq<CDim...>>,
209 typename Subindex::type_seq_dimensions>...});
210 }
211 }
212};
213
214// Returns Index::access_id for the subindex Index of the IndicesTypeSeq
215template <class Index, class IndicesTypeSeq, class... CDim>
216static constexpr std::size_t access_id()
217{
218 if constexpr (TensorNatIndex<Index>) {
219 return IdFromTypeSeqDims<
220 Index,
221 ddc::DiscreteDomain<Index>,
222 type_seq_dims_at_ints_t<
223 ddc::detail::TypeSeq<CDim...>,
224 offset_index_seq_t<
225 NbDimsBeforeIndex<Index, IndicesTypeSeq>::run(0),
226 std::make_integer_sequence<std::size_t, Index::rank()>>>>::run();
227 } else {
228 return IdFromTypeSeqDims<
229 Index,
230 typename Index::subindices_domain_t,
231 type_seq_dims_at_ints_t<
232 ddc::detail::TypeSeq<CDim...>,
233 offset_index_seq_t<
234 NbDimsBeforeIndex<Index, IndicesTypeSeq>::run(0),
235 std::make_integer_sequence<std::size_t, Index::rank()>>>>::run();
236 }
237}
238
239template <class Index, class SubindicesDomain>
240struct IdFromElem;
241
242template <class Index, class... Subindex>
243struct IdFromElem<Index, ddc::DiscreteDomain<Subindex...>>
244{
245 template <class Elem>
246 static constexpr std::size_t run(Elem natural_elem)
247 {
248 if constexpr (TensorNatIndex<Index>) {
249 return Index::access_id(natural_elem.template uid<Index>());
250 } else {
251 return Index::access_id(std::array<std::size_t, sizeof...(Subindex)> {
252 natural_elem.template uid<Subindex>()...});
253 }
254 }
255};
256
257template <class Index, class IndicesTypeSeq, class... NaturalIndex>
258static constexpr std::size_t access_id(ddc::DiscreteElement<NaturalIndex...> natural_elem)
259{
260 if constexpr (TensorNatIndex<Index>) {
261 return IdFromElem<Index, ddc::DiscreteDomain<Index>>::run(natural_elem);
262 } else {
263 return IdFromElem<Index, typename Index::subindices_domain_t>::run(natural_elem);
264 }
265}
266
267} // namespace detail
268
// TensorAccessor class: builds the domains which represent the tensor and gives access to its elements.
270template <TensorIndex... Index>
272{
273public:
274 explicit constexpr TensorAccessor();
275
276 using discrete_domain_type = ddc::DiscreteDomain<Index...>;
277
278 using discrete_element_type = ddc::DiscreteElement<Index...>;
279
280 using natural_domain_t = ddc::cartesian_prod_t<std::conditional_t< // TODO natural_domain_type
282 ddc::DiscreteDomain<Index>,
283 typename Index::subindices_domain_t>...>;
284
285 static constexpr natural_domain_t natural_domain();
286
287 static constexpr discrete_domain_type domain();
288
289 static constexpr discrete_domain_type access_domain();
290
291 template <class... CDim>
292 static constexpr discrete_element_type access_element();
293
294 template <class... NaturalIndex>
295 static constexpr discrete_element_type access_element(
296 ddc::DiscreteElement<NaturalIndex...> natural_elem);
297
298 template <class... MemIndex>
299 static constexpr natural_domain_t::discrete_element_type canonical_natural_element(
300 ddc::DiscreteElement<MemIndex...> mem_elem);
301};
302
303namespace detail {
304
305template <class Seq>
306struct TensorAccessorForTypeSeq;
307
308template <TensorIndex... Index>
309struct TensorAccessorForTypeSeq<ddc::detail::TypeSeq<Index...>>
310{
311 using type = TensorAccessor<Index...>;
312};
313
314template <class Dom>
315struct TensorAccessorForDomain;
316
317template <class... DDim>
318struct TensorAccessorForDomain<ddc::DiscreteDomain<DDim...>>
319{
320 using type = typename TensorAccessorForTypeSeq<
321 ddc::to_type_seq_t<ddc::cartesian_prod_t<std::conditional_t<
322 TensorIndex<DDim>,
323 ddc::DiscreteDomain<DDim>,
324 ddc::DiscreteDomain<>>...>>>::type;
325};
326
327} // namespace detail
328
329template <misc::Specialization<ddc::DiscreteDomain> Dom>
330using tensor_accessor_for_domain_t = detail::TensorAccessorForDomain<Dom>::type;
331
332template <TensorIndex... Index>
336
337namespace detail {
338template <class Index>
339constexpr auto natural_domain()
340{
341 if constexpr (TensorNatIndex<Index>) {
342 return typename ddc::DiscreteDomain<
343 Index>(ddc::DiscreteElement<Index>(0), ddc::DiscreteVector<Index>(Index::size()));
344 } else {
345 return Index::subindices_domain();
346 }
347}
348} // namespace detail
349
350template <TensorIndex... Index>
352{
353 return natural_domain_t(detail::natural_domain<Index>()...);
354}
355
356template <TensorIndex... Index>
357constexpr TensorAccessor<Index...>::discrete_domain_type TensorAccessor<Index...>::domain()
358{
359 return ddc::DiscreteDomain<Index...>(
360 ddc::DiscreteElement<Index...>(ddc::DiscreteElement<Index>(0)...),
361 ddc::DiscreteVector<Index...>(ddc::DiscreteVector<Index>(Index::mem_size())...));
362}
363
364template <TensorIndex... Index>
365constexpr TensorAccessor<Index...>::discrete_domain_type TensorAccessor<Index...>::access_domain()
366{
367 return ddc::DiscreteDomain<Index...>(
368 ddc::DiscreteElement<Index...>(ddc::DiscreteElement<Index>(0)...),
369 ddc::DiscreteVector<Index...>(ddc::DiscreteVector<Index>(Index::access_size())...));
370}
371
372template <TensorIndex... Index>
373template <class... CDim>
374constexpr TensorAccessor<Index...>::discrete_element_type TensorAccessor<Index...>::access_element()
375{
376 return ddc::DiscreteElement<Index...>(ddc::DiscreteElement<Index>(
377 detail::access_id<Index, ddc::detail::TypeSeq<Index...>, CDim...>())...);
378}
379
380template <TensorIndex... Index>
381template <class... NaturalIndex>
382constexpr TensorAccessor<Index...>::discrete_element_type TensorAccessor<Index...>::access_element(
383 [[maybe_unused]] ddc::DiscreteElement<NaturalIndex...> natural_elem)
384{
385 return ddc::DiscreteElement<Index...>(
386 ddc::DiscreteElement<Index>(detail::access_id<Index, ddc::detail::TypeSeq<Index...>>(
387 typename natural_domain_t::discrete_element_type(natural_elem)))...);
388}
389
390template <TensorIndex... Index>
391template <class... MemIndex>
392constexpr TensorAccessor<Index...>::natural_domain_t::discrete_element_type TensorAccessor<
393 Index...>::canonical_natural_element(ddc::DiscreteElement<MemIndex...> mem_elem)
394{
395 std::array<std::size_t, natural_domain_t::rank()> ids {};
396 auto it = ids.begin();
397 (
398 [&]() {
399 auto i = MemIndex::mem_id_to_canonical_natural_ids(
400 mem_elem.template uid<MemIndex>());
401 misc::detail::copy(i.begin(), i.end(), it);
402 it += i.size();
403 }(),
404 ...);
405 typename natural_domain_t::discrete_element_type natural_elem;
406 ddc::detail::array(natural_elem) = std::array<std::size_t, natural_domain_t::rank()>(ids);
407 return natural_elem;
408}
409
410namespace detail {
411
// Helpers to handle memory access and processing for particular tensor structures (i.e. a possible multiplication by -1 for antisymmetry, or non-stored zeros)
413template <
414 class TensorField,
415 class Element,
416 class IndexHeadsTypeSeq,
417 class IndexInterest,
418 class... IndexTail>
419struct Access;
420
421template <
422 class TensorField,
423 class Element,
424 class... IndexHead,
425 class IndexInterest,
426 class... IndexTail>
427struct Access<TensorField, Element, ddc::detail::TypeSeq<IndexHead...>, IndexInterest, IndexTail...>
428{
429 template <class Elem>
430 KOKKOS_FUNCTION static TensorField::element_type run(TensorField tensor_field, Elem const& elem)
431 {
432 /*
433 ----- Important warning -----
434 The general case is not correctly handled here. It would be difficult to do so.
435 It means you can get silent bug (with wrong result) if you try to use exotic ordering
436 of dimensions/indices. Ie., a TensorYoungTableauIndex has to be the last of the list.
437 */
438 if constexpr (sizeof...(IndexTail) > 0) {
439 if constexpr (TensorIndex<IndexInterest>) {
440 return IndexInterest::template process_access<TensorField, Elem, IndexInterest>(
441 KOKKOS_LAMBDA(TensorField tensor_field_, Elem elem_)
442 ->TensorField::element_type {
443 return Access<
444 TensorField,
445 Element,
446 ddc::detail::TypeSeq<IndexHead..., IndexInterest>,
447 IndexTail...>::run(tensor_field_, elem_);
448 },
449 tensor_field,
450 elem);
451 } else {
452 return Access<
453 TensorField,
454 Element,
455 ddc::detail::TypeSeq<IndexHead..., IndexInterest>,
456 IndexTail...>::run(tensor_field, elem);
457 }
458 } else {
459 if constexpr (TensorIndex<IndexInterest>) {
460 return IndexInterest::template process_access<TensorField, Elem, IndexInterest>(
461 KOKKOS_LAMBDA(TensorField tensor_field_, Elem elem_)
462 ->TensorField::element_type {
463 double tensor_field_value = 0;
464 if constexpr (IndexInterest::is_explicitely_stored_tensor) {
465 std::size_t const mem_id
466 = IndexInterest::access_id_to_mem_id(
467 elem_.template uid<IndexInterest>());
468 if (mem_id != std::numeric_limits<std::size_t>::max()) {
469 tensor_field_value
470 = tensor_field_
471 .mem(ddc::DiscreteElement<
472 IndexHead...>(elem_),
473 ddc::DiscreteElement<
474 IndexInterest>(mem_id));
475 } else {
476 tensor_field_value = 1.;
477 }
478 } else {
479 std::pair<
480 std::vector<double>,
481 std::vector<std::size_t>> const mem_lin_comb
482 = IndexInterest::access_id_to_mem_lin_comb(
483 elem_.template uid<IndexInterest>());
484
485 if (std::get<0>(mem_lin_comb).size() > 0) {
486 for (std::size_t i = 0;
487 i < std::get<0>(mem_lin_comb).size();
488 ++i) {
489 tensor_field_value
490 += std::get<0>(mem_lin_comb)[i]
491 * tensor_field_.mem(
492 ddc::DiscreteElement<
493 IndexHead...>(elem_),
494 ddc::DiscreteElement<
495 IndexInterest>(std::get<
496 1>(
497 mem_lin_comb)[i]));
498 }
499 } else {
500 tensor_field_value = 1.;
501 }
502 }
503
504 return tensor_field_value;
505 },
506 tensor_field,
507 elem);
508 } else {
509 return tensor_field(elem);
510 }
511 }
512 }
513};
514
515// Functor for memory element access (if defined)
516template <class InterestDim>
517struct LambdaMemElem
518{
519 template <class Elem>
520 KOKKOS_FUNCTION static ddc::DiscreteElement<InterestDim> run(Elem elem)
521 {
522 return ddc::DiscreteElement<InterestDim>(elem);
523 }
524};
525
526template <TensorIndex InterestDim>
527struct LambdaMemElem<InterestDim>
528{
529 template <class Elem>
530 KOKKOS_FUNCTION static ddc::DiscreteElement<InterestDim> run(Elem elem)
531 {
532 if constexpr (InterestDim::is_explicitely_stored_tensor) {
533 std::size_t const mem_id
534 = InterestDim::access_id_to_mem_id(elem.template uid<InterestDim>());
535 assert(mem_id != std::numeric_limits<std::size_t>::max()
536 && "mem_elem is not defined because mem_id() returned a max integer. Maybe you "
537 "used Tensor::operator() in place of Tensor::get ?");
538 return ddc::DiscreteElement<InterestDim>(mem_id);
539 } else {
540 std::pair<std::vector<double>, std::vector<std::size_t>> const mem_lin_comb
541 = InterestDim::access_id_to_mem_lin_comb(elem.template uid<InterestDim>());
542 assert(std::get<0>(mem_lin_comb).size() > 0
543 && "mem_elem is not defined because mem_lin_comb contains no id. Maybe you used "
544 "Tensor::operator() in place of Tensor::get ?");
545 assert(std::get<0>(mem_lin_comb).size() == 1
546 && "mem_elem is not defined because mem_lin_comb contains several ids. Maybe "
547 "you used Tensor::operator() in place of Tensor::get ?");
548 return ddc::DiscreteElement<InterestDim>(std::get<1>(mem_lin_comb)[0]);
549 }
550 }
551};
552
553} // namespace detail
554
555// @cond
556
557template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
558class Tensor;
559
560} // namespace tensor
561
562} // namespace sil
563
564namespace ddc {
565
566template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
567inline constexpr bool enable_chunk<
569 = true;
570
571template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
572inline constexpr bool enable_borrowed_chunk<
574 = true;
575
576} // namespace ddc
577
578namespace sil {
579
580namespace tensor {
581
582// @endcond
583
585template <class ElementType, class... DDim, class LayoutStridedPolicy, class MemorySpace>
586class Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
587 : public ddc::
588 ChunkSpan<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
589{
590protected:
591 using base_type = ddc::
592 ChunkSpan<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>;
593
594public:
595 using base_type::ChunkSpan;
596 using reference = base_type::reference;
597 using discrete_domain_type = base_type::discrete_domain_type;
598 using discrete_element_type = base_type::discrete_element_type;
599
600 using base_type::domain;
601 using base_type::operator();
602
603 KOKKOS_FUNCTION constexpr explicit Tensor(ddc::ChunkSpan<
604 ElementType,
605 ddc::DiscreteDomain<DDim...>,
606 LayoutStridedPolicy,
607 MemorySpace> other) noexcept
608 : base_type(other)
609 {
610 }
611
612 using accessor_t = tensor_accessor_for_domain_t<ddc::cartesian_prod_t<std::conditional_t<
614 ddc::DiscreteDomain<DDim>,
615 ddc::DiscreteDomain<>>...>>;
616
617 static constexpr accessor_t accessor()
618 {
619 return accessor_t();
620 }
621
622 using indices_domain_t = accessor_t::discrete_domain_type;
623
625 = ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_remove_t<
626 ddc::to_type_seq_t<discrete_domain_type>,
627 ddc::to_type_seq_t<indices_domain_t>>>;
629 KOKKOS_FUNCTION constexpr indices_domain_t indices_domain() const noexcept
630 {
631 return indices_domain_t(domain());
632 }
634 KOKKOS_FUNCTION constexpr non_indices_domain_t non_indices_domain() const noexcept
635 {
636 return non_indices_domain_t(domain());
637 }
639 using natural_domain_t
640 = ddc::cartesian_prod_t<non_indices_domain_t, typename accessor_t::natural_domain_t>;
641
642 KOKKOS_FUNCTION constexpr natural_domain_t natural_domain() const noexcept
643 {
644 return natural_domain_t(non_indices_domain(), accessor_t::natural_domain());
646
647 KOKKOS_FUNCTION constexpr discrete_domain_type access_domain() const noexcept
648 {
649 return discrete_domain_type(non_indices_domain(), accessor_t::access_domain());
650 }
652 template <class... CDim>
653 KOKKOS_FUNCTION constexpr discrete_element_type access_element()
654 const noexcept // TODO merge this with the one below
655 {
656 return discrete_element_type(accessor_t::template access_element<CDim...>());
657 }
659 template <class... Elem>
660 KOKKOS_FUNCTION constexpr discrete_element_type access_element(Elem... elem) const noexcept
661 {
663 accessor_t::access_element(
664 typename accessor_t::natural_domain_t::discrete_element_type(elem...)),
665 typename non_indices_domain_t::discrete_element_type(elem...));
666 }
668 template <class... Elem>
669 KOKKOS_FUNCTION constexpr natural_domain_t::discrete_element_type canonical_natural_element(
670 Elem... mem_elem) const noexcept
671 {
672 return typename natural_domain_t::discrete_element_type(
673 accessor_t::canonical_natural_element(
674 typename accessor_t::discrete_element_type(mem_elem...)),
675 typename non_indices_domain_t::discrete_element_type(mem_elem...));
676 }
678 template <class... DElems>
679 KOKKOS_FUNCTION constexpr reference mem(DElems const&... delems) const noexcept
680 {
681 return ddc::ChunkSpan<
682 ElementType,
683 ddc::DiscreteDomain<DDim...>,
684 LayoutStridedPolicy,
685 MemorySpace>::operator()(delems...);
686 }
688 template <class... DElems>
689 KOKKOS_FUNCTION constexpr reference operator()(DElems const&... delems) const noexcept
690 {
691 return ddc::ChunkSpan<
692 ElementType,
693 ddc::DiscreteDomain<DDim...>,
694 LayoutStridedPolicy,
695 MemorySpace>::
696 operator()(ddc::DiscreteElement<DDim...>(
697 detail::LambdaMemElem<DDim>::run(ddc::DiscreteElement<DDim>(delems...))...));
698 }
700 template <class... ODDim>
701 KOKKOS_FUNCTION constexpr auto operator[](
702 ddc::DiscreteElement<ODDim...> const& slice_spec) const noexcept
703 {
704 ddc::ChunkSpan chunkspan = ddc::ChunkSpan<
705 ElementType,
706 ddc::DiscreteDomain<DDim...>,
707 LayoutStridedPolicy,
708 MemorySpace>::
709 operator[](
710 ddc::DiscreteElement<ODDim...>(detail::LambdaMemElem<ODDim>::run(slice_spec)...));
711 return Tensor<
712 ElementType,
713 ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_remove_t<
714 ddc::detail::TypeSeq<DDim...>,
715 ddc::detail::TypeSeq<ODDim...>>>,
716 typename decltype(chunkspan)::layout_type,
717 MemorySpace>(chunkspan);
718 }
720 template <class... DElems>
721 KOKKOS_FUNCTION ElementType get(DElems const&... delems) const noexcept
722 {
723 if constexpr (sizeof...(DDim) == 0) {
724 return operator()(delems...);
725 } else {
726 return detail::Access<
727 Tensor<ElementType,
728 ddc::DiscreteDomain<DDim...>,
729 LayoutStridedPolicy,
730 MemorySpace>,
731 ddc::DiscreteElement<DDim...>,
732 ddc::detail::TypeSeq<>,
733 DDim...>::run(*this, ddc::DiscreteElement<DDim...>(delems...));
734 }
735 }
736
737 KOKKOS_FUNCTION Tensor<
738 ElementType,
739 ddc::DiscreteDomain<DDim...>,
740 LayoutStridedPolicy,
741 MemorySpace>&
742 operator+=(const Tensor<
743 ElementType,
744 ddc::DiscreteDomain<DDim...>,
745 LayoutStridedPolicy,
746 MemorySpace>& tensor)
747 {
748 ddc::annotated_for_each(this->domain(), [&](ddc::DiscreteElement<DDim...> elem) {
749 this->mem(elem) += tensor.mem(elem);
750 });
751 return *this;
752 }
753
754 KOKKOS_FUNCTION Tensor<
755 ElementType,
756 ddc::DiscreteDomain<DDim...>,
757 LayoutStridedPolicy,
758 MemorySpace>&
759 operator*=(const ElementType scalar)
760 {
761 ddc::annotated_for_each(this->domain(), [&](ddc::DiscreteElement<DDim...> elem) {
762 this->mem(elem) *= scalar;
763 });
764 return *this;
765 }
766};
767
768template <class ElementType, class SupportType, class Allocator>
769Tensor(ddc::Chunk<ElementType, SupportType, Allocator>)
771
772template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
773Tensor(ddc::ChunkSpan<ElementType, SupportType, LayoutStridedPolicy, MemorySpace>)
775
776namespace detail {
777
778// Domain of a tensor result of product between two tensors
779template <class Dom1, class Dom2>
780struct NaturalTensorProdDomain;
781
782template <class... DDim1, class... DDim2>
783struct NaturalTensorProdDomain<ddc::DiscreteDomain<DDim1...>, ddc::DiscreteDomain<DDim2...>>
784{
785 using type = ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_merge_t<
786 ddc::type_seq_remove_t<ddc::detail::TypeSeq<DDim1...>, ddc::detail::TypeSeq<DDim2...>>,
787 ddc::type_seq_remove_t<
788 ddc::detail::TypeSeq<DDim2...>,
789 ddc::detail::TypeSeq<DDim1...>>>>;
790};
791
792} // namespace detail
793
794template <
797using natural_tensor_prod_domain_t = detail::NaturalTensorProdDomain<Dom1, Dom2>::type;
798
799template <
806
807namespace detail {
808
809// Product between two tensors naturally indexed.
810template <class HeadDDim1TypeSeq, class ContractDDimTypeSeq, class TailDDim2TypeSeq>
811struct NaturalTensorProd;
812
813template <class... HeadDDim1, class... ContractDDim, class... TailDDim2>
814struct NaturalTensorProd<
815 ddc::detail::TypeSeq<HeadDDim1...>,
816 ddc::detail::TypeSeq<ContractDDim...>,
817 ddc::detail::TypeSeq<TailDDim2...>>
818{
819 template <class ElementType, class LayoutStridedPolicy, class MemorySpace>
820 KOKKOS_FUNCTION static Tensor<
821 ElementType,
822 ddc::DiscreteDomain<HeadDDim1..., TailDDim2...>,
823 LayoutStridedPolicy,
824 MemorySpace>
825 run(Tensor<ElementType,
826 ddc::DiscreteDomain<HeadDDim1..., TailDDim2...>,
827 LayoutStridedPolicy,
828 MemorySpace> prod_tensor,
829 Tensor<ElementType,
830 ddc::DiscreteDomain<HeadDDim1..., ContractDDim...>,
831 LayoutStridedPolicy,
832 MemorySpace> tensor1,
833 Tensor<ElementType,
834 ddc::DiscreteDomain<ContractDDim..., TailDDim2...>,
835 LayoutStridedPolicy,
836 MemorySpace> tensor2)
837 {
838 ddc::annotated_for_each(
839 prod_tensor.domain(),
840 [&](ddc::DiscreteElement<HeadDDim1..., TailDDim2...> elem) {
841 prod_tensor(elem) = ddc::annotated_transform_reduce(
842 tensor1.template domain<ContractDDim...>(),
843 0.,
844 ddc::reducer::sum<ElementType>(),
845 [&](ddc::DiscreteElement<ContractDDim...> contract_elem) {
846 return tensor1(ddc::select<HeadDDim1...>(elem), contract_elem)
847 * tensor2(ddc::select<TailDDim2...>(elem), contract_elem);
848 });
849 });
850 return prod_tensor;
851 }
852};
853
854} // namespace detail
855
856template <
857 TensorNatIndex... ProdDDim,
858 TensorNatIndex... DDim1,
859 TensorNatIndex... DDim2,
860 class ElementType,
861 class LayoutStridedPolicy,
862 class MemorySpace>
863Tensor<ElementType, ddc::DiscreteDomain<ProdDDim...>, LayoutStridedPolicy, MemorySpace> tensor_prod(
864 Tensor<ElementType, ddc::DiscreteDomain<ProdDDim...>, LayoutStridedPolicy, MemorySpace>
865 prod_tensor,
866 Tensor<ElementType, ddc::DiscreteDomain<DDim1...>, LayoutStridedPolicy, MemorySpace>
867 tensor1,
868 Tensor<ElementType, ddc::DiscreteDomain<DDim2...>, LayoutStridedPolicy, MemorySpace>
869 tensor2)
870{
871 static_assert(std::is_same_v<
872 ddc::type_seq_remove_t<
873 ddc::detail::TypeSeq<DDim1...>,
874 ddc::detail::TypeSeq<ProdDDim...>>,
875 ddc::type_seq_remove_t<
876 ddc::detail::TypeSeq<DDim2...>,
877 ddc::detail::TypeSeq<ProdDDim...>>>);
878 return detail::NaturalTensorProd<
879 ddc::type_seq_remove_t<
880 ddc::detail::TypeSeq<ProdDDim...>,
881 ddc::detail::TypeSeq<DDim2...>>,
882 ddc::type_seq_remove_t<
883 ddc::detail::TypeSeq<DDim1...>,
884 ddc::detail::TypeSeq<ProdDDim...>>,
885 ddc::type_seq_remove_t<
886 ddc::detail::TypeSeq<ProdDDim...>,
887 ddc::detail::TypeSeq<DDim1...>>>::run(prod_tensor, tensor1, tensor2);
888}
889
890namespace detail {
891
892template <class HeadDom, class InterestDom, class TailDom>
893struct PrintTensor;
894
895template <class... HeadDDim, class InterestDDim>
896struct PrintTensor<
897 ddc::DiscreteDomain<HeadDDim...>,
898 ddc::DiscreteDomain<InterestDDim>,
899 ddc::DiscreteDomain<>>
900{
901 template <class TensorType>
902 static std::string run(
903 std::string& str,
904 TensorType const& tensor,
905 ddc::DiscreteElement<HeadDDim...> i)
906 {
907 for (ddc::DiscreteElement<InterestDDim> elem :
908 ddc::DiscreteDomain<InterestDDim>(tensor.natural_domain())) {
909 str = str + " "
910 + std::to_string(tensor.get(tensor.access_element(
911 ddc::DiscreteElement<HeadDDim..., InterestDDim>(i, elem))));
912 }
913 str += "\n";
914 return str;
915 }
916};
917
918template <class... HeadDDim, class InterestDDim, class HeadOfTailDDim, class... TailOfTailDDim>
919struct PrintTensor<
920 ddc::DiscreteDomain<HeadDDim...>,
921 ddc::DiscreteDomain<InterestDDim>,
922 ddc::DiscreteDomain<HeadOfTailDDim, TailOfTailDDim...>>
923{
924 template <class TensorType>
925 static std::string run(
926 std::string& str,
927 TensorType const& tensor,
928 ddc::DiscreteElement<HeadDDim...> i)
929 {
930 str += "[";
931 for (ddc::DiscreteElement<InterestDDim> elem :
932 ddc::DiscreteDomain<InterestDDim>(tensor.natural_domain())) {
933 str = PrintTensor<
934 ddc::DiscreteDomain<HeadDDim..., InterestDDim>,
935 ddc::DiscreteDomain<HeadOfTailDDim>,
936 ddc::DiscreteDomain<TailOfTailDDim...>>::
937 run(str, tensor, ddc::DiscreteElement<HeadDDim..., InterestDDim>(i, elem));
938 }
939 str += "]\n";
940 return str;
941 }
942};
943
944} // namespace detail
945
946template <misc::Specialization<Tensor> TensorType>
947std::ostream& operator<<(std::ostream& os, TensorType const& tensor)
948{
949 std::string str = "";
950 os << detail::PrintTensor<
951 ddc::DiscreteDomain<>,
952 ddc::DiscreteDomain<ddc::type_seq_element_t<
953 0,
954 ddc::to_type_seq_t<typename TensorType::natural_domain_t>>>,
955 ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_remove_t<
956 ddc::to_type_seq_t<typename TensorType::natural_domain_t>,
957 ddc::detail::TypeSeq<ddc::type_seq_element_t<
958 0,
959 ddc::to_type_seq_t<typename TensorType::natural_domain_t>>>>>>::
960 run(str, tensor, ddc::DiscreteElement<>());
961 return os;
962}
963
964} // namespace tensor
965
966} // namespace sil
ddc::cartesian_prod_t< std::conditional_t< TensorNatIndex< Index >, ddc::DiscreteDomain< Index >, typename Index::subindices_domain_t >... > natural_domain_t
static constexpr natural_domain_t natural_domain()
static constexpr natural_domain_t::discrete_element_type canonical_natural_element(ddc::DiscreteElement< MemIndex... > mem_elem)
static constexpr discrete_domain_type domain()
static constexpr discrete_element_type access_element()
static constexpr discrete_domain_type access_domain()
ddc::DiscreteDomain< Index... > discrete_domain_type
ddc::DiscreteElement< Index... > discrete_element_type
ddc:: ChunkSpan< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace > base_type
ddc::cartesian_prod_t< non_indices_domain_t, typename accessor_t::natural_domain_t > natural_domain_t
ddc::detail::convert_type_seq_to_discrete_domain_t< ddc::type_seq_remove_t< ddc::to_type_seq_t< discrete_domain_type >, ddc::to_type_seq_t< indices_domain_t > > > non_indices_domain_t
tensor_accessor_for_domain_t< ddc::cartesian_prod_t< std::conditional_t< TensorIndex< DDim >, ddc::DiscreteDomain< DDim >, ddc::DiscreteDomain<> >... > > accessor_t
KOKKOS_FUNCTION constexpr Tensor(ddc::ChunkSpan< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace > other) noexcept
Tensor(ddc::Chunk< ElementType, SupportType, Allocator >) -> Tensor< ElementType, SupportType, Kokkos::layout_right, typename Allocator::memory_space >
detail::NaturalTensorProdDomain< Dom1, Dom2 >::type natural_tensor_prod_domain_t
natural_tensor_prod_domain_t< Dom1, Dom2 > natural_tensor_prod_domain(Dom1 dom1, Dom2 dom2)
typename detail::NaturalDomainType< Index >::type natural_domain_t
detail::TensorAccessorForDomain< Dom >::type tensor_accessor_for_domain_t
The top-level namespace of SimiLie.
Definition csr.hpp:14
static KOKKOS_FUNCTION constexpr std::size_t mem_id(std::size_t const natural_id)
static KOKKOS_FUNCTION constexpr std::size_t mem_id()
ddc::DiscreteDomain<> subindices_domain_t
static KOKKOS_FUNCTION constexpr subindices_domain_t subindices_domain()
static KOKKOS_FUNCTION constexpr std::size_t access_id_to_mem_id(std::size_t access_id)
static constexpr bool is_explicitely_stored_tensor
static KOKKOS_FUNCTION constexpr std::size_t access_size()
static KOKKOS_FUNCTION constexpr std::size_t access_id(std::size_t const natural_id)
ddc::detail::TypeSeq< CDim... > type_seq_dimensions
static KOKKOS_FUNCTION constexpr std::size_t rank()
static KOKKOS_FUNCTION constexpr Tensor::element_type process_access(const FunctorType &access, Tensor tensor, Elem elem)
static KOKKOS_FUNCTION constexpr std::size_t size()
static KOKKOS_FUNCTION constexpr std::size_t mem_size()
static KOKKOS_FUNCTION constexpr std::array< std::size_t, rank()> mem_id_to_canonical_natural_ids(std::size_t mem_id)
static constexpr bool is_tensor_natural_index
static constexpr bool is_tensor_index