# GitHub Repository: trixi-framework/Trixi.jl
# Path: blob/main/src/callbacks_step/amr.jl
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
8
"""
    AMRCallback(semi, controller [,adaptor=AdaptorAMR(semi)];
                interval,
                adapt_initial_condition=true,
                adapt_initial_condition_only_refine=true,
                dynamic_load_balancing=true)

Performs adaptive mesh refinement (AMR) every `interval` time steps
for a given semidiscretization `semi` using the chosen `controller`.

# Keywords
- `interval`: number of *accepted* time steps between two AMR steps. AMR is never
  performed after the final step. `interval = 0` disables AMR during the time
  integration; adaptation to the initial condition may still happen.
- `adapt_initial_condition`: if `true`, the mesh is adapted to the initial condition
  during initialization, repeatedly, until it no longer changes.
- `adapt_initial_condition_only_refine`: if `true`, only refinement (no coarsening)
  is performed while adapting to the initial condition.
- `dynamic_load_balancing`: if `true`, the mesh is repartitioned after it changed
  when running in parallel with MPI.
"""
struct AMRCallback{Controller, Adaptor, Cache}
    # decides which cells are refined or coarsened
    controller::Controller
    # number of accepted time steps between two AMR steps (0 disables AMR)
    interval::Int
    adapt_initial_condition::Bool
    adapt_initial_condition_only_refine::Bool
    dynamic_load_balancing::Bool
    # transfers the solver data to the adapted mesh
    adaptor::Adaptor
    # reusable buffers `(; to_refine, to_coarsen)` for the marked cell ids
    amr_cache::Cache
end
27
28
# Construct the AMR callback as a `DiscreteCallback`. The `AMRCallback` object is the
# affect, `initialize!` adapts the initial condition, and the condition decides
# after which accepted steps AMR is triggered.
function AMRCallback(semi, controller, adaptor;
                     interval,
                     adapt_initial_condition = true,
                     adapt_initial_condition_only_refine = true,
                     dynamic_load_balancing = true)
    # check arguments
    if !(interval isa Integer && interval >= 0)
        throw(ArgumentError("`interval` must be a non-negative integer (provided `interval = $interval`)"))
    end

    # AMR every `interval` time steps, but not after the final step
    # With error-based step size control, some steps can be rejected. Thus,
    #   `integrator.iter >= integrator.stats.naccept`
    #    (total #steps)       (#accepted steps)
    # We need to check the number of accepted steps since callbacks are not
    # activated after a rejected step.
    if interval > 0
        condition = (u, t, integrator) -> ((integrator.stats.naccept % interval == 0) &&
                                           !(integrator.stats.naccept == 0 &&
                                             integrator.iter > 0) &&
                                           !isfinished(integrator))
    else # disable the AMR callback except possibly for initial refinement during initialization
        condition = (u, t, integrator) -> false
    end

    # Buffers for the cell ids marked for refinement/coarsening, reused in every AMR step
    to_refine = Int[]
    to_coarsen = Int[]
    amr_cache = (; to_refine, to_coarsen)

    amr_callback = AMRCallback{typeof(controller), typeof(adaptor), typeof(amr_cache)}(controller,
                                                                                     interval,
                                                                                     adapt_initial_condition,
                                                                                     adapt_initial_condition_only_refine,
                                                                                     dynamic_load_balancing,
                                                                                     adaptor,
                                                                                     amr_cache)

    return DiscreteCallback(condition, amr_callback,
                            save_positions = (false, false),
                            initialize = initialize!)
end
69
70
# Convenience constructor: use the default adaptor `AdaptorAMR(semi)` and forward all
# keyword arguments to the three-argument constructor.
AMRCallback(semi, controller; kwargs...) = AMRCallback(semi, controller,
                                                       AdaptorAMR(semi); kwargs...)
74
75
# Build an `AdaptorAMR` from the semidiscretization `semi`; only its mesh and its
# solver are needed, the equations and the cache are ignored.
function AdaptorAMR(semi; kwargs...)
    semi_mesh, _, semi_solver, _ = mesh_equations_solver_cache(semi)
    adaptor = AdaptorAMR(semi_mesh, semi_solver; kwargs...)
    return adaptor
end
79
80
# TODO: Taal bikeshedding, implement a method with less information and the signature
81
# function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:AMRCallback})
82
# @nospecialize cb # reduce precompilation time
83
#
84
# amr_callback = cb.affect!
85
# print(io, "AMRCallback")
86
# end
87
# Multi-line summary of the AMR callback for the REPL; falls back to the compact
# `show` when the `IOContext` requests compact output.
function Base.show(io::IO, mime::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:AMRCallback})
    @nospecialize cb # reduce precompilation time

    if get(io, :compact, false)
        show(io, cb)
    else
        amr_callback = cb.affect!

        summary_header(io, "AMRCallback")
        summary_line(io, "controller", amr_callback.controller |> typeof |> nameof)
        show(increment_indent(io), mime, amr_callback.controller)
        summary_line(io, "interval", amr_callback.interval)
        summary_line(io, "adapt IC",
                     amr_callback.adapt_initial_condition ? "yes" : "no")
        # The "only refine" setting is only relevant if the initial condition is adapted
        if amr_callback.adapt_initial_condition
            summary_line(io, "│ only refine",
                         amr_callback.adapt_initial_condition_only_refine ? "yes" :
                         "no")
        end
        summary_footer(io)
    end
