trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/callbacks_step/amr_dg.jl
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
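# (Illustrative sketch, not part of the upstream file: `@muladd` from MuladdMacro.jl
# rewrites an expression such as `a * u + b * v` into `muladd(a, u, b * v)`, which
# LLVM may lower to fused multiply-add instructions on supporting hardware.)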
@muladd begin
#! format: noindent

# Redistribute data for load balancing after partitioning the mesh
function rebalance_solver!(u_ode::AbstractVector,
                           mesh::Union{ParallelP4estMesh, ParallelT8codeMesh},
                           equations,
                           dg::DGSEM, cache, old_global_first_quadrant)

    # MPI ranks are 0-based. This array uses 1-based indices.
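    # Entry `mpi_rank() + 1` holds the global id of this rank's first quadrant;
    # entry `mpi_rank() + 2` holds the first quadrant of the next rank, i.e., one
    # past the last local quadrant.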
    global_first_quadrant = get_global_first_element_ids(mesh)

    if global_first_quadrant[mpi_rank() + 1] ==
       old_global_first_quadrant[mpi_rank() + 1] &&
       global_first_quadrant[mpi_rank() + 2] ==
       old_global_first_quadrant[mpi_rank() + 2]
        # The global ids of the first and last local quadrants are the same for the
        # newly partitioned mesh, so the solver does not need to be rebalanced on
        # this rank.
        # Container init uses all-to-all communication -> reinitialize even if there
        # is nothing to do locally (other MPI ranks still need to be rebalanced if
        # this function is called).
        reinitialize_containers!(mesh, equations, dg, cache)
        return
    end
    # Retain current solution data
    old_n_elements = nelements(dg, cache)
    old_u_ode = copy(u_ode)
    GC.@preserve old_u_ode begin # NOTE: If we don't GC.@preserve old_u_ode, it might be GC'ed
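        # (The wrapped array below shares memory with `old_u_ode` without keeping it
        # alive, and the non-blocking sends read from that memory.)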
        # Use `wrap_array_native` instead of `wrap_array` since MPI might not interact
        # nicely with non-base array types
        old_u = wrap_array_native(old_u_ode, mesh, equations, dg, cache)

        @trixi_timeit timer() "reinitialize data structures" begin
            reinitialize_containers!(mesh, equations, dg, cache)
        end

        resize!(u_ode,
                nvariables(equations) * nnodes(dg)^ndims(mesh) * nelements(dg, cache))
        u = wrap_array_native(u_ode, mesh, equations, dg, cache)

        @trixi_timeit timer() "exchange data" begin
            # Collect MPI requests for MPI_Waitall
            requests = Vector{MPI.Request}()
            # Find elements that will change their rank and send their data to the new rank
            for old_element_id in 1:old_n_elements
                # Get global quad ID of old element; local quad id is element id - 1
                global_quad_id = old_global_first_quadrant[mpi_rank() + 1] +
                                 old_element_id - 1
                if !(global_first_quadrant[mpi_rank() + 1] <= global_quad_id <
                     global_first_quadrant[mpi_rank() + 2])
                    # Send element data to new rank, use global_quad_id as tag (non-blocking)
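                    # The new owner is the rank r whose half-open range
                    # [global_first_quadrant[r], global_first_quadrant[r + 1]) contains
                    # global_quad_id; subtract 1 to convert the 1-based search index to
                    # a 0-based MPI rank.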
                    dest = findfirst(r -> global_first_quadrant[r] <= global_quad_id <
                                          global_first_quadrant[r + 1],
                                     1:mpi_nranks()) - 1 # mpi ranks 0-based
                    request = MPI.Isend(@view(old_u[:, .., old_element_id]), dest,
                                        global_quad_id, mpi_comm())
                    push!(requests, request)
                end
            end

            # Loop over all elements in new container and either copy them from old container
            # or receive them with MPI
            for element in eachelement(dg, cache)
                # Get global quad ID of element; local quad id is element id - 1
                global_quad_id = global_first_quadrant[mpi_rank() + 1] + element - 1
                if old_global_first_quadrant[mpi_rank() + 1] <= global_quad_id <
                   old_global_first_quadrant[mpi_rank() + 2]
                    # Quad ids are 0-based, element ids are 1-based, hence add 1
                    old_element_id = global_quad_id -
                                     old_global_first_quadrant[mpi_rank() + 1] + 1
                    # Copy old element data to new element container
                    @views u[:, .., element] .= old_u[:, .., old_element_id]
                else
                    # Receive old element data
                    src = findfirst(r -> old_global_first_quadrant[r] <=
                                         global_quad_id <
                                         old_global_first_quadrant[r + 1],
                                     1:mpi_nranks()) - 1 # mpi ranks 0-based
                    request = MPI.Irecv!(@view(u[:, .., element]), src, global_quad_id,
                                         mpi_comm())
                    push!(requests, request)
                end
            end

            # Wait for all non-blocking MPI send/receive operations to finish
            MPI.Waitall(requests, MPI.Status)
        end
    end # GC.@preserve old_u_ode

    return nothing
end

# Construct cache for ControllerThreeLevel and ControllerThreeLevelCombined.
# This method is called when a controller is constructed.
function create_cache(::Union{Type{ControllerThreeLevel},
                              Type{ControllerThreeLevelCombined}},
                      mesh::Union{TreeMesh, P4estMesh, T8codeMesh},
                      equations, dg::DG, cache)
    controller_value = Vector{Int}(undef, nelements(dg, cache))
    return (; controller_value)
end
end # @muladd
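
# Illustrative usage sketch (all names other than `rebalance_solver!` and
# `get_global_first_element_ids` are assumptions, not taken from this file): a
# load-balancing step during AMR might look roughly like
#
#     old_global_first_quadrant = copy(get_global_first_element_ids(mesh))
#     partition!(mesh)  # repartition the parallel mesh among the MPI ranks
#     rebalance_solver!(u_ode, mesh, equations, dg, cache, old_global_first_quadrant)
#
# so that afterwards every rank holds the solution data of its newly assigned elements.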