Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/time_integration/methods_SSP.jl
2055 views
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
8
# Abstract base type for time integration schemes of explicit strong stability-preserving (SSP)
# Runge-Kutta (RK) methods. These high-order time discretizations preserve the strong stability
# properties (e.g., the TVD property) of the forward Euler method, provided the time step
# obeys a corresponding step size restriction.
abstract type SimpleAlgorithmSSP <: AbstractTimeIntegrationAlgorithm end
11
12
"""
    SimpleSSPRK33(; stage_callbacks=())

The third-order SSP Runge-Kutta method of Shu and Osher.

`stage_callbacks` is an iterable of callbacks that are invoked as
`cb(u, integrator, stage)` right after the forward Euler update of every stage.

## References

- Shu, Osher (1988)
  "Efficient Implementation of Essentially Non-oscillatory Shock-Capturing Schemes" (Eq. 2.18)
  [DOI: 10.1016/0021-9991(88)90177-5](https://doi.org/10.1016/0021-9991(88)90177-5)
"""
struct SimpleSSPRK33{StageCallbacks} <: SimpleAlgorithmSSP
    numerator_a::SVector{3, Float64}
    numerator_b::SVector{3, Float64}
    denominator::SVector{3, Float64}
    c::SVector{3, Float64}
    stage_callbacks::StageCallbacks

    function SimpleSSPRK33(; stage_callbacks = ())
        # Shu-Osher form: with `u_n` the solution at the beginning of the step, stage `i` sets
        #   u = (numerator_a[i] * u_n + numerator_b[i] * (u + dt * f(u))) / denominator[i]
        # Keeping numerator and denominator apart (instead of precomputing the quotients) is
        # mathematically unnecessary, but it avoids rounding errors at the level of machine
        # accuracy that would otherwise accumulate over time and spoil conservation.
        # See also https://github.com/trixi-framework/Trixi.jl/pull/1640.
        weights_old = SVector(0.0, 3.0, 1.0)   # a = numerator_a / denominator
        weights_euler = SVector(1.0, 1.0, 2.0) # b = numerator_b / denominator
        divisors = SVector(1.0, 4.0, 3.0)
        stage_times = SVector(0.0, 1.0, 0.5)

        # Butcher tableau
        #   c  | A
        #   0  |
        #   1  | 1
        #  1/2 | 1/4  1/4
        # --------------------
        #   b  | 1/6  1/6  2/3

        return new{typeof(stage_callbacks)}(weights_old, weights_euler, divisors,
                                            stage_times, stage_callbacks)
    end
end
53
54
# This struct is needed to fake https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L1
# It is mutable so that fields such as the time stops can be updated during time integration.
mutable struct SimpleIntegratorSSPOptions{Callback, TStops}
    # Callbacks; used in Trixi (e.g. by `solve!`/`step!` via `handle_callbacks!`).
    callback::Callback
    # Whether the algorithm is adaptive. Read by `modify_dt_for_tstops!`; the outer
    # constructor `SimpleIntegratorSSPOptions(callback, tspan)` always sets it to `false`.
    adaptive::Bool
    # Maximal time step size; ignored.
    dtmax::Float64
    # Maximal number of time steps.
    maxiters::Int
    # Upcoming time stops (cf. https://diffeq.sciml.ai/v6.8/basics/common_solver_opts/#Output-Control-1).
    # NOT ignored: read by `modify_dt_for_tstops!` to shorten `dt`, updated by `add_tstop!`,
    # and emptied at the end of `solve!`.
    tstops::TStops
end
62
63
# Build the options for a `SimpleIntegratorSSP` from the callbacks and the time span.
# Only `maxiters` is taken from the keyword arguments; all other keywords are accepted
# for interface compatibility and silently ignored.
function SimpleIntegratorSSPOptions(callback, tspan; maxiters = typemax(Int), kwargs...)
    t_final = last(tspan)

    # Heap of time stops; `first` on it is used as the next stop in `modify_dt_for_tstops!`
    # (assumes `FasterForward` yields ascending order - the earliest stop comes first).
    stops = BinaryHeap{eltype(tspan)}(FasterForward())

    # The final time ensures that the time integration halts at the end of `tspan`.
    push!(stops, t_final)

    # DiffEqCallbacks.jl only calls `add_tstop!(integrator, t)` if `tstops` still holds a time
    # larger than `t`, hence the additional sentinel entry 2 * t_final
    # (https://github.com/SciML/DiffEqCallbacks.jl/blob/025dfe99029bd0f30a2e027582744528eb92cd24/src/iterative_and_periodic.jl#L92)
    push!(stops, 2 * t_final)

    adaptive = false
    dtmax = Inf
    return SimpleIntegratorSSPOptions{typeof(callback), typeof(stops)}(callback, adaptive,
                                                                       dtmax, maxiters,
                                                                       stops)
