Path: blob/main/src/callbacks_step/stepsize.jl
2055 views
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

"""
    StepsizeCallback(; cfl=1.0, interval = 1)

Set the time step size according to a CFL condition with CFL number `cfl`
if the time integration method isn't adaptive itself.

The keyword argument `cfl` is either a `Real` number or a function of time `t`
returning a `Real` number.
The keyword argument `interval` controls how often the time step is recomputed:
it is adjusted every `interval` steps, and the default `interval = 1` adjusts it
at every step.

Using this callback with an adaptive time integration method throws an
`ArgumentError` when the callback is applied.
"""
mutable struct StepsizeCallback{CflType}
    cfl_number::CflType
    interval::Int
end

# One-line representation, e.g. used in compact IO contexts.
function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:StepsizeCallback})
    @nospecialize cb # reduce precompilation time

    # The `StepsizeCallback` object is stored as the `affect!` of the `DiscreteCallback`.
    affect = cb.affect!
    print(io, "StepsizeCallback(",
          "cfl_number=", affect.cfl_number, ", ",
          "interval=", affect.interval, ")")
end

# Multi-line representation; falls back to the one-line form for compact IO contexts.
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:StepsizeCallback})
    @nospecialize cb # reduce precompilation time

    get(io, :compact, false) && return show(io, cb)

    affect = cb.affect!
    setup = ["CFL number" => affect.cfl_number
             "Interval" => affect.interval]
    summary_box(io, "StepsizeCallback", setup)
end

# Keyword constructor: wraps a `StepsizeCallback` into a `DiscreteCallback`.
function StepsizeCallback(; cfl = 1.0, interval = 1)
    callback = StepsizeCallback{typeof(cfl)}(cfl, interval)

    # The same object is passed twice:
    # the first one is the condition, the second the affect!
    return DiscreteCallback(callback, callback,
                            save_positions = (false, false),
                            initialize = initialize!)
end

# Compatibility constructor
# Returns the bare `StepsizeCallback` (with `interval = 1`), not a `DiscreteCallback`.
StepsizeCallback(cfl) = StepsizeCallback{typeof(cfl)}(cfl, 1)

# Apply the callback once when it is initialized.
function initialize!(cb::DiscreteCallback{Condition, Affect!}, u, t,
                     integrator) where {Condition, Affect! <: StepsizeCallback}
    return cb.affect!(integrator)
end

# this method is called to determine whether the callback should be activated
function (stepsize_callback::StepsizeCallback)(u, t, integrator)
    every = stepsize_callback.interval

    # A non-positive `interval` never activates the callback.
    every > 0 || return false

    # Although the CFL-based timestep is usually not used with
    # adaptive time integration methods, we still check the accepted steps `naccept` here.
    return integrator.stats.naccept % every == 0
end

# This method is called as callback during the time integration.
# It throws an `ArgumentError` if the integrator is adaptive; otherwise it computes
# the new time step via `calculate_dt` and stores it in the integrator.
@inline function (stepsize_callback::StepsizeCallback)(integrator)
    integrator.opts.adaptive &&
        throw(ArgumentError("The `StepsizeCallback` has no effect when using an adaptive time integration scheme. Please remove the `StepsizeCallback` or set `adaptive = false` in `solve`."))

    time = integrator.t
    state = integrator.u
    semi = integrator.p
    cfl_number = stepsize_callback.cfl_number

    # Dispatch based on semidiscretization
    dt = @trixi_timeit timer() "calculate dt" calculate_dt(state, time, cfl_number, semi)

    set_proposed_dt!(integrator, dt)
    integrator.opts.dtmax = dt
    integrator.dtcache = dt

    # avoid re-evaluating possible FSAL stages
    u_modified!(integrator, false)
    return nothing
end

# Time integration methods from the DiffEq ecosystem without adaptive time stepping on their own
# such as `CarpenterKennedy2N54` require passing `dt=...` in `solve(ode, ...)`. Since we don't have
# an integrator at this stage but only the ODE, this method will be used there. It's called in
# many examples in `solve(ode, ..., dt=stepsize_callback(ode), ...)`.
function (cb::DiscreteCallback{Condition, Affect!})(ode::ODEProblem) where {Condition,
                                                                            Affect! <:
                                                                            StepsizeCallback
                                                                            }
    affect = cb.affect!
    cfl_number = affect.cfl_number

    # Use the initial state and the start of the time span of the ODE problem.
    initial_state = ode.u0
    initial_time = first(ode.tspan)

    return calculate_dt(initial_state, initial_time, cfl_number, ode.p)
end

# General case for a single (i.e., non-coupled) semidiscretization
# Case for constant `cfl_number`.
function calculate_dt(u_ode, t, cfl_number::Real, semi::AbstractSemidiscretization)
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    u = wrap_array(u_ode, mesh, equations, solver, cache)

    dt_unit = max_dt(u, t, mesh,
                     have_constant_speed(equations), equations,
                     solver, cache)
    return cfl_number * dt_unit
end
# Case for `cfl_number` as a function of time `t`.
function calculate_dt(u_ode, t, cfl_number, semi::AbstractSemidiscretization)
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    u = wrap_array(u_ode, mesh, equations, solver, cache)

    # Evaluate the CFL number at the current time before computing the step size bound.
    cfl = cfl_number(t)
    dt_unit = max_dt(u, t, mesh,
                     have_constant_speed(equations), equations,
                     solver, cache)
    return cfl * dt_unit
end

include("stepsize_dg1d.jl")
include("stepsize_dg2d.jl")
include("stepsize_dg3d.jl")
end # @muladd