Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/callbacks_step/amr.jl
2055 views
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
"""
    AMRCallback(semi, controller [,adaptor=AdaptorAMR(semi)];
                interval,
                adapt_initial_condition=true,
                adapt_initial_condition_only_refine=true,
                dynamic_load_balancing=true)

Performs adaptive mesh refinement (AMR) every `interval` time steps
for a given semidiscretization `semi` using the chosen `controller`.

# Keywords
- `interval`: number of accepted time steps between two AMR steps; must be a
  non-negative integer. With `interval = 0`, AMR is only performed (if enabled)
  for the initial condition.
- `adapt_initial_condition`: adapt the mesh to the initial condition during
  initialization of the callback.
- `adapt_initial_condition_only_refine`: restrict the initial adaptation to
  refinement only.
- `dynamic_load_balancing`: repartition the mesh after AMR in MPI-parallel runs.
"""
struct AMRCallback{Controller, Adaptor, Cache}
    # decides which cells to refine or coarsen
    controller::Controller
    # number of accepted time steps between AMR steps (0 disables AMR during the run)
    interval::Int
    # settings for adapting the mesh to the initial condition
    adapt_initial_condition::Bool
    adapt_initial_condition_only_refine::Bool
    # repartition the mesh after AMR in MPI-parallel runs
    dynamic_load_balancing::Bool
    # transfers the solution between old and new meshes
    adaptor::Adaptor
    # reusable buffers (`to_refine`, `to_coarsen`) for the marked cell ids
    amr_cache::Cache
end
27
28
function AMRCallback(semi, controller, adaptor;
                     interval,
                     adapt_initial_condition = true,
                     adapt_initial_condition_only_refine = true,
                     dynamic_load_balancing = true)
    # Validate the keyword arguments before building anything
    if !(interval isa Integer) || interval < 0
        throw(ArgumentError("`interval` must be a non-negative integer (provided `interval = $interval`)"))
    end

    # Trigger AMR every `interval` *accepted* steps, but not after the final step.
    # With error-based step size control some steps may be rejected, so
    # `integrator.iter >= integrator.stats.naccept`; since callbacks are not
    # activated after a rejected step, the number of accepted steps is decisive.
    condition = if interval > 0
        function (u, t, integrator)
            naccept = integrator.stats.naccept
            return naccept % interval == 0 &&
                   !(naccept == 0 && integrator.iter > 0) &&
                   !isfinished(integrator)
        end
    else
        # AMR is disabled during time integration; only the initial adaptation
        # during initialization may still take place
        (u, t, integrator) -> false
    end

    # Buffers for the cell ids marked for refinement/coarsening, reused between calls
    amr_cache = (; to_refine = Int[], to_coarsen = Int[])

    CallbackT = AMRCallback{typeof(controller), typeof(adaptor), typeof(amr_cache)}
    amr_callback = CallbackT(controller, interval,
                             adapt_initial_condition,
                             adapt_initial_condition_only_refine,
                             dynamic_load_balancing,
                             adaptor, amr_cache)

    return DiscreteCallback(condition, amr_callback;
                            save_positions = (false, false),
                            initialize = initialize!)
end
69
70
# Convenience constructor: use the default adaptor `AdaptorAMR(semi)`
AMRCallback(semi, controller; kwargs...) = AMRCallback(semi, controller,
                                                       AdaptorAMR(semi); kwargs...)
74
75
# Create the default adaptor from the mesh and solver of the semidiscretization `semi`
function AdaptorAMR(semi; kwargs...)
    mesh, _equations, solver, _cache = mesh_equations_solver_cache(semi)
    return AdaptorAMR(mesh, solver; kwargs...)
end
79
80
# TODO: Taal bikeshedding, implement a method with less information and the signature
81
# function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:AMRCallback})
82
# @nospecialize cb # reduce precompilation time
83
#
84
# amr_callback = cb.affect!
85
# print(io, "AMRCallback")
86
# end
87
# Multi-line summary of an AMR callback (controller, interval, initial adaptation)
function Base.show(io::IO, mime::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:AMRCallback})
    @nospecialize cb # reduce precompilation time

    # Compact contexts fall back to the generic one-line representation
    get(io, :compact, false) && return show(io, cb)

    amr_callback = cb.affect!
    controller = amr_callback.controller

    summary_header(io, "AMRCallback")
    summary_line(io, "controller", nameof(typeof(controller)))
    show(increment_indent(io), mime, controller)
    summary_line(io, "interval", amr_callback.interval)

    adapt_ic = amr_callback.adapt_initial_condition
    summary_line(io, "adapt IC", adapt_ic ? "yes" : "no")
    if adapt_ic
        only_refine = amr_callback.adapt_initial_condition_only_refine
        summary_line(io, "│ only refine", only_refine ? "yes" : "no")
    end

    return summary_footer(io)
end
110
111
# The function below is used to control the output depending on whether or not AMR is enabled.
"""
    uses_amr(callback)

Checks whether the provided callback or `CallbackSet` is an [`AMRCallback`](@ref)
or contains one.
"""
uses_amr(cb) = false
function uses_amr(cb::DiscreteCallback{Condition, Affect!}) where {Condition,
                                                                   Affect! <:
                                                                   AMRCallback}
    return true