end
75
76
# This struct is needed to fake https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L77
# This implements the interface components described at
# https://diffeq.sciml.ai/v6.8/basics/integrator/#Handing-Integrators-1
# which are used in Trixi.
#
# The last type parameter is called `Opts` (instead of re-using the name of the struct
# `SimpleIntegratorSSPOptions`) so that it does not shadow that global type inside this
# definition; the parameter order is unchanged.
mutable struct SimpleIntegratorSSP{RealT <: Real, uType,
                                   Params, Sol, F, Alg,
                                   Opts} <: AbstractTimeIntegrator
    u::uType                 # current solution
    du::uType                # storage for the right-hand side `f(du, u, p, t)`
    u_tmp::uType             # solution at the beginning of a step (see `step!`); also the tmp cache
    t::RealT                 # current time
    tdir::RealT              # DIRection of time integration, i.e., if one marches forward or backward in time
    dt::RealT                # current time step
    dtcache::RealT           # manually set time step
    iter::Int                # current number of time steps (iteration)
    p::Params                # will be the semidiscretization from Trixi
    sol::Sol                 # faked; `init` stores the NamedTuple `(prob = ode,)` here
    f::F                     # `rhs!` of the semidiscretization
    alg::Alg                 # a `SimpleAlgorithmSSP`, e.g., `SimpleSSPRK33`
    opts::Opts               # a `SimpleIntegratorSSPOptions` when created by `init`
    finalstep::Bool          # set by `terminate!`; `solve!` stops stepping once it is `true`
    dtchangeable::Bool       # whether `modify_dt_for_tstops!` may change `dt`
    force_stepfail::Bool     # if `true`, `modify_dt_for_tstops!` does not reset `dt` to `dtcache`
end
100
101
"""
    add_tstop!(integrator::SimpleIntegratorSSP, t)

Add a time stop during the time integration process.
Callbacks use this to schedule the next time at which the integrator must stop, e.g.,
the periodic `SaveSolutionCallback` to specify the next stop to save the solution.
Adding a time behind the current time `integrator.t` raises an error.
"""
function add_tstop!(integrator::SimpleIntegratorSSP, t)
    if integrator.tdir * (t - integrator.t) < zero(integrator.t)
        error("Tried to add a tstop that is behind the current time. This is strictly forbidden")
    end

    stops = integrator.opts.tstops
    # Discard the entry at the top of the heap before inserting the new stop. Otherwise, the
    # simulation gets stuck at the previous tstop and dt is adjusted to zero.
    length(stops) > 1 && pop!(stops)

    # Time stops are stored multiplied by the direction of time integration.
    return push!(stops, integrator.tdir * t)
end
116
117
# Return `true` if at least one time stop is still pending.
function has_tstop(integrator::SimpleIntegratorSSP)
    return !isempty(integrator.opts.tstops)
end

# Return the time stop at the top of the `tstops` heap.
function first_tstop(integrator::SimpleIntegratorSSP)
    return first(integrator.opts.tstops)
end
119
120
# Create a `SimpleIntegratorSSP` for the problem `ode` and the SSP algorithm `alg`, starting
# with the time step `dt`. `ode.u0` is copied, so the initial condition stored in the problem
# is not modified. Remaining keyword arguments are forwarded to `SimpleIntegratorSSPOptions`,
# which only uses `maxiters`.
function init(ode::ODEProblem, alg::SimpleAlgorithmSSP;
              dt, callback::Union{CallbackSet, Nothing} = nothing, kwargs...)
    u = copy(ode.u0)
    du = similar(u)
    u_tmp = similar(u)

    t_start = first(ode.tspan)
    t_end = last(ode.tspan)
    # The fields `t`, `tdir`, `dt`, and `dtcache` of `SimpleIntegratorSSP` share the single
    # type parameter `RealT`. Promote them to a common type so that, e.g., an integer `dt`
    # or a `Float32` time step with `Float64` times does not cause a `MethodError`.
    t, tdir, dt_init = promote(t_start, sign(t_end - t_start), dt)
    iter = 0

    integrator = SimpleIntegratorSSP(u, du, u_tmp, t, tdir, dt_init, dt_init, iter, ode.p,
                                     (prob = ode,), ode.f, alg,
                                     SimpleIntegratorSSPOptions(callback, ode.tspan;
                                                                kwargs...),
                                     false, true, false)

    # resize container
    resize!(integrator.p, integrator.p.solver.volume_integral,
            nelements(integrator.p.solver, integrator.p.cache))

    # Standard callbacks
    initialize_callbacks!(callback, integrator)

    # Addition for `SimpleAlgorithmSSP` which may have stage callbacks
    for stage_callback in alg.stage_callbacks
        init_callback(stage_callback, integrator.p)
    end

    return integrator
end
148
149
# Advance `integrator` with `step!` until a callback (or `step!` itself) marks the final
# step, finalize all callbacks, and return the solution at the initial and final times.
function solve!(integrator::SimpleIntegratorSSP)
    prob = integrator.sol.prob

    integrator.finalstep = false
    @trixi_timeit timer() "main loop" while !integrator.finalstep
        step!(integrator)
    end

    # Clear all remaining time stops. This is deliberately not done in
    # `terminate!(integrator::SimpleIntegratorSSP)` since
    # `DiffEqCallbacks.PeriodicCallbackAffect` would throw an error in that case.
    extract_all!(integrator.opts.tstops)

    foreach(cb -> finalize_callback(cb, integrator.p), integrator.alg.stage_callbacks)

    finalize_callbacks(integrator)

    times = (first(prob.tspan), integrator.t)
    states = (prob.u0, integrator.u)
    return TimeIntegratorSolution(times, states, prob)
