Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/meshes/parallel_tree.jl
2055 views
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
8
# Composite type that represents a NDIMS-dimensional tree (parallel version).
9
#
10
# Implements everything required for AbstractContainer.
11
#
12
# Note: The way the data structures are set up and the way most algorithms
13
# work, it is *always* assumed that
14
# a) we have a balanced tree (= at most one level difference between
15
# neighboring cells, or 2:1 rule)
16
# b) we may not have all children (= some children may not exist)
17
# c) the tree is stored depth-first
18
#
19
# However, the way the refinement/coarsening algorithms are currently
20
# implemented, we only have fully refined cells. That is, a cell either has 2^NDIMS children or
21
# no children at all (= leaf cell). This restriction is also assumed at
22
# multiple positions in the refinement/coarsening algorithms.
23
#
24
# An exception to the 2:1 rule exists for the low-level `refine_unbalanced!`
25
# function, which is required for implementing level-wise refinement in a sane
26
# way. Also, depth-first ordering *might* not be guaranteed during
27
# refinement/coarsening operations.
28
mutable struct ParallelTree{NDIMS, RealT <: Real} <: AbstractTree{NDIMS}
    # Per-cell storage. Every array provides `capacity + 1` cell slots (see constructor).
    parent_ids::Vector{Int}        # id of the parent cell (0 for the root cell)
    child_ids::Matrix{Int}         # 2^NDIMS entries per cell (0 where no child exists)
    neighbor_ids::Matrix{Int}      # 2 * NDIMS entries per cell (0 where no neighbor exists)
    levels::Vector{Int}            # refinement level (the root cell has level 0)
    coordinates::Matrix{RealT}     # NDIMS coordinate entries per cell
    original_cell_ids::Vector{Int} # set to 0 for every newly created cell
    mpi_ranks::Vector{Int}         # compared against `mpi_rank()` to detect own cells

    # Bookkeeping
    capacity::Int # number of regular cell slots
    length::Int   # number of cell slots currently in use
    dummy::Int    # index of the extra slot (`capacity + 1`) used as temporary storage

    # Domain information, filled in by `init!`
    center_level_0::SVector{NDIMS, RealT}
    length_level_0::RealT
    periodicity::NTuple{NDIMS, Bool}

    # Create a tree with room for `capacity` cells in which every cell slot is marked
    # as invalid. No root cell is created here; `init!` does that.
    function ParallelTree{NDIMS, RealT}(capacity::Integer) where {NDIMS, RealT <: Real}
        # Verify that NDIMS is an integer
        @assert NDIMS isa Integer

        # Create instance
        t = new()

        # One extra slot (index `capacity + 1`) serves as temporary storage for swap
        # operations
        n_slots = capacity + 1
        invalid_id = typemin(Int)
        invalid_real = convert(RealT, NaN) # `NaN` is of type Float64

        # Initialize all per-cell fields with values that are invalid on purpose
        t.parent_ids = fill(invalid_id, n_slots)
        t.child_ids = fill(invalid_id, 2^NDIMS, n_slots)
        t.neighbor_ids = fill(invalid_id, 2 * NDIMS, n_slots)
        t.levels = fill(invalid_id, n_slots)
        t.coordinates = fill(invalid_real, NDIMS, n_slots)
        t.original_cell_ids = fill(invalid_id, n_slots)
        t.mpi_ranks = fill(invalid_id, n_slots)

        t.capacity = capacity
        t.length = 0
        t.dummy = n_slots

        t.center_level_0 = SVector(ntuple(_ -> invalid_real, NDIMS))
        t.length_level_0 = invalid_real

        # `new()` leaves isbits fields with undefined contents. `Bool` has no "invalid"
        # sentinel, so store a deterministic value instead of leaving the field
        # uninitialized. `init!` overwrites it with the actual periodicity.
        t.periodicity = ntuple(_ -> false, NDIMS)

        return t
    end
end
72
73
# Constructor for passing the dimension as an argument. Default datatype: Float64
74
function ParallelTree(::Val{NDIMS}, args...) where {NDIMS}
    # The dimension arrives as a `Val`; all remaining arguments are forwarded
    # unchanged to the `Float64` variant of the constructor.
    return ParallelTree{NDIMS, Float64}(args...)
end
75
76
# Create and initialize tree
77
function ParallelTree{NDIMS, RealT}(capacity::Int, center::AbstractArray{RealT},
                                    length::RealT,
                                    periodicity = true) where {NDIMS, RealT <: Real}
    # Allocate storage for `capacity` cells; all cell slots start out invalidated
    tree = ParallelTree{NDIMS, RealT}(capacity)

    # Store the domain information and set up the root cell
    init!(tree, center, length, periodicity)

    return tree
