SimiLie
Loading...
Searching...
No Matches
relabelization.hpp
1// SPDX-FileCopyrightText: 2024 Baptiste Legouix
2// SPDX-License-Identifier: MIT
3
4#pragma once
5
6#include <ddc/ddc.hpp>
7
8#include "tensor_impl.hpp"
9
10namespace sil {
11
12namespace tensor {
13
14// Relabelize index without altering allocation
15namespace detail {
16template <class IndexToRelabelize, TensorIndex OldIndex, TensorIndex NewIndex>
17struct RelabelizeIndex;
18
19template <class IndexToRelabelize, TensorIndex OldIndex, TensorIndex NewIndex>
20 requires(!TensorIndex<IndexToRelabelize> || TensorNatIndex<IndexToRelabelize>)
21struct RelabelizeIndex<IndexToRelabelize, OldIndex, NewIndex>
22{
23 using type = std::
24 conditional_t<std::is_same_v<IndexToRelabelize, OldIndex>, NewIndex, IndexToRelabelize>;
25};
26
27template <
28 template <class...>
29 class IndexToRelabelizeType,
30 TensorIndex OldIndex,
31 TensorIndex NewIndex,
32 class... Arg>
33struct RelabelizeIndex<IndexToRelabelizeType<Arg...>, OldIndex, NewIndex>
34{
35 using type = std::conditional_t<
36 std::is_same_v<IndexToRelabelizeType<Arg...>, OldIndex>,
37 NewIndex,
38 IndexToRelabelizeType<typename RelabelizeIndex<Arg, OldIndex, NewIndex>::type...>>;
39};
40
41template <class T, class OldIndex, class NewIndex>
42struct RelabelizeIndexInType;
43
44template <template <class...> class T, class... DDim, class OldIndex, class NewIndex>
45struct RelabelizeIndexInType<T<DDim...>, OldIndex, NewIndex>
46{
47 using type = T<typename RelabelizeIndex<DDim, OldIndex, NewIndex>::type...>;
48};
49
50} // namespace detail
51
52template <class T, TensorIndex OldIndex, TensorIndex NewIndex>
53using relabelize_index_in_t = detail::RelabelizeIndexInType<T, OldIndex, NewIndex>::type;
54
55namespace detail {
56
57template <class OldIndex, class NewIndex>
58struct RelabelizeIndexIn
59{
60 template <class... DDim>
61 static constexpr auto run(ddc::DiscreteElement<DDim...> elem)
62 {
63 return ddc::DiscreteElement<
64 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::type...>(
65 elem.template uid<DDim>()...);
66 }
67
68 template <class... DDim>
69 static constexpr auto run(ddc::DiscreteVector<DDim...> vect)
70 {
71 return ddc::DiscreteVector<
72 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::type...>(
73 static_cast<std::size_t>(vect.template get<DDim>())...);
74 }
75
76 template <class... DDim>
77 static constexpr auto run(ddc::DiscreteDomain<DDim...> dom)
78 {
79 return relabelize_index_in_t<ddc::DiscreteDomain<DDim...>, OldIndex, NewIndex>(
82 }
83};
84
85} // namespace detail
86
87template <class OldIndex, class NewIndex, class T>
89{
90 return detail::RelabelizeIndexIn<OldIndex, NewIndex>::run(t);
91}
92
93namespace detail {
94template <class TensorType, class OldIndex, class NewIndex>
95struct RelabelizeIndexOfType;
96
97template <
98 class OldIndex,
99 class NewIndex,
100 class ElementType,
101 class Dom,
102 class LayoutStridedPolicy,
103 class MemorySpace>
104struct RelabelizeIndexOfType<
105 Tensor<ElementType, Dom, LayoutStridedPolicy, MemorySpace>,
106 OldIndex,
107 NewIndex>
108{
109 using type = Tensor<
110 ElementType,
111 typename RelabelizeIndexInType<Dom, OldIndex, NewIndex>::type,
112 LayoutStridedPolicy,
113 MemorySpace>;
114};
115
116} // namespace detail
117
118template <misc::Specialization<Tensor> TensorType, TensorIndex OldIndex, TensorIndex NewIndex>
119using relabelize_index_of_t = detail::RelabelizeIndexOfType<TensorType, OldIndex, NewIndex>::type;
120
121template <
122 TensorIndex OldIndex,
123 TensorIndex NewIndex,
124 class ElementType,
125 class... DDim,
126 class LayoutStridedPolicy,
127 class MemorySpace>
128constexpr relabelize_index_of_t<
129 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>,
130 OldIndex,
131 NewIndex>
133 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
134 old_tensor)
135{
137 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>,
138 OldIndex,
139 NewIndex>(
140 old_tensor.data_handle(),
141 typename detail::RelabelizeIndexInType<
142 ddc::DiscreteDomain<DDim...>,
143 OldIndex,
144 NewIndex>::
145 type(ddc::DiscreteDomain<
146 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::type>(
147 ddc::DiscreteElement<
148 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::
149 type>(old_tensor.domain().front().template uid<DDim>()),
150 ddc::DiscreteVector<
151 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::
152 type>(static_cast<std::size_t>(
153 old_tensor.template extent<DDim>())))...));
154}
155
156namespace detail {
157template <class IndexToRelabelize, class OldIndices, class NewIndices>
158struct RelabelizeIndices;
159
160template <class IndexToRelabelize>
161struct RelabelizeIndices<IndexToRelabelize, ddc::detail::TypeSeq<>, ddc::detail::TypeSeq<>>
162{
163 using type = IndexToRelabelize;
164};
165
166template <
167 class IndexToRelabelize,
168 class HeadOldIndex,
169 class... TailOldIndex,
170 class HeadNewIndex,
171 class... TailNewIndex>
172struct RelabelizeIndices<
173 IndexToRelabelize,
174 ddc::detail::TypeSeq<HeadOldIndex, TailOldIndex...>,
175 ddc::detail::TypeSeq<HeadNewIndex, TailNewIndex...>>
176{
177 static_assert(sizeof...(TailOldIndex) == sizeof...(TailNewIndex));
178 using type = std::conditional_t<
179 (sizeof...(TailOldIndex) > 0),
180 typename RelabelizeIndices<
181 typename RelabelizeIndex<IndexToRelabelize, HeadOldIndex, HeadNewIndex>::type,
182 ddc::detail::TypeSeq<TailOldIndex...>,
183 ddc::detail::TypeSeq<TailNewIndex...>>::type,
184 typename RelabelizeIndex<IndexToRelabelize, HeadOldIndex, HeadNewIndex>::type>;
185};
186
187template <class T, class OldIndices, class NewIndices>
188struct RelabelizeIndicesInType;
189
190template <class T>
191struct RelabelizeIndicesInType<T, ddc::detail::TypeSeq<>, ddc::detail::TypeSeq<>>
192{
193 using type = T;
194};
195
196template <
197 class T,
198 class HeadOldIndex,
199 class... TailOldIndex,
200 class HeadNewIndex,
201 class... TailNewIndex>
202struct RelabelizeIndicesInType<
203 T,
204 ddc::detail::TypeSeq<HeadOldIndex, TailOldIndex...>,
205 ddc::detail::TypeSeq<HeadNewIndex, TailNewIndex...>>
206{
207 static_assert(sizeof...(TailOldIndex) == sizeof...(TailNewIndex));
208 using type = typename RelabelizeIndicesInType<
210 ddc::detail::TypeSeq<TailOldIndex...>,
211 ddc::detail::TypeSeq<TailNewIndex...>>::type;
212};
213
214} // namespace detail
215
216template <class T, class OldIndices, class NewIndices>
218 typename detail::RelabelizeIndicesInType<T, OldIndices, NewIndices>::type;
219
220namespace detail {
221
222template <class OldIndices, class NewIndices, std::size_t I = 0>
223struct RelabelizeIndicesIn
224{
225 template <class... DDim>
226 static constexpr auto run(ddc::DiscreteElement<DDim...> elem)
227 {
228 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
229 return RelabelizeIndicesIn<OldIndices, NewIndices, I + 1>::run(
231 ddc::type_seq_element_t<I, OldIndices>,
232 ddc::type_seq_element_t<I, NewIndices>>(elem));
233 } else {
234 return elem;
235 }
236 }
237
238 template <class... DDim>
239 static constexpr auto run(ddc::DiscreteVector<DDim...> vect)
240 {
241 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
242 return RelabelizeIndicesIn<OldIndices, NewIndices, I + 1>::run(
244 ddc::type_seq_element_t<I, OldIndices>,
245 ddc::type_seq_element_t<I, NewIndices>>(vect));
246 } else {
247 return vect;
248 }
249 }
250
251 template <class... DDim>
252 static constexpr auto run(ddc::DiscreteDomain<DDim...> dom)
253 {
254 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
255 return RelabelizeIndicesIn<OldIndices, NewIndices, I + 1>::run(
257 ddc::type_seq_element_t<I, OldIndices>,
258 ddc::type_seq_element_t<I, NewIndices>>(dom));
259 } else {
260 return dom;
261 }
262 }
263};
264
265} // namespace detail
266
267template <class OldIndices, class NewIndices, class T>
269{
270 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
271 return detail::RelabelizeIndicesIn<OldIndices, NewIndices>::run(t);
272}
273
274namespace detail {
275
276template <class TensorType, class OldIndex, class NewIndex>
277struct RelabelizeIndicesOfType;
278
279template <
280 class OldIndices,
281 class NewIndices,
282 class ElementType,
283 class Dom,
284 class LayoutStridedPolicy,
285 class MemorySpace>
286struct RelabelizeIndicesOfType<
287 Tensor<ElementType, Dom, LayoutStridedPolicy, MemorySpace>,
288 OldIndices,
289 NewIndices>
290{
291 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
292 using type = Tensor<
293 ElementType,
294 typename RelabelizeIndicesInType<Dom, OldIndices, NewIndices>::type,
295 LayoutStridedPolicy,
296 MemorySpace>;
297};
298
299template <
300 class OldIndices,
301 class NewIndices,
302 std::size_t I,
303 class ElementType,
304 class... DDim,
305 class LayoutStridedPolicy,
306 class MemorySpace>
307constexpr auto RelabelizeIndicesOf(
308 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
309 old_tensor)
310{
311 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
312 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
313 return RelabelizeIndicesOf<
314 ddc::type_seq_replace_t<
315 OldIndices,
316 ddc::detail::TypeSeq<ddc::type_seq_element_t<I, OldIndices>>,
317 ddc::detail::TypeSeq<ddc::type_seq_element_t<I, NewIndices>>>,
318 NewIndices,
320 Tensor<ElementType,
321 ddc::DiscreteDomain<DDim...>,
322 LayoutStridedPolicy,
323 MemorySpace>,
324 ddc::type_seq_element_t<I, OldIndices>,
325 ddc::type_seq_element_t<I, NewIndices>>(
326 old_tensor.data_handle(),
327 typename detail::RelabelizeIndexInType<
328 ddc::DiscreteDomain<DDim...>,
329 ddc::type_seq_element_t<I, OldIndices>,
330 ddc::type_seq_element_t<I, NewIndices>>::
331 type(ddc::DiscreteDomain<typename detail::RelabelizeIndex<
332 DDim,
333 ddc::type_seq_element_t<I, OldIndices>,
334 ddc::type_seq_element_t<I, NewIndices>>::type>(
335 ddc::DiscreteElement<typename detail::RelabelizeIndex<
336 DDim,
337 ddc::type_seq_element_t<I, OldIndices>,
338 ddc::type_seq_element_t<I, NewIndices>>::type>(
339 old_tensor.domain().front().template uid<DDim>()),
340 ddc::DiscreteVector<typename detail::RelabelizeIndex<
341 DDim,
342 ddc::type_seq_element_t<I, OldIndices>,
343 ddc::type_seq_element_t<I, NewIndices>>::type>(
344 static_cast<std::size_t>(
345 old_tensor.template extent<DDim>())))...)));
346 } else {
347 return old_tensor;
348 }
349}
350
351} // namespace detail
352
353template <
354 misc::Specialization<Tensor> TensorType,
355 misc::Specialization<ddc::detail::TypeSeq> OldIndices,
356 misc::Specialization<ddc::detail::TypeSeq> NewIndices>
358 = detail::RelabelizeIndicesOfType<TensorType, OldIndices, NewIndices>::type;
359
360template <
365 Tensor tensor)
366{
367 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
368 return detail::RelabelizeIndicesOf<OldIndices, NewIndices, 0>(tensor);
369}
370
371} // namespace tensor
372
373} // namespace sil
Tensor(ddc::Chunk< ElementType, SupportType, Allocator >) -> Tensor< ElementType, SupportType, Kokkos::layout_right, typename Allocator::memory_space >
detail::RelabelizeIndicesOfType< TensorType, OldIndices, NewIndices >::type relabelize_indices_of_t
constexpr relabelize_indices_in_t< T, OldIndices, NewIndices > relabelize_indices_in(T t)
detail::RelabelizeIndexInType< T, OldIndex, NewIndex >::type relabelize_index_in_t
constexpr relabelize_index_in_t< T, OldIndex, NewIndex > relabelize_index_in(T t)
detail::RelabelizeIndexOfType< TensorType, OldIndex, NewIndex >::type relabelize_index_of_t
constexpr relabelize_index_of_t< Tensor< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace >, OldIndex, NewIndex > relabelize_index_of(Tensor< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace > old_tensor)
constexpr relabelize_indices_of_t< Tensor, OldIndices, NewIndices > relabelize_indices_of(Tensor tensor)
typename detail::RelabelizeIndicesInType< T, OldIndices, NewIndices >::type relabelize_indices_in_t
The top-level namespace of SimiLie.
Definition csr.hpp:14