Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/auxiliary/mpi.jl
2055 views
1
2
"""
    init_mpi()

Initialize MPI by calling `MPI.Init` and cache the global MPI state (rank, number of
ranks, whether the run is serial or parallel, and whether this process is the root
rank) in module-level `Ref`s.

The function first checks the module-level flag `MPI_INITIALIZED` and returns
immediately if it is already set. Thus it is safe to call it multiple times.

Throws an error if the MPI library does not provide at least `MPI.THREAD_FUNNELED`
threading support.
"""
function init_mpi()
    # Fast path: the global MPI state has already been set up by a previous call
    if MPI_INITIALIZED[]
        return nothing
    end

    # MPI.jl handles multiple calls to MPI.Init appropriately. Thus, we don't need
    # any common checks of the form `if MPI.Initialized() ...`.
    # threadlevel=MPI.THREAD_FUNNELED: Only main thread makes MPI calls
    # finalize_atexit=true          : MPI.jl will call MPI.Finalize as `atexit` hook
    provided = MPI.Init(threadlevel = MPI.THREAD_FUNNELED, finalize_atexit = true)
    # Explicit check instead of `@assert`: this validates the runtime environment
    # (not an internal invariant), and assertions may be disabled
    provided >= MPI.THREAD_FUNNELED ||
        error("MPI library with insufficient threading support")

    # Initialize global MPI state
    MPI_RANK[] = MPI.Comm_rank(MPI.COMM_WORLD)
    MPI_SIZE[] = MPI.Comm_size(MPI.COMM_WORLD)
    MPI_IS_PARALLEL[] = MPI_SIZE[] > 1
    MPI_IS_SERIAL[] = !MPI_IS_PARALLEL[]
    # In a serial run the single process is always root; otherwise rank 0 is root
    MPI_IS_ROOT[] = MPI_IS_SERIAL[] || MPI_RANK[] == 0
    MPI_INITIALIZED[] = true

    return nothing
end
30
31
# Cached global MPI state. These are the values before `init_mpi()` has run;
# `init_mpi()` overwrites them once MPI is initialized.
const MPI_INITIALIZED = Ref{Bool}(false)
const MPI_RANK = Ref{Int}(-1)
const MPI_SIZE = Ref{Int}(-1)
const MPI_IS_PARALLEL = Ref{Bool}(false)
const MPI_IS_SERIAL = Ref{Bool}(true)
const MPI_IS_ROOT = Ref{Bool}(true)

# Communicator used for MPI communication (always `MPI.COMM_WORLD`)
@inline function mpi_comm()
    return MPI.COMM_WORLD
end

# Rank of this process as cached in `MPI_RANK`
@inline function mpi_rank()
    return MPI_RANK[]
end

# Total number of ranks as cached in `MPI_SIZE`
@inline function mpi_nranks()
    return MPI_SIZE[]
end

# Whether more than one rank participates, as cached in `MPI_IS_PARALLEL`
@inline function mpi_isparallel()
    return MPI_IS_PARALLEL[]
end

# Whether this process is the root rank, as cached in `MPI_IS_ROOT`
@inline function mpi_isroot()
    return MPI_IS_ROOT[]
end

# Rank number designated as root
@inline function mpi_root()
    return 0
end
49
50
# Forward `args...` to `println`, but only on the root rank; all other ranks print
# nothing. Always returns `nothing`.
@inline function mpi_println(args...)
    mpi_isroot() || return nothing
    println(args...)
    return nothing
end
56
# Forward `args...` to `print`, but only on the root rank; all other ranks print
# nothing. Always returns `nothing`.
@inline function mpi_print(args...)
    mpi_isroot() || return nothing
    print(args...)
    return nothing
