@muladd begin
# Move the solution data stored in `u_ode` so that it matches the current
# distribution of mesh elements across MPI ranks.
#
# `old_global_first_quadrant` holds the per-rank global element offsets from before
# the mesh was repartitioned; the current offsets are queried from `mesh`. Both are
# indexed with `rank + 1`, and the rank stored at index `r` owns the global ids in
# the half-open range `offsets[r] <= id < offsets[r + 1]`.
#
# The solver containers are reinitialized on every call. Unless the local id range
# is unchanged, `u_ode` is resized in place to the new local element count and each
# element is either copied locally or exchanged with non-blocking MPI messages that
# use the global element id as tag. Always returns `nothing`.
function rebalance_solver!(u_ode::AbstractVector,
                           mesh::Union{ParallelP4estMesh, ParallelT8codeMesh},
                           equations,
                           dg::DGSEM, cache, old_global_first_quadrant)
    new_offsets = get_global_first_element_ids(mesh)
    old_offsets = old_global_first_quadrant

    # One-based index of this rank in the offset arrays, looked up once and reused
    me = mpi_rank() + 1

    same_range = new_offsets[me] == old_offsets[me] &&
                 new_offsets[me + 1] == old_offsets[me + 1]
    if same_range
        # Nothing has to be moved on this rank, but the containers are rebuilt anyway.
        # NOTE(review): presumably required because container initialization
        # involves communication with the other ranks -- confirm.
        reinitialize_containers!(mesh, equations, dg, cache)
        return nothing
    end

    # Zero-based MPI rank whose half-open id range in `offsets` contains `quad_id`
    function owning_rank(offsets, quad_id)
        position = findfirst(r -> offsets[r] <= quad_id < offsets[r + 1],
                             1:mpi_nranks())
        return position - 1
    end

    # Keep a snapshot of the current solution before any container is touched
    n_old = nelements(dg, cache)
    old_u_ode = copy(u_ode)

    # `old_u_ode` is kept alive for as long as `old_u` is in use (presumably `old_u`
    # refers to the memory of `old_u_ode` without owning it -- confirm).
    GC.@preserve old_u_ode begin
        # Wrapped before the containers are reinitialized; keep this order.
        # NOTE(review): presumably the wrapper takes its size from `cache` -- confirm.
        old_u = wrap_array_native(old_u_ode, mesh, equations, dg, cache)

        @trixi_timeit timer() "reinitialize data structures" begin
            reinitialize_containers!(mesh, equations, dg, cache)
        end

        # Size of the local solution vector for the new number of local elements
        n_dofs = nvariables(equations) * nnodes(dg)^ndims(mesh) * nelements(dg, cache)
        resize!(u_ode, n_dofs)
        u = wrap_array_native(u_ode, mesh, equations, dg, cache)

        @trixi_timeit timer() "exchange data" begin
            # All non-blocking operations are gathered here and completed together
            pending = MPI.Request[]

            # Post a send for every old element that is no longer owned by this rank
            for old_elem in 1:n_old
                quad_id = old_offsets[me] + old_elem - 1
                if new_offsets[me] <= quad_id < new_offsets[me + 1]
                    continue # stays on this rank; handled by the loop below
                end

                target = owning_rank(new_offsets, quad_id)
                send_request = MPI.Isend(@view(old_u[:, .., old_elem]), target,
                                         quad_id, mpi_comm())
                push!(pending, send_request)
            end

            # Fill every new element, either from a remote rank or from the snapshot
            for element in eachelement(dg, cache)
                quad_id = new_offsets[me] + element - 1
                was_local = old_offsets[me] <= quad_id < old_offsets[me + 1]

                if !was_local
                    source = owning_rank(old_offsets, quad_id)
                    recv_request = MPI.Irecv!(@view(u[:, .., element]), source,
                                              quad_id, mpi_comm())
                    push!(pending, recv_request)
                    continue
                end

                # Element was already on this rank: translate the global id back to
                # its old local index and copy the data over
                old_elem = quad_id - old_offsets[me] + 1
                @views u[:, .., element] .= old_u[:, .., old_elem]
            end

            # Block until every posted send and receive has completed
            MPI.Waitall(pending, MPI.Status)
        end
    end

    return nothing
end
# Build the cache entries for the three-level AMR controllers
# (`ControllerThreeLevel` and `ControllerThreeLevelCombined`).
#
# Returns a named tuple with the single field `controller_value`, an integer vector
# with one entry per element. The entries are left uninitialized here.
function create_cache(::Union{Type{ControllerThreeLevel},
                              Type{ControllerThreeLevelCombined}},
                      mesh::Union{TreeMesh, P4estMesh, T8codeMesh},
                      equations, dg::DG, cache)
    n_elements = nelements(dg, cache)
    return (; controller_value = Vector{Int}(undef, n_elements))
end
end