end
# `any` returns `false` for a `CallbackSet` without discrete callbacks, whereas
# `mapreduce` without an `init` value throws on such an empty collection
uses_amr(callbacks::CallbackSet) = any(uses_amr, callbacks.discrete_callbacks)
125
126
# Delegate to the controller, passing the AMR callback along as an extra argument
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                amr_callback::AMRCallback; kwargs...)
    controller = amr_callback.controller
    return get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                  controller, amr_callback; kwargs...)
end
131
132
# Initialization of the AMR callback: optionally adapt the mesh to the initial
# condition, repeating the adaptation until the mesh no longer changes (bounded by
# `max(10, max_level(controller))` passes).
function initialize!(cb::DiscreteCallback{Condition, Affect!}, u, t,
                     integrator) where {Condition, Affect! <: AMRCallback}
    amr_callback = cb.affect!
    semi = integrator.p

    @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
        only_refine = amr_callback.adapt_initial_condition_only_refine

        # First adaptation pass on the initial mesh
        mesh_changed = amr_callback(integrator, only_refine = only_refine)
        if mesh_changed
            # Upper bound on the number of passes (loop-invariant)
            allowed_max_iterations = max(10, max_level(amr_callback.controller))
            n_passes = 1
            while mesh_changed
                # Re-evaluate the initial condition on the adapted mesh
                compute_coefficients!(integrator.u, t, semi)
                u_modified!(integrator, true)
                mesh_changed = amr_callback(integrator, only_refine = only_refine)
                n_passes += 1
                if n_passes > allowed_max_iterations
                    @warn "AMR for initial condition did not settle within $(allowed_max_iterations) iterations!\n" *
                          "Consider adjusting thresholds or setting `adapt_initial_condition_only_refine`."
                    break
                end
            end
        end

        # Update initial state integrals of analysis callback if it exists
        # See https://github.com/trixi-framework/Trixi.jl/issues/2536 for more information.
        discrete_callbacks = integrator.opts.callback.discrete_callbacks
        idx = findfirst(c -> c.affect! isa AnalysisCallback, discrete_callbacks)
        if idx !== nothing
            analysis_callback = discrete_callbacks[idx].affect!
            analysis_callback.initial_state_integrals = integrate(integrator.u, semi)
        end
    end

    return nothing
end
170
171
# TODO: Taal remove?
172
# function (cb::DiscreteCallback{Condition,Affect!})(ode::ODEProblem) where {Condition, Affect!<:AMRCallback}
173
# amr_callback = cb.affect!
174
# semi = ode.p
175
176
# @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
177
# # iterate until mesh does not change anymore
178
# has_changed = true
179
# while has_changed
180
# has_changed = amr_callback(ode.u0, semi,
181
# only_refine=amr_callback.adapt_initial_condition_only_refine)
182
# compute_coefficients!(ode.u0, ode.tspan[1], semi)
183
# end
184
# end
185
186
# return nothing
187
# end
188
189
# Perform one AMR step on the integrator state. If the mesh changed, the integrator is
# resized to the new length of `u` and told that `u` was modified.
# Returns whether the mesh changed.
function (amr_callback::AMRCallback)(integrator; kwargs...)
    u_ode = integrator.u
    semi = integrator.p

    mesh_changed = @trixi_timeit timer() "AMR" begin
        changed = amr_callback(u_ode, semi, integrator.t, integrator.iter; kwargs...)
        if changed
            resize!(integrator, length(u_ode))
            u_modified!(integrator, true)
        end
        changed
    end

    return mesh_changed
end
204
205
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::SemidiscretizationHyperbolic,
                                             t, iter;
                                             kwargs...)
    # `u_ode` is deliberately forwarded without `wrap_array` so that it can still be
    # `resize!`d during AMR while dispatch happens on the mesh, solver, etc.
    mesh_equations_solver = mesh_equations_solver_cache(semi)
    return amr_callback(u_ode, mesh_equations_solver..., semi, t, iter; kwargs...)
end
213
214
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::SemidiscretizationHyperbolicParabolic,
                                             t, iter;
                                             kwargs...)
    # `u_ode` is deliberately forwarded without `wrap_array` so that it can still be
    # `resize!`d during AMR while dispatch happens on the mesh, solver, etc.
    # The parabolic cache is passed in addition to the hyperbolic one.
    mesh_equations_solver = mesh_equations_solver_cache(semi)
    cache_parabolic = semi.cache_parabolic
    return amr_callback(u_ode, mesh_equations_solver..., cache_parabolic,
                        semi, t, iter; kwargs...)
