GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/callbacks_step/save_restart_dg.jl
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

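# Illustrative note (not part of the original source): `@muladd` comes from
# MuladdMacro.jl and rewrites expressions such as `a * b + c` into
# `muladd(a, b, c)`, which lets LLVM emit FMA instructions. A minimal sketch of
# the effect:
#
#     @muladd f(a, b, c) = a * b + c   # rewritten to muladd(a, b, c)
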
function save_restart_file(u, time, dt, timestep,
                           mesh::Union{SerialTreeMesh, StructuredMesh,
                                       UnstructuredMesh2D, SerialP4estMesh,
                                       SerialT8codeMesh},
                           equations, dg::DG, cache,
                           restart_callback)
    @unpack output_directory = restart_callback

    # Filename based on current time step
    filename = joinpath(output_directory, @sprintf("restart_%09d.h5", timestep))

    # Restart files always store conservative variables
    data = u

    # Open file (clobber existing content)
    h5open(filename, "w") do file
        # Add context information as attributes
        attributes(file)["ndims"] = ndims(mesh)
        attributes(file)["equations"] = get_name(equations)
        attributes(file)["polydeg"] = polydeg(dg)
        attributes(file)["n_vars"] = nvariables(equations)
        attributes(file)["n_elements"] = nelements(dg, cache)
        attributes(file)["mesh_type"] = get_name(mesh)
        attributes(file)["mesh_file"] = splitdir(mesh.current_filename)[2]
        attributes(file)["time"] = convert(Float64, time) # Ensure that `time` is written as a double precision scalar
        attributes(file)["dt"] = convert(Float64, dt) # Ensure that `dt` is written as a double precision scalar
        attributes(file)["timestep"] = timestep

        # Store each variable of the solution
        for v in eachvariable(equations)
            # Convert to 1D array
            file["variables_$v"] = vec(data[v, .., :])

            # Add variable name as attribute
            var = file["variables_$v"]
            attributes(var)["name"] = varnames(cons2cons, equations)[v]
        end
    end

    return filename
end

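# Illustrative usage sketch (not part of the original source): a restart file
# written by the method above can be inspected with plain HDF5.jl calls; the
# filename below is a hypothetical example.
#
#     h5open("out/restart_000000010.h5", "r") do file
#         read(attributes(file)["time"])    # simulation time as Float64
#         read(attributes(file)["n_vars"])  # number of conservative variables
#         read(file["variables_1"])         # first variable, flattened to a 1D array
#     end
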
function load_restart_file(mesh::Union{SerialTreeMesh, StructuredMesh,
                                       UnstructuredMesh2D, SerialP4estMesh,
                                       SerialT8codeMesh},
                           equations, dg::DG, cache, restart_file)

    # allocate memory
    u_ode = allocate_coefficients(mesh, equations, dg, cache)
    u = wrap_array_native(u_ode, mesh, equations, dg, cache)

    h5open(restart_file, "r") do file
        # Read attributes to perform some sanity checks
        if read(attributes(file)["ndims"]) != ndims(mesh)
            error("restart mismatch: ndims differs from value in restart file")
        end
        if read(attributes(file)["equations"]) != get_name(equations)
            error("restart mismatch: equations differ from value in restart file")
        end
        if read(attributes(file)["polydeg"]) != polydeg(dg)
            error("restart mismatch: polynomial degree in solver differs from value in restart file")
        end
        if read(attributes(file)["n_elements"]) != nelements(dg, cache)
            error("restart mismatch: number of elements in solver differs from value in restart file")
        end

        # Read data
        for v in eachvariable(equations)
            # Check if variable name matches
            var = file["variables_$v"]
            if (name = read(attributes(var)["name"])) !=
               varnames(cons2cons, equations)[v]
                error("mismatch: variables_$v should be '$(varnames(cons2cons, equations)[v])', but found '$name'")
            end

            # Read variable
            u[v, .., :] = read(file["variables_$v"])
        end
    end

    return u_ode
end

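# Illustrative note (not part of the original source): the `..` in `u[v, .., :]`
# is EllipsisNotation.jl syntax that expands to as many colons as needed, so the
# same code works for 1D, 2D, and 3D solution arrays. A minimal sketch:
#
#     using EllipsisNotation
#     A = rand(2, 3, 4, 5)
#     A[1, .., :] == A[1, :, :, :]   # true; `..` fills the remaining dimensions
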
function save_restart_file(u, time, dt, timestep,
                           mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                       ParallelT8codeMesh}, equations,
                           dg::DG, cache,
                           restart_callback)
    @unpack output_directory = restart_callback
    # Filename based on current time step
    filename = joinpath(output_directory, @sprintf("restart_%09d.h5", timestep))

    if HDF5.has_parallel()
        save_restart_file_parallel(u, time, dt, timestep, mesh, equations, dg, cache,
                                   filename)
    else
        save_restart_file_on_root(u, time, dt, timestep, mesh, equations, dg, cache,
                                  filename)
    end
end

function save_restart_file_parallel(u, time, dt, timestep,
                                    mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                                ParallelT8codeMesh},
                                    equations, dg::DG, cache,
                                    filename)

    # Restart files always store conservative variables
    data = u

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)
    # Cumulative sum of nodes per rank starting with an additional 0
    cum_node_counts = append!(zeros(eltype(node_counts), 1), cumsum(node_counts))

    # Open file (clobber existing content)
    h5open(filename, "w", mpi_comm()) do file
        # Add context information as attributes
        attributes(file)["ndims"] = ndims(mesh)
        attributes(file)["equations"] = get_name(equations)
        attributes(file)["polydeg"] = polydeg(dg)
        attributes(file)["n_vars"] = nvariables(equations)
        attributes(file)["n_elements"] = nelementsglobal(mesh, dg, cache)
        attributes(file)["mesh_type"] = get_name(mesh)
        attributes(file)["mesh_file"] = splitdir(mesh.current_filename)[2]
        attributes(file)["time"] = convert(Float64, time) # Ensure that `time` is written as a double precision scalar
        attributes(file)["dt"] = convert(Float64, dt) # Ensure that `dt` is written as a double precision scalar
        attributes(file)["timestep"] = timestep

        # Store each variable of the solution
        for v in eachvariable(equations)
            # Need to create dataset explicitly in parallel case
            var = create_dataset(file, "/variables_$v", datatype(eltype(data)),
                                 dataspace((ndofsglobal(mesh, dg, cache),)))
            # Write data of each process in slices (ranks start with 0)
            slice = (cum_node_counts[mpi_rank() + 1] + 1):cum_node_counts[mpi_rank() + 2]
            # Convert to 1D array
            var[slice] = vec(data[v, .., :])
            # Add variable name as attribute
            attributes(var)["name"] = varnames(cons2cons, equations)[v]
        end
    end

    return filename
end

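# Illustrative note (not part of the original source): the slice computation
# above partitions the global 1D dataset by rank. With hypothetical
# node_counts = [8, 8, 4], cum_node_counts becomes [0, 8, 16, 20], so rank 0
# writes indices 1:8, rank 1 writes 9:16, and rank 2 writes 17:20.
#
#     node_counts = Cint[8, 8, 4]   # hypothetical per-rank node counts
#     cum_node_counts = append!(zeros(eltype(node_counts), 1), cumsum(node_counts))
#     # For rank r (0-based): (cum_node_counts[r + 1] + 1):cum_node_counts[r + 2]
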
function save_restart_file_on_root(u, time, dt, timestep,
                                   mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                               ParallelT8codeMesh},
                                   equations, dg::DG, cache,
                                   filename)

    # Restart files always store conservative variables
    data = u

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)

    # non-root ranks only send data
    if !mpi_isroot()
        # Send nodal data to root
        for v in eachvariable(equations)
            MPI.Gatherv!(vec(data[v, .., :]), nothing, mpi_root(), mpi_comm())
        end

        return filename
    end

    # Open file (clobber existing content)
    h5open(filename, "w") do file
        # Add context information as attributes
        attributes(file)["ndims"] = ndims(mesh)
        attributes(file)["equations"] = get_name(equations)
        attributes(file)["polydeg"] = polydeg(dg)
        attributes(file)["n_vars"] = nvariables(equations)
        attributes(file)["n_elements"] = nelements(dg, cache)
        attributes(file)["mesh_type"] = get_name(mesh)
        attributes(file)["mesh_file"] = splitdir(mesh.current_filename)[2]
        attributes(file)["time"] = convert(Float64, time) # Ensure that `time` is written as a double precision scalar
        attributes(file)["dt"] = convert(Float64, dt) # Ensure that `dt` is written as a double precision scalar
        attributes(file)["timestep"] = timestep

        # Store each variable of the solution
        for v in eachvariable(equations)
            # Convert to 1D array
            recv = Vector{eltype(data)}(undef, sum(node_counts))
            MPI.Gatherv!(vec(data[v, .., :]), MPI.VBuffer(recv, node_counts),
                         mpi_root(), mpi_comm())
            file["variables_$v"] = recv

            # Add variable name as attribute
            var = file["variables_$v"]
            attributes(var)["name"] = varnames(cons2cons, equations)[v]
        end
    end

    return filename
end

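# Illustrative note (not part of the original source): the gather pattern above
# mirrors MPI.jl's variable-length collectives. Only the root passes a receive
# buffer; every other rank passes `nothing`:
#
#     # on non-root ranks
#     MPI.Gatherv!(vec(data[v, .., :]), nothing, mpi_root(), mpi_comm())
#     # on the root rank, with per-rank lengths in `node_counts`
#     MPI.Gatherv!(vec(data[v, .., :]), MPI.VBuffer(recv, node_counts),
#                  mpi_root(), mpi_comm())
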
function load_restart_file(mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                       ParallelT8codeMesh}, equations,
                           dg::DG, cache, restart_file)
    if HDF5.has_parallel()
        load_restart_file_parallel(mesh, equations, dg, cache, restart_file)
    else
        load_restart_file_on_root(mesh, equations, dg, cache, restart_file)
    end
end

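# Illustrative note (not part of the original source): `HDF5.has_parallel()`
# reports whether the underlying HDF5 library was built with MPI support; the
# save and load entry points above use it to choose between a truly parallel
# read/write and the root-only fallback. A minimal sketch:
#
#     using HDF5
#     HDF5.has_parallel()   # true only for an MPI-enabled HDF5 build
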
function load_restart_file_parallel(mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                                ParallelT8codeMesh},
                                    equations, dg::DG, cache, restart_file)

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)
    # Cumulative sum of nodes per rank starting with an additional 0
    cum_node_counts = append!(zeros(eltype(node_counts), 1), cumsum(node_counts))

    # allocate memory
    u_ode = allocate_coefficients(mesh, equations, dg, cache)
    u = wrap_array_native(u_ode, mesh, equations, dg, cache)

    # read in parallel
    h5open(restart_file, "r", mpi_comm()) do file
        # Read attributes to perform some sanity checks
        if read(attributes(file)["ndims"]) != ndims(mesh)
            error("restart mismatch: ndims differs from value in restart file")
        end
        if read(attributes(file)["equations"]) != get_name(equations)
            error("restart mismatch: equations differ from value in restart file")
        end
        if read(attributes(file)["polydeg"]) != polydeg(dg)
            error("restart mismatch: polynomial degree in solver differs from value in restart file")
        end
        if read(attributes(file)["n_elements"]) != nelementsglobal(mesh, dg, cache)
            error("restart mismatch: number of elements in solver differs from value in restart file")
        end

        # Read data
        for v in eachvariable(equations)
            # Check if variable name matches
            var = file["variables_$v"]
            if (name = read(attributes(var)["name"])) !=
               varnames(cons2cons, equations)[v]
                error("mismatch: variables_$v should be '$(varnames(cons2cons, equations)[v])', but found '$name'")
            end

            # Read variable
            mpi_println("Reading variables_$v ($name)...")
            # Read data of each process in slices (ranks start with 0)
            slice = (cum_node_counts[mpi_rank() + 1] + 1):cum_node_counts[mpi_rank() + 2]
            # Convert 1D array back to actual size of `u`
            u[v, .., :] = reshape(read(var)[slice], size(@view u[v, .., :]))
        end
    end

    return u_ode
end

function load_restart_file_on_root(mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                               ParallelT8codeMesh},
                                   equations, dg::DG, cache, restart_file)

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)

    # allocate memory
    u_ode = allocate_coefficients(mesh, equations, dg, cache)
    u = wrap_array_native(u_ode, mesh, equations, dg, cache)

    # non-root ranks only receive data
    if !mpi_isroot()
        # Receive nodal data from root
        for v in eachvariable(equations)
            # put Scatterv in both blocks of the if condition to avoid type instability
            if isempty(u)
                data = eltype(u)[]
                MPI.Scatterv!(nothing, data, mpi_root(), mpi_comm())
            else
                data = @view u[v, .., :]
                MPI.Scatterv!(nothing, data, mpi_root(), mpi_comm())
            end
        end

        return u_ode
    end

    # read only on MPI root
    h5open(restart_file, "r") do file
        # Read attributes to perform some sanity checks
        if read(attributes(file)["ndims"]) != ndims(mesh)
            error("restart mismatch: ndims differs from value in restart file")
        end
        if read(attributes(file)["equations"]) != get_name(equations)
            error("restart mismatch: equations differ from value in restart file")
        end
        if read(attributes(file)["polydeg"]) != polydeg(dg)
            error("restart mismatch: polynomial degree in solver differs from value in restart file")
        end
        if read(attributes(file)["n_elements"]) != nelements(dg, cache)
            error("restart mismatch: number of elements in solver differs from value in restart file")
        end

        # Read data
        for v in eachvariable(equations)
            # Check if variable name matches
            var = file["variables_$v"]
            if (name = read(attributes(var)["name"])) !=
               varnames(cons2cons, equations)[v]
                error("mismatch: variables_$v should be '$(varnames(cons2cons, equations)[v])', but found '$name'")
            end

            # Read variable
            println("Reading variables_$v ($name)...")
            sendbuf = MPI.VBuffer(read(file["variables_$v"]), node_counts)
            MPI.Scatterv!(sendbuf, @view(u[v, .., :]), mpi_root(), mpi_comm())
        end
    end

    return u_ode
end

# Store controller values for an adaptive time stepping scheme
function save_adaptive_time_integrator(integrator,
                                       controller, restart_callback)
    # Save only on root
    if mpi_isroot()
        @unpack output_directory = restart_callback
        timestep = integrator.stats.naccept

        # Filename based on current time step
        filename = joinpath(output_directory, @sprintf("restart_%09d.h5", timestep))

        # Open file (preserve existing content)
        h5open(filename, "r+") do file
            # Add context information as attributes for both the PIController and the PIDController
            attributes(file)["time_integrator_qold"] = integrator.qold
            attributes(file)["time_integrator_dtpropose"] = integrator.dtpropose
            # For the PIDController, it is necessary to save additional parameters
            if hasproperty(controller, :err) # Distinguish PIDController from PIController
                attributes(file)["time_integrator_controller_err"] = controller.err
            end
        end
    end
end
end # @muladd