SimiLie
Loading...
Searching...
No Matches
relabelization.hpp
1// SPDX-FileCopyrightText: 2024 Baptiste Legouix
2// SPDX-License-Identifier: MIT
3
4#pragma once
5
6#include <ddc/ddc.hpp>
7
8#include "tensor_impl.hpp"
9
10namespace sil {
11
12namespace tensor {
13
14// Relabelize index without altering allocation
15namespace detail {
16template <class IndexToRelabelize, TensorIndex OldIndex, TensorIndex NewIndex>
17struct RelabelizeIndex;
18
19template <class IndexToRelabelize, TensorIndex OldIndex, TensorIndex NewIndex>
20 requires(!TensorIndex<IndexToRelabelize> || TensorNatIndex<IndexToRelabelize>)
21struct RelabelizeIndex<IndexToRelabelize, OldIndex, NewIndex>
22{
23 using type = std::
24 conditional_t<std::is_same_v<IndexToRelabelize, OldIndex>, NewIndex, IndexToRelabelize>;
25};
26
27template <
28 template <class...> class IndexToRelabelizeType,
29 TensorIndex OldIndex,
30 TensorIndex NewIndex,
31 class... Arg>
32struct RelabelizeIndex<IndexToRelabelizeType<Arg...>, OldIndex, NewIndex>
33{
34 using type = std::conditional_t<
35 std::is_same_v<IndexToRelabelizeType<Arg...>, OldIndex>,
36 NewIndex,
37 IndexToRelabelizeType<typename RelabelizeIndex<Arg, OldIndex, NewIndex>::type...>>;
38};
39
40template <class T, class OldIndex, class NewIndex>
41struct RelabelizeIndexInType;
42
43template <template <class...> class T, class... DDim, class OldIndex, class NewIndex>
44struct RelabelizeIndexInType<T<DDim...>, OldIndex, NewIndex>
45{
46 using type = T<typename RelabelizeIndex<DDim, OldIndex, NewIndex>::type...>;
47};
48
49} // namespace detail
50
51template <class T, TensorIndex OldIndex, TensorIndex NewIndex>
52using relabelize_index_in_t = detail::RelabelizeIndexInType<T, OldIndex, NewIndex>::type;
53
54namespace detail {
55
// Dispatcher behind relabelize_index_in: one run() overload per DDC object kind.
// Each overload rebuilds the object with every tag DDim mapped through RelabelizeIndex
// (OldIndex -> NewIndex), keeping the per-dimension values.
56template <class OldIndex, class NewIndex>
57struct RelabelizeIndexIn
58{
// DiscreteElement: same uid for each dimension, under the relabelized tag.
59 template <class... DDim>
60 static constexpr auto run(ddc::DiscreteElement<DDim...> elem)
61 {
62 return ddc::DiscreteElement<
63 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::type...>(
64 elem.template uid<DDim>()...);
65 }
66
// DiscreteVector: same component for each dimension, under the relabelized tag.
// NOTE(review): components are funneled through std::size_t. DiscreteVector components are
// presumably signed (std::ptrdiff_t in DDC) and may be negative, so the value only survives
// through modular conversion back to the signed type. Consider passing
// vect.template get<DDim>() directly; confirm against the DDC version in use.
67 template <class... DDim>
68 static constexpr auto run(ddc::DiscreteVector<DDim...> vect)
69 {
70 return ddc::DiscreteVector<
71 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::type...>(
72 static_cast<std::size_t>(vect.template get<DDim>())...);
73 }
74
// DiscreteDomain: rebuilt as relabelize_index_in_t<ddc::DiscreteDomain<DDim...>, ...>,
// constructed from dom (presumably from its relabelized front and extents; confirm).
75 template <class... DDim>
76 static constexpr auto run(ddc::DiscreteDomain<DDim...> dom)
77 {
78 return relabelize_index_in_t<ddc::DiscreteDomain<DDim...>, OldIndex, NewIndex>(
81 }
82};
83
84} // namespace detail
85
// Relabelizes OldIndex into NewIndex in t, which must be one of the kinds handled by
// detail::RelabelizeIndexIn::run (ddc::DiscreteElement, DiscreteVector or DiscreteDomain).
// Only the tags change: element uids and vector components are carried over.
86template <class OldIndex, class NewIndex, class T>
88{
89 return detail::RelabelizeIndexIn<OldIndex, NewIndex>::run(t);
90}
91
92namespace detail {
93template <class TensorType, class OldIndex, class NewIndex>
94struct RelabelizeIndexOfType;
95
96template <
97 class OldIndex,
98 class NewIndex,
99 class ElementType,
100 class Dom,
101 class LayoutStridedPolicy,
102 class MemorySpace>
103struct RelabelizeIndexOfType<
104 Tensor<ElementType, Dom, LayoutStridedPolicy, MemorySpace>,
105 OldIndex,
106 NewIndex>
107{
108 using type = Tensor<
109 ElementType,
110 typename RelabelizeIndexInType<Dom, OldIndex, NewIndex>::type,
111 LayoutStridedPolicy,
112 MemorySpace>;
113};
114
115} // namespace detail
116
117template <misc::Specialization<Tensor> TensorType, TensorIndex OldIndex, TensorIndex NewIndex>
118using relabelize_index_of_t = detail::RelabelizeIndexOfType<TensorType, OldIndex, NewIndex>::type;
119
// Returns a Tensor built over old_tensor.data_handle() (the allocation is reused, not
// copied) whose domain has OldIndex replaced by NewIndex.
// The new domain is assembled dimension by dimension. For each DDim it uses a 1-D
// ddc::DiscreteDomain of the relabelized tag, with the same front uid and the same extent
// as in old_tensor.
120template <
121 TensorIndex OldIndex,
122 TensorIndex NewIndex,
123 class ElementType,
124 class... DDim,
125 class LayoutStridedPolicy,
126 class MemorySpace>
127constexpr relabelize_index_of_t<
128 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>,
129 OldIndex,
130 NewIndex>
132 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
133 old_tensor)
134{
136 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>,
137 OldIndex,
138 NewIndex>(
// Same data pointer: only the labels of the domain change.
139 old_tensor.data_handle(),
140 typename detail::RelabelizeIndexInType<
141 ddc::DiscreteDomain<DDim...>,
142 OldIndex,
143 NewIndex>::
144 type(ddc::DiscreteDomain<
145 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::type>(
146 ddc::DiscreteElement<
147 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::
148 type>(old_tensor.domain().front().template uid<DDim>()),
149 ddc::DiscreteVector<
150 typename detail::RelabelizeIndex<DDim, OldIndex, NewIndex>::
151 type>(static_cast<std::size_t>(
152 old_tensor.template extent<DDim>())))...));
153}
154
155namespace detail {
156template <class IndexToRelabelize, class OldIndices, class NewIndices>
157struct RelabelizeIndices;
158
159template <class IndexToRelabelize>
160struct RelabelizeIndices<IndexToRelabelize, ddc::detail::TypeSeq<>, ddc::detail::TypeSeq<>>
161{
162 using type = IndexToRelabelize;
163};
164
165template <
166 class IndexToRelabelize,
167 class HeadOldIndex,
168 class... TailOldIndex,
169 class HeadNewIndex,
170 class... TailNewIndex>
171struct RelabelizeIndices<
172 IndexToRelabelize,
173 ddc::detail::TypeSeq<HeadOldIndex, TailOldIndex...>,
174 ddc::detail::TypeSeq<HeadNewIndex, TailNewIndex...>>
175{
176 static_assert(sizeof...(TailOldIndex) == sizeof...(TailNewIndex));
177 using type = std::conditional_t<
178 (sizeof...(TailOldIndex) > 0),
179 typename RelabelizeIndices<
180 typename RelabelizeIndex<IndexToRelabelize, HeadOldIndex, HeadNewIndex>::type,
181 ddc::detail::TypeSeq<TailOldIndex...>,
182 ddc::detail::TypeSeq<TailNewIndex...>>::type,
183 typename RelabelizeIndex<IndexToRelabelize, HeadOldIndex, HeadNewIndex>::type>;
184};
185
186template <class T, class OldIndices, class NewIndices>
187struct RelabelizeIndicesInType;
188
189template <class T>
190struct RelabelizeIndicesInType<T, ddc::detail::TypeSeq<>, ddc::detail::TypeSeq<>>
191{
192 using type = T;
193};
194
// Recursive case of RelabelizeIndicesInType: consume the (HeadOldIndex -> HeadNewIndex)
// pair and recurse on the tails. The two tails must have the same length (static_assert).
// NOTE(review): the first argument of the recursive instantiation is presumably T with the
// head pair applied (relabelize_index_in_t); confirm.
195template <
196 class T,
197 class HeadOldIndex,
198 class... TailOldIndex,
199 class HeadNewIndex,
200 class... TailNewIndex>
201struct RelabelizeIndicesInType<
202 T,
203 ddc::detail::TypeSeq<HeadOldIndex, TailOldIndex...>,
204 ddc::detail::TypeSeq<HeadNewIndex, TailNewIndex...>>
205{
206 static_assert(sizeof...(TailOldIndex) == sizeof...(TailNewIndex));
207 using type = typename RelabelizeIndicesInType<
209 ddc::detail::TypeSeq<TailOldIndex...>,
210 ddc::detail::TypeSeq<TailNewIndex...>>::type;
211};
212
213} // namespace detail
214
// Type obtained by applying every (OldIndices[k] -> NewIndices[k]) substitution, in order,
// to the template arguments of T (see detail::RelabelizeIndicesInType).
215template <class T, class OldIndices, class NewIndices>
217 typename detail::RelabelizeIndicesInType<T, OldIndices, NewIndices>::type;
218
219namespace detail {
220
// Applies the substitutions OldIndices[k] -> NewIndices[k] for k = I, I + 1, ... to a
// ddc::DiscreteElement, DiscreteVector or DiscreteDomain, one pair per recursion step.
// I is the position of the next pair to apply; relabelize_indices_in starts at the
// default I = 0. Each step hands its result to the next one, so a later pair operates on
// tags produced by earlier pairs.
// NOTE(review): unlike detail::RelabelizeIndicesOf, OldIndices is passed unchanged to the
// next step (no ddc::type_seq_replace_t). Confirm the two helpers are meant to agree.
221template <class OldIndices, class NewIndices, std::size_t I = 0>
222struct RelabelizeIndicesIn
223{
// DiscreteElement: relabel with pair I, then recurse with I + 1. Once I reaches the
// number of pairs, the element is returned as-is.
224 template <class... DDim>
225 static constexpr auto run(ddc::DiscreteElement<DDim...> elem)
226 {
227 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
228 return RelabelizeIndicesIn<OldIndices, NewIndices, I + 1>::run(
230 ddc::type_seq_element_t<I, OldIndices>,
231 ddc::type_seq_element_t<I, NewIndices>>(elem));
232 } else {
233 return elem;
234 }
235 }
236
// DiscreteVector: same recursion as for DiscreteElement.
237 template <class... DDim>
238 static constexpr auto run(ddc::DiscreteVector<DDim...> vect)
239 {
240 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
241 return RelabelizeIndicesIn<OldIndices, NewIndices, I + 1>::run(
243 ddc::type_seq_element_t<I, OldIndices>,
244 ddc::type_seq_element_t<I, NewIndices>>(vect));
245 } else {
246 return vect;
247 }
248 }
249
// DiscreteDomain: same recursion as for DiscreteElement.
250 template <class... DDim>
251 static constexpr auto run(ddc::DiscreteDomain<DDim...> dom)
252 {
253 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
254 return RelabelizeIndicesIn<OldIndices, NewIndices, I + 1>::run(
256 ddc::type_seq_element_t<I, OldIndices>,
257 ddc::type_seq_element_t<I, NewIndices>>(dom));
258 } else {
259 return dom;
260 }
261 }
262};
263
264} // namespace detail
265
// Relabelizes every index of OldIndices to the matching index of NewIndices in t (a
// ddc::DiscreteElement, DiscreteVector or DiscreteDomain), applying the pairs in order.
// OldIndices and NewIndices must be type sequences of equal length (static_assert below).
266template <class OldIndices, class NewIndices, class T>
268{
269 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
270 return detail::RelabelizeIndicesIn<OldIndices, NewIndices>::run(t);
271}
272
273namespace detail {
274
275template <class TensorType, class OldIndex, class NewIndex>
276struct RelabelizeIndicesOfType;
277
278template <
279 class OldIndices,
280 class NewIndices,
281 class ElementType,
282 class Dom,
283 class LayoutStridedPolicy,
284 class MemorySpace>
285struct RelabelizeIndicesOfType<
286 Tensor<ElementType, Dom, LayoutStridedPolicy, MemorySpace>,
287 OldIndices,
288 NewIndices>
289{
290 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
291 using type = Tensor<
292 ElementType,
293 typename RelabelizeIndicesInType<Dom, OldIndices, NewIndices>::type,
294 LayoutStridedPolicy,
295 MemorySpace>;
296};
297
// Recursively relabelizes the indices of old_tensor. Step I replaces OldIndices[I] by
// NewIndices[I] in the tensor's domain, then recurses with I + 1. When I reaches the
// length of OldIndices, the tensor is returned unchanged.
// Each step rebuilds a Tensor over old_tensor.data_handle() (the allocation is reused).
// For every DDim, the domain keeps the same front uid and extent under the relabelized tag.
// Before recursing, occurrences of OldIndices[I] inside OldIndices are rewritten to
// NewIndices[I] via ddc::type_seq_replace_t, so later steps see the updated labels.
298template <
299 class OldIndices,
300 class NewIndices,
301 std::size_t I,
302 class ElementType,
303 class... DDim,
304 class LayoutStridedPolicy,
305 class MemorySpace>
306constexpr auto RelabelizeIndicesOf(
307 Tensor<ElementType, ddc::DiscreteDomain<DDim...>, LayoutStridedPolicy, MemorySpace>
308 old_tensor)
309{
310 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
311 if constexpr (I != ddc::type_seq_size_v<OldIndices>) {
312 return RelabelizeIndicesOf<
313 ddc::type_seq_replace_t<
314 OldIndices,
315 ddc::detail::TypeSeq<ddc::type_seq_element_t<I, OldIndices>>,
316 ddc::detail::TypeSeq<ddc::type_seq_element_t<I, NewIndices>>>,
317 NewIndices,
319 Tensor<ElementType,
320 ddc::DiscreteDomain<DDim...>,
321 LayoutStridedPolicy,
322 MemorySpace>,
323 ddc::type_seq_element_t<I, OldIndices>,
324 ddc::type_seq_element_t<I, NewIndices>>(
// Same data pointer; only the domain labels change in this step.
325 old_tensor.data_handle(),
326 typename detail::RelabelizeIndexInType<
327 ddc::DiscreteDomain<DDim...>,
328 ddc::type_seq_element_t<I, OldIndices>,
329 ddc::type_seq_element_t<I, NewIndices>>::
330 type(ddc::DiscreteDomain<typename detail::RelabelizeIndex<
331 DDim,
332 ddc::type_seq_element_t<I, OldIndices>,
333 ddc::type_seq_element_t<I, NewIndices>>::type>(
334 ddc::DiscreteElement<typename detail::RelabelizeIndex<
335 DDim,
336 ddc::type_seq_element_t<I, OldIndices>,
337 ddc::type_seq_element_t<I, NewIndices>>::type>(
338 old_tensor.domain().front().template uid<DDim>()),
339 ddc::DiscreteVector<typename detail::RelabelizeIndex<
340 DDim,
341 ddc::type_seq_element_t<I, OldIndices>,
342 ddc::type_seq_element_t<I, NewIndices>>::type>(
343 static_cast<std::size_t>(
344 old_tensor.template extent<DDim>())))...)));
345 } else {
346 return old_tensor;
347 }
348}
349
350} // namespace detail
351
352template <
353 misc::Specialization<Tensor> TensorType,
354 misc::Specialization<ddc::detail::TypeSeq> OldIndices,
355 misc::Specialization<ddc::detail::TypeSeq> NewIndices>
357 = detail::RelabelizeIndicesOfType<TensorType, OldIndices, NewIndices>::type;
358
// Public entry point: relabelizes every index of OldIndices to the corresponding index of
// NewIndices in the domain of `tensor`, reusing its data (see detail::RelabelizeIndicesOf).
// Both type sequences must have the same length (static_assert below).
359template <
364 Tensor tensor)
365{
366 static_assert(ddc::type_seq_size_v<OldIndices> == ddc::type_seq_size_v<NewIndices>);
// Start the recursion at the first (OldIndex, NewIndex) pair.
367 return detail::RelabelizeIndicesOf<OldIndices, NewIndices, 0>(tensor);
368}
369
370} // namespace tensor
371
372} // namespace sil
Tensor(ddc::Chunk< ElementType, SupportType, Allocator >) -> Tensor< ElementType, SupportType, Kokkos::layout_right, typename Allocator::memory_space >
detail::RelabelizeIndicesOfType< TensorType, OldIndices, NewIndices >::type relabelize_indices_of_t
constexpr relabelize_indices_in_t< T, OldIndices, NewIndices > relabelize_indices_in(T t)
detail::RelabelizeIndexInType< T, OldIndex, NewIndex >::type relabelize_index_in_t
constexpr relabelize_index_in_t< T, OldIndex, NewIndex > relabelize_index_in(T t)
detail::RelabelizeIndexOfType< TensorType, OldIndex, NewIndex >::type relabelize_index_of_t
constexpr relabelize_index_of_t< Tensor< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace >, OldIndex, NewIndex > relabelize_index_of(Tensor< ElementType, ddc::DiscreteDomain< DDim... >, LayoutStridedPolicy, MemorySpace > old_tensor)
constexpr relabelize_indices_of_t< Tensor, OldIndices, NewIndices > relabelize_indices_of(Tensor tensor)
typename detail::RelabelizeIndicesInType< T, OldIndices, NewIndices >::type relabelize_indices_in_t
The top-level namespace of SimiLie.
Definition csr.hpp:14