Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/callbacks_step/save_solution.jl
2055 views
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
8
"""
    SaveSolutionCallback(; interval::Integer=0,
                           dt=nothing,
                           save_initial_solution=true,
                           save_final_solution=true,
                           output_directory="out",
                           solution_variables=cons2prim,
                           extra_node_variables=())

Save the current numerical solution at regular intervals. Either pass `interval` to save
every `interval` accepted time steps or pass `dt` to save in intervals of `dt` in terms
of integration time by adding additional (shortened) time steps where necessary (note
that this may change the solution). Setting both `interval > 0` and `dt` throws an
`ArgumentError`. If neither is set (`interval = 0` and `dt = nothing`, the defaults), no
solution files are written during the simulation, not even the final one; only the
initial solution is written if `save_initial_solution` is `true`.

`solution_variables` can be any callable that converts the conservative variables
at a single point to a set of solution variables. The first parameter passed
to `solution_variables` will be the set of conservative variables
and the second parameter is the equation struct.

Additional nodal variables such as vorticity or the Mach number can be saved by passing
a tuple of symbols to `extra_node_variables`, e.g.,
`extra_node_variables = (:vorticity, :mach)`.
In that case the function `get_node_variable` must be defined for each symbol in the tuple.
The expected signature of the function for (purely) hyperbolic equations is:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache)
    # Implementation goes here
end
```
and must return an array of dimension
`(ntuple(_ -> n_nodes, ndims(mesh))..., n_elements)`.

For parabolic-hyperbolic equations `equations_parabolic` and `cache_parabolic` must be added:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache,
                           equations_parabolic, cache_parabolic)
    # Implementation goes here
end
```
"""
mutable struct SaveSolutionCallback{IntervalType, SolutionVariablesType}
    # Number of accepted time steps between outputs (step-based callback) or the
    # integration-time distance `dt` between outputs (time-based callback)
    interval_or_dt::IntervalType
    save_initial_solution::Bool
    save_final_solution::Bool
    output_directory::String
    # Callable `(u_conservative, equations) -> solution variables` applied before output
    solution_variables::SolutionVariablesType
    # Requested extra node variables by name; the constructor initializes every value
    # to `nothing` and the dictionary is passed to `get_node_variables!` when saving
    node_variables::Dict{Symbol, Any}
end
53
54
# One-line representation of a step-based (`interval`) SaveSolutionCallback
function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    # For a plain DiscreteCallback the SaveSolutionCallback is stored as `affect!`
    interval = cb.affect!.interval_or_dt
    print(io, "SaveSolutionCallback(interval=", interval, ")")
    return nothing