end
223
224
# `passive_args` is currently used for Euler with self-gravity to adapt the gravity solver
# passively without querying its indicator, based on the assumption that both solvers use
# the same mesh. That's a hack and should be improved in the future once we have more examples
# and a better understanding of such a coupling.
# `passive_args` is expected to be an iterable of `Tuple`s of the form
# `(p_u_ode, p_mesh, p_equations, p_dg, p_cache)`.
#
# Adapt a `TreeMesh` and the solver state `u_ode` based on the controller values:
# positive values mark leaf cells for refinement, negative values for coarsening.
# Returns `true` if any cell was refined or coarsened, otherwise `false`.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    if mpi_isparallel()
        # Collect lambda for all elements, matching the global list of leaf cells below
        lambda_global = Vector{eltype(lambda)}(undef, nelementsglobal(mesh, dg, cache))
        # Use parent because n_elements_by_rank is an OffsetArray
        recvbuf = MPI.VBuffer(lambda_global, parent(cache.mpi_cache.n_elements_by_rank))
        MPI.Allgatherv!(lambda, recvbuf, mpi_comm())
        lambda = lambda_global
    end

    leaf_cell_ids = leaf_cells(mesh.tree)
    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    # Sort the leaf cells into refinement and coarsening candidates, reusing the
    # preallocated buffers of the callback
    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    # Note: This assumes that the entries of `lambda` are sorted with ascending cell ids
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        # refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                      to_refine)

        # Find all indices of elements whose cell ids are in refined_original_cells;
        # the `Set` keeps the lookup linear in the number of elements
        elements_to_refine = findall(in(Set(refined_original_cells)),
                                     cache.elements.cell_ids)

        # refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, elements_to_refine)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           elements_to_refine)
        end
    else
        # If there is nothing to refine, create empty array for later use
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Since the cells may have been shifted due to refinement, first we need to
        # translate the old cell ids to the new cell ids (`to_coarsen` is non-empty
        # here, as guaranteed by the branch condition)
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Next, determine the parent cells from which the fine cells are to be
        # removed, since these are needed for the coarsen! function. However, since
        # we only want to coarsen if *all* child cells are marked for coarsening,
        # we count the coarsening indicators for each parent cell and only coarsen
        # if all children are marked as such (i.e., where the count is 2^ndims). At
        # the same time, check if a cell is marked for coarsening even though it is
        # *not* a leaf cell -> this can only happen if it was refined due to 2:1
        # smoothing during the preceding refinement operation.
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # Cells without a parent cannot be coarsened, and neither can cells that
            # are not leaf cells (anymore)
            if !has_parent(mesh.tree, cell_id) || !is_leaf(mesh.tree, cell_id)
                continue
            end

            # Increase count for parent cell
            parents_to_coarsen[mesh.tree.parent_ids[cell_id]] += 1
        end

        # Extract only those parent cells for which all children should be coarsened
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        # Finally, coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Convert coarsened parent cell ids to the set of child cell ids that have
        # been removed, since this is the information that is expected by the solver.
        # The children of a parent cell are taken to be the cells with ids
        # `parent_id + 1, ..., parent_id + n_children`.
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = Set(coarse_cell_id + child
                                  for coarse_cell_id in coarsened_original_cells
                                  for child in 1:n_children)

        # Find all indices of elements whose cell ids are in removed_child_cells
        elements_to_remove = findall(in(removed_child_cells), cache.elements.cell_ids)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, elements_to_remove)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            elements_to_remove)
        end
    else
        # If there is nothing to coarsen, create empty array for later use
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true
    end

    # Dynamically balance computational load by first repartitioning the mesh and then redistributing the cells/elements
    if has_changed && mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            old_mpi_ranks_per_cell = copy(mesh.tree.mpi_ranks)

            partition!(mesh)

            rebalance_solver!(u_ode, mesh, equations, dg, cache, old_mpi_ranks_per_cell)
        end
    end

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
378
379
# Adapt a `TreeMesh` together with the hyperbolic and parabolic solver caches based on
# the controller values: positive values mark leaf cells for refinement, negative
# values for coarsening. The indicator is evaluated on the hyperbolic variables only.
# MPI-parallel runs are rejected with an error since they have not been verified yet.
# Returns `true` if any cell was refined or coarsened, otherwise `false`.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG,
                                     cache, cache_parabolic,
                                     semi::SemidiscretizationHyperbolicParabolic,
                                     t, iter;
                                     only_refine = false, only_coarsen = false)
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    # Indicator kept based on hyperbolic variables
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    if mpi_isparallel()
        # Gathering `lambda` across ranks and dynamic load balancing (as done in the
        # purely hyperbolic `TreeMesh` method) are not supported here yet
        error("MPI has not been verified yet for parabolic AMR")
    end

    leaf_cell_ids = leaf_cells(mesh.tree)
    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    # Sort the leaf cells into refinement and coarsening candidates, reusing the
    # preallocated buffers of the callback
    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    # Note: This assumes that the entries of `lambda` are sorted with ascending cell ids
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        # refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                      to_refine)

        # Find all indices of elements whose cell ids are in refined_original_cells;
        # the `Set` keeps the lookup linear in the number of elements.
        # Note: This assumes same indices for hyperbolic and parabolic part.
        elements_to_refine = findall(in(Set(refined_original_cells)),
                                     cache.elements.cell_ids)

        # refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic,
                                               elements_to_refine)
    else
        # If there is nothing to refine, create empty array for later use
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Since the cells may have been shifted due to refinement, first we need to
        # translate the old cell ids to the new cell ids (`to_coarsen` is non-empty
        # here, as guaranteed by the branch condition)
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Next, determine the parent cells from which the fine cells are to be
        # removed, since these are needed for the coarsen! function. However, since
        # we only want to coarsen if *all* child cells are marked for coarsening,
        # we count the coarsening indicators for each parent cell and only coarsen
        # if all children are marked as such (i.e., where the count is 2^ndims). At
        # the same time, check if a cell is marked for coarsening even though it is
        # *not* a leaf cell -> this can only happen if it was refined due to 2:1
        # smoothing during the preceding refinement operation.
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # Cells without a parent cannot be coarsened, and neither can cells that
            # are not leaf cells (anymore)
            if !has_parent(mesh.tree, cell_id) || !is_leaf(mesh.tree, cell_id)
                continue
            end

            # Increase count for parent cell
            parents_to_coarsen[mesh.tree.parent_ids[cell_id]] += 1
        end

        # Extract only those parent cells for which all children should be coarsened
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        # Finally, coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Convert coarsened parent cell ids to the set of child cell ids that have
        # been removed, since this is the information that is expected by the solver.
        # The children of a parent cell are taken to be the cells with ids
        # `parent_id + 1, ..., parent_id + n_children`.
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = Set(coarse_cell_id + child
                                  for coarse_cell_id in coarsened_original_cells
                                  for child in 1:n_children)

        # Find all indices of elements whose cell ids are in removed_child_cells
        # Note: This assumes same indices for hyperbolic and parabolic part.
        elements_to_remove = findall(in(removed_child_cells), cache.elements.cell_ids)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic,
                                                elements_to_remove)
    else
        # If there is nothing to coarsen, create empty array for later use
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true
    end

    # No dynamic load balancing here: MPI-parallel runs were rejected above

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
527
528
# Volume callback for `iterate_p4est` (C signature: see `cfunction`). It copies the
# controller value of the visited quadrant from `user_data` (the `lambda` vector handed
# to `iterate_p4est`) into the second entry of the quadrant's own user data.
function copy_to_quad_iter_volume(info, user_data)
    info_pw = PointerWrapper(info)

    # p4est tree ids start at zero; the global trees array is loaded one-based
    tree_pw = load_pointerwrapper_tree(info_pw.p4est, info_pw.treeid[] + 1)

    # Zero-based position of this quadrant in `lambda`: offset of its tree plus the
    # quadrant's id within that tree
    quad_index = tree_pw.quadrants_offset[] + info_pw.quadid[]

    lambda_pw = PointerWrapper(Int, user_data)
    quad_user_data_pw = PointerWrapper(Int, info_pw.quad.p.user_data[])

    # Quadrant user data layout: [global quad ID, controller_value]
    quad_user_data_pw[2] = lambda_pw[quad_index + 1]

    return nothing