end
110
111
# The function below is used to control the output depending on whether or not AMR is enabled.
112
"""
    uses_amr(callback)

Checks whether the provided callback or `CallbackSet` is an [`AMRCallback`](@ref)
or contains one. A `CallbackSet` without any discrete callbacks does not use AMR.
"""
uses_amr(cb) = false
function uses_amr(cb::DiscreteCallback{Condition, Affect!}) where {Condition,
                                                                   Affect! <: AMRCallback}
    return true
end
# `any` returns `false` for an empty tuple of discrete callbacks, whereas
# `mapreduce(uses_amr, |, ())` throws an error when reducing over an empty collection
uses_amr(callbacks::CallbackSet) = any(uses_amr, callbacks.discrete_callbacks)
125
126
# Forward the collection of element-wise output variables to the AMR controller,
# passing the callback itself as an additional argument.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                amr_callback::AMRCallback; kwargs...)
    return get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                  amr_callback.controller, amr_callback; kwargs...)
end
131
132
# Initialization hook of the AMR callback. If `adapt_initial_condition` is set, the mesh
# is adapted repeatedly to the initial condition, which is re-evaluated on each adapted
# mesh. This stops once the mesh no longer changes or an iteration limit is exceeded.
# Afterwards, the initial state integrals of an `AnalysisCallback` (if present) are
# recomputed from the adapted solution.
function initialize!(cb::DiscreteCallback{Condition, Affect!}, u, t,
                     integrator) where {Condition, Affect! <: AMRCallback}
    amr_callback = cb.affect!
    semi = integrator.p

    @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
        only_refine = amr_callback.adapt_initial_condition_only_refine

        # iterate until mesh does not change anymore
        has_changed = amr_callback(integrator, only_refine = only_refine)
        iterations = 1
        while has_changed
            # Evaluate the initial condition on the adapted mesh
            compute_coefficients!(integrator.u, t, semi)
            u_modified!(integrator, true)
            has_changed = amr_callback(integrator, only_refine = only_refine)
            iterations = iterations + 1
            allowed_max_iterations = max(10, max_level(amr_callback.controller))
            if iterations > allowed_max_iterations
                @warn "AMR for initial condition did not settle within $(allowed_max_iterations) iterations!\n" *
                      "Consider adjusting thresholds or setting `adapt_initial_condition_only_refine`."
                if has_changed
                    # The last AMR step changed the mesh again. Evaluate the initial
                    # condition on this final mesh, as is done after every other change,
                    # instead of keeping the values produced by the AMR step.
                    compute_coefficients!(integrator.u, t, semi)
                    u_modified!(integrator, true)
                end
                break
            end
        end

        # Update initial state integrals of analysis callback if it exists
        # See https://github.com/trixi-framework/Trixi.jl/issues/2536 for more information.
        discrete_callbacks = integrator.opts.callback.discrete_callbacks
        # Use a distinct argument name to avoid shadowing the method argument `cb`
        index = findfirst(callback -> callback.affect! isa AnalysisCallback,
                          discrete_callbacks)
        if !isnothing(index)
            analysis_callback = discrete_callbacks[index].affect!

            initial_state_integrals = integrate(integrator.u, semi)
            analysis_callback.initial_state_integrals = initial_state_integrals
        end
    end

    return nothing
end
170
171
# TODO: Taal remove?
172
# function (cb::DiscreteCallback{Condition,Affect!})(ode::ODEProblem) where {Condition, Affect!<:AMRCallback}
173
# amr_callback = cb.affect!
174
# semi = ode.p
175
176
# @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
177
# # iterate until mesh does not change anymore
178
# has_changed = true
179
# while has_changed
180
# has_changed = amr_callback(ode.u0, semi,
181
# only_refine=amr_callback.adapt_initial_condition_only_refine)
182
# compute_coefficients!(ode.u0, ode.tspan[1], semi)
183
# end
184
# end
185
186
# return nothing
187
# end
188
189
# Affect of the AMR `DiscreteCallback`: adapt the mesh and the solution stored in
# `integrator`. If the mesh changed, resize the integrator caches to the new length
# of `integrator.u` and mark the solution as modified. Returns whether the mesh changed.
function (amr_callback::AMRCallback)(integrator; kwargs...)
    u_ode = integrator.u
    semi = integrator.p

    @trixi_timeit timer() "AMR" begin
        has_changed = amr_callback(u_ode, semi,
                                   integrator.t, integrator.iter; kwargs...)
        if has_changed
            resize!(integrator, length(u_ode))
            u_modified!(integrator, true)
        end
    end

    return has_changed
