@muladd begin
"""
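    AMRCallback(semi, controller [,adaptor=AdaptorAMR(semi)];
                interval,
                adapt_initial_condition=true,
                adapt_initial_condition_only_refine=true,
                dynamic_load_balancing=true)

Performs adaptive mesh refinement (AMR) every `interval` accepted time steps
for a given semidiscretization `semi` using the chosen `controller`.

# Keywords
- `interval`: number of accepted time steps between two AMR steps. Must be a
  non-negative integer; `interval = 0` disables AMR during time integration.
- `adapt_initial_condition`: if `true`, the mesh is adapted to the initial
  condition (repeatedly, until it no longer changes) when the callback is initialized.
- `adapt_initial_condition_only_refine`: if `true`, the initial adaptation only
  refines and never coarsens.
- `dynamic_load_balancing`: if `true`, the mesh is repartitioned after it has
  changed when running in parallel with MPI.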
"""
struct AMRCallback{Controller, Adaptor, Cache}
    # Callable returning one value per element: > 0 refine, < 0 coarsen, 0 keep
    controller::Controller
    # AMR is triggered every `interval` accepted time steps (0 disables it)
    interval::Int
    adapt_initial_condition::Bool
    adapt_initial_condition_only_refine::Bool
    # Only relevant when running in parallel with MPI
    dynamic_load_balancing::Bool
    # Passed to the solver-side `refine!`/`coarsen!`/`adapt!` methods
    adaptor::Adaptor
    # Reusable buffers `to_refine`/`to_coarsen` (cell ids) for the TreeMesh methods
    amr_cache::Cache
end
# Construct the AMR callback as a `DiscreteCallback` whose affect is an
# `AMRCallback`. Throws an `ArgumentError` unless `interval` is a non-negative integer.
function AMRCallback(semi, controller, adaptor;
                     interval,
                     adapt_initial_condition = true,
                     adapt_initial_condition_only_refine = true,
                     dynamic_load_balancing = true)
    interval isa Integer && interval >= 0 ||
        throw(ArgumentError("`interval` must be a non-negative integer (provided `interval = $interval`)"))

    if interval > 0
        # Trigger after every `interval` accepted steps. Skip it when steps were
        # attempted but none was accepted yet, and once the integrator is finished.
        condition = function (u, t, integrator)
            n_accepted = integrator.stats.naccept
            return n_accepted % interval == 0 &&
                   !(n_accepted == 0 && integrator.iter > 0) &&
                   !isfinished(integrator)
        end
    else
        # `interval == 0`: never adapt during time integration
        condition = (u, t, integrator) -> false
    end

    # Buffers for the cell ids flagged for refinement/coarsening, reused every AMR step
    amr_cache = (; to_refine = Int[], to_coarsen = Int[])

    ControllerT, AdaptorT, CacheT = typeof(controller), typeof(adaptor), typeof(amr_cache)
    amr_callback = AMRCallback{ControllerT, AdaptorT, CacheT}(controller, interval,
                                                              adapt_initial_condition,
                                                              adapt_initial_condition_only_refine,
                                                              dynamic_load_balancing,
                                                              adaptor, amr_cache)

    return DiscreteCallback(condition, amr_callback;
                            save_positions = (false, false),
                            initialize = initialize!)
end
# Convenience constructor using the default adaptor `AdaptorAMR(semi)`
function AMRCallback(semi, controller; kwargs...)
    return AMRCallback(semi, controller, AdaptorAMR(semi); kwargs...)
end
# Build the AMR adaptor from the mesh and solver of a semidiscretization
function AdaptorAMR(semi; kwargs...)
    mesh, _equations, solver, _cache = mesh_equations_solver_cache(semi)
    return AdaptorAMR(mesh, solver; kwargs...)
end
# Multi-line summary of an AMR callback; falls back to `show(io, cb)` in compact mode
function Base.show(io::IO, mime::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:AMRCallback})
    @nospecialize cb # avoid compiling a specialization per callback type

    get(io, :compact, false) && return show(io, cb)

    yes_no(flag) = flag ? "yes" : "no"

    amr_callback = cb.affect!
    controller = amr_callback.controller

    summary_header(io, "AMRCallback")
    summary_line(io, "controller", nameof(typeof(controller)))
    show(increment_indent(io), mime, controller)
    summary_line(io, "interval", amr_callback.interval)
    summary_line(io, "adapt IC", yes_no(amr_callback.adapt_initial_condition))
    if amr_callback.adapt_initial_condition
        summary_line(io, "│ only refine",
                     yes_no(amr_callback.adapt_initial_condition_only_refine))
    end
    summary_footer(io)
end
"""
    uses_amr(callback)

Check whether the provided callback or `CallbackSet` is an [`AMRCallback`](@ref)
or contains one. Returns `false` for all other callbacks, including a
`CallbackSet` without any discrete callbacks.
"""
# Fallback: any other callback is not an AMR callback
uses_amr(cb) = false
function uses_amr(cb::DiscreteCallback{Condition, Affect!}) where {Condition,
                                                                   Affect! <:
                                                                   AMRCallback}
    return true
end
# `any` returns `false` for a `CallbackSet` without discrete callbacks, whereas the
# previous `mapreduce(uses_amr, |, ...)` threw on the empty reduction
uses_amr(callbacks::CallbackSet) = any(uses_amr, callbacks.discrete_callbacks)
# Forward to the controller-specific method, which decides which indicator values
# are stored as element variables
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                amr_callback::AMRCallback; kwargs...)
    controller = amr_callback.controller
    return get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                  controller, amr_callback; kwargs...)
end
# Initialization hook of the AMR callback. If `adapt_initial_condition` is set, the
# mesh is adapted to the initial condition repeatedly: after each change, the
# initial condition is re-evaluated on the new mesh. This stops when the mesh no
# longer changes, or with a warning after `max(10, max_level(controller))` passes.
# Afterwards, the initial state integrals of an `AnalysisCallback` (if present)
# are recomputed for the adapted solution.
function initialize!(cb::DiscreteCallback{Condition, Affect!}, u, t,
                     integrator) where {Condition, Affect! <: AMRCallback}
    amr_callback = cb.affect!
    semi = integrator.p

    @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
        only_refine = amr_callback.adapt_initial_condition_only_refine

        n_iterations = 1
        mesh_changed = amr_callback(integrator, only_refine = only_refine)
        while mesh_changed
            # Re-evaluate the initial condition on the adapted mesh, then adapt again
            compute_coefficients!(integrator.u, t, semi)
            u_modified!(integrator, true)
            mesh_changed = amr_callback(integrator, only_refine = only_refine)
            n_iterations += 1

            allowed_max_iterations = max(10, max_level(amr_callback.controller))
            if n_iterations > allowed_max_iterations
                @warn "AMR for initial condition did not settle within $(allowed_max_iterations) iterations!\n" *
                      "Consider adjusting thresholds or setting `adapt_initial_condition_only_refine`."
                break
            end
        end

        # The discrete solution may have changed: refresh the initial state
        # integrals stored in an `AnalysisCallback`, if there is one
        discrete_callbacks = integrator.opts.callback.discrete_callbacks
        analysis_index = findfirst(c -> c.affect! isa AnalysisCallback, discrete_callbacks)
        if analysis_index !== nothing
            analysis_callback = discrete_callbacks[analysis_index].affect!
            analysis_callback.initial_state_integrals = integrate(integrator.u, semi)
        end
    end

    return nothing
end
# Affect function of the AMR `DiscreteCallback`: adapt the mesh and, if it changed,
# resize the integrator to the (possibly new) length of `u_ode`.
# Returns whether the mesh changed.
function (amr_callback::AMRCallback)(integrator; kwargs...)
    semi = integrator.p
    u_ode = integrator.u

    @trixi_timeit timer() "AMR" begin
        has_changed = amr_callback(u_ode, semi, integrator.t, integrator.iter; kwargs...)
        if has_changed
            resize!(integrator, length(u_ode))
            u_modified!(integrator, true)
        end
    end

    return has_changed
end
# Unpack a hyperbolic semidiscretization and dispatch on its mesh type
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::SemidiscretizationHyperbolic,
                                             t, iter;
                                             kwargs...)
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    return amr_callback(u_ode, mesh, equations, solver, cache, semi, t, iter; kwargs...)
end
# Unpack a hyperbolic-parabolic semidiscretization and pass the parabolic cache
# along explicitly, dispatching on the mesh type
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::SemidiscretizationHyperbolicParabolic,
                                             t, iter;
                                             kwargs...)
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    cache_parabolic = semi.cache_parabolic
    return amr_callback(u_ode, mesh, equations, solver, cache, cache_parabolic,
                        semi, t, iter; kwargs...)
end
# AMR step for a `TreeMesh`.
#
# Evaluates the controller, refines the flagged leaf cells, then coarsens every
# parent all of whose `2^ndims(mesh)` children are leaves flagged for coarsening.
# The solver data is adapted accordingly. `passive_args` is an iterable of
# `(u_ode, mesh, equations, dg, cache)` tuples that are refined/coarsened with the
# same element lists as `u_ode`. `only_refine`/`only_coarsen` suppress the other
# operation. Returns `true` if the mesh changed.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    # Per-element controller value: > 0 refine, < 0 coarsen, 0 keep
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    if mpi_isparallel()
        # Gather the controller values of all ranks so that `lambda` matches the
        # global list of leaf cells used below
        lambda_global = Vector{eltype(lambda)}(undef, nelementsglobal(mesh, dg, cache))
        recvbuf = MPI.VBuffer(lambda_global, parent(cache.mpi_cache.n_elements_by_rank))
        MPI.Allgatherv!(lambda, recvbuf, mpi_comm())
        lambda = lambda_global
    end

    leaf_cell_ids = leaf_cells(mesh.tree)

    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    # Collect the cell ids flagged for refinement/coarsening in the reusable buffers
    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                      to_refine)

        # Elements whose cells were refined (ids before refinement)
        elements_to_refine = findall(in(refined_original_cells),
                                     cache.elements.cell_ids)

        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, elements_to_refine)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           elements_to_refine)
        end
    else
        # Nothing refined; the empty list is used for the id translation below
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Refinement inserted new cells, so translate the flagged cell ids to the
        # ids they have now. (`to_coarsen` is non-empty by the condition above.)
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Count, per parent, how many children are leaves flagged for coarsening
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # Root cells cannot be coarsened; non-leaf cells are skipped as well
            if !has_parent(mesh.tree, cell_id) || !is_leaf(mesh.tree, cell_id)
                continue
            end
            parents_to_coarsen[mesh.tree.parent_ids[cell_id]] += 1
        end

        # Only coarsen parents whose children are *all* flagged
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Child cell ids that were removed; the children of a coarsened cell are
        # taken to be `coarse_cell_id + 1, ..., coarse_cell_id + n_children`
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = [coarse_cell_id + child
                               for coarse_cell_id in coarsened_original_cells
                               for child in 1:n_children]

        elements_to_remove = findall(in(removed_child_cells), cache.elements.cell_ids)

        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, elements_to_remove)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            elements_to_remove)
        end
    else
        coarsened_original_cells = Int[]
    end

    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed
        mesh.unsaved_changes = true
    end

    # Repartition the tree among the MPI ranks and move the solver data accordingly
    if has_changed && mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            old_mpi_ranks_per_cell = copy(mesh.tree.mpi_ranks)
            partition!(mesh)
            rebalance_solver!(u_ode, mesh, equations, dg, cache, old_mpi_ranks_per_cell)
        end
    end

    return has_changed
end
# AMR step for a `TreeMesh` with a separate parabolic cache (serial only).
#
# Same algorithm as the hyperbolic `TreeMesh` method, but the solver-side
# `refine!`/`coarsen!` also adapt `cache_parabolic`. Raises an error when
# running in parallel with MPI. Returns `true` if the mesh changed.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG,
                                     cache, cache_parabolic,
                                     semi::SemidiscretizationHyperbolicParabolic,
                                     t, iter;
                                     only_refine = false, only_coarsen = false)
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    # Per-element controller value: > 0 refine, < 0 coarsen, 0 keep
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    # Parabolic AMR is not supported with MPI. The gather of `lambda` and the
    # dynamic load balancing of the hyperbolic method were previously kept here as
    # unreachable code after this `error` and have been removed.
    if mpi_isparallel()
        error("MPI has not been verified yet for parabolic AMR")
    end

    leaf_cell_ids = leaf_cells(mesh.tree)

    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    # Collect the cell ids flagged for refinement/coarsening in the reusable buffers
    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                      to_refine)

        # Elements whose cells were refined (ids before refinement)
        elements_to_refine = findall(in(refined_original_cells),
                                     cache.elements.cell_ids)

        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic,
                                               elements_to_refine)
    else
        # Nothing refined; the empty list is used for the id translation below
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Refinement inserted new cells, so translate the flagged cell ids to the
        # ids they have now. (`to_coarsen` is non-empty by the condition above.)
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Count, per parent, how many children are leaves flagged for coarsening
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # Root cells cannot be coarsened; non-leaf cells are skipped as well
            if !has_parent(mesh.tree, cell_id) || !is_leaf(mesh.tree, cell_id)
                continue
            end
            parents_to_coarsen[mesh.tree.parent_ids[cell_id]] += 1
        end

        # Only coarsen parents whose children are *all* flagged
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Child cell ids that were removed; the children of a coarsened cell are
        # taken to be `coarse_cell_id + 1, ..., coarse_cell_id + n_children`
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = [coarse_cell_id + child
                               for coarse_cell_id in coarsened_original_cells
                               for child in 1:n_children]

        elements_to_remove = findall(in(removed_child_cells), cache.elements.cell_ids)

        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic,
                                                elements_to_remove)
    else
        coarsened_original_cells = Int[]
    end

    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed
        mesh.unsaved_changes = true
    end

    return has_changed
end
# p4est volume-iterator callback: copy the controller value of the current
# quadrant from `user_data` into the quadrant's own user data
function copy_to_quad_iter_volume(info, user_data)
    info_pw = PointerWrapper(info)

    # Zero-based quadrant index: the tree's quadrant offset plus the quadrant's
    # index within its tree
    tree_pw = load_pointerwrapper_tree(info_pw.p4est, info_pw.treeid[] + 1)
    quad_id = tree_pw.quadrants_offset[] + info_pw.quadid[]

    # `user_data` points to the controller values (one `Int` per element)
    lambda_pw = PointerWrapper(Int, user_data)
    quad_user_data_pw = PointerWrapper(Int, info_pw.quad.p.user_data[])

    # Store the controller value in the second `Int` slot of the quadrant's user data
    quad_user_data_pw[2] = lambda_pw[quad_id + 1]

    return nothing
end
# AMR step for a `P4estMesh` with a separate parabolic cache.
#
# The controller values are copied into the user data of the p4est quadrants.
# Then `refine!(mesh)`/`coarsen!(mesh)` adapt the forest, and the solver data in
# both `cache` and `cache_parabolic` is adapted. `passive_args` is an iterable of
# `(u_ode, mesh, equations, dg, cache)` tuples adapted alongside `u_ode`.
# Returns `true` if the mesh changed on any MPI rank.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, cache_parabolic,
                                     semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    @boundscheck begin
        @assert axes(lambda)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(lambda))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Copy the controller values into the quadrant user data. `refine!(mesh)` and
    # `coarsen!(mesh)` below take no indicator argument, so they presumably read
    # the values from there.
    iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh)))

    # `copy_to_quad_iter_volume` reads `lambda` through a raw `Int` pointer
    @assert lambda isa Vector{Int}
    iterate_p4est(mesh.p4est, lambda; iter_volume_c = iter_volume_c)

    @trixi_timeit timer() "refine" if !only_coarsen
        # Refine the forest; the result presumably lists the refined original cells
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh)

        # Adapt the solver data (including the parabolic cache) to the refinement
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic,
                                               refined_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations,
                                                           p_dg, p_cache,
                                                           refined_original_cells)
        end
    else
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)

        # Adapt the solver data (including the parabolic cache) to the coarsening
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic,
                                                coarsened_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations,
                                                            p_dg, p_cache,
                                                            coarsened_original_cells)
        end
    else
        coarsened_original_cells = Int[]
    end

    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    # Global OR over all ranks so that every rank takes the same branch below
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    if has_changed
        mesh.unsaved_changes = true

        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                # Snapshot the partition (first global quadrant per rank) before
                # `partition!` changes it; the solver data is moved based on it
                global_first_quadrant = unsafe_wrap(Array,
                                                    unsafe_load(mesh.p4est).global_first_quadrant,
                                                    mpi_nranks() + 1)
                old_global_first_quadrant = copy(global_first_quadrant)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_quadrant)
            end
        end

        # Boundary data depends on the mesh and must be rebuilt after adaptation
        reinitialize_boundaries!(semi.boundary_conditions, cache)
        if hasproperty(semi, :boundary_conditions_parabolic)
            reinitialize_boundaries!(semi.boundary_conditions_parabolic, cache)
        end
    end

    return has_changed
end
# C-callable function pointers for `copy_to_quad_iter_volume`, selected by spatial
# dimension: 2D uses `p4est_iter_volume_info_t`, 3D uses `p8est_iter_volume_info_t`
function cfunction(::typeof(copy_to_quad_iter_volume), ::Val{2})
    @cfunction(copy_to_quad_iter_volume, Cvoid,
               (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))
end
function cfunction(::typeof(copy_to_quad_iter_volume), ::Val{3})
    @cfunction(copy_to_quad_iter_volume, Cvoid,
               (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
end
# AMR step for a `P4estMesh`.
#
# The controller values are copied into the user data of the p4est quadrants.
# Then `refine!(mesh)`/`coarsen!(mesh)` adapt the forest, and the solver data is
# adapted. `passive_args` is an iterable of `(u_ode, mesh, equations, dg, cache)`
# tuples adapted alongside `u_ode`. Returns `true` if the mesh changed on any MPI rank.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    @boundscheck begin
        @assert axes(lambda)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(lambda))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Copy the controller values into the quadrant user data. `refine!(mesh)` and
    # `coarsen!(mesh)` below take no indicator argument, so they presumably read
    # the values from there.
    iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh)))

    # `copy_to_quad_iter_volume` reads `lambda` through a raw `Int` pointer
    @assert lambda isa Vector{Int}
    iterate_p4est(mesh.p4est, lambda; iter_volume_c = iter_volume_c)

    @trixi_timeit timer() "refine" if !only_coarsen
        # Refine the forest; the result presumably lists the refined original cells
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh)

        # Adapt the solver data to the refinement
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache,
                                               refined_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations,
                                                           p_dg, p_cache,
                                                           refined_original_cells)
        end
    else
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)

        # Adapt the solver data to the coarsening
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache,
                                                coarsened_original_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations,
                                                            p_dg, p_cache,
                                                            coarsened_original_cells)
        end
    else
        coarsened_original_cells = Int[]
    end

    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    # Global OR over all ranks so that every rank takes the same branch below
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    if has_changed
        mesh.unsaved_changes = true

        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                # Snapshot the partition (first global quadrant per rank) before
                # `partition!` changes it; the solver data is moved based on it
                global_first_quadrant = unsafe_wrap(Array,
                                                    unsafe_load(mesh.p4est).global_first_quadrant,
                                                    mpi_nranks() + 1)
                old_global_first_quadrant = copy(global_first_quadrant)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_quadrant)
            end
        end

        # Boundary data depends on the mesh and must be rebuilt after adaptation
        reinitialize_boundaries!(semi.boundary_conditions, cache)
    end

    return has_changed
end
# AMR step for a `T8codeMesh`.
#
# The controller's per-element indicators are handed to `trixi_t8_adapt!`, which
# adapts the forest and returns a per-element `difference`. Any nonzero entry
# means the mesh changed; `difference` is also passed to the solver-side `adapt!`.
# Returns `true` if the mesh changed on any MPI rank.
#
# NOTE(review): `passive_args` is accepted but not used by this method, unlike
# the TreeMesh and P4estMesh methods - confirm whether that is intended.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::T8codeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    has_changed = false

    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    indicators = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg,
                                                              cache, t = t, iter = iter)

    # Zero out refinement (positive) or coarsening (negative) requests in place
    # when only the other operation is allowed
    if only_coarsen
        indicators[indicators .> 0] .= 0
    end

    if only_refine
        indicators[indicators .< 0] .= 0
    end

    @boundscheck begin
        @assert axes(indicators)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(indicators))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    @trixi_timeit timer() "adapt" begin
        difference = @trixi_timeit timer() "mesh" trixi_t8_adapt!(mesh, indicators)

        has_changed = any(difference .!= 0)

        # Global OR over all ranks so that every rank takes the same branches below
        if mpi_isparallel()
            has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
        end

        if has_changed
            @trixi_timeit timer() "solver" adapt!(u_ode, adaptor, mesh, equations, dg,
                                                  cache, difference)
        end
    end

    if has_changed
        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                # Record the partition before `partition!` changes it
                old_global_first_element_ids = get_global_first_element_ids(mesh)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_element_ids)
            end
        end

        # Boundary data depends on the mesh and must be rebuilt after adaptation
        reinitialize_boundaries!(semi.boundary_conditions, cache)
    end

    mesh.unsaved_changes |= has_changed
    return has_changed
end
# Rebuild mesh-dependent boundary data after the mesh has changed. Only
# `UnstructuredSortedBoundaryTypes` are re-initialized.
function reinitialize_boundaries!(boundary_conditions::UnstructuredSortedBoundaryTypes,
                                  cache)
    return initialize!(boundary_conditions, cache)
end

# Fallback: other boundary condition containers are returned unchanged
reinitialize_boundaries!(boundary_conditions, cache) = boundary_conditions
# Translate `original_cell_ids` (cell ids before refinement) into the ids these
# cells have after the cells in `refined_original_cells` were refined.
#
# Every refined cell `r` shifts each cell with an original id greater than `r` by
# `2^ndims(mesh)`. Both inputs must be sorted. An empty `original_cell_ids` yields
# an empty result. Returns a new vector; the inputs are not modified.
function original2refined(original_cell_ids, refined_original_cells, mesh)
    @assert issorted(original_cell_ids) "`original_cell_ids` not sorted"
    @assert issorted(refined_original_cells) "`refined_original_cells` not sorted"

    n_children = 2^ndims(mesh)

    # Because `refined_original_cells` is sorted, the number of refined cells with
    # an id strictly smaller than `cell_id` is one less than the first index whose
    # value is >= `cell_id`. Each of them shifts `cell_id` by `n_children`.
    return [cell_id +
            n_children * (searchsortedfirst(refined_original_cells, cell_id) - 1)
            for cell_id in original_cell_ids]
end
"""
    ControllerThreeLevel(semi, indicator; base_level=1,
                         med_level=base_level, med_threshold=0.0,
                         max_level=base_level, max_threshold=1.0)

An AMR controller based on three levels (in descending order of precedence):
- set the target level to `max_level` if `indicator > max_threshold`
- set the target level to `med_level` if `indicator > med_threshold`;
  if `med_level` is not positive (`med_level <= 0`), keep the current level instead
- set the target level to `base_level` otherwise

For each element, the controller returns `1` (refine) if its current level is
below the target level, `-1` (coarsen) if it is above, and `0` otherwise.
"""
struct ControllerThreeLevel{RealT <: Real, Indicator, Cache}
    base_level::Int
    med_level::Int           # a non-positive value means "keep the current level"
    max_level::Int
    med_threshold::RealT
    max_threshold::RealT
    indicator::Indicator     # called on the solution to obtain per-element values
    cache::Cache             # holds the reusable `controller_value` output buffer
end
# Keyword constructor: promotes both thresholds to a common real type and
# allocates the controller cache for `semi`
function ControllerThreeLevel(semi, indicator; base_level = 1,
                              med_level = base_level, med_threshold = 0.0,
                              max_level = base_level, max_threshold = 1.0)
    thresholds = promote(med_threshold, max_threshold)
    med_threshold, max_threshold = thresholds
    cache = create_cache(ControllerThreeLevel, semi)

    RealT = typeof(max_threshold)
    return ControllerThreeLevel{RealT, typeof(indicator), typeof(cache)}(base_level,
                                                                         med_level,
                                                                         max_level,
                                                                         med_threshold,
                                                                         max_threshold,
                                                                         indicator,
                                                                         cache)
end
# Highest refinement level this controller can request
max_level(controller::ControllerThreeLevel) = controller.max_level

# Create the controller cache from the components of the semidiscretization
function create_cache(controller_type::Type{ControllerThreeLevel}, semi)
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    return create_cache(controller_type, mesh, equations, solver, cache)
end
# Compact one-line representation of a `ControllerThreeLevel`
function Base.show(io::IO, controller::ControllerThreeLevel)
    @nospecialize controller # avoid compiling a specialization per controller type

    print(io, "ControllerThreeLevel(", controller.indicator,
          ", base_level=", controller.base_level,
          ", med_level=", controller.med_level,
          ", max_level=", controller.max_level,
          ", med_threshold=", controller.med_threshold,
          ", max_threshold=", controller.max_threshold,
          ")")
end
# Multi-line summary of a `ControllerThreeLevel`; compact mode falls back to `show`
function Base.show(io::IO, mime::MIME"text/plain", controller::ControllerThreeLevel)
    @nospecialize controller # avoid compiling a specialization per controller type

    get(io, :compact, false) && return show(io, controller)

    summary_header(io, "ControllerThreeLevel")
    summary_line(io, "indicator", nameof(typeof(controller.indicator)))
    show(increment_indent(io), mime, controller.indicator)
    for field in (:base_level, :med_level, :max_level, :med_threshold, :max_threshold)
        summary_line(io, string(field), getfield(controller, field))
    end
    summary_footer(io)
end
# Store the indicator of a `ControllerThreeLevel` as element variable. The
# indicator is evaluated first for its side effect, because the storing method
# reads from the indicator's cache.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevel,
                                amr_callback::AMRCallback;
                                kwargs...)
    indicator = controller.indicator
    indicator(u, mesh, equations, solver, cache; kwargs...)
    return get_element_variables!(element_variables, indicator, amr_callback)
end
# Store the most recent indicator values (`indicator.cache.alpha`) under
# `:indicator_amr`
function get_element_variables!(element_variables, indicator::AbstractIndicator,
                                ::AMRCallback)
    alpha = indicator.cache.alpha
    element_variables[:indicator_amr] = alpha
    return nothing
end
# Refinement level of every local element, looked up via the element's tree cell
current_element_levels(mesh::TreeMesh, solver, cache) =
    mesh.tree.levels[cache.elements.cell_ids[eachelement(solver, cache)]]
# p4est volume-iterator callback: write the level of the current quadrant into the
# `Int` array behind `user_data`, at the quadrant's one-based element index
function extract_levels_iter_volume(info, user_data)
    info_pw = PointerWrapper(info)

    # One-based element index: the tree's quadrant offset plus the quadrant's
    # index within its tree, plus one
    tree_pw = load_pointerwrapper_tree(info_pw.p4est, info_pw.treeid[] + 1)
    element_id = tree_pw.quadrants_offset[] + info_pw.quadid[] + 1

    levels_pw = PointerWrapper(Int, user_data)
    levels_pw[element_id] = info_pw.quad.level[]

    return nothing
end
# C-callable function pointers for `extract_levels_iter_volume`, selected by spatial
# dimension: 2D uses `p4est_iter_volume_info_t`, 3D uses `p8est_iter_volume_info_t`
function cfunction(::typeof(extract_levels_iter_volume), ::Val{2})
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))
end
function cfunction(::typeof(extract_levels_iter_volume), ::Val{3})
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
end
# Refinement level of every local element, collected by iterating over all p4est
# quadrants
function current_element_levels(mesh::P4estMesh, solver, cache)
    levels = Vector{Int}(undef, nelements(solver, cache))
    extract_c = cfunction(extract_levels_iter_volume, Val(ndims(mesh)))
    iterate_p4est(mesh.p4est, levels; iter_volume_c = extract_c)
    return levels
end
# Refinement level of every local element, as reported by t8code
current_element_levels(mesh::T8codeMesh, solver, cache) =
    trixi_t8_get_local_element_levels(mesh.forest)
# Evaluate the controller. Returns the cached `controller_value` vector with one
# entry per element: `1` to refine, `-1` to coarsen, `0` to keep the element.
function (controller::ControllerThreeLevel)(u::AbstractArray{<:Any},
                                            mesh, equations, dg::DG, cache;
                                            kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    alpha = controller.indicator(u, mesh, equations, dg, cache; kwargs...)
    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        current_level = current_levels[element]
        alpha_element = alpha[element]

        target_level = if alpha_element > controller.max_threshold
            controller.max_level
        elseif alpha_element > controller.med_threshold
            # A non-positive `med_level` means "keep the current level"
            controller.med_level > 0 ? controller.med_level : current_level
        else
            controller.base_level
        end

        # +1 below the target level, -1 above it, 0 at it
        controller_value[element] = sign(target_level - current_level)
    end

    return controller_value
end
"""
    ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
                                 base_level=1,
                                 med_level=base_level, med_threshold=0.0,
                                 max_level=base_level, max_threshold=1.0,
                                 max_threshold_secondary=1.0)

An AMR controller based on three levels (in descending order of precedence):
- set the target level to `max_level` if `indicator_primary > max_threshold`
- set the target level to `med_level` if `indicator_primary > med_threshold`;
  if `med_level` is not positive (`med_level <= 0`), keep the current level instead
- set the target level to `base_level` otherwise

Regardless of the primary indicator, if `indicator_secondary >= max_threshold_secondary`,
set the target level to `max_level`.

For each element, the controller returns `1` (refine) if its current level is
below the target level, `-1` (coarsen) if it is above, and `0` otherwise.
"""
struct ControllerThreeLevelCombined{RealT <: Real, IndicatorPrimary, IndicatorSecondary,
                                    Cache}
    base_level::Int
    med_level::Int                      # a non-positive value means "keep the current level"
    max_level::Int
    med_threshold::RealT
    max_threshold::RealT
    max_threshold_secondary::RealT      # secondary values >= this force `max_level`
    indicator_primary::IndicatorPrimary
    indicator_secondary::IndicatorSecondary
    cache::Cache                        # holds the reusable `controller_value` output buffer
end
# Keyword constructor: promotes all thresholds to a common real type and
# allocates the controller cache for `semi`
function ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
                                      base_level = 1,
                                      med_level = base_level, med_threshold = 0.0,
                                      max_level = base_level, max_threshold = 1.0,
                                      max_threshold_secondary = 1.0)
    thresholds = promote(med_threshold, max_threshold, max_threshold_secondary)
    med_threshold, max_threshold, max_threshold_secondary = thresholds
    cache = create_cache(ControllerThreeLevelCombined, semi)

    RealT = typeof(max_threshold)
    PrimaryT = typeof(indicator_primary)
    SecondaryT = typeof(indicator_secondary)
    return ControllerThreeLevelCombined{RealT, PrimaryT, SecondaryT, typeof(cache)}(base_level,
                                                                                    med_level,
                                                                                    max_level,
                                                                                    med_threshold,
                                                                                    max_threshold,
                                                                                    max_threshold_secondary,
                                                                                    indicator_primary,
                                                                                    indicator_secondary,
                                                                                    cache)
end
# Highest refinement level this controller can request
max_level(controller::ControllerThreeLevelCombined) = controller.max_level

# Create the controller cache from the components of the semidiscretization
function create_cache(controller_type::Type{ControllerThreeLevelCombined}, semi)
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    return create_cache(controller_type, mesh, equations, solver, cache)
end
# Compact one-line representation of a `ControllerThreeLevelCombined`.
# Lists every threshold, including `max_threshold`, which was previously omitted;
# this is now consistent with the `ControllerThreeLevel` method.
function Base.show(io::IO, controller::ControllerThreeLevelCombined)
    @nospecialize controller # avoid compiling a specialization per controller type

    print(io, "ControllerThreeLevelCombined(")
    print(io, controller.indicator_primary)
    print(io, ", ", controller.indicator_secondary)
    print(io, ", base_level=", controller.base_level)
    print(io, ", med_level=", controller.med_level)
    print(io, ", max_level=", controller.max_level)
    print(io, ", med_threshold=", controller.med_threshold)
    print(io, ", max_threshold=", controller.max_threshold)
    print(io, ", max_threshold_secondary=", controller.max_threshold_secondary)
    print(io, ")")
end
# Multi-line summary of a `ControllerThreeLevelCombined`; compact mode falls back
# to `show`
function Base.show(io::IO, mime::MIME"text/plain",
                   controller::ControllerThreeLevelCombined)
    @nospecialize controller # avoid compiling a specialization per controller type

    get(io, :compact, false) && return show(io, controller)

    summary_header(io, "ControllerThreeLevelCombined")
    for (label, indicator) in (("primary indicator", controller.indicator_primary),
                               ("secondary indicator", controller.indicator_secondary))
        summary_line(io, label, nameof(typeof(indicator)))
        show(increment_indent(io), mime, indicator)
    end
    for field in (:base_level, :med_level, :max_level, :med_threshold, :max_threshold,
                  :max_threshold_secondary)
        summary_line(io, string(field), getfield(controller, field))
    end
    summary_footer(io)
end
# Store the primary indicator of a `ControllerThreeLevelCombined` as element
# variable. It is evaluated first for its side effect, because the storing
# method reads from the indicator's cache.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevelCombined,
                                amr_callback::AMRCallback;
                                kwargs...)
    primary = controller.indicator_primary
    primary(u, mesh, equations, solver, cache; kwargs...)
    return get_element_variables!(element_variables, primary, amr_callback)
end
# Evaluate the combined controller. Returns the cached `controller_value` vector
# with one entry per element: `1` to refine, `-1` to coarsen, `0` to keep it.
# The secondary indicator is evaluated without the keyword arguments passed to
# the primary one.
function (controller::ControllerThreeLevelCombined)(u::AbstractArray{<:Any},
                                                    mesh, equations, dg::DG, cache;
                                                    kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    alpha = controller.indicator_primary(u, mesh, equations, dg, cache; kwargs...)
    alpha_secondary = controller.indicator_secondary(u, mesh, equations, dg, cache)

    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        current_level = current_levels[element]
        alpha_element = alpha[element]

        target_level = if alpha_secondary[element] >= controller.max_threshold_secondary
            # The secondary indicator overrides the primary one
            controller.max_level
        elseif alpha_element > controller.max_threshold
            controller.max_level
        elseif alpha_element > controller.med_threshold
            # A non-positive `med_level` means "keep the current level"
            controller.med_level > 0 ? controller.med_level : current_level
        else
            controller.base_level
        end

        # +1 below the target level, -1 above it, 0 at it
        controller_value[element] = sign(target_level - current_level)
    end

    return controller_value
end
include("amr_dg.jl")
include("amr_dg1d.jl")
include("amr_dg2d.jl")
include("amr_dg3d.jl")
end