end
88
89
# Constructors accepting a single number as center (as opposed to an array) for 1D
90
function ParallelTree{1, RealT}(cap::Int, center::RealT, len::RealT,
                                periodicity = true) where {RealT <: Real}
    # Wrap the scalar center into a one-element vector and defer to the
    # array-based constructor
    center_as_vector = [center]
    return ParallelTree{1, RealT}(cap, center_as_vector, len, periodicity)
end
94
function ParallelTree{1}(cap::Int, center::RealT, len::RealT,
                         periodicity = true) where {RealT <: Real}
    # The floating point type is taken from the type of `center`/`len`; the scalar
    # center is wrapped into a one-element vector for the array-based constructor
    center_as_vector = [center]
    return ParallelTree{1, RealT}(cap, center_as_vector, len, periodicity)
end
98
99
# Clear tree by deleting data structures, store center and length, and create root cell
100
function init!(t::ParallelTree, center::AbstractArray{RealT}, length::RealT,
               periodicity = true) where {RealT}
    # `periodicity` may be a single `Bool` or an iterable with one `Bool` per dimension
    clear!(t)

    # Remember the extent of the domain
    t.center_level_0 = center
    t.length_level_0 = length

    # The root cell is stored in the first slot. It has no parent and no children
    # (both encoded as 0) and sits on level 0.
    root_id = 1
    t.length += 1
    t.parent_ids[root_id] = 0
    t.child_ids[:, root_id] .= 0
    t.levels[root_id] = 0
    set_cell_coordinates!(t, t.center_level_0, root_id)
    t.original_cell_ids[root_id] = 0
    t.mpi_ranks[root_id] = typemin(Int)

    # Neighbors of the root cell: in a periodic direction the root cell is its own
    # neighbor, in a non-periodic direction there is no neighbor (0)
    if all(periodicity)
        # Fully periodic; this branch also handles `periodicity = true`
        t.neighbor_ids[:, root_id] .= root_id
        t.periodicity = ntuple(_ -> true, ndims(t))
    elseif !any(periodicity)
        # Not periodic at all; this branch also handles `periodicity = false`
        t.neighbor_ids[:, root_id] .= 0
        t.periodicity = ntuple(_ -> false, ndims(t))
    else
        # Mixed case: `periodicity` is an iterable, decide per dimension
        for dimension in 1:ndims(t)
            neighbor_id = periodicity[dimension] ? root_id : 0
            t.neighbor_ids[2 * dimension - 1, root_id] = neighbor_id
            t.neighbor_ids[2 * dimension, root_id] = neighbor_id
        end

        t.periodicity = Tuple(periodicity)
    end
end
141
142
# Convenience output for debugging
143
function Base.show(io::IO, ::MIME"text/plain", t::ParallelTree)
    @nospecialize t # reduce precompilation time

    # Only the first `t.length` cell slots are printed
    used = 1:(t.length)
    frame = '*'^20

    println(io, frame)

    # Per-cell data (matrices are transposed such that each cell is printed as a row)
    println(io, "t.parent_ids[1:l] = $(t.parent_ids[used])")
    println(io, "transpose(t.child_ids[:, 1:l]) = $(transpose(t.child_ids[:, used]))")
    println(io,
            "transpose(t.neighbor_ids[:, 1:l]) = $(transpose(t.neighbor_ids[:, used]))")
    println(io, "t.levels[1:l] = $(t.levels[used])")
    println(io,
            "transpose(t.coordinates[:, 1:l]) = $(transpose(t.coordinates[:, used]))")
    println(io, "t.original_cell_ids[1:l] = $(t.original_cell_ids[used])")
    println(io, "t.mpi_ranks[1:l] = $(t.mpi_ranks[used])")

    # Bookkeeping and domain information
    println(io, "t.capacity = $(t.capacity)")
    println(io, "t.length = $(t.length)")
    println(io, "t.dummy = $(t.dummy)")
    println(io, "t.center_level_0 = $(t.center_level_0)")
    println(io, "t.length_level_0 = $(t.length_level_0)")

    println(io, frame)
end
164
165
# Check if cell is own cell, i.e., belongs to this MPI rank
166
function is_own_cell(t::ParallelTree, cell_id)
    # A cell is an own cell if its stored rank equals the rank returned by `mpi_rank()`
    owner = t.mpi_ranks[cell_id]
    return owner == mpi_rank()