end
204
205
# Unpack mesh, equations, solver and cache of a hyperbolic semidiscretization and
# forward to the mesh-specific AMR method.
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::SemidiscretizationHyperbolic,
                                             t, iter;
                                             kwargs...)
    # Note that we don't `wrap_array` the vector `u_ode` to be able to `resize!`
    # it when doing AMR while still dispatching on the `mesh` etc.
    return amr_callback(u_ode, mesh_equations_solver_cache(semi)..., semi, t, iter;
                        kwargs...)
end
214
215
# Unpack mesh, equations, solver, cache and parabolic cache of a hyperbolic-parabolic
# semidiscretization and forward to the mesh-specific AMR method.
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::SemidiscretizationHyperbolicParabolic,
                                             t, iter;
                                             kwargs...)
    # Note that we don't `wrap_array` the vector `u_ode` to be able to `resize!`
    # it when doing AMR while still dispatching on the `mesh` etc.
    return amr_callback(u_ode, mesh_equations_solver_cache(semi)...,
                        semi.cache_parabolic,
                        semi, t, iter; kwargs...)
end
225
226
# `passive_args` is currently used for Euler with self-gravity to adapt the gravity solver
# passively without querying its indicator, based on the assumption that both solvers use
# the same mesh. That's a hack and should be improved in the future once we have more examples
# and a better understanding of such a coupling.
# `passive_args` is expected to be an iterable of `Tuple`s of the form
# `(p_u_ode, p_mesh, p_equations, p_dg, p_cache)`.
#
# AMR step on a `TreeMesh`. Leaf cells with a positive controller value are refined.
# Leaf cells with a negative value are coarsened, but only if all siblings are marked
# as well. The solver data (and the passive data) follow the mesh. Returns `true` if
# any cell was refined or coarsened.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    if mpi_isparallel()
        # Collect lambda for all elements
        lambda_global = Vector{eltype(lambda)}(undef, nelementsglobal(mesh, dg, cache))
        # Use parent because n_elements_by_rank is an OffsetArray
        recvbuf = MPI.VBuffer(lambda_global, parent(cache.mpi_cache.n_elements_by_rank))
        MPI.Allgatherv!(lambda, recvbuf, mpi_comm())
        lambda = lambda_global
    end

    leaf_cell_ids = leaf_cells(mesh.tree)
    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    # Reuse the buffers stored in the callback for the marked cell ids
    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    # Note: This assumes that the entries of `lambda` are sorted with ascending cell ids
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        # refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                       to_refine)

        # Find all indices of elements whose cell ids are in refined_original_cells
        # (a `Set` turns each membership test into an O(1) lookup)
        elements_to_refine = findall(in(Set(refined_original_cells)),
                                     cache.elements.cell_ids)

        # refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, elements_to_refine)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           elements_to_refine)
        end
    else
        # If there is nothing to refine, create empty array for later use
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Since the cells may have been shifted due to refinement, first we need to
        # translate the old cell ids to the new cell ids
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Next, determine the parent cells from which the fine cells are to be
        # removed, since these are needed for the coarsen! function. However, since
        # we only want to coarsen if *all* child cells are marked for coarsening,
        # we count the coarsening indicators for each parent cell and only coarsen
        # if all children are marked as such (i.e., where the count is 2^ndims). At
        # the same time, check if a cell is marked for coarsening even though it is
        # *not* a leaf cell -> this can only happen if it was refined due to 2:1
        # smoothing during the preceding refinement operation.
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # If cell has no parent, it cannot be coarsened
            if !has_parent(mesh.tree, cell_id)
                continue
            end

            # If cell is not leaf (anymore), it cannot be coarsened
            if !is_leaf(mesh.tree, cell_id)
                continue
            end

            # Increase count for parent cell
            parent_id = mesh.tree.parent_ids[cell_id]
            parents_to_coarsen[parent_id] += 1
        end

        # Extract only those parent cells for which all children should be coarsened
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        # Finally, coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Convert coarsened parent cell ids to the list of child cell ids that have
        # been removed, since this is the information that is expected by the solver
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = [coarse_cell_id + child
                               for coarse_cell_id in coarsened_original_cells
                               for child in 1:n_children]

        # Find all indices of elements whose cell ids are in removed_child_cells
        elements_to_remove = findall(in(Set(removed_child_cells)),
                                     cache.elements.cell_ids)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, elements_to_remove)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            elements_to_remove)
        end
    else
        # If there is nothing to coarsen, create empty array for later use
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true
    end

    # Dynamically balance computational load by first repartitioning the mesh and then redistributing the cells/elements
    if has_changed && mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            old_mpi_ranks_per_cell = copy(mesh.tree.mpi_ranks)

            partition!(mesh)

            rebalance_solver!(u_ode, mesh, equations, dg, cache, old_mpi_ranks_per_cell)
        end
    end

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
380
381
# AMR step on a `TreeMesh` for hyperbolic-parabolic semidiscretizations. The indicator is
# evaluated on the hyperbolic variables, and `cache` as well as `cache_parabolic` are
# adapted. MPI-parallel runs are rejected with an error. Returns `true` if any cell was
# refined or coarsened.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG,
                                     cache, cache_parabolic,
                                     semi::SemidiscretizationHyperbolicParabolic,
                                     t, iter;
                                     only_refine = false, only_coarsen = false)
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    # Indicator kept based on hyperbolic variables
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    if mpi_isparallel()
        # Gathering the indicator across ranks and dynamic load balancing have not been
        # verified for the parabolic part. Since this method stops here for MPI runs,
        # everything below runs serially.
        error("MPI has not been verified yet for parabolic AMR")
    end

    leaf_cell_ids = leaf_cells(mesh.tree)
    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    # Reuse the buffers stored in the callback for the marked cell ids
    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    # Note: This assumes that the entries of `lambda` are sorted with ascending cell ids
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        # refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                       to_refine)

        # Find all indices of elements whose cell ids are in refined_original_cells
        # (a `Set` turns each membership test into an O(1) lookup)
        # Note: This assumes same indices for hyperbolic and parabolic part.
        elements_to_refine = findall(in(Set(refined_original_cells)),
                                     cache.elements.cell_ids)

        # refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic,
                                               elements_to_refine)
    else
        # If there is nothing to refine, create empty array for later use
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Since the cells may have been shifted due to refinement, first we need to
        # translate the old cell ids to the new cell ids
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Next, determine the parent cells from which the fine cells are to be
        # removed, since these are needed for the coarsen! function. However, since
        # we only want to coarsen if *all* child cells are marked for coarsening,
        # we count the coarsening indicators for each parent cell and only coarsen
        # if all children are marked as such (i.e., where the count is 2^ndims). At
        # the same time, check if a cell is marked for coarsening even though it is
        # *not* a leaf cell -> this can only happen if it was refined due to 2:1
        # smoothing during the preceding refinement operation.
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # If cell has no parent, it cannot be coarsened
            if !has_parent(mesh.tree, cell_id)
                continue
            end

            # If cell is not leaf (anymore), it cannot be coarsened
            if !is_leaf(mesh.tree, cell_id)
                continue
            end

            # Increase count for parent cell
            parent_id = mesh.tree.parent_ids[cell_id]
            parents_to_coarsen[parent_id] += 1
        end

        # Extract only those parent cells for which all children should be coarsened
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        # Finally, coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Convert coarsened parent cell ids to the list of child cell ids that have
        # been removed, since this is the information that is expected by the solver
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = [coarse_cell_id + child
                               for coarse_cell_id in coarsened_original_cells
                               for child in 1:n_children]

        # Find all indices of elements whose cell ids are in removed_child_cells
        # Note: This assumes same indices for hyperbolic and parabolic part.
        elements_to_remove = findall(in(Set(removed_child_cells)),
                                     cache.elements.cell_ids)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic,
                                                elements_to_remove)
    else
        # If there is nothing to coarsen, create empty array for later use
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true
    end

    # No dynamic load balancing here: MPI-parallel runs were rejected above.

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
529
530
# p4est/p8est volume-iterator callback (turned into a C function pointer by `cfunction`):
# write the controller value of the visited quadrant into that quadrant's user data.
# `user_data` is the pointer to the controller values `lambda`, read as `Int`s.
function copy_to_quad_iter_volume(info, user_data)
    info_pw = PointerWrapper(info)

    # Global id of the quadrant: the quadrant offset of its tree plus its local id.
    # `treeid` is zero-based, whereas the trees are loaded with one-based indexing.
    tree_pw = load_pointerwrapper_tree(info_pw.p4est, info_pw.treeid[] + 1)
    quad_id = tree_pw.quadrants_offset[] + info_pw.quadid[]

    # `lambda` is indexed one-based, the global quadrant id is zero-based
    lambda_pw = PointerWrapper(Int, user_data)
    value = lambda_pw[quad_id + 1]

    # The quadrant's user data holds `[global quad ID, controller_value]`
    quad_user_data_pw = PointerWrapper(Int, info_pw.quad.p.user_data[])
    quad_user_data_pw[2] = value

    return nothing