end
61
62
# One-line representation of a time-based (`dt`) SaveSolutionCallback
function Base.show(io::IO,
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    # A PeriodicCallback wraps the SaveSolutionCallback in a PeriodicCallbackAffect
    dt = cb.affect!.affect!.interval_or_dt
    print(io, "SaveSolutionCallback(dt=", dt, ")")
    return nothing
end
70
71
# Multi-line summary box for a step-based (`interval`) SaveSolutionCallback.
# In compact IO contexts the one-line representation is printed instead.
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    get(io, :compact, false) && return show(io, cb)

    callback = cb.affect!
    yes_or_no(flag) = flag ? "yes" : "no"
    setup = [
        "interval" => callback.interval_or_dt,
        "solution variables" => callback.solution_variables,
        "save initial solution" => yes_or_no(callback.save_initial_solution),
        "save final solution" => yes_or_no(callback.save_final_solution),
        "output directory" => abspath(normpath(callback.output_directory)),
    ]
    return summary_box(io, "SaveSolutionCallback", setup)
end
92
93
# Multi-line summary box for a time-based (`dt`) SaveSolutionCallback.
# In compact IO contexts the one-line representation is printed instead.
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    get(io, :compact, false) && return show(io, cb)

    # The SaveSolutionCallback sits inside the PeriodicCallbackAffect wrapper
    callback = cb.affect!.affect!
    yes_or_no(flag) = flag ? "yes" : "no"
    setup = [
        "dt" => callback.interval_or_dt,
        "solution variables" => callback.solution_variables,
        "save initial solution" => yes_or_no(callback.save_initial_solution),
        "save final solution" => yes_or_no(callback.save_final_solution),
        "output directory" => abspath(normpath(callback.output_directory)),
    ]
    return summary_box(io, "SaveSolutionCallback", setup)
end
115
116
# Keyword constructor: builds the SaveSolutionCallback and wraps it either in a
# DiscreteCallback (`interval`-based, default) or a PeriodicCallback (`dt`-based).
function SaveSolutionCallback(; interval::Integer = 0,
                              dt = nothing,
                              save_initial_solution = true,
                              save_final_solution = true,
                              output_directory = "out",
                              solution_variables = cons2prim,
                              extra_node_variables = ())
    time_based = !isnothing(dt)
    if time_based && interval > 0
        throw(ArgumentError("You can either set the number of steps between output (using `interval`) or the time between outputs (using `dt`) but not both simultaneously"))
    end

    # Every requested extra node variable starts out as `nothing`; the dictionary is
    # later handed to `get_node_variables!` whenever a solution file is saved.
    node_variables = Dict{Symbol, Any}(var => nothing for var in extra_node_variables)
    solution_callback = SaveSolutionCallback(time_based ? dt : interval,
                                             save_initial_solution, save_final_solution,
                                             output_directory, solution_variables,
                                             node_variables)

    if time_based
        # Add a `tstop` every `dt`; whether the final solution is saved as well is
        # forwarded as `final_affect`.
        return PeriodicCallback(solution_callback, dt,
                                save_positions = (false, false),
                                initialize = initialize_save_cb!,
                                final_affect = save_final_solution)
    end

    # Save every `interval` accepted time steps; the same object serves as the
    # condition (first argument) and as the affect! (second argument).
    return DiscreteCallback(solution_callback, solution_callback,
                            save_positions = (false, false),
                            initialize = initialize_save_cb!)
end
155
156
# Initialization hook passed to both DiscreteCallback and PeriodicCallback.
# With a DiscreteCallback, `cb.affect!` already is the SaveSolutionCallback; with a
# PeriodicCallback it is a wrapper whose own `affect!` is the SaveSolutionCallback.
# Each call unwraps one layer until the SaveSolutionCallback method is selected.
function initialize_save_cb!(cb, u, t, integrator)
    inner = cb.affect!
    return initialize_save_cb!(inner, u, t, integrator)
end
162
163
# Prepare output at the start of a simulation: create the output directory (on the
# MPI root only), save the mesh, and optionally save the initial solution.
function initialize_save_cb!(solution_callback::SaveSolutionCallback, u, t, integrator)
    output_directory = solution_callback.output_directory
    if mpi_isroot()
        mkpath(output_directory)
    end

    @trixi_timeit timer() "I/O" save_mesh(integrator.p, output_directory)

    solution_callback.save_initial_solution && solution_callback(integrator)

    return nothing
end
175
176
# Save mesh for a general semidiscretization (default).
# The mesh file is written only when `mesh.unsaved_changes` is set; in every case the
# (possibly updated) `mesh.current_filename` is returned.
function save_mesh(semi::AbstractSemidiscretization, output_directory, timestep = 0)
    mesh, _, _, _ = mesh_equations_solver_cache(semi)
    mesh.unsaved_changes || return mesh.current_filename

    # The time step number is appended to the file name only if the mesh changed
    # during the simulation due to AMR; the save at time step 0 keeps the plain name.
    filename = if iszero(timestep)
        save_mesh_file(mesh, output_directory)
    else
        save_mesh_file(mesh, output_directory, timestep)
    end
    mesh.current_filename = filename
    mesh.unsaved_changes = false

    return mesh.current_filename
end
193
194
# Save mesh for a DGMultiMesh, which requires passing the solver's `basis` to
# `save_mesh_file`. As in the generic method, the file is written only when
# `mesh.unsaved_changes` is set, and `mesh.current_filename` is always returned.
function save_mesh(semi::Union{SemidiscretizationHyperbolic{<:DGMultiMesh},
                              SemidiscretizationHyperbolicParabolic{<:DGMultiMesh}},
                   output_directory, timestep = 0)
    mesh, _, solver, _ = mesh_equations_solver_cache(semi)
    mesh.unsaved_changes || return mesh.current_filename

    # The time step number is appended to the file name only if the mesh changed
    # during the simulation due to AMR; the save at time step 0 keeps the plain name.
    basis = solver.basis
    filename = if iszero(timestep)
        save_mesh_file(semi.mesh, basis, output_directory)
    else
        save_mesh_file(semi.mesh, basis, output_directory, timestep)
    end
    mesh.current_filename = filename
    mesh.unsaved_changes = false

    return mesh.current_filename
end
216
217
# Condition of the step-based callback: decides whether the callback is activated
# after the current (accepted) step.
function (solution_callback::SaveSolutionCallback)(u, t, integrator)
    interval = solution_callback.interval_or_dt

    # A non-positive interval disables output here, including the final save
    interval > 0 || return false

    # With error-based step size control, some steps can be rejected, so the total
    # step count `integrator.iter` may exceed the accepted count `integrator.stats.naccept`.
    # Callbacks are not activated after a rejected step, hence accepted steps are counted.
    on_interval = integrator.stats.naccept % interval == 0
    return on_interval ||
           (solution_callback.save_final_solution && isfinished(integrator))
end
229
230
# Affect of the callback: saves the mesh (only written if it has unsaved changes)
# and the current solution.
function (solution_callback::SaveSolutionCallback)(integrator)
    u_ode, semi = integrator.u, integrator.p
    accepted_steps = integrator.stats.naccept
    output_directory = solution_callback.output_directory

    @trixi_timeit timer() "I/O" begin
        # High-level functions dispatch on the type of the semidiscretization
        @trixi_timeit timer() "save mesh" save_mesh(semi, output_directory,
                                                    accepted_steps)
        save_solution_file(semi, u_ode, solution_callback, integrator)
    end

    # The state was not changed, so possible FSAL stages need not be re-evaluated
    u_modified!(integrator, false)
    return nothing
end
248
249
# Collect element and node variables for the current state of `integrator` and
# forward everything to the lower-level `save_solution_file` method.
@inline function save_solution_file(semi::AbstractSemidiscretization, u_ode,
                                    solution_callback,
                                    integrator; system = "")
    @unpack t, dt = integrator
    iter = integrator.stats.naccept

    # Element variables come from the semidiscretization itself and, if the callbacks
    # are bundled in a CallbackSet, from each continuous and discrete callback.
    element_variables = Dict{Symbol, Any}()
    @trixi_timeit timer() "get element variables" begin
        get_element_variables!(element_variables, u_ode, semi)

        callbacks = integrator.opts.callback
        if callbacks isa CallbackSet
            add_from_callback = cb -> get_element_variables!(element_variables, u_ode,
                                                             semi, cb;
                                                             t = t, iter = iter)
            foreach(add_from_callback, callbacks.continuous_callbacks)
            foreach(add_from_callback, callbacks.discrete_callbacks)
        end
    end

    node_variables = solution_callback.node_variables
    @trixi_timeit timer() "get node variables" get_node_variables!(node_variables,
                                                                   u_ode, semi)

    @trixi_timeit timer() "save solution" save_solution_file(u_ode, t, dt, iter, semi,
                                                             solution_callback,
                                                             element_variables,
                                                             node_variables;
                                                             system = system)

    return nothing
end
282
283
# Wrap the ODE state `u_ode` with `wrap_array_native` and forward it, together with
# mesh, equations, solver, and cache of `semi`, to the mesh/solver-specific method.
@inline function save_solution_file(u_ode, t, dt, iter,
                                    semi::AbstractSemidiscretization, solution_callback,
                                    element_variables = Dict{Symbol, Any}(),
                                    node_variables = Dict{Symbol, Any}();
                                    system = "")
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    u_native = wrap_array_native(u_ode, mesh, equations, solver, cache)

    save_solution_file(u_native, t, dt, iter, mesh, equations, solver, cache,
                       solution_callback, element_variables, node_variables;
                       system = system)
    return nothing
end
297
298
# TODO: Taal refactor, move save_mesh_file?
299
# function save_mesh_file(mesh::TreeMesh, output_directory, timestep=-1) in io/io.jl
300
301
include("save_solution_dg.jl")
302
end # @muladd
303
304