SimiLie
Loading...
Searching...
No Matches
tensor_impl.hpp
1// SPDX-FileCopyrightText: 2024 Baptiste Legouix
2// SPDX-License-Identifier: MIT
3
4#pragma once
5
#include <iostream>
#include <type_traits>

#include <ddc/ddc.hpp>

#include <similie/misc/portable_stl.hpp>
#include <similie/misc/specialization.hpp>
12
13namespace sil {
14
15namespace tensor {
16
17// struct representing an index mu or nu in a tensor Tmunu.
18template <class... CDim>
20{
21 static constexpr bool is_tensor_index = true;
22 static constexpr bool is_tensor_natural_index = true;
23 static constexpr bool is_explicitely_stored_tensor = true;
24
25 using type_seq_dimensions = ddc::detail::TypeSeq<CDim...>;
26
27 using subindices_domain_t = ddc::DiscreteDomain<>;
28
29 KOKKOS_FUNCTION static constexpr subindices_domain_t subindices_domain()
30 {
31 return ddc::DiscreteDomain<>();
32 }
33
34 KOKKOS_FUNCTION static constexpr std::size_t rank()
35 {
36 return sizeof...(CDim) != 0;
37 }
38
39 KOKKOS_FUNCTION static constexpr std::size_t size()
40 {
41 if constexpr (rank() == 0) {
42 return 1;
43 } else {
44 return sizeof...(CDim);
45 }
46 }
47
48 KOKKOS_FUNCTION static constexpr std::size_t mem_size()
49 {
50 return size();
51 }
52
53 KOKKOS_FUNCTION static constexpr std::size_t access_size()
54 {
55 return size();
56 }
57
58 template <class ODim>
59 KOKKOS_FUNCTION static constexpr std::size_t mem_id()
60 {
61 if constexpr (rank() == 0) {
62 return 0;
63 } else {
64 return ddc::type_seq_rank_v<ODim, type_seq_dimensions>;
65 }
66 }
67
68 KOKKOS_FUNCTION static constexpr std::size_t mem_id(std::size_t const natural_id)
69 {
70 return natural_id;
71 }
72
73 KOKKOS_FUNCTION static constexpr std::size_t access_id(std::size_t const natural_id)
74 {
75 return natural_id;
76 }
77
78 KOKKOS_FUNCTION static constexpr std::size_t access_id_to_mem_id(std::size_t access_id)
79 {
80 return access_id;
81 }
82
83 template <class Tensor, class Elem, class Id, class FunctorType>
84 KOKKOS_FUNCTION static constexpr Tensor::element_type process_access(
85 const FunctorType& access,
86 Tensor tensor,
87 Elem elem)
88 {
89 return access(tensor, elem);
90 }
91
92 KOKKOS_FUNCTION static constexpr std::array<std::size_t, rank()>
94 {
95 assert(mem_id < mem_size());
96 if constexpr (rank() == 0) {
97 return std::array<std::size_t, rank()> {};
98 } else {
99 return std::array<std::size_t, rank()> {mem_id};
100 }
101 }
102};
103
104template <class DDim>
105concept TensorIndex = requires {
106 { DDim::is_tensor_index } -> std::convertible_to<bool>;
107} && DDim::is_tensor_index;
108
109template <class DDim>
110concept TensorNatIndex = requires {
111 { DDim::is_tensor_natural_index } -> std::convertible_to<bool>;
112} && DDim::is_tensor_natural_index;
113
114// ScalarIndex, a generic rank-0 index
116{
117};
118
119// natural_domain_t is obtained using concept and specialization
120namespace detail {
121
122template <class Index>
123struct NaturalDomainType;
124
125template <class Index>
127struct NaturalDomainType<Index>
128{
129 using type = typename Index::subindices_domain_t;
130};
131
132template <TensorNatIndex Index>
133struct NaturalDomainType<Index>
134{
135 using type = ddc::DiscreteDomain<Index>;
136};
137
138} // namespace detail
139
140template <TensorIndex Index>
141using natural_domain_t = typename detail::NaturalDomainType<Index>::type;
142
143// Helpers to build the access_id() function which computes the ids of subindices of an index. This cumbersome logic is necessary because subindices do not necessarily have the same rank.
144namespace detail {
145// For Tmunu and index=nu, returns 1
146template <class Index, class...>
147struct NbDimsBeforeIndex;
148
149template <class Index, class IndexHead, class... IndexTail>
150struct NbDimsBeforeIndex<Index, ddc::detail::TypeSeq<IndexHead, IndexTail...>>
151{
152 static constexpr std::size_t run(std::size_t nb_dims_before_index)
153 {
154 if constexpr (std::is_same_v<IndexHead, Index>) {
155 return nb_dims_before_index;
156 } else {
157 return NbDimsBeforeIndex<Index, ddc::detail::TypeSeq<IndexTail...>>::run(
158 nb_dims_before_index + IndexHead::rank());
159 }
160 }
161};
162
// Shifts every value of a std::size_t integer sequence by a constant Offset,
// e.g. Offset=2 and <0, 1, 2> give <2, 3, 4>.
template <std::size_t Offset, class IndexSeq>
struct OffsetIndexSeq;

template <std::size_t Offset, std::size_t... Ids>
struct OffsetIndexSeq<Offset, std::integer_sequence<std::size_t, Ids...>>
{
    using type = std::integer_sequence<std::size_t, (Ids + Offset)...>;
};

template <std::size_t Offset, class IndexSeq>
using offset_index_seq_t = typename OffsetIndexSeq<Offset, IndexSeq>::type;
175
176// Returns dimensions from integers (ie. for Tmunu, <1> gives nu)
177template <class CDimTypeSeq, class IndexSeq>
178struct TypeSeqDimsAtInts;
179
180template <class CDimTypeSeq, std::size_t... Is>
181struct TypeSeqDimsAtInts<CDimTypeSeq, std::integer_sequence<std::size_t, Is...>>
182{
183 using type = ddc::detail::TypeSeq<ddc::type_seq_element_t<Is, CDimTypeSeq>...>;
184};
185
186template <class CDimTypeSeq, class IndexSeq>
187using type_seq_dims_at_ints_t = TypeSeqDimsAtInts<CDimTypeSeq, IndexSeq>::type;
188
189// Returns Index::access_id but from a type seq (in place of a variadic template CDim...)
190template <class Index, class SubindicesDomain, class TypeSeqDims>
191struct IdFromTypeSeqDims;
192
193template <class Index, class... Subindex, class... CDim>
194struct IdFromTypeSeqDims<Index, ddc::DiscreteDomain<Subindex...>, ddc::detail::TypeSeq<CDim...>>
195{
196 static constexpr std::size_t run()
197 {
198 static_assert(sizeof...(Subindex) == sizeof...(CDim));
199 if constexpr (TensorNatIndex<Index>) {
200 return Index::access_id(
201 ddc::type_seq_rank_v<CDim, typename Index::type_seq_dimensions>...);
202 } else {
203 return Index::access_id(
204 std::array<std::size_t, sizeof...(Subindex)> {ddc::type_seq_rank_v<
205 typename ddc::type_seq_element_t<
206 ddc::type_seq_rank_v<
207 Subindex,
208 ddc::detail::TypeSeq<Subindex...>>,
209 ddc::detail::TypeSeq<CDim...>>,
210 typename Subindex::type_seq_dimensions>...});
211 }
212 }
213};
214
215// Returns Index::access_id for the subindex Index of the IndicesTypeSeq
216template <class Index, class IndicesTypeSeq, class... CDim>
217static constexpr std::size_t access_id()
218{
219 if constexpr (TensorNatIndex<Index>) {
220 return IdFromTypeSeqDims<
221 Index,
222 ddc::DiscreteDomain<Index>,
223 type_seq_dims_at_ints_t<
224 ddc::detail::TypeSeq<CDim...>,
225 offset_index_seq_t<
226 NbDimsBeforeIndex<Index, IndicesTypeSeq>::run(0),
227 std::make_integer_sequence<std::size_t, Index::rank()>>>>::run();
228 } else {
229 return IdFromTypeSeqDims<
230 Index,
231 typename Index::subindices_domain_t,
232 type_seq_dims_at_ints_t<
233 ddc::detail::TypeSeq<CDim...>,
234 offset_index_seq_t<
235 NbDimsBeforeIndex<Index, IndicesTypeSeq>::run(0),
236 std::make_integer_sequence<std::size_t, Index::rank()>>>>::run();
237 }
238}
239
240template <class Index, class SubindicesDomain>
241struct IdFromElem;
242
243template <class Index, class... Subindex>
244struct IdFromElem<Index, ddc::DiscreteDomain<Subindex...>>
245{
246 template <class Elem>
247 static constexpr std::size_t run(Elem natural_elem)
248 {
249 if constexpr (TensorNatIndex<Index>) {
250 return Index::access_id(natural_elem.template uid<Index>());
251 } else {
252 return Index::access_id(
253 std::array<std::size_t, sizeof...(Subindex)> {
254 natural_elem.template uid<Subindex>()...});
255 }
256 }
257};
258
259template <class Index, class IndicesTypeSeq, class... NaturalIndex>
260static constexpr std::size_t access_id(ddc::DiscreteElement<NaturalIndex...> natural_elem)
261{
262 if constexpr (TensorNatIndex<Index>) {
263 return IdFromElem<Index, ddc::DiscreteDomain<Index>>::run(natural_elem);
264 } else {
265 return IdFromElem<Index, typename Index::subindices_domain_t>::run(natural_elem);
266 }
267}
268
269} // namespace detail
270
271// TensorAccessor class, allows to build a domain which represents the tensor and access elements.
272template <TensorIndex... Index>
274{
275public:
276 explicit constexpr TensorAccessor();
277
278 using discrete_domain_type = ddc::DiscreteDomain<Index...>;
279
280 using discrete_element_type = ddc::DiscreteElement<Index...>;
281
282 using natural_domain_t = ddc::cartesian_prod_t<std::conditional_t< // TODO natural_domain_type
284 ddc::DiscreteDomain<Index>,
285 typename Index::subindices_domain_t>...>;
286
287 static constexpr natural_domain_t natural_domain();
288
289 static constexpr discrete_domain_type domain();
290
291 static constexpr discrete_domain_type access_domain();
292
293 template <class... CDim>
294 static constexpr discrete_element_type access_element();
295
296 template <class... NaturalIndex>
297 static constexpr discrete_element_type access_element(
298 ddc::DiscreteElement<NaturalIndex...> natural_elem);
299
300 template <class... MemIndex>
301 static constexpr natural_domain_t::discrete_element_type canonical_natural_element(
302 ddc::DiscreteElement<MemIndex...> mem_elem);
303};
304
305namespace detail {
306
307template <class Seq>
308struct TensorAccessorForTypeSeq;
309
310template <TensorIndex... Index>
311struct TensorAccessorForTypeSeq<ddc::detail::TypeSeq<Index...>>
312{
313 using type = TensorAccessor<Index...>;
314};
315
316template <class Dom>
317struct TensorAccessorForDomain;
318
319template <class... DDim>
320struct TensorAccessorForDomain<ddc::DiscreteDomain<DDim...>>
321{
322 using type = typename TensorAccessorForTypeSeq<
323 ddc::to_type_seq_t<ddc::cartesian_prod_t<std::conditional_t<
324 TensorIndex<DDim>,
325 ddc::DiscreteDomain<DDim>,
326 ddc::DiscreteDomain<>>...>>>::type;
327};
328
329} // namespace detail
330
331template <misc::Specialization<ddc::DiscreteDomain> Dom>
332using tensor_accessor_for_domain_t = detail::TensorAccessorForDomain<Dom>::type;
333
334template <TensorIndex... Index>
338
339namespace detail {
340template <class Index>
341constexpr auto natural_domain()
342{
343 if constexpr (TensorNatIndex<Index>) {
344 return typename ddc::DiscreteDomain<
345 Index>(ddc::DiscreteElement<Index>(0), ddc::DiscreteVector<Index>(Index::size()));
346 } else {
347 return Index::subindices_domain();
348 }
349}
350} // namespace detail
351
352template <TensorIndex... Index>
354{
355 return natural_domain_t(detail::natural_domain<Index>()...);
356}
357
358template <TensorIndex... Index>
359constexpr TensorAccessor<Index...>::discrete_domain_type TensorAccessor<Index...>::domain()
360{
361 return ddc::DiscreteDomain<Index...>(
362 ddc::DiscreteElement<Index...>(ddc::DiscreteElement<Index>(0)...),
363 ddc::DiscreteVector<Index...>(ddc::DiscreteVector<Index>(Index::mem_size())...));
364}
365
366template <TensorIndex... Index>
367constexpr TensorAccessor<Index...>::discrete_domain_type TensorAccessor<Index...>::access_domain()
368{
369 return ddc::DiscreteDomain<Index...>(
370 ddc::DiscreteElement<Index...>(ddc::DiscreteElement<Index>(0)...),
371 ddc::DiscreteVector<Index...>(ddc::DiscreteVector<Index>(Index::access_size())...));
372}
373
374template <TensorIndex... Index>
375template <class... CDim>
376constexpr TensorAccessor<Index...>::discrete_element_type TensorAccessor<Index...>::access_element()
377{
378 return ddc::DiscreteElement<Index...>(ddc::DiscreteElement<Index>(
379 detail::access_id<Index, ddc::detail::TypeSeq<Index...>, CDim...>())...);
380}
381
382template <TensorIndex... Index>
383template <class... NaturalIndex>
384constexpr TensorAccessor<Index...>::discrete_element_type TensorAccessor<Index...>::access_element(
385 [[maybe_unused]] ddc::DiscreteElement<NaturalIndex...> natural_elem)
386{
387 return ddc::DiscreteElement<Index...>(
388 ddc::DiscreteElement<Index>(detail::access_id<Index, ddc::detail::TypeSeq<Index...>>(
389 typename natural_domain_t::discrete_element_type(natural_elem)))...);
390}
391
392template <TensorIndex... Index>
393template <class... MemIndex>
394constexpr TensorAccessor<Index...>::natural_domain_t::discrete_element_type TensorAccessor<
395 Index...>::canonical_natural_element(ddc::DiscreteElement<MemIndex...> mem_elem)
396{
397 std::array<std::size_t, natural_domain_t::rank()> ids {};
398 auto it = ids.begin();
399 (
400 [&]() {
401 auto i = MemIndex::mem_id_to_canonical_natural_ids(
402 mem_elem.template uid<MemIndex>());
403 misc::detail::copy(i.begin(), i.end(), it);
404 it += i.size();
405 }(),
406 ...);
407 typename natural_domain_t::discrete_element_type natural_elem;
408 ddc::detail::array(natural_elem) = std::array<std::size_t, natural_domain_t::rank()>(ids);
409 return natural_elem;
410}
411
412namespace detail {
413
414// Helpers to handle memory access and processing for particular tensor structures (ie. eventual multiplication with -1 for antisymmetry or non-stored zeros)
415template <
416 class TensorField,
417 class Element,
418 class IndexHeadsTypeSeq,
419 class IndexInterest,
420 class... IndexTail>
421struct Access;
422
423template <
424 class TensorField,
425 class Element,
426 class... IndexHead,
427 class IndexInterest,
428 class... IndexTail>
429struct Access<TensorField, Element, ddc::detail::TypeSeq<IndexHead...>, IndexInterest, IndexTail...>
430{
431 template <class Elem>
432 KOKKOS_FUNCTION static TensorField::element_type run(TensorField tensor_field, Elem const& elem)
433 {
434 /*
435 ----- Important warning -----
436 The general case is not correctly handled here. It would be difficult to do so.
437 It means you can get silent bug (with wrong result) if you try to use exotic ordering
438 of dimensions/indices. Ie., a TensorYoungTableauIndex has to be the last of the list.
439 */
440 if constexpr (sizeof...(IndexTail) > 0) {
441 if constexpr (TensorIndex<IndexInterest>) {
442 return IndexInterest::template process_access<TensorField, Elem, IndexInterest>(
443 KOKKOS_LAMBDA(TensorField tensor_field_, Elem elem_)
444 ->TensorField::element_type {
445 return Access<
446 TensorField,
447 Element,
448 ddc::detail::TypeSeq<IndexHead..., IndexInterest>,
449 IndexTail...>::run(tensor_field_, elem_);
450 },
451 tensor_field,
452 elem);
453 } else {
454 return Access<
455 TensorField,
456 Element,
457 ddc::detail::TypeSeq<IndexHead..., IndexInterest>,
458 IndexTail...>::run(tensor_field, elem);
459 }
460 } else {
461 if constexpr (TensorIndex<IndexInterest>) {
462 return IndexInterest::template process_access<TensorField, Elem, IndexInterest>(
463 KOKKOS_LAMBDA(TensorField tensor_field_, Elem elem_)
464 ->TensorField::element_type {
465 double tensor_field_value = 0;
466 if constexpr (IndexInterest::is_explicitely_stored_tensor) {
467 std::size_t const mem_id
468 = IndexInterest::access_id_to_mem_id(
469 elem_.template uid<IndexInterest>());
470 if (mem_id != std::numeric_limits<std::size_t>::max()) {
471 tensor_field_value
472 = tensor_field_
473 .mem(ddc::DiscreteElement<
474 IndexHead...>(elem_),
475 ddc::DiscreteElement<
476 IndexInterest>(mem_id));
477 } else {
478 tensor_field_value = 1.;
479 }
480 } else {
481 std::pair<
482 std::vector<double>,
483 std::vector<std::size_t>> const mem_lin_comb
484 = IndexInterest::access_id_to_mem_lin_comb(
485 elem_.template uid<IndexInterest>());
486
487 if (std::get<0>(mem_lin_comb).size() > 0) {
488 for (std::size_t i = 0;
489 i < std::get<0>(mem_lin_comb).size();
490 ++i) {
491 tensor_field_value
492 += std::get<0>(mem_lin_comb)[i]
493 * tensor_field_.mem(
494 ddc::DiscreteElement<
495 IndexHead...>(elem_),
496 ddc::DiscreteElement<
497 IndexInterest>(std::get<
498 1>(
499 mem_lin_comb)[i]));
500 }
501 } else {
502 tensor_field_value = 1.;
503 }
504 }
505
506 return tensor_field_value;
507 },
508 tensor_field,
509 elem);
510 } else {
511 return tensor_field(elem);
512 }
513 }
514 }
515};
516
517// Functor for memory element access (if defined)
518template <class InterestDim>
519struct LambdaMemElem
520{
521 template <class Elem>
522 KOKKOS_FUNCTION static ddc::DiscreteElement<InterestDim> run(Elem elem)
523 {
524 return ddc::DiscreteElement<InterestDim>(elem);
525 }
526};
527
528template <TensorIndex InterestDim>
529struct LambdaMemElem<InterestDim>
530{
531 template <class Elem>
532 KOKKOS_FUNCTION static ddc::DiscreteElement<InterestDim> run(Elem elem)
533 {
534 if constexpr (InterestDim::is_explicitely_stored_tensor) {
535 std::size_t const mem_id
536 = InterestDim::access_id_to_mem_id(elem.template uid<InterestDim>());
537 assert(mem_id != std::numeric_limits<std::size_t>::max()
538 && "mem_elem is not defined because mem_id() returned a max integer. Maybe you "
539 "used Tensor::operator() in place of Tensor::get ?");
540 return ddc::DiscreteElement<InterestDim>(mem_id);
541 } else {
542 std::pair<std::vector<double>, std::vector<std::size_t>> const mem_lin_comb
543 = InterestDim::access_id_to_mem_lin_comb(elem.template uid<InterestDim>());
544 assert(std::get<0>(mem_lin_comb).size() > 0
545 && "mem_elem is not defined because mem_lin_comb contains no id. Maybe you used "
546 "Tensor::operator() in place of Tensor::get ?");
547 assert(std::get<0>(mem_lin_comb).size() == 1
548 && "mem_elem is not defined because mem_lin_comb contains several ids. Maybe "
549 "you used Tensor::operator() in place of Tensor::get ?");
550 return ddc::DiscreteElement<InterestDim>(std::get<1>(mem_lin_comb)[0]);
551 }
552 }
553};
554
555} // namespace detail
556
557// @cond
558
559template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
560class Tensor;
561
562} // namespace tensor
563
564} // namespace sil
565
566namespace ddc {
567
568template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
569inline constexpr bool enable_chunk<
571 = true;
572
573template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
574inline constexpr bool enable_borrowed_chunk<
576 = true;
577
578} // namespace ddc
579
580namespace sil {
581
582namespace tensor {
583
584// @endcond
585
587template <class ElementType, class... DDim, class LayoutStridedPolicy, class MemorySpace>
588class Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
589 : public ddc::
590 ChunkSpan<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
591{
592protected:
593 using base_type = ddc::
594 ChunkSpan<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>;
595
596public:
597 using base_type::ChunkSpan;
598 using reference = base_type::reference;
599 using discrete_domain_type = base_type::discrete_domain_type;
600 using discrete_element_type = base_type::discrete_element_type;
601
602 using base_type::domain;
603 using base_type::operator();
604
605 KOKKOS_FUNCTION constexpr explicit Tensor(
606 ddc::ChunkSpan<
607 ElementType,
608 ddc::DiscreteDomain<DDim...>,
609 LayoutStridedPolicy,
610 MemorySpace> other) noexcept
611 : base_type(other)
612 {
613 }
614
615 using accessor_t = tensor_accessor_for_domain_t<ddc::cartesian_prod_t<std::conditional_t<
617 ddc::DiscreteDomain<DDim>,
618 ddc::DiscreteDomain<>>...>>;
619
620 static constexpr accessor_t accessor()
621 {
622 return accessor_t();
623 }
624
625 using indices_domain_t = accessor_t::discrete_domain_type;
626
628 = ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_remove_t<
629 ddc::to_type_seq_t<discrete_domain_type>,
630 ddc::to_type_seq_t<indices_domain_t>>>;
632 KOKKOS_FUNCTION constexpr indices_domain_t indices_domain() const noexcept
633 {
634 return indices_domain_t(domain());
635 }
637 KOKKOS_FUNCTION constexpr non_indices_domain_t non_indices_domain() const noexcept
638 {
639 return non_indices_domain_t(domain());
640 }
642 using natural_domain_t
643 = ddc::cartesian_prod_t<non_indices_domain_t, typename accessor_t::natural_domain_t>;
644
645 KOKKOS_FUNCTION constexpr natural_domain_t natural_domain() const noexcept
646 {
647 return natural_domain_t(non_indices_domain(), accessor_t::natural_domain());
649
650 KOKKOS_FUNCTION constexpr discrete_domain_type access_domain() const noexcept
651 {
652 return discrete_domain_type(non_indices_domain(), accessor_t::access_domain());
653 }
655 template <class... CDim>
656 KOKKOS_FUNCTION constexpr discrete_element_type access_element()
657 const noexcept // TODO merge this with the one below
658 {
659 return discrete_element_type(accessor_t::template access_element<CDim...>());
660 }
662 template <class... Elem>
663 KOKKOS_FUNCTION constexpr discrete_element_type access_element(Elem... elem) const noexcept
664 {
666 accessor_t::access_element(
667 typename accessor_t::natural_domain_t::discrete_element_type(elem...)),
668 typename non_indices_domain_t::discrete_element_type(elem...));
669 }
671 template <class... Elem>
672 KOKKOS_FUNCTION constexpr natural_domain_t::discrete_element_type canonical_natural_element(
673 Elem... mem_elem) const noexcept
674 {
675 return typename natural_domain_t::discrete_element_type(
676 accessor_t::canonical_natural_element(
677 typename accessor_t::discrete_element_type(mem_elem...)),
678 typename non_indices_domain_t::discrete_element_type(mem_elem...));
679 }
681 template <class... DElems>
682 KOKKOS_FUNCTION constexpr reference mem(DElems const&... delems) const noexcept
683 {
684 return ddc::ChunkSpan<
685 ElementType,
686 ddc::DiscreteDomain<DDim...>,
687 LayoutStridedPolicy,
688 MemorySpace>::
689 operator()(delems...);
690 }
692 template <class... DElems>
693 KOKKOS_FUNCTION constexpr reference operator()(DElems const&... delems) const noexcept
694 {
695 return ddc::ChunkSpan<
696 ElementType,
697 ddc::DiscreteDomain<DDim...>,
698 LayoutStridedPolicy,
699 MemorySpace>::
700 operator()(ddc::DiscreteElement<DDim...>(
701 detail::LambdaMemElem<DDim>::run(ddc::DiscreteElement<DDim>(delems...))...));
702 }
704 template <class... ODDim>
705 KOKKOS_FUNCTION constexpr auto operator[](
706 ddc::DiscreteElement<ODDim...> const& slice_spec) const noexcept
707 {
708 ddc::ChunkSpan chunkspan = ddc::ChunkSpan<
709 ElementType,
710 ddc::DiscreteDomain<DDim...>,
711 LayoutStridedPolicy,
712 MemorySpace>::
713 operator[](
714 ddc::DiscreteElement<ODDim...>(detail::LambdaMemElem<ODDim>::run(slice_spec)...));
715 return Tensor<
716 ElementType,
717 ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_remove_t<
718 ddc::detail::TypeSeq<DDim...>,
719 ddc::detail::TypeSeq<ODDim...>>>,
720 typename decltype(chunkspan)::layout_type,
721 MemorySpace>(chunkspan);
722 }
724 template <class... DElems>
725 KOKKOS_FUNCTION ElementType get(DElems const&... delems) const noexcept
726 {
727 if constexpr (sizeof...(DDim) == 0) {
728 return operator()(delems...);
729 } else {
730 return detail::Access<
731 Tensor<ElementType,
732 ddc::DiscreteDomain<DDim...>,
733 LayoutStridedPolicy,
734 MemorySpace>,
735 ddc::DiscreteElement<DDim...>,
736 ddc::detail::TypeSeq<>,
737 DDim...>::run(*this, ddc::DiscreteElement<DDim...>(delems...));
738 }
739 }
740
741 KOKKOS_FUNCTION Tensor<
742 ElementType,
743 ddc::DiscreteDomain<DDim...>,
744 LayoutStridedPolicy,
745 MemorySpace>&
746 operator+=(const Tensor<
747 ElementType,
748 ddc::DiscreteDomain<DDim...>,
749 LayoutStridedPolicy,
750 MemorySpace>& tensor)
751 {
752 ddc::device_for_each(this->domain(), [&](ddc::DiscreteElement<DDim...> elem) {
753 this->mem(elem) += tensor.mem(elem);
754 });
755 return *this;
756 }
757
758 KOKKOS_FUNCTION Tensor<
759 ElementType,
760 ddc::DiscreteDomain<DDim...>,
761 LayoutStridedPolicy,
762 MemorySpace>&
763 operator*=(const ElementType scalar)
764 {
765 ddc::device_for_each(this->domain(), [&](ddc::DiscreteElement<DDim...> elem) {
766 this->mem(elem) *= scalar;
767 });
768 return *this;
769 }
770};
771
772template <class ElementType, class SupportType, class Allocator>
773Tensor(ddc::Chunk<ElementType, SupportType, Allocator>)
775
776template <class ElementType, class SupportType, class LayoutStridedPolicy, class MemorySpace>
777Tensor(ddc::ChunkSpan<ElementType, SupportType, LayoutStridedPolicy, MemorySpace>)
779
780namespace detail {
781
782// Domain of a tensor result of product between two tensors
783template <class Dom1, class Dom2>
784struct NaturalTensorProdDomain;
785
786template <class... DDim1, class... DDim2>
787struct NaturalTensorProdDomain<ddc::DiscreteDomain<DDim1...>, ddc::DiscreteDomain<DDim2...>>
788{
789 using type = ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_merge_t<
790 ddc::type_seq_remove_t<ddc::detail::TypeSeq<DDim1...>, ddc::detail::TypeSeq<DDim2...>>,
791 ddc::type_seq_remove_t<
792 ddc::detail::TypeSeq<DDim2...>,
793 ddc::detail::TypeSeq<DDim1...>>>>;
794};
795
796} // namespace detail
797
798template <
801using natural_tensor_prod_domain_t = detail::NaturalTensorProdDomain<Dom1, Dom2>::type;
802
803template <
810
811namespace detail {
812
813// Product between two tensors naturally indexed.
814template <class HeadDDim1TypeSeq, class ContractDDimTypeSeq, class TailDDim2TypeSeq>
815struct NaturalTensorProd;
816
817template <class... HeadDDim1, class... ContractDDim, class... TailDDim2>
818struct NaturalTensorProd<
819 ddc::detail::TypeSeq<HeadDDim1...>,
820 ddc::detail::TypeSeq<ContractDDim...>,
821 ddc::detail::TypeSeq<TailDDim2...>>
822{
823 template <class ElementType, class LayoutStridedPolicy, class MemorySpace>
824 KOKKOS_FUNCTION static Tensor<
825 ElementType,
826 ddc::DiscreteDomain<HeadDDim1..., TailDDim2...>,
827 LayoutStridedPolicy,
828 MemorySpace>
829 run(Tensor<ElementType,
830 ddc::DiscreteDomain<HeadDDim1..., TailDDim2...>,
831 LayoutStridedPolicy,
832 MemorySpace> prod_tensor,
833 Tensor<ElementType,
834 ddc::DiscreteDomain<HeadDDim1..., ContractDDim...>,
835 LayoutStridedPolicy,
836 MemorySpace> tensor1,
837 Tensor<ElementType,
838 ddc::DiscreteDomain<ContractDDim..., TailDDim2...>,
839 LayoutStridedPolicy,
840 MemorySpace> tensor2)
841 {
842 ddc::device_for_each(
843 prod_tensor.domain(),
844 [&](ddc::DiscreteElement<HeadDDim1..., TailDDim2...> elem) {
845 prod_tensor(elem) = ddc::device_transform_reduce(
846 tensor1.template domain<ContractDDim...>(),
847 0.,
848 ddc::reducer::sum<ElementType>(),
849 [&](ddc::DiscreteElement<ContractDDim...> contract_elem) {
850 return tensor1(ddc::select<HeadDDim1...>(elem), contract_elem)
851 * tensor2(ddc::select<TailDDim2...>(elem), contract_elem);
852 });
853 });
854 return prod_tensor;
855 }
856};
857
858} // namespace detail
859
860template <
861 TensorNatIndex... ProdDDim,
862 TensorNatIndex... DDim1,
863 TensorNatIndex... DDim2,
864 class ElementType,
865 class LayoutStridedPolicy,
866 class MemorySpace>
867Tensor<ElementType, ddc::DiscreteDomain<ProdDDim...>, LayoutStridedPolicy, MemorySpace> tensor_prod(
868 Tensor<ElementType, ddc::DiscreteDomain<ProdDDim...>, LayoutStridedPolicy, MemorySpace>
869 prod_tensor,
870 Tensor<ElementType, ddc::DiscreteDomain<DDim1...>, LayoutStridedPolicy, MemorySpace>
871 tensor1,
872 Tensor<ElementType, ddc::DiscreteDomain<DDim2...>, LayoutStridedPolicy, MemorySpace>
873 tensor2)
874{
875 static_assert(std::is_same_v<
876 ddc::type_seq_remove_t<
877 ddc::detail::TypeSeq<DDim1...>,
878 ddc::detail::TypeSeq<ProdDDim...>>,
879 ddc::type_seq_remove_t<
880 ddc::detail::TypeSeq<DDim2...>,
881 ddc::detail::TypeSeq<ProdDDim...>>>);
882 return detail::NaturalTensorProd<
883 ddc::type_seq_remove_t<
884 ddc::detail::TypeSeq<ProdDDim...>,
885 ddc::detail::TypeSeq<DDim2...>>,
886 ddc::type_seq_remove_t<
887 ddc::detail::TypeSeq<DDim1...>,
888 ddc::detail::TypeSeq<ProdDDim...>>,
889 ddc::type_seq_remove_t<
890 ddc::detail::TypeSeq<ProdDDim...>,
891 ddc::detail::TypeSeq<DDim1...>>>::run(prod_tensor, tensor1, tensor2);
892}
893
894namespace detail {
895
896template <class HeadDom, class InterestDom, class TailDom>
897struct PrintTensor;
898
899template <class... HeadDDim, class InterestDDim>
900struct PrintTensor<
901 ddc::DiscreteDomain<HeadDDim...>,
902 ddc::DiscreteDomain<InterestDDim>,
903 ddc::DiscreteDomain<>>
904{
905 template <class TensorType>
906 static std::string run(
907 std::string& str,
908 TensorType const& tensor,
909 ddc::DiscreteElement<HeadDDim...> i)
910 {
911 for (ddc::DiscreteElement<InterestDDim> elem :
912 ddc::DiscreteDomain<InterestDDim>(tensor.natural_domain())) {
913 str = str + " "
914 + std::to_string(tensor.get(tensor.access_element(
915 ddc::DiscreteElement<HeadDDim..., InterestDDim>(i, elem))));
916 }
917 str += "\n";
918 return str;
919 }
920};
921
922template <class... HeadDDim, class InterestDDim, class HeadOfTailDDim, class... TailOfTailDDim>
923struct PrintTensor<
924 ddc::DiscreteDomain<HeadDDim...>,
925 ddc::DiscreteDomain<InterestDDim>,
926 ddc::DiscreteDomain<HeadOfTailDDim, TailOfTailDDim...>>
927{
928 template <class TensorType>
929 static std::string run(
930 std::string& str,
931 TensorType const& tensor,
932 ddc::DiscreteElement<HeadDDim...> i)
933 {
934 str += "[";
935 for (ddc::DiscreteElement<InterestDDim> elem :
936 ddc::DiscreteDomain<InterestDDim>(tensor.natural_domain())) {
937 str = PrintTensor<
938 ddc::DiscreteDomain<HeadDDim..., InterestDDim>,
939 ddc::DiscreteDomain<HeadOfTailDDim>,
940 ddc::DiscreteDomain<TailOfTailDDim...>>::
941 run(str, tensor, ddc::DiscreteElement<HeadDDim..., InterestDDim>(i, elem));
942 }
943 str += "]\n";
944 return str;
945 }
946};
947
948} // namespace detail
949
950template <misc::Specialization<Tensor> TensorType>
951std::ostream& operator<<(std::ostream& os, TensorType const& tensor)
952{
953 std::string str = "";
954 os << detail::PrintTensor<
955 ddc::DiscreteDomain<>,
956 ddc::DiscreteDomain<ddc::type_seq_element_t<
957 0,
958 ddc::to_type_seq_t<typename TensorType::natural_domain_t>>>,
959 ddc::detail::convert_type_seq_to_discrete_domain_t<ddc::type_seq_remove_t<
960 ddc::to_type_seq_t<typename TensorType::natural_domain_t>,
961 ddc::detail::TypeSeq<ddc::type_seq_element_t<
962 0,
963 ddc::to_type_seq_t<typename TensorType::natural_domain_t>>>>>>::
964 run(str, tensor, ddc::DiscreteElement<>());
965 return os;
966}
967
968} // namespace tensor
969
970} // namespace sil
ddc::cartesian_prod_t< std::conditional_t< TensorNatIndex< Index >, ddc::DiscreteDomain< Index >, typename Index::subindices_domain_t >... > natural_domain_t
static constexpr natural_domain_t natural_domain()
static constexpr natural_domain_t::discrete_element_type canonical_natural_element(ddc::DiscreteElement< MemIndex... > mem_elem)
static constexpr discrete_domain_type domain()
static constexpr discrete_element_type access_element()
static constexpr discrete_domain_type access_domain()
ddc::DiscreteDomain< Index... > discrete_domain_type
ddc::DiscreteElement< Index... > discrete_element_type
ddc:: ChunkSpan< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace > base_type
ddc::cartesian_prod_t< non_indices_domain_t, typename accessor_t::natural_domain_t > natural_domain_t
ddc::detail::convert_type_seq_to_discrete_domain_t< ddc::type_seq_remove_t< ddc::to_type_seq_t< discrete_domain_type >, ddc::to_type_seq_t< indices_domain_t > > > non_indices_domain_t
tensor_accessor_for_domain_t< ddc::cartesian_prod_t< std::conditional_t< TensorIndex< DDim >, ddc::DiscreteDomain< DDim >, ddc::DiscreteDomain<> >... > > accessor_t
KOKKOS_FUNCTION constexpr Tensor(ddc::ChunkSpan< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace > other) noexcept
Tensor(ddc::Chunk< ElementType, SupportType, Allocator >) -> Tensor< ElementType, SupportType, Kokkos::layout_right, typename Allocator::memory_space >
detail::NaturalTensorProdDomain< Dom1, Dom2 >::type natural_tensor_prod_domain_t
natural_tensor_prod_domain_t< Dom1, Dom2 > natural_tensor_prod_domain(Dom1 dom1, Dom2 dom2)
typename detail::NaturalDomainType< Index >::type natural_domain_t
detail::TensorAccessorForDomain< Dom >::type tensor_accessor_for_domain_t
The top-level namespace of SimiLie.
Definition csr.hpp:14
static KOKKOS_FUNCTION constexpr std::size_t mem_id(std::size_t const natural_id)
static KOKKOS_FUNCTION constexpr std::size_t mem_id()
ddc::DiscreteDomain<> subindices_domain_t
static KOKKOS_FUNCTION constexpr subindices_domain_t subindices_domain()
static KOKKOS_FUNCTION constexpr std::size_t access_id_to_mem_id(std::size_t access_id)
static constexpr bool is_explicitely_stored_tensor
static KOKKOS_FUNCTION constexpr std::size_t access_size()
static KOKKOS_FUNCTION constexpr std::size_t access_id(std::size_t const natural_id)
ddc::detail::TypeSeq< CDim... > type_seq_dimensions
static KOKKOS_FUNCTION constexpr std::size_t rank()
static KOKKOS_FUNCTION constexpr Tensor::element_type process_access(const FunctorType &access, Tensor tensor, Elem elem)
static KOKKOS_FUNCTION constexpr std::size_t size()
static KOKKOS_FUNCTION constexpr std::size_t mem_size()
static KOKKOS_FUNCTION constexpr std::array< std::size_t, rank()> mem_id_to_canonical_natural_ids(std::size_t mem_id)
static constexpr bool is_tensor_natural_index
static constexpr bool is_tensor_index