end
553
554
# specialized callback which includes the `cache_parabolic` argument
#
# AMR step on a `P4estMesh` for hyperbolic-parabolic semidiscretizations. The controller
# values are copied into the quadrant user data. The mesh is then adapted via
# `refine!(mesh)`/`coarsen!(mesh)`, which get no explicit cell lists (presumably they
# read the copied values — confirm). Finally, the solver data in `cache` and
# `cache_parabolic` (and any `passive_args`) follow the mesh. Returns `true` if the
# mesh changed on any MPI rank.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, cache_parabolic,
                                     semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    @boundscheck begin
        @assert axes(lambda)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(lambda))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Copy controller value of each quad to the quad's user data storage
    iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh)))

    # The pointer to lambda is interpreted as Ptr{Int} in `copy_to_quad_iter_volume`,
    # so the element type must match exactly
    @assert lambda isa Vector{Int}
    iterate_p4est(mesh.p4est, lambda; iter_volume_c = iter_volume_c)

    @trixi_timeit timer() "refine" if !only_coarsen
        # Refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh)

        # Refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic,
                                               refined_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations,
                                                           p_dg, p_cache,
                                                           refined_original_cells)
        end
    else
        # Refinement is disabled (`only_coarsen`): no cells have been refined
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine
        # Coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic,
                                                coarsened_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations,
                                                            p_dg, p_cache,
                                                            coarsened_original_cells)
        end
    else
        # Coarsening is disabled (`only_refine`): no cells have been coarsened
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined and perform load balancing
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    # Check if mesh changed on other processes
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true

        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                global_first_quadrant = unsafe_wrap(Array,
                                                    unsafe_load(mesh.p4est).global_first_quadrant,
                                                    mpi_nranks() + 1)
                old_global_first_quadrant = copy(global_first_quadrant)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_quadrant)
            end
        end

        reinitialize_boundaries!(semi.boundary_conditions, cache)
        # if the semidiscretization also stores parabolic boundary conditions,
        # reinitialize them after each refinement step as well.
        if hasproperty(semi, :boundary_conditions_parabolic)
            reinitialize_boundaries!(semi.boundary_conditions_parabolic, cache)
        end
    end

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
650
651
# Return a C function pointer to `copy_to_quad_iter_volume` with the argument types
# of a p4est (2D) or p8est (3D) volume iterator callback.
# 2D
function cfunction(::typeof(copy_to_quad_iter_volume), ::Val{2})
    return @cfunction(copy_to_quad_iter_volume, Cvoid,
                      (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))