end
62
63
"""
    ode_norm(u, t)

Implementation of the weighted L2 norm of Hairer and Wanner used for error-based
step size control in OrdinaryDiffEq.jl. This function is aware of MPI and uses
global MPI communication when running in parallel.

For a number, this returns `abs(u)`. For an array, it returns
`sqrt(sum(abs2, u) / length(u))`, where the sum and the length are computed
recursively through nested arrays. When running in parallel, both quantities are
accumulated over all MPI ranks. The length is clamped to at least one, so an empty
array has norm zero instead of `NaN`.

You must pass this function as a keyword argument
`internalnorm = ode_norm`
to OrdinaryDiffEq.jl's `solve` when using error-based step size control with MPI
parallel execution of Trixi.jl.

See the "Advanced Adaptive Stepsize Control" section of the
[documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/).
"""
ode_norm(u::Number, t) = @fastmath abs(u)
function ode_norm(u::AbstractArray, t)
    local_sumabs2 = recursive_sum_abs2(u) # sum(abs2, u)
    local_length = recursive_length(u) # length(u)
    if mpi_isparallel()
        # Reduce both quantities with a single collective call. Note that the array
        # literal promotes `local_length` to the type of `local_sumabs2`.
        global_sumabs2, global_length = MPI.Allreduce([local_sumabs2, local_length], +,
                                                      mpi_comm())
        # Avoid 0/0 = NaN if the global state is empty
        return sqrt(global_sumabs2 / max(global_length, 1))
    else
        # Avoid 0/0 = NaN for an empty state
        return sqrt(local_sumabs2 / max(local_length, 1))
    end
end
90
91
# Recursive `sum(abs2, ...)` and `length(...)` are required when dealing with
# arrays of arrays, e.g., when using `DGMulti` solvers with an array-of-structs
# (`Array{SVector}`) or a structure-of-arrays (`StructArray`) layout. We need to
# handle these situations to allow `ode_norm` to be used as the default norm in
# OrdinaryDiffEq.jl throughout all applications of Trixi.jl.
96
# Sum of `abs2` over all scalar entries of `u`, recursing through nested arrays
recursive_sum_abs2(u::Number) = abs2(u)
# Use `mapreduce` instead of `sum` since `sum` from StaticArrays.jl does not
# support the kwarg `init`.
# The initial value is `abs2(zero(T))`, where `T` is the innermost ("bottom") element
# type of the possibly nested array `u` (e.g., `Float64` for an array of `SVector`s
# with `Float64` entries). This works for arbitrarily deep nesting. It also keeps
# the result real-valued for complex inputs, since `abs2` of a complex number is
# real (`zero(eltype(eltype(u)))` would have made the sum complex).
# A more general solution would be `recursive_unitless_bottom_eltype` from
# https://github.com/SciML/RecursiveArrayTools.jl
# However, what we have is good enough for us for now, so we don't need this
# additional dependency at the moment.
function recursive_sum_abs2(u::AbstractArray)
    init = abs2(zero(_recursive_bottom_eltype(eltype(u))))
    return mapreduce(recursive_sum_abs2, +, u; init = init)
end

# Innermost element type of a (possibly nested) array element type `T`
_recursive_bottom_eltype(::Type{T}) where {T} = T
function _recursive_bottom_eltype(::Type{T}) where {T <: AbstractArray}
    return _recursive_bottom_eltype(eltype(T))
end
107
108
# Total number of scalar entries in `u`, counted recursively through nested arrays
recursive_length(u::Number) = length(u)
recursive_length(u::AbstractArray{<:Number}) = length(u)
# Use `mapreduce` with `init` so that an empty array of arrays has length zero
# instead of throwing an "empty collection" error. `sum` is avoided since `sum`
# from StaticArrays.jl does not support the kwarg `init`.
function recursive_length(u::AbstractArray{<:AbstractArray})
    return mapreduce(recursive_length, +, u; init = 0)
end
111
# Specialization for arrays of static arrays with scalar entries: the number of
# scalars per entry is fixed by the element type, so the total count is that number
# times the number of entries of `u`
function recursive_length(u::AbstractArray{<:StaticArrays.StaticArray{S, <:Number}}) where {S}
    scalars_per_entry = prod(StaticArrays.Size(eltype(u)))
    return scalars_per_entry * length(u)
end
115
116
"""
    ode_unstable_check(dt, u, semi, t)

Basic instability check in the style of the one used in OrdinaryDiffEq.jl. It only
tests whether the time step `dt` is `NaN` instead of inspecting the solution itself
(e.g., via `any(isnan, u)`); the arguments `u`, `semi`, and `t` are ignored. This is
useful with MPI parallelization: no additional global communication is required,
and all ranks will return the same result.

You should pass this function as a keyword argument
`unstable_check=ode_unstable_check`
to OrdinaryDiffEq.jl's `solve` when using error-based step size control with MPI
parallel execution of Trixi.jl.

See the "Miscellaneous" section of the
[documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/).
"""
function ode_unstable_check(dt, u, semi, t)
    return isnan(dt)
end
133
134