end
551
552
# specialized callback which includes the `cache_parabolic` argument
#
# Adapt a `P4estMesh` together with the hyperbolic and parabolic solver caches: the
# controller values are copied into the quadrants' user data, then the mesh and the
# solver (plus any `passive_args`) are refined and coarsened.
# Returns `true` if the mesh changed on any rank, otherwise `false`.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, cache_parabolic,
                                     semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    controller = amr_callback.controller
    adaptor = amr_callback.adaptor

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    indicator_values = @trixi_timeit timer() "indicator" controller(u, mesh, equations,
                                                                    dg, cache,
                                                                    t = t, iter = iter)

    @boundscheck begin
        @assert axes(indicator_values)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(indicator_values))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # `copy_to_quad_iter_volume` reinterprets the user data pointer as `Ptr{Int}`,
    # so the indicator values must be stored in a `Vector{Int}`
    @assert indicator_values isa Vector{Int}
    # Copy the indicator value of each quadrant into the quadrant's user data
    iterate_p4est(mesh.p4est, indicator_values;
                  iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh))))

    @trixi_timeit timer() "refine" if only_coarsen
        # Refinement is disabled: nothing is refined
        refined_cells = Int[]
    else
        refined_cells = @trixi_timeit timer() "mesh" refine!(mesh)

        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic, refined_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           refined_cells)
        end
    end

    @trixi_timeit timer() "coarsen" if only_refine
        # Coarsening is disabled: nothing is coarsened
        coarsened_cells = Int[]
    else
        coarsened_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)

        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic, coarsened_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            coarsened_cells)
        end
    end

    # The local mesh changed if any cell was refined or coarsened; in parallel runs,
    # the flags of all ranks are combined so that every rank sees the same result
    has_changed = !(isempty(refined_cells) && isempty(coarsened_cells))
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    # Nothing else to update if the mesh is unchanged everywhere
    has_changed || return has_changed

    # Only ever set (never reset) here since there can be changes from earlier calls
    mesh.unsaved_changes = true

    # Dynamically balance the computational load by repartitioning the mesh and
    # redistributing the solution
    if mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            # Copy the wrapped p4est data before `partition!` (presumably) updates it
            first_quadrant_ptr = unsafe_load(mesh.p4est).global_first_quadrant
            old_global_first_quadrant = copy(unsafe_wrap(Array, first_quadrant_ptr,
                                                         mpi_nranks() + 1))
            partition!(mesh)
            rebalance_solver!(u_ode, mesh, equations, dg, cache,
                              old_global_first_quadrant)
        end
    end

    reinitialize_boundaries!(semi.boundary_conditions, cache)
    # Parabolic boundary conditions, if stored by the semidiscretization, need the
    # same treatment after each adaptation
    if hasproperty(semi, :boundary_conditions_parabolic)
        reinitialize_boundaries!(semi.boundary_conditions_parabolic, cache)
    end

    return has_changed