end
# 3D
function cfunction(::typeof(copy_to_quad_iter_volume), ::Val{3})
    return @cfunction(copy_to_quad_iter_volume, Cvoid,
                      (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
end
661
662
# AMR step on a `P4estMesh`. The controller values are copied into the quadrant user
# data. The mesh is then adapted via `refine!(mesh)`/`coarsen!(mesh)`, which get no
# explicit cell lists (presumably they read the copied values — confirm). Finally, the
# solver data (and any `passive_args`) follow the mesh. Returns `true` if the mesh
# changed on any MPI rank.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    @boundscheck begin
        @assert axes(lambda)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(lambda))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Copy controller value of each quad to the quad's user data storage
    iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh)))

    # The pointer to lambda is interpreted as Ptr{Int} in `copy_to_quad_iter_volume`,
    # so the element type must match exactly
    @assert lambda isa Vector{Int}
    iterate_p4est(mesh.p4est, lambda; iter_volume_c = iter_volume_c)

    @trixi_timeit timer() "refine" if !only_coarsen
        # Refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh)

        # Refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache,
                                               refined_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations,
                                                           p_dg, p_cache,
                                                           refined_original_cells)
        end
    else
        # Refinement is disabled (`only_coarsen`): no cells have been refined
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine
        # Coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache,
                                                coarsened_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations,
                                                            p_dg, p_cache,
                                                            coarsened_original_cells)
        end
    else
        # Coarsening is disabled (`only_refine`): no cells have been coarsened
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined and perform load balancing
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    # Check if mesh changed on other processes
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true

        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                global_first_quadrant = unsafe_wrap(Array,
                                                    unsafe_load(mesh.p4est).global_first_quadrant,
                                                    mpi_nranks() + 1)
                old_global_first_quadrant = copy(global_first_quadrant)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_quadrant)
            end
        end

        reinitialize_boundaries!(semi.boundary_conditions, cache)
    end

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
751
752
# AMR step on a `T8codeMesh`. Positive indicator values request refinement and negative
# values request coarsening; `only_coarsen`/`only_refine` zero out the other kind. The
# mesh and the solver data are adapted based on the `difference` returned by
# `trixi_t8_adapt!`. Returns `true` if the mesh changed on any MPI rank.
# NOTE(review): `passive_args` is accepted for interface compatibility but not used here.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::T8codeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    has_changed = false

    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    indicators = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg,
                                                              cache, t = t, iter = iter)

    @boundscheck begin
        @assert axes(indicators)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(indicators))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Suppress refinement requests (positive values) if only coarsening is allowed
    if only_coarsen
        indicators[indicators .> 0] .= 0
    end

    # Suppress coarsening requests (negative values) if only refinement is allowed
    if only_refine
        indicators[indicators .< 0] .= 0
    end

    @trixi_timeit timer() "adapt" begin
        difference = @trixi_timeit timer() "mesh" trixi_t8_adapt!(mesh, indicators)

        # Store whether there were any cells coarsened or refined and perform load balancing.
        # `any(!iszero, ...)` avoids allocating a temporary array for `difference .!= 0`.
        has_changed = any(!iszero, difference)

        # Check if mesh changed on other processes
        if mpi_isparallel()
            has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
        end

        if has_changed
            @trixi_timeit timer() "solver" adapt!(u_ode, adaptor, mesh, equations, dg,
                                                  cache, difference)
        end
    end

    if has_changed
        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                old_global_first_element_ids = get_global_first_element_ids(mesh)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_element_ids)
            end
        end

        reinitialize_boundaries!(semi.boundary_conditions, cache)
    end

    # Keep earlier unsaved changes; only add the ones from this call
    mesh.unsaved_changes |= has_changed

    # Return true if there were any cells coarsened or refined, otherwise false.
    return has_changed
