Path: blob/main/src/callbacks_step/save_solution.jl
2055 views
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

"""
    SaveSolutionCallback(; interval::Integer=0,
                           dt=nothing,
                           save_initial_solution=true,
                           save_final_solution=true,
                           output_directory="out",
                           solution_variables=cons2prim,
                           extra_node_variables=())

Save the current numerical solution at regular intervals. Either pass `interval` to save
after every `interval` accepted time steps, or pass `dt` to save in intervals of `dt` in
terms of integration time by adding additional (shortened) time steps where necessary
(note that this may change the solution). Passing both a positive `interval` and a `dt`
throws an `ArgumentError`. With the default `interval = 0` and `dt = nothing`, the
callback never triggers during time integration, so only the initial solution (if
`save_initial_solution` is `true`) is written.

`solution_variables` can be any callable that converts the conservative variables
at a single point to a set of solution variables. The first parameter passed
to `solution_variables` is the set of conservative variables and the second
parameter is the equations struct.

Additional nodal variables such as the vorticity or the Mach number can be saved by
passing a tuple of symbols to `extra_node_variables`, e.g.,
`extra_node_variables = (:vorticity, :mach)`. In that case, the function
`get_node_variable` must be defined for each symbol in the tuple.
For (purely) hyperbolic equations, the expected signature is
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache)
    # Implementation goes here
end
```
and the function must return an array of size
`(ntuple(_ -> n_nodes, ndims(mesh))..., n_elements)`.

For parabolic-hyperbolic equations, `equations_parabolic` and `cache_parabolic`
must be added:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache,
                           equations_parabolic, cache_parabolic)
    # Implementation goes here
end
```
"""
mutable struct SaveSolutionCallback{IntervalType, SolutionVariablesType}
    # Either the number of accepted steps between outputs (`DiscreteCallback` flavor)
    # or the output time interval `dt` (`PeriodicCallback` flavor)
    interval_or_dt::IntervalType
    save_initial_solution::Bool
    save_final_solution::Bool
    output_directory::String
    solution_variables::SolutionVariablesType
    # One entry per requested extra node variable, initialized to `nothing`; the dict is
    # handed to `get_node_variables!` before each output (presumably filled there)
    node_variables::Dict{Symbol, Any}
end

# One-line representation of the step-interval flavor
function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    interval = cb.affect!.interval_or_dt
    print(io, "SaveSolutionCallback(interval=", interval, ")")
end