end
648
649
# C function pointers to `copy_to_quad_iter_volume` for the p4est volume iterator,
# selected by the spatial dimension of the mesh via `Val(ndims(mesh))`.

# 2D (p4est)
cfunction(::typeof(copy_to_quad_iter_volume), ::Val{2}) =
    @cfunction(copy_to_quad_iter_volume, Cvoid,
               (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))

# 3D (p8est)
cfunction(::typeof(copy_to_quad_iter_volume), ::Val{3}) =
    @cfunction(copy_to_quad_iter_volume, Cvoid,
               (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
659
660
# Adapt a `P4estMesh` and the solver state `u_ode`: the controller values are copied
# into the quadrants' user data, then the mesh and the solver (plus any
# `passive_args`) are refined and coarsened.
# Returns `true` if the mesh changed on any rank, otherwise `false`.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    controller = amr_callback.controller
    adaptor = amr_callback.adaptor

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    indicator_values = @trixi_timeit timer() "indicator" controller(u, mesh, equations,
                                                                    dg, cache,
                                                                    t = t, iter = iter)

    @boundscheck begin
        @assert axes(indicator_values)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(indicator_values))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # `copy_to_quad_iter_volume` reinterprets the user data pointer as `Ptr{Int}`,
    # so the indicator values must be stored in a `Vector{Int}`
    @assert indicator_values isa Vector{Int}
    # Copy the indicator value of each quadrant into the quadrant's user data
    iterate_p4est(mesh.p4est, indicator_values;
                  iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh))))

    @trixi_timeit timer() "refine" if only_coarsen
        # Refinement is disabled: nothing is refined
        refined_cells = Int[]
    else
        refined_cells = @trixi_timeit timer() "mesh" refine!(mesh)

        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, refined_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           refined_cells)
        end
    end

    @trixi_timeit timer() "coarsen" if only_refine
        # Coarsening is disabled: nothing is coarsened
        coarsened_cells = Int[]
    else
        coarsened_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)

        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, coarsened_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            coarsened_cells)
        end
    end

    # The local mesh changed if any cell was refined or coarsened; in parallel runs,
    # the flags of all ranks are combined so that every rank sees the same result
    has_changed = !(isempty(refined_cells) && isempty(coarsened_cells))
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    if !has_changed
        # Nothing else to update if the mesh is unchanged everywhere
        return has_changed
    end

    # Only ever set (never reset) here since there can be changes from earlier calls
    mesh.unsaved_changes = true

    # Dynamically balance the computational load by repartitioning the mesh and
    # redistributing the solution
    if mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            # Copy the wrapped p4est data before `partition!` (presumably) updates it
            first_quadrant_ptr = unsafe_load(mesh.p4est).global_first_quadrant
            old_global_first_quadrant = copy(unsafe_wrap(Array, first_quadrant_ptr,
                                                         mpi_nranks() + 1))
            partition!(mesh)
            rebalance_solver!(u_ode, mesh, equations, dg, cache,
                              old_global_first_quadrant)
        end
    end

    reinitialize_boundaries!(semi.boundary_conditions, cache)

    return has_changed
end
749
750
# Adapt a `T8codeMesh` and the solver state `u_ode` based on the controller values:
# positive values request refinement, negative values coarsening of a cell.
# `only_coarsen` discards refinement requests, `only_refine` discards coarsening requests.
# Note: `passive_args` is accepted for a uniform interface but not used here.
# Returns `true` if the mesh changed on any rank, otherwise `false`.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::T8codeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    has_changed = false

    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    indicators = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg,
                                                              cache, t = t, iter = iter)

    # Validate the indicator array before it is modified and used below
    @boundscheck begin
        @assert axes(indicators)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(indicators))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Discard refinement requests (positive values) if only coarsening is allowed ...
    if only_coarsen
        indicators[indicators .> 0] .= 0
    end

    # ... and coarsening requests (negative values) if only refinement is allowed
    if only_refine
        indicators[indicators .< 0] .= 0
    end

    @trixi_timeit timer() "adapt" begin
        difference = @trixi_timeit timer() "mesh" trixi_t8_adapt!(mesh, indicators)

        # Store whether there were any cells coarsened or refined; `any` with a
        # predicate avoids allocating a temporary array
        has_changed = any(!iszero, difference)

        # Check if mesh changed on other processes
        if mpi_isparallel()
            has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
        end

        if has_changed
            @trixi_timeit timer() "solver" adapt!(u_ode, adaptor, mesh, equations, dg,
                                                  cache, difference)
        end
    end

    if has_changed
        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                old_global_first_element_ids = get_global_first_element_ids(mesh)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_element_ids)
            end
        end

        reinitialize_boundaries!(semi.boundary_conditions, cache)
    end

    mesh.unsaved_changes |= has_changed

    # Return true if there were any cells coarsened or refined, otherwise false.
    return has_changed