end
167
168
# Return an array with the ids of all leaf cells for a given rank
169
function leaf_cells_by_rank(t::ParallelTree, rank)
    # Let `filter_leaf_cells` apply a predicate that accepts only cells whose stored
    # MPI rank equals `rank`
    return filter_leaf_cells(cell_id -> t.mpi_ranks[cell_id] == rank, t)
end
173
174
# Return an array with the ids of all local leaf cells
175
function local_leaf_cells(t::ParallelTree)
    # "Local" means: assigned to the rank returned by `mpi_rank()`
    return leaf_cells_by_rank(t, mpi_rank())
end
176
177
# Set information for child cell `child_id` based on parent cell `cell_id` (except neighbors)
178
function init_child!(t::ParallelTree, cell_id, child, child_id)
    # Link parent and child in both directions; `child` selects the slot in the
    # parent's list of children
    t.parent_ids[child_id] = cell_id
    t.child_ids[child, cell_id] = child_id

    # The new cell has no children of its own (encoded as 0) and lives one level
    # below its parent
    t.child_ids[:, child_id] .= 0
    t.levels[child_id] = t.levels[cell_id] + 1

    # Derive the child's coordinates from the parent's coordinates and length
    parent_center = cell_coordinates(t, cell_id)
    parent_length = length_at_cell(t, cell_id)
    child_center = child_coordinates(t, parent_center, parent_length, child)
    set_cell_coordinates!(t, child_center, child_id)

    # The original cell id is reset to 0, the MPI rank is inherited from the parent
    t.original_cell_ids[child_id] = 0
    t.mpi_ranks[child_id] = t.mpi_ranks[cell_id]

    return nothing
end
192
193
# Reset range of cells to values that are prone to cause errors as soon as they are used.
194
#
195
# Rationale: If an invalid cell is accidentally used, we want to know it as soon as possible.
196
function invalidate!(t::ParallelTree{NDIMS, RealT},
                     first::Int, last::Int) where {NDIMS, RealT <: Real}
    # The range may include the extra slot at `capacity + 1`
    @assert first > 0
    @assert last <= t.capacity + 1

    cells = first:last

    # Markers for invalid data: smallest negative value for integers, NaN for
    # floating point values
    invalid_id = typemin(Int)
    invalid_real = convert(RealT, NaN) # `NaN` is of type Float64

    t.parent_ids[cells] .= invalid_id
    t.child_ids[:, cells] .= invalid_id
    t.neighbor_ids[:, cells] .= invalid_id
    t.levels[cells] .= invalid_id
    t.coordinates[:, cells] .= invalid_real
    t.original_cell_ids[cells] .= invalid_id
    t.mpi_ranks[cells] .= invalid_id

    return nothing
end
212
213
# Raw copy operation for ranges of cells.
214
#
215
# This method is used by the higher-level copy operations for AbstractContainer
216
function raw_copy!(target::ParallelTree, source::ParallelTree, first::Int, last::Int,
                   destination::Int)
    # Every per-cell field is copied with the same source range and destination
    span = (first, last, destination)

    # Matrix-valued fields pass their number of entries per cell as an additional
    # trailing argument
    copy_data!(target.parent_ids, source.parent_ids, span...)
    copy_data!(target.child_ids, source.child_ids, span...,
               n_children_per_cell(target))
    copy_data!(target.neighbor_ids, source.neighbor_ids, span...,
               n_directions(target))
    copy_data!(target.levels, source.levels, span...)
    copy_data!(target.coordinates, source.coordinates, span..., ndims(target))
    copy_data!(target.original_cell_ids, source.original_cell_ids, span...)
    copy_data!(target.mpi_ranks, source.mpi_ranks, span...)
end
230
231
# Reset data structures by recreating all internal storage containers and invalidating all elements
232
function reset_data_structures!(t::ParallelTree{NDIMS, RealT}) where {NDIMS,
                                                                      RealT <: Real}
    # One slot more than `t.capacity`, matching the layout set up by the constructor
    n_slots = t.capacity + 1

    # Allocate fresh, uninitialized storage for all per-cell fields ...
    t.parent_ids = Vector{Int}(undef, n_slots)
    t.child_ids = Matrix{Int}(undef, 2^NDIMS, n_slots)
    t.neighbor_ids = Matrix{Int}(undef, 2 * NDIMS, n_slots)
    t.levels = Vector{Int}(undef, n_slots)
    t.coordinates = Matrix{RealT}(undef, NDIMS, n_slots)
    t.original_cell_ids = Vector{Int}(undef, n_slots)
    t.mpi_ranks = Vector{Int}(undef, n_slots)

    # ... and overwrite every slot with invalid marker values
    return invalidate!(t, 1, capacity(t) + 1)
end
244
end # @muladd
245
246