# One-line representation of the time-interval (`dt`) flavor
function Base.show(io::IO,
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    dt = cb.affect!.affect!.interval_or_dt
    print(io, "SaveSolutionCallback(dt=", dt, ")")
end

# Multi-line summary box of the step-interval flavor
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    # In compact contexts, fall back to the one-line representation
    get(io, :compact, false) && return show(io, cb)

    callback = cb.affect!
    yes_or_no(flag) = flag ? "yes" : "no"
    setup = [
        "interval" => callback.interval_or_dt,
        "solution variables" => callback.solution_variables,
        "save initial solution" => yes_or_no(callback.save_initial_solution),
        "save final solution" => yes_or_no(callback.save_final_solution),
        "output directory" => abspath(normpath(callback.output_directory))
    ]
    summary_box(io, "SaveSolutionCallback", setup)
end

# Multi-line summary box of the time-interval (`dt`) flavor
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    # In compact contexts, fall back to the one-line representation
    get(io, :compact, false) && return show(io, cb)

    callback = cb.affect!.affect!
    yes_or_no(flag) = flag ? "yes" : "no"
    setup = [
        "dt" => callback.interval_or_dt,
        "solution variables" => callback.solution_variables,
        "save initial solution" => yes_or_no(callback.save_initial_solution),
        "save final solution" => yes_or_no(callback.save_final_solution),
        "output directory" => abspath(normpath(callback.output_directory))
    ]
    summary_box(io, "SaveSolutionCallback", setup)
end

function SaveSolutionCallback(; interval::Integer = 0,
                              dt = nothing,
                              save_initial_solution = true,
                              save_final_solution = true,
                              output_directory = "out",
                              solution_variables = cons2prim,
                              extra_node_variables = ())
    use_dt = !isnothing(dt)
    if use_dt && interval > 0
        throw(ArgumentError("You can either set the number of steps between output (using `interval`) or the time between outputs (using `dt`) but not both simultaneously"))
    end

    node_variables = Dict{Symbol, Any}(name => nothing for name in extra_node_variables)
    solution_callback = SaveSolutionCallback(use_dt ? dt : interval,
                                             save_initial_solution, save_final_solution,
                                             output_directory, solution_variables,
                                             node_variables)

    if use_dt
        # Add a `tstop` every `dt`, and save the final solution.
        return PeriodicCallback(solution_callback, dt,
                                save_positions = (false, false),
                                initialize = initialize_save_cb!,
                                final_affect = save_final_solution)
    end

    # Save every `interval` (accepted) time steps; the callback object serves both as
    # the condition (first argument) and as the affect! (second argument).
    return DiscreteCallback(solution_callback, solution_callback,
                            save_positions = (false, false),
                            initialize = initialize_save_cb!)
end

# `initialize` hook shared by both callback flavors. The `SaveSolutionCallback` sits at
# `cb.affect!` for a `DiscreteCallback` and at `cb.affect!.affect!` for a
# `PeriodicCallback`; peeling off one `affect!` layer per call lets dispatch reach the
# `SaveSolutionCallback` method below.
function initialize_save_cb!(cb, u, t, integrator)
    return initialize_save_cb!(cb.affect!, u, t, integrator)
end

function initialize_save_cb!(solution_callback::SaveSolutionCallback, u, t, integrator)
    output_directory = solution_callback.output_directory
    # Only the root rank creates the output directory
    mpi_isroot() && mkpath(output_directory)

    semi = integrator.p
    @trixi_timeit timer() "I/O" save_mesh(semi, output_directory)

    # Write the initial state right away if requested
    solution_callback.save_initial_solution && solution_callback(integrator)

    return nothing
end

# Save the mesh for a general semidiscretization (default).
# Returns the file name of the most recently written mesh file.
function save_mesh(semi::AbstractSemidiscretization, output_directory, timestep = 0)
    mesh, _, _, _ = mesh_equations_solver_cache(semi)

    # Nothing new to write: reuse the file name from the last save
    mesh.unsaved_changes || return mesh.current_filename

    # A mesh written at time step 0 gets no step suffix; later writes (which only happen
    # when the mesh changed during the simulation, e.g., due to AMR) append the step number.
    mesh.current_filename = if timestep == 0
        save_mesh_file(mesh, output_directory)
    else
        save_mesh_file(mesh, output_directory, timestep)
    end
    mesh.unsaved_changes = false

    return mesh.current_filename
end

# Save the mesh for a `DGMultiMesh`, whose `save_mesh_file` additionally needs the
# solver's `basis`. Returns the file name of the most recently written mesh file.
function save_mesh(semi::Union{SemidiscretizationHyperbolic{<:DGMultiMesh},
                               SemidiscretizationHyperbolicParabolic{<:DGMultiMesh}},
                   output_directory, timestep = 0)
    mesh, _, solver, _ = mesh_equations_solver_cache(semi)

    # Nothing new to write: reuse the file name from the last save
    mesh.unsaved_changes || return mesh.current_filename

    # A mesh written at time step 0 gets no step suffix; later writes (which only happen
    # when the mesh changed during the simulation, e.g., due to AMR) append the step number.
    basis = solver.basis
    mesh.current_filename = if timestep == 0
        save_mesh_file(semi.mesh, basis, output_directory)
    else
        save_mesh_file(semi.mesh, basis, output_directory, timestep)
    end
    mesh.unsaved_changes = false

    return mesh.current_filename
end

# Condition of the `DiscreteCallback`: decide whether to save after the current step
function (solution_callback::SaveSolutionCallback)(u, t, integrator)
    interval = solution_callback.interval_or_dt
    # A non-positive interval disables step-based output entirely
    interval > 0 || return false

    # With error-based step size control, some steps can be rejected. Thus,
    #   `integrator.iter >= integrator.stats.naccept`
    #    (total #steps)       (#accepted steps)
    # Callbacks are not activated after a rejected step, so count accepted steps only.
    n_accepted = integrator.stats.naccept
    return n_accepted % interval == 0 ||
           (solution_callback.save_final_solution && isfinished(integrator))
end

# Affect of the callback: write the mesh (if it changed) and the current solution
function (solution_callback::SaveSolutionCallback)(integrator)
    semi = integrator.p
    state = integrator.u
    iteration = integrator.stats.naccept
    output_directory = solution_callback.output_directory

    @trixi_timeit timer() "I/O" begin
        # High-level functions that dispatch on the semidiscretization type
        @trixi_timeit timer() "save mesh" save_mesh(semi, output_directory, iteration)
        save_solution_file(semi, state, solution_callback, integrator)
    end

    # The state was not modified, so possible FSAL stages need not be re-evaluated
    u_modified!(integrator, false)
    return nothing
end

# Gather element-wise variables (from the semidiscretization and, if the integrator
# holds a `CallbackSet`, from every callback in it) and node variables, then write the
# solution file for the current state `u_ode`.
@inline function save_solution_file(semi::AbstractSemidiscretization, u_ode,
                                    solution_callback,
                                    integrator; system = "")
    t = integrator.t
    dt = integrator.dt
    iter = integrator.stats.naccept

    element_variables = Dict{Symbol, Any}()
    @trixi_timeit timer() "get element variables" begin
        get_element_variables!(element_variables, u_ode, semi)

        callbacks = integrator.opts.callback
        if callbacks isa CallbackSet
            # Continuous callbacks contribute first, then discrete ones
            add_from_callback! = cb -> get_element_variables!(element_variables, u_ode,
                                                              semi, cb;
                                                              t = integrator.t,
                                                              iter = iter)
            foreach(add_from_callback!, callbacks.continuous_callbacks)
            foreach(add_from_callback!, callbacks.discrete_callbacks)
        end
    end

    node_variables = solution_callback.node_variables
    @trixi_timeit timer() "get node variables" get_node_variables!(node_variables,
                                                                   u_ode, semi)

    @trixi_timeit timer() "save solution" save_solution_file(u_ode, t, dt, iter, semi,
                                                             solution_callback,
                                                             element_variables,
                                                             node_variables,
                                                             system = system)

    return nothing
end

# Wrap the ODE state vector in the solver's native array layout and forward to the
# mesh/solver-specific `save_solution_file` method.
@inline function save_solution_file(u_ode, t, dt, iter,
                                    semi::AbstractSemidiscretization, solution_callback,
                                    element_variables = Dict{Symbol, Any}(),
                                    node_variables = Dict{Symbol, Any}();
                                    system = "")
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    u_wrapped = wrap_array_native(u_ode, mesh, equations, solver, cache)
    save_solution_file(u_wrapped, t, dt, iter, mesh, equations, solver, cache,
                       solution_callback, element_variables, node_variables;
                       system = system)

    return nothing
end

# TODO: Taal refactor, move save_mesh_file?
# function save_mesh_file(mesh::TreeMesh, output_directory, timestep=-1) in io/io.jl

include("save_solution_dg.jl")
end # @muladd