end
812
813
# Sorted boundary types must be rebuilt after mesh adaptation.
function reinitialize_boundaries!(boundary_conditions::UnstructuredSortedBoundaryTypes,
                                  cache)
    # Reinitialize boundary types container because boundaries may have changed.
    return initialize!(boundary_conditions, cache)
end

# Fallback: all other boundary condition types are returned unchanged.
function reinitialize_boundaries!(boundary_conditions, cache)
    return boundary_conditions
end
822
823
# After refining cells, shift original cell ids to match new locations
# Note: Assumes sorted lists of original and refined cell ids!
# Note: `mesh` is only required to extract ndims
#
# Each refined cell shifts all cells with a larger id by 2^ndims ids. The new id of a
# cell is therefore its old id plus 2^ndims times the number of refined cells with a
# smaller id. Because `refined_original_cells` is sorted, that count is found by binary
# search, giving O(m log k) work instead of O(k·n). An empty `original_cell_ids` yields
# an empty result.
function original2refined(original_cell_ids, refined_original_cells, mesh)
    # Sanity check
    @assert issorted(original_cell_ids) "`original_cell_ids` not sorted"
    @assert issorted(refined_original_cells) "`refined_original_cells` not sorted"

    n_children = 2^ndims(mesh)

    # `searchsortedfirst(...) - 1` is the number of refined cells with an id < `cell_id`
    return [cell_id +
            n_children * (searchsortedfirst(refined_original_cells, cell_id) - 1)
            for cell_id in original_cell_ids]
end
848
849
"""
    ControllerThreeLevel(semi, indicator; base_level=1,
                         med_level=base_level, med_threshold=0.0,
                         max_level=base_level, max_threshold=1.0)

An AMR controller based on three levels (in descending order of precedence):
- set the target level to `max_level` if `indicator > max_threshold`
- set the target level to `med_level` if `indicator > med_threshold`;
  if `med_level <= 0` (e.g., `med_level = -1`), set the target level to the current level
- set the target level to `base_level` otherwise
"""
struct ControllerThreeLevel{RealT <: Real, Indicator, Cache}
    base_level::Int # target level where the indicator is at most `med_threshold`
    med_level::Int # target level above `med_threshold`; `<= 0` keeps the current level
    max_level::Int # target level above `max_threshold`
    med_threshold::RealT
    max_threshold::RealT
    indicator::Indicator # evaluated to obtain one indicator value per element
    cache::Cache # holds the `controller_value` buffer filled when the controller is called
end
869
870
function ControllerThreeLevel(semi, indicator; base_level = 1,
                              med_level = base_level, med_threshold = 0.0,
                              max_level = base_level, max_threshold = 1.0)
    # Both thresholds share the single `RealT` type parameter, so bring them to a
    # common type first
    thresholds = promote(med_threshold, max_threshold)
    real_type = typeof(last(thresholds))

    cache = create_cache(ControllerThreeLevel, semi)
    controller_type = ControllerThreeLevel{real_type, typeof(indicator), typeof(cache)}

    # Field order: levels, then (med_threshold, max_threshold), then indicator and cache
    return controller_type(base_level, med_level, max_level,
                           thresholds..., indicator, cache)
end
884
885
# Return the `max_level` configured for this controller.
function max_level(controller::ControllerThreeLevel)
    return controller.max_level
end
886
887
function create_cache(indicator_type::Type{ControllerThreeLevel}, semi)
    # Unpack the semidiscretization and dispatch on its individual components
    components = mesh_equations_solver_cache(semi)
    return create_cache(indicator_type, components...)
end
890
891
function Base.show(io::IO, controller::ControllerThreeLevel)
    @nospecialize controller # reduce precompilation time

    # One-line form: the indicator first, then each level/threshold as `name=value`
    print(io, "ControllerThreeLevel(", controller.indicator)
    for field in (:base_level, :med_level, :max_level, :med_threshold, :max_threshold)
        print(io, ", ", field, "=", getfield(controller, field))
    end
    print(io, ")")
    return nothing
end
904
905
function Base.show(io::IO, mime::MIME"text/plain", controller::ControllerThreeLevel)
    @nospecialize controller # reduce precompilation time

    # Compact contexts fall back to the one-line representation
    get(io, :compact, false) && return show(io, controller)

    indicator = controller.indicator
    summary_header(io, "ControllerThreeLevel")
    summary_line(io, "indicator", nameof(typeof(indicator)))
    show(increment_indent(io), mime, indicator)

    # Levels and thresholds are labelled with their field names
    for field in (:base_level, :med_level, :max_level, :med_threshold, :max_threshold)
        summary_line(io, string(field), getfield(controller, field))
    end
    return summary_footer(io)