end
810
811
# Sorted boundary types are rebuilt from `cache`, since the set of boundaries may
# have changed during AMR; the result of `initialize!` is returned
reinitialize_boundaries!(boundary_conditions::UnstructuredSortedBoundaryTypes, cache) =
    initialize!(boundary_conditions, cache)

# All other boundary conditions are returned unchanged
reinitialize_boundaries!(boundary_conditions, cache) = boundary_conditions
820
821
# After refining cells, shift original cell ids to match new locations
# Note: Assumes sorted lists of original and refined cell ids!
# Note: `mesh` is only required to extract ndims
#
# Every refined cell with a smaller id than a given cell inserts `2^ndims` new cells
# before it, so the cell's new id is its old id plus `2^ndims` times the number of
# refined cells with a smaller id. Returns the shifted ids in the order of
# `original_cell_ids` (an empty input yields an empty result).
function original2refined(original_cell_ids, refined_original_cells, mesh)
    # Sanity check
    @assert issorted(original_cell_ids) "`original_cell_ids` not sorted"
    @assert issorted(refined_original_cells) "`refined_original_cells` not sorted"

    n_children = 2^ndims(mesh)

    # `searchsortedfirst(...) - 1` counts the refined cells with an id smaller than
    # `cell_id` (binary search, since the refined ids are sorted)
    return [cell_id + n_children * (searchsortedfirst(refined_original_cells, cell_id) - 1)
            for cell_id in original_cell_ids]
end
846
847
"""
    ControllerThreeLevel(semi, indicator; base_level=1,
                         med_level=base_level, med_threshold=0.0,
                         max_level=base_level, max_threshold=1.0)

An AMR controller based on three levels (in descending order of precedence):
- set the target level to `max_level` if `indicator > max_threshold`
- set the target level to `med_level` if `indicator > med_threshold`;
  if `med_level` is not positive (e.g., `med_level = -1`), keep the current level
- set the target level to `base_level` otherwise
"""
struct ControllerThreeLevel{RealT <: Real, Indicator, Cache}
    base_level::Int          # target level where the indicator is below `med_threshold`
    med_level::Int           # target level between the thresholds (<= 0: keep current)
    max_level::Int           # target level where the indicator exceeds `max_threshold`
    med_threshold::RealT
    max_threshold::RealT
    indicator::Indicator     # called on the solution to obtain per-element values
    cache::Cache             # holds the `controller_value` buffer
end
867
868
# Keyword constructor: promotes both thresholds to a common real type so that the
# `RealT` parameter of the controller is concrete, and allocates the controller cache
# for the given semidiscretization.
function ControllerThreeLevel(semi, indicator; base_level = 1,
                              med_level = base_level, med_threshold = 0.0,
                              max_level = base_level, max_threshold = 1.0)
    thresholds = promote(med_threshold, max_threshold)
    cache = create_cache(ControllerThreeLevel, semi)

    RealT = typeof(last(thresholds))
    IndicatorT = typeof(indicator)
    CacheT = typeof(cache)
    # `thresholds...` expands to `med_threshold, max_threshold` (promoted)
    return ControllerThreeLevel{RealT, IndicatorT, CacheT}(base_level, med_level,
                                                           max_level, thresholds...,
                                                           indicator, cache)
end
881
882
# Highest refinement level this controller can request
function max_level(controller::ControllerThreeLevel)
    return controller.max_level
end
883
884
# Unpack the semidiscretization and forward to the mesh/equations/solver-specific
# `create_cache` method for this controller type.
function create_cache(controller_type::Type{ControllerThreeLevel}, semi)
    components = mesh_equations_solver_cache(semi)
    return create_cache(controller_type, components...)
end
887
888
# Compact single-line representation: the indicator followed by all parameters
function Base.show(io::IO, controller::ControllerThreeLevel)
    @nospecialize controller # reduce precompilation time

    print(io, "ControllerThreeLevel(", controller.indicator,
          ", base_level=", controller.base_level,
          ", med_level=", controller.med_level,
          ", max_level=", controller.max_level,
          ", med_threshold=", controller.med_threshold,
          ", max_threshold=", controller.max_threshold,
          ")")
end
900
901
# Multi-line summary box; falls back to the compact form when `io` requests it
function Base.show(io::IO, mime::MIME"text/plain", controller::ControllerThreeLevel)
    @nospecialize controller # reduce precompilation time

    get(io, :compact, false) && return show(io, controller)

    summary_header(io, "ControllerThreeLevel")
    summary_line(io, "indicator", nameof(typeof(controller.indicator)))
    show(increment_indent(io), mime, controller.indicator)

    # Remaining parameters, printed in a fixed order
    parameters = ("base_level" => controller.base_level,
                  "med_level" => controller.med_level,
                  "max_level" => controller.max_level,
                  "med_threshold" => controller.med_threshold,
                  "max_threshold" => controller.max_threshold)
    for (label, value) in parameters
        summary_line(io, label, value)
    end

    return summary_footer(io)
end
918
919
# Store the controller's indicator values in `element_variables` for output. The
# indicator is evaluated first so that the stored values reflect the current `u`.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevel,
                                amr_callback::AMRCallback;
                                kwargs...)
    indicator = controller.indicator
    indicator(u, mesh, equations, solver, cache; kwargs...)

    return get_element_variables!(element_variables, indicator, amr_callback)