end
171
172
# Perform one time step of the SSP method stored in `integrator.alg`, written as a sequence
# of forward Euler steps followed by convex combinations with the solution at the beginning
# of the step (kept in `integrator.u_tmp`). Afterwards, the step callbacks are handled.
function step!(integrator::SimpleIntegratorSSP)
    prob = integrator.sol.prob
    alg = integrator.alg
    t_end = last(prob.tspan)
    callbacks = integrator.opts.callback

    @assert !integrator.finalstep
    isnan(integrator.dt) && error("time step size `dt` is NaN")

    # Shorten `dt` so that neither the next tstop nor the final time is overshot.
    modify_dt_for_tstops!(integrator)
    limit_dt!(integrator, t_end)

    # Save the solution at the beginning of the step; every stage blends it back in.
    @. integrator.u_tmp = integrator.u
    for (stage, c_stage) in enumerate(alg.c)
        # Right-hand side at the current stage time.
        t_stage = integrator.t + integrator.dt * c_stage
        integrator.f(integrator.du, integrator.u, integrator.p, t_stage)

        # Forward Euler step.
        @. integrator.u = integrator.u + integrator.dt * integrator.du

        for stage_callback in alg.stage_callbacks
            stage_callback(integrator.u, integrator, stage)
        end

        # Convex combination with the solution from the beginning of the step.
        @. integrator.u = (alg.numerator_a[stage] * integrator.u_tmp +
                           alg.numerator_b[stage] * integrator.u) /
                          alg.denominator[stage]
    end

    integrator.iter += 1
    integrator.t += integrator.dt

    @trixi_timeit timer() "Step-Callbacks" handle_callbacks!(callbacks, integrator)

    check_max_iter!(integrator)

    return nothing
end
214
215
# Return a tuple of scratch arrays where the RHS can be stored.
# Note: `step!` keeps the solution from the beginning of the step in `u_tmp` while looping
# over the stages, so using this cache from within a stage callback would corrupt the step.
function get_tmp_cache(integrator::SimpleIntegratorSSP)
    return (integrator.u_tmp,)
end
217
218
# Some algorithms from DiffEq (like FSAL ones) must be informed when a callback has modified
# `u`. The SSP integrator keeps no such state, so this is a no-op that always returns `false`.
function u_modified!(integrator::SimpleIntegratorSSP, ::Bool)
    return false
end
220
221
# Stop the time integration: `solve!` keeps calling `step!` only while `finalstep` is `false`.
function terminate!(integrator::SimpleIntegratorSSP)
    integrator.finalstep = true
    return nothing
end
227
228
"""
    modify_dt_for_tstops!(integrator::SimpleIntegratorSSP)

Modify the time-step size to match the time stops specified in `integrator.opts.tstops`.
Nothing is changed if there are no time stops.

To avoid adding OrdinaryDiffEq to Trixi's dependencies, this routine is a copy of
https://github.com/SciML/OrdinaryDiffEq.jl/blob/d76335281c540ee5a6d1bd8bb634713e004f62ee/src/integrators/integrator_utils.jl#L38-L54
"""
function modify_dt_for_tstops!(integrator::SimpleIntegratorSSP)
    has_tstop(integrator) || return nothing

    tdir = integrator.tdir
    # Time stops are stored multiplied by `tdir`, so compare them with `tdir * t`.
    # (Kept as a separate product on purpose, i.e., not fused with the subtraction below.)
    tdir_t = tdir * integrator.t
    tdir_tstop = first_tstop(integrator)
    distance = abs(tdir_tstop - tdir_t)

    if integrator.opts.adaptive
        # step! to the end
        integrator.dt = tdir * min(abs(integrator.dt), distance)
    elseif iszero(integrator.dtcache) && integrator.dtchangeable
        integrator.dt = tdir * distance
    elseif integrator.dtchangeable && !integrator.force_stepfail
        # Always try to step! with dtcache, but lower it if a tstop is closer.
        # If force_stepfail is set, dt is not reset to dtcache and tstops are not considered.
        integrator.dt = tdir * min(abs(integrator.dtcache), distance)
    end

    return nothing
end
253
254
# Used for AMR: resize the solution-sized buffers of `integrator` to `new_size` and adapt the
# containers of the semidiscretization to its current number of elements.
function Base.resize!(integrator::SimpleIntegratorSSP, new_size)
    for buffer in (integrator.u, integrator.du, integrator.u_tmp)
        resize!(buffer, new_size)
    end

    # Resize container
    # new_size = n_variables * n_nodes^n_dims * n_elements
    semi = integrator.p
    n_elements = nelements(semi.solver, semi.cache)
    resize!(semi, semi.solver.volume_integral, n_elements)

    return nothing
end
267
end # @muladd
268
269