end
922
923
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevel,
                                amr_callback::AMRCallback;
                                kwargs...)
    indicator = controller.indicator
    # Evaluate the indicator first so the values exported below are up to date
    indicator(u, mesh, equations, solver, cache; kwargs...)
    return get_element_variables!(element_variables, indicator, amr_callback)
end
931
932
function get_element_variables!(element_variables, indicator::AbstractIndicator,
                                ::AMRCallback)
    # Export the indicator values kept in the indicator's cache under `:indicator_amr`
    alpha = indicator.cache.alpha
    element_variables[:indicator_amr] = alpha
    return nothing
end
937
938
function current_element_levels(mesh::TreeMesh, solver, cache)
    # Map each element to its tree cell and look up that cell's refinement level
    element_cell_ids = cache.elements.cell_ids
    return mesh.tree.levels[element_cell_ids[eachelement(solver, cache)]]
end
943
944
function extract_levels_iter_volume(info, user_data)
    info_pw = PointerWrapper(info)

    # Tree ids coming from p4est are zero-based; the global trees array is loaded
    # with one-based indexing
    tree_pw = load_pointerwrapper_tree(info_pw.p4est, info_pw.treeid[] + 1)

    # Quadrant offset of this tree plus the quadrant id, shifted to a one-based
    # Julia element id
    element_id = tree_pw.quadrants_offset[] + info_pw.quadid[] + 1

    # `user_data` is read as a buffer of `Int` levels (presumably the `current_levels`
    # vector handed to `iterate_p4est`); store this element's level there
    levels_pw = PointerWrapper(Int, user_data)
    levels_pw[element_id] = info_pw.quad.level[]

    return nothing
end
964
965
# 2D: C function pointer to `extract_levels_iter_volume`, usable as `iter_volume_c`
cfunction(::typeof(extract_levels_iter_volume), ::Val{2}) =
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))
970
# 3D: C function pointer to `extract_levels_iter_volume`, usable as `iter_volume_c`
cfunction(::typeof(extract_levels_iter_volume), ::Val{3}) =
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
975
976
function current_element_levels(mesh::P4estMesh, solver, cache)
    # One slot per element; filled by `extract_levels_iter_volume` during the iteration
    levels = Vector{Int}(undef, nelements(solver, cache))

    volume_callback = cfunction(extract_levels_iter_volume, Val(ndims(mesh)))
    iterate_p4est(mesh.p4est, levels; iter_volume_c = volume_callback)

    return levels
end
984
985
# Element levels of the local elements, queried from the t8code forest
current_element_levels(mesh::T8codeMesh, solver, cache) =
    trixi_t8_get_local_element_levels(mesh.forest)
988
989
# TODO: Taal refactor, merge the two loops of ControllerThreeLevel and IndicatorLöhner etc.?
# But that would remove the simplest possibility to write that stuff to a file...
# We could of course implement some additional logic and workarounds, but is it worth the effort?
function (controller::ControllerThreeLevel)(u::AbstractArray{<:Any},
                                           mesh, equations, dg::DG, cache;
                                           kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    alpha = controller.indicator(u, mesh, equations, dg, cache; kwargs...)
    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        current_level = current_levels[element]
        indicator_value = alpha[element]

        # Choose the target level; any `med_level <= 0` (e.g. -1) means
        # "keep the current level" in the medium band
        target_level = if indicator_value > controller.max_threshold
            controller.max_level
        elseif indicator_value > controller.med_threshold
            controller.med_level > 0 ? controller.med_level : current_level
        else
            controller.base_level
        end

        # `cmp` gives 1 (refine!), -1 (coarsen!) or 0 (we're good)
        controller_value[element] = cmp(target_level, current_level)
    end

    return controller_value
end
1030
1031
"""
    ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
                                 base_level=1,
                                 med_level=base_level, med_threshold=0.0,
                                 max_level=base_level, max_threshold=1.0,
                                 max_threshold_secondary=1.0)

An AMR controller based on three levels (in descending order of precedence):
- set the target level to `max_level` if `indicator_primary > max_threshold`
- set the target level to `med_level` if `indicator_primary > med_threshold`;
  if `med_level <= 0` (e.g., `med_level = -1`), set the target level to the current level
- set the target level to `base_level` otherwise
If `indicator_secondary >= max_threshold_secondary`,
set the target level to `max_level`, regardless of the primary indicator.
"""
struct ControllerThreeLevelCombined{RealT <: Real, IndicatorPrimary, IndicatorSecondary,
                                    Cache}
    base_level::Int # target level where the primary indicator is at most `med_threshold`
    med_level::Int # target level above `med_threshold`; `<= 0` keeps the current level
    max_level::Int # target level above `max_threshold` or `max_threshold_secondary`
    med_threshold::RealT
    max_threshold::RealT
    max_threshold_secondary::RealT # compared (>=) against the secondary indicator
    indicator_primary::IndicatorPrimary
    indicator_secondary::IndicatorSecondary
    cache::Cache # holds the `controller_value` buffer filled when the controller is called
end
1058
1059
function ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
                                      base_level = 1,
                                      med_level = base_level, med_threshold = 0.0,
                                      max_level = base_level, max_threshold = 1.0,
                                      max_threshold_secondary = 1.0)
    # All three thresholds share the single `RealT` type parameter
    thresholds = promote(med_threshold, max_threshold, max_threshold_secondary)
    real_type = typeof(first(thresholds))

    cache = create_cache(ControllerThreeLevelCombined, semi)
    controller_type = ControllerThreeLevelCombined{real_type, typeof(indicator_primary),
                                                   typeof(indicator_secondary),
                                                   typeof(cache)}

    # Field order: levels, then (med, max, max_secondary) thresholds, indicators, cache
    return controller_type(base_level, med_level, max_level, thresholds...,
                           indicator_primary, indicator_secondary, cache)