end
927
928
# Expose the indicator's cached `alpha` values under the key `:indicator_amr`
function get_element_variables!(element_variables, indicator::AbstractIndicator,
                                ::AMRCallback)
    alpha = indicator.cache.alpha
    element_variables[:indicator_amr] = alpha
    return nothing
end
933
934
# Refinement level of each element, looked up via the tree cell each element belongs to
function current_element_levels(mesh::TreeMesh, solver, cache)
    tree_levels = mesh.tree.levels
    element_range = eachelement(solver, cache)
    return tree_levels[cache.elements.cell_ids[element_range]]
end
939
940
# p4est/p8est volume-iteration callback: writes the refinement level of the visited
# quadrant into the `Int` vector passed as `user_data`, at the quadrant's one-based
# element index (tree offset + quadrant index within the tree + 1).
function extract_levels_iter_volume(info, user_data)
    info_wrapped = PointerWrapper(info)

    # Trees are loaded from the global trees array with one-based indexing
    tree_wrapped = load_pointerwrapper_tree(info_wrapped.p4est, info_wrapped.treeid[] + 1)

    # Quadrant numbering offset of this tree plus the quadrant's index, shifted to a
    # one-based Julia element ID
    element_id = tree_wrapped.quadrants_offset[] + info_wrapped.quadid[] + 1

    # `user_data` points to the `current_levels` vector
    levels_wrapped = PointerWrapper(Int, user_data)
    levels_wrapped[element_id] = info_wrapped.quad.level[]

    return nothing
end
960
961
# 2D: C-callable pointer to `extract_levels_iter_volume` for p4est volume iteration
cfunction(::typeof(extract_levels_iter_volume), ::Val{2}) =
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))
966
# 3D: C-callable pointer to `extract_levels_iter_volume` for p8est volume iteration
cfunction(::typeof(extract_levels_iter_volume), ::Val{3}) =
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
971
972
# Refinement level of each element, gathered by iterating over all p4est/p8est volumes
function current_element_levels(mesh::P4estMesh, solver, cache)
    # Filled in place by `extract_levels_iter_volume`, one entry per element
    levels = Vector{Int}(undef, nelements(solver, cache))

    # The C callback signature depends on the spatial dimension (p4est vs. p8est)
    volume_callback = cfunction(extract_levels_iter_volume, Val(ndims(mesh)))
    iterate_p4est(mesh.p4est, levels; iter_volume_c = volume_callback)

    return levels
end
980
981
# Refinement level of each local element, queried directly from the t8code forest
current_element_levels(mesh::T8codeMesh, solver, cache) =
    trixi_t8_get_local_element_levels(mesh.forest)
984
985
# TODO: Taal refactor, merge the two loops of ControllerThreeLevel and IndicatorLöhner etc.?
# But that would remove the simplest possibility to write that stuff to a file...
# We could of course implement some additional logic and workarounds, but is it worth the effort?
#
# Evaluate the controller: for each element, store +1 (refine), -1 (coarsen), or 0 (keep)
# in `controller.cache.controller_value`, which is resized to the number of elements
# and returned.
function (controller::ControllerThreeLevel)(u::AbstractArray{<:Any},
                                            mesh, equations, dg::DG, cache;
                                            kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    alpha = controller.indicator(u, mesh, equations, dg, cache; kwargs...)
    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        level = current_levels[element]
        indicator_value = alpha[element]

        # Target level, checked in descending order of precedence
        target = if indicator_value > controller.max_threshold
            controller.max_level
        elseif indicator_value > controller.med_threshold
            # A non-positive `med_level` (e.g. -1) means "keep the current level"
            controller.med_level > 0 ? controller.med_level : level
        else
            controller.base_level
        end

        # cmp yields 1 if the element is too coarse (refine), -1 if too fine (coarsen),
        # and 0 if it is already at the target level
        controller_value[element] = cmp(target, level)
    end

    return controller_value
end
1026
1027
"""
    ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
                                 base_level=1,
                                 med_level=base_level, med_threshold=0.0,
                                 max_level=base_level, max_threshold=1.0,
                                 max_threshold_secondary=1.0)

An AMR controller based on three levels (in descending order of precedence):
- set the target level to `max_level` if `indicator_primary > max_threshold`
- set the target level to `med_level` if `indicator_primary > med_threshold`;
  if `med_level` is not positive (e.g., `med_level = -1`), keep the current level
- set the target level to `base_level` otherwise
If `indicator_secondary >= max_threshold_secondary`, the target level is set to
`max_level` regardless of the value of `indicator_primary`.
"""
struct ControllerThreeLevelCombined{RealT <: Real, IndicatorPrimary, IndicatorSecondary,
                                    Cache}
    base_level::Int                  # target level below `med_threshold`
    med_level::Int                   # target level between thresholds (<= 0: keep current)
    max_level::Int                   # target level above either maximum threshold
    med_threshold::RealT
    max_threshold::RealT
    max_threshold_secondary::RealT   # secondary-indicator threshold forcing `max_level`
    indicator_primary::IndicatorPrimary
    indicator_secondary::IndicatorSecondary
    cache::Cache                     # holds the `controller_value` buffer
end
1054
1055
# Keyword constructor: promotes all three thresholds to a common real type so that the
# `RealT` parameter is concrete, and allocates the controller cache for `semi`.
function ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
                                      base_level = 1,
                                      med_level = base_level, med_threshold = 0.0,
                                      max_level = base_level, max_threshold = 1.0,
                                      max_threshold_secondary = 1.0)
    # (med_threshold, max_threshold, max_threshold_secondary), promoted
    thresholds = promote(med_threshold, max_threshold, max_threshold_secondary)
    cache = create_cache(ControllerThreeLevelCombined, semi)

    ControllerT = ControllerThreeLevelCombined{typeof(first(thresholds)),
                                               typeof(indicator_primary),
                                               typeof(indicator_secondary),
                                               typeof(cache)}
    return ControllerT(base_level, med_level, max_level, thresholds...,
                       indicator_primary, indicator_secondary, cache)
end
1075
1076
# Highest refinement level this controller can request
function max_level(controller::ControllerThreeLevelCombined)
    return controller.max_level
end
1077
1078
# Unpack the semidiscretization and forward to the mesh/equations/solver-specific
# `create_cache` method for this controller type.
function create_cache(controller_type::Type{ControllerThreeLevelCombined}, semi)
    components = mesh_equations_solver_cache(semi)
    return create_cache(controller_type, components...)
end
1081
1082
# Compact single-line representation: both indicators followed by all parameters,
# in the same order as the multi-line `text/plain` summary (including `max_threshold`).
function Base.show(io::IO, controller::ControllerThreeLevelCombined)
    @nospecialize controller # reduce precompilation time

    print(io, "ControllerThreeLevelCombined(")
    print(io, controller.indicator_primary)
    print(io, ", ", controller.indicator_secondary)
    print(io, ", base_level=", controller.base_level)
    print(io, ", med_level=", controller.med_level)
    print(io, ", max_level=", controller.max_level)
    print(io, ", med_threshold=", controller.med_threshold)
    print(io, ", max_threshold=", controller.max_threshold)
    print(io, ", max_threshold_secondary=", controller.max_threshold_secondary)
    print(io, ")")
end
1095
1096
# Multi-line summary box; falls back to the compact form when `io` requests it
function Base.show(io::IO, mime::MIME"text/plain",
                   controller::ControllerThreeLevelCombined)
    @nospecialize controller # reduce precompilation time

    get(io, :compact, false) && return show(io, controller)

    summary_header(io, "ControllerThreeLevelCombined")

    # Each indicator: its type name, then its own indented summary
    summary_line(io, "primary indicator", nameof(typeof(controller.indicator_primary)))
    show(increment_indent(io), mime, controller.indicator_primary)
    summary_line(io, "secondary indicator",
                 nameof(typeof(controller.indicator_secondary)))
    show(increment_indent(io), mime, controller.indicator_secondary)

    # Remaining parameters, printed in a fixed order
    parameters = ("base_level" => controller.base_level,
                  "med_level" => controller.med_level,
                  "max_level" => controller.max_level,
                  "med_threshold" => controller.med_threshold,
                  "max_threshold" => controller.max_threshold,
                  "max_threshold_secondary" => controller.max_threshold_secondary)
    for (label, value) in parameters
        summary_line(io, label, value)
    end

    return summary_footer(io)
end
1119
1120
# Store the primary indicator's values in `element_variables` for output; only the
# primary indicator is evaluated (with the current `u`) and written.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevelCombined,
                                amr_callback::AMRCallback;
                                kwargs...)
    primary = controller.indicator_primary
    primary(u, mesh, equations, solver, cache; kwargs...)

    return get_element_variables!(element_variables, primary, amr_callback)
end
1129
1130
# Evaluate the controller: for each element, store +1 (refine), -1 (coarsen), or 0 (keep)
# in `controller.cache.controller_value`, which is resized to the number of elements
# and returned. The secondary indicator can force `max_level` regardless of the primary.
function (controller::ControllerThreeLevelCombined)(u::AbstractArray{<:Any},
                                                    mesh, equations, dg::DG, cache;
                                                    kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    # Only the primary indicator receives the keyword arguments
    alpha = controller.indicator_primary(u, mesh, equations, dg, cache; kwargs...)
    alpha_secondary = controller.indicator_secondary(u, mesh, equations, dg, cache)

    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        level = current_levels[element]
        primary_value = alpha[element]

        target = if alpha_secondary[element] >= controller.max_threshold_secondary
            # The secondary indicator overrides every decision of the primary one
            controller.max_level
        elseif primary_value > controller.max_threshold
            controller.max_level
        elseif primary_value > controller.med_threshold
            # A non-positive `med_level` (e.g. -1) means "keep the current level"
            controller.med_level > 0 ? controller.med_level : level
        else
            controller.base_level
        end

        # cmp yields 1 if the element is too coarse (refine), -1 if too fine (coarsen),
        # and 0 if it is already at the target level
        controller_value[element] = cmp(target, level)
    end

    return controller_value
end
1174
1175
include("amr_dg.jl")
1176
include("amr_dg1d.jl")
1177
include("amr_dg2d.jl")
1178
include("amr_dg3d.jl")
1179
end # @muladd
1180
1181