end
1080
1081
# Return the `max_level` configured for this controller.
function max_level(controller::ControllerThreeLevelCombined)
    return controller.max_level
end
1082
1083
function create_cache(indicator_type::Type{ControllerThreeLevelCombined}, semi)
    # Unpack the semidiscretization and dispatch on its individual components
    components = mesh_equations_solver_cache(semi)
    return create_cache(indicator_type, components...)
end
1086
1087
function Base.show(io::IO, controller::ControllerThreeLevelCombined)
    @nospecialize controller # reduce precompilation time

    # One-line representation listing both indicators and every level/threshold in the
    # same order as the struct fields. `max_threshold` was previously omitted here even
    # though the text/plain representation reports it.
    print(io, "ControllerThreeLevelCombined(")
    print(io, controller.indicator_primary)
    print(io, ", ", controller.indicator_secondary)
    print(io, ", base_level=", controller.base_level)
    print(io, ", med_level=", controller.med_level)
    print(io, ", max_level=", controller.max_level)
    print(io, ", med_threshold=", controller.med_threshold)
    print(io, ", max_threshold=", controller.max_threshold)
    print(io, ", max_threshold_secondary=", controller.max_threshold_secondary)
    print(io, ")")
    return nothing
end
1101
1102
function Base.show(io::IO, mime::MIME"text/plain",
                   controller::ControllerThreeLevelCombined)
    @nospecialize controller # reduce precompilation time

    # Compact contexts fall back to the one-line representation
    get(io, :compact, false) && return show(io, controller)

    summary_header(io, "ControllerThreeLevelCombined")

    # Each indicator: its type name, followed by its own indented summary
    for (label, indicator) in (("primary indicator", controller.indicator_primary),
                               ("secondary indicator", controller.indicator_secondary))
        summary_line(io, label, nameof(typeof(indicator)))
        show(increment_indent(io), mime, indicator)
    end

    # Levels and thresholds are labelled with their field names
    for field in (:base_level, :med_level, :max_level, :med_threshold, :max_threshold,
                  :max_threshold_secondary)
        summary_line(io, string(field), getfield(controller, field))
    end
    return summary_footer(io)
end
1125
1126
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevelCombined,
                                amr_callback::AMRCallback;
                                kwargs...)
    # Only the primary indicator is exported; evaluate it first so its values are current
    primary = controller.indicator_primary
    primary(u, mesh, equations, solver, cache; kwargs...)
    return get_element_variables!(element_variables, primary, amr_callback)
end
1135
1136
function (controller::ControllerThreeLevelCombined)(u::AbstractArray{<:Any},
                                                   mesh, equations, dg::DG, cache;
                                                   kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    alpha = controller.indicator_primary(u, mesh, equations, dg, cache; kwargs...)
    alpha_secondary = controller.indicator_secondary(u, mesh, equations, dg, cache)

    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        current_level = current_levels[element]
        indicator_value = alpha[element]

        # A secondary indicator at or above its threshold forces `max_level`,
        # independent of the primary indicator
        secondary_triggered = alpha_secondary[element] >= controller.max_threshold_secondary

        # Choose the target level; any `med_level <= 0` (e.g. -1) means
        # "keep the current level" in the medium band
        target_level = if secondary_triggered ||
                          indicator_value > controller.max_threshold
            controller.max_level
        elseif indicator_value > controller.med_threshold
            controller.med_level > 0 ? controller.med_level : current_level
        else
            controller.base_level
        end

        # `cmp` gives 1 (refine!), -1 (coarsen!) or 0 (we're good)
        controller_value[element] = cmp(target_level, current_level)
    end

    return controller_value
end
1180
1181
include("amr_dg.jl")
1182
include("amr_dg1d.jl")
1183
include("amr_dg2d.jl")
1184
include("amr_dg3d.jl")
1185
end # @muladd
1186
1187