Path: blob/main/src/callbacks_step/save_restart_dg.jl
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

function save_restart_file(u, time, dt, timestep,
                           mesh::Union{SerialTreeMesh, StructuredMesh,
                                       UnstructuredMesh2D, SerialP4estMesh,
                                       SerialT8codeMesh},
                           equations, dg::DG, cache,
                           restart_callback)
    @unpack output_directory = restart_callback

    # Filename based on current time step
    filename = joinpath(output_directory, @sprintf("restart_%09d.h5", timestep))

    # Restart files always store conservative variables
    data = u

    # Open file (clobber existing content)
    h5open(filename, "w") do file
        # Add context information as attributes
        attributes(file)["ndims"] = ndims(mesh)
        attributes(file)["equations"] = get_name(equations)
        attributes(file)["polydeg"] = polydeg(dg)
        attributes(file)["n_vars"] = nvariables(equations)
        attributes(file)["n_elements"] = nelements(dg, cache)
        attributes(file)["mesh_type"] = get_name(mesh)
        attributes(file)["mesh_file"] = splitdir(mesh.current_filename)[2]
        attributes(file)["time"] = convert(Float64, time) # Ensure that `time` is written as a double precision scalar
        attributes(file)["dt"] = convert(Float64, dt) # Ensure that `dt` is written as a double precision scalar
        attributes(file)["timestep"] = timestep

        # Store each variable of the solution
        for v in eachvariable(equations)
            # Convert to 1D array
            file["variables_$v"] = vec(data[v, .., :])

            # Add variable name as attribute
            var = file["variables_$v"]
            attributes(var)["name"] = varnames(cons2cons, equations)[v]
        end
    end

    return filename
end

function load_restart_file(mesh::Union{SerialTreeMesh, StructuredMesh,
                                       UnstructuredMesh2D, SerialP4estMesh,
                                       SerialT8codeMesh},
                           equations, dg::DG, cache, restart_file)

    # allocate memory
    u_ode = allocate_coefficients(mesh, equations, dg, cache)
    u = wrap_array_native(u_ode, mesh, equations, dg, cache)

    h5open(restart_file, "r") do file
        # Read attributes to perform some sanity checks
        if read(attributes(file)["ndims"]) != ndims(mesh)
            error("restart mismatch: ndims differs from value in restart file")
        end
        if read(attributes(file)["equations"]) != get_name(equations)
            error("restart mismatch: equations differ from value in restart file")
        end
        if read(attributes(file)["polydeg"]) != polydeg(dg)
            error("restart mismatch: polynomial degree in solver differs from value in restart file")
        end
        if read(attributes(file)["n_elements"]) != nelements(dg, cache)
            error("restart mismatch: number of elements in solver differs from value in restart file")
        end

        # Read data
        for v in eachvariable(equations)
            # Check if variable name matches
            var = file["variables_$v"]
            if (name = read(attributes(var)["name"])) !=
               varnames(cons2cons, equations)[v]
                error("mismatch: variables_$v should be '$(varnames(cons2cons, equations)[v])', but found '$name'")
            end

            # Read variable
            u[v, .., :] = read(file["variables_$v"])
        end
    end

    return u_ode
end

function save_restart_file(u, time, dt, timestep,
                           mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                       ParallelT8codeMesh}, equations,
                           dg::DG, cache,
                           restart_callback)
    @unpack output_directory = restart_callback
    # Filename based on current time step
    filename = joinpath(output_directory, @sprintf("restart_%09d.h5", timestep))

    if HDF5.has_parallel()
        save_restart_file_parallel(u, time, dt, timestep, mesh, equations, dg,
                                   cache,
                                   filename)
    else
        save_restart_file_on_root(u, time, dt, timestep, mesh, equations, dg, cache,
                                  filename)
    end
end

function save_restart_file_parallel(u, time, dt, timestep,
                                    mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                                ParallelT8codeMesh},
                                    equations, dg::DG, cache,
                                    filename)

    # Restart files always store conservative variables
    data = u

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)
    # Cumulative sum of nodes per rank starting with an additional 0
    cum_node_counts = append!(zeros(eltype(node_counts), 1), cumsum(node_counts))

    # Open file (clobber existing content)
    h5open(filename, "w", mpi_comm()) do file
        # Add context information as attributes
        attributes(file)["ndims"] = ndims(mesh)
        attributes(file)["equations"] = get_name(equations)
        attributes(file)["polydeg"] = polydeg(dg)
        attributes(file)["n_vars"] = nvariables(equations)
        attributes(file)["n_elements"] = nelementsglobal(mesh, dg, cache)
        attributes(file)["mesh_type"] = get_name(mesh)
        attributes(file)["mesh_file"] = splitdir(mesh.current_filename)[2]
        attributes(file)["time"] = convert(Float64, time) # Ensure that `time` is written as a double precision scalar
        attributes(file)["dt"] = convert(Float64, dt) # Ensure that `dt` is written as a double precision scalar
        attributes(file)["timestep"] = timestep

        # Store each variable of the solution
        for v in eachvariable(equations)
            # Need to create dataset explicitly in parallel case
            var = create_dataset(file, "/variables_$v", datatype(eltype(data)),
                                 dataspace((ndofsglobal(mesh, dg, cache),)))
            # Write data of each process in slices (ranks start with 0)
            slice = (cum_node_counts[mpi_rank() + 1] + 1):cum_node_counts[mpi_rank() + 2]
            # Convert to 1D array
            var[slice] = vec(data[v, .., :])
            # Add variable name as attribute
            attributes(var)["name"] = varnames(cons2cons, equations)[v]
        end
    end

    return filename
end

function save_restart_file_on_root(u, time, dt, timestep,
                                   mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                               ParallelT8codeMesh},
                                   equations, dg::DG, cache,
                                   filename)

    # Restart files always store conservative variables
    data = u

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)

    # non-root ranks only send data
    if !mpi_isroot()
        # Send nodal data to root
        for v in eachvariable(equations)
            MPI.Gatherv!(vec(data[v, .., :]), nothing, mpi_root(), mpi_comm())
        end

        return filename
    end

    # Open file (clobber existing content)
    h5open(filename, "w") do file
        # Add context information as attributes
        attributes(file)["ndims"] = ndims(mesh)
        attributes(file)["equations"] = get_name(equations)
        attributes(file)["polydeg"] = polydeg(dg)
        attributes(file)["n_vars"] = nvariables(equations)
        attributes(file)["n_elements"] = nelements(dg, cache)
        attributes(file)["mesh_type"] = get_name(mesh)
        attributes(file)["mesh_file"] = splitdir(mesh.current_filename)[2]
        attributes(file)["time"] = convert(Float64, time) # Ensure that `time` is written as a double precision scalar
        attributes(file)["dt"] = convert(Float64, dt) # Ensure that `dt` is written as a double precision scalar
        attributes(file)["timestep"] = timestep

        # Store each variable of the solution
        for v in eachvariable(equations)
            # Convert to 1D array
            recv = Vector{eltype(data)}(undef, sum(node_counts))
            MPI.Gatherv!(vec(data[v, .., :]), MPI.VBuffer(recv, node_counts),
                         mpi_root(), mpi_comm())
            file["variables_$v"] = recv

            # Add variable name as attribute
            var = file["variables_$v"]
            attributes(var)["name"] = varnames(cons2cons, equations)[v]
        end
    end

    return filename
end

function load_restart_file(mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                       ParallelT8codeMesh}, equations,
                           dg::DG, cache, restart_file)
    if HDF5.has_parallel()
        load_restart_file_parallel(mesh, equations, dg, cache, restart_file)
    else
        load_restart_file_on_root(mesh, equations, dg, cache, restart_file)
    end
end

function load_restart_file_parallel(mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                                ParallelT8codeMesh},
                                    equations, dg::DG, cache, restart_file)

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)
    # Cumulative sum of nodes per rank starting with an additional 0
    cum_node_counts = append!(zeros(eltype(node_counts), 1), cumsum(node_counts))

    # allocate memory
    u_ode = allocate_coefficients(mesh, equations, dg, cache)
    u = wrap_array_native(u_ode, mesh, equations, dg, cache)

    # read in parallel
    h5open(restart_file, "r", mpi_comm()) do file
        # Read attributes to perform some sanity checks
        if read(attributes(file)["ndims"]) != ndims(mesh)
            error("restart mismatch: ndims differs from value in restart file")
        end
        if read(attributes(file)["equations"]) != get_name(equations)
            error("restart mismatch: equations differ from value in restart file")
        end
        if read(attributes(file)["polydeg"]) != polydeg(dg)
            error("restart mismatch: polynomial degree in solver differs from value in restart file")
        end
        if read(attributes(file)["n_elements"]) != nelementsglobal(mesh, dg, cache)
            error("restart mismatch: number of elements in solver differs from value in restart file")
        end

        # Read data
        for v in eachvariable(equations)
            # Check if variable name matches
            var = file["variables_$v"]
            if (name = read(attributes(var)["name"])) !=
               varnames(cons2cons, equations)[v]
                error("mismatch: variables_$v should be '$(varnames(cons2cons, equations)[v])', but found '$name'")
            end

            # Read variable
            mpi_println("Reading variables_$v ($name)...")
            # Read data of each process in slices (ranks start with 0)
            slice = (cum_node_counts[mpi_rank() + 1] + 1):cum_node_counts[mpi_rank() + 2]
            # Convert 1D array back to actual size of `u`
            u[v, .., :] = reshape(read(var)[slice], size(@view u[v, .., :]))
        end
    end

    return u_ode
end

function load_restart_file_on_root(mesh::Union{ParallelTreeMesh, ParallelP4estMesh,
                                               ParallelT8codeMesh},
                                   equations, dg::DG, cache, restart_file)

    # Calculate element and node counts by MPI rank
    element_size = nnodes(dg)^ndims(mesh)
    element_counts = convert(Vector{Cint}, collect(cache.mpi_cache.n_elements_by_rank))
    node_counts = element_counts * Cint(element_size)

    # allocate memory
    u_ode = allocate_coefficients(mesh, equations, dg, cache)
    u = wrap_array_native(u_ode, mesh, equations, dg, cache)

    # non-root ranks only receive data
    if !mpi_isroot()
        # Receive nodal data from root
        for v in eachvariable(equations)
            # put Scatterv in both blocks of the if condition to avoid type instability
            if isempty(u)
                data = eltype(u)[]
                MPI.Scatterv!(nothing, data, mpi_root(), mpi_comm())
            else
                data = @view u[v, .., :]
                MPI.Scatterv!(nothing, data, mpi_root(), mpi_comm())
            end
        end

        return u_ode
    end

    # read only on MPI root
    h5open(restart_file, "r") do file
        # Read attributes to perform some sanity checks
        if read(attributes(file)["ndims"]) != ndims(mesh)
            error("restart mismatch: ndims differs from value in restart file")
        end
        if read(attributes(file)["equations"]) != get_name(equations)
            error("restart mismatch: equations differ from value in restart file")
        end
        if read(attributes(file)["polydeg"]) != polydeg(dg)
            error("restart mismatch: polynomial degree in solver differs from value in restart file")
        end
        if read(attributes(file)["n_elements"]) != nelements(dg, cache)
            error("restart mismatch: number of elements in solver differs from value in restart file")
        end

        # Read data
        for v in eachvariable(equations)
            # Check if variable name matches
            var = file["variables_$v"]
            if (name = read(attributes(var)["name"])) !=
               varnames(cons2cons, equations)[v]
                error("mismatch: variables_$v should be '$(varnames(cons2cons, equations)[v])', but found '$name'")
            end

            # Read variable
            println("Reading variables_$v ($name)...")
            sendbuf = MPI.VBuffer(read(file["variables_$v"]), node_counts)
            MPI.Scatterv!(sendbuf, @view(u[v, .., :]), mpi_root(), mpi_comm())
        end
    end

    return u_ode
end

# Store controller values for an adaptive time stepping scheme
function save_adaptive_time_integrator(integrator,
                                       controller, restart_callback)
    # Save only on root
    if mpi_isroot()
        @unpack output_directory = restart_callback
        timestep = integrator.stats.naccept

        # Filename based on current time step
        filename = joinpath(output_directory, @sprintf("restart_%09d.h5", timestep))

        # Open file (preserve existing content)
        h5open(filename, "r+") do file
            # Add context information as attributes both for PIController and PIDController
            attributes(file)["time_integrator_qold"] = integrator.qold
            attributes(file)["time_integrator_dtpropose"] = integrator.dtpropose
            # For the PIDController, it is necessary to save additional parameters
            if hasproperty(controller, :err) # Distinguish PIDController from PIController
                attributes(file)["time_integrator_controller_err"] = controller.err
            end
        end
    end
end
end # @muladd
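
# The sketch below is not part of the file above; it is a minimal, standalone
# illustration of how `node_counts` and `cum_node_counts` (as computed in the
# parallel save/load functions) partition the flat "variables_$v" dataset into
# contiguous per-rank slices. The element counts per rank, the polynomial
# degree, and the number of dimensions are assumed values chosen for the example.
let
    element_counts = Cint[3, 4, 2]   # elements owned by ranks 0, 1, 2 (assumed)
    element_size = (4 + 1)^2         # nodes per element for polydeg = 4 in 2D (assumed)
    node_counts = element_counts * Cint(element_size)
    # Prepend a 0 so that rank r (0-based) owns the 1-based index range
    # (cum_node_counts[r + 1] + 1):cum_node_counts[r + 2] of the flat vector.
    cum_node_counts = append!(zeros(eltype(node_counts), 1), cumsum(node_counts))
    for rank in 0:(length(element_counts) - 1)
        slice = (cum_node_counts[rank + 1] + 1):cum_node_counts[rank + 2]
        println("rank $rank writes/reads indices $slice")
    end
    # With the assumed values this prints
    #   rank 0 writes/reads indices 1:75
    #   rank 1 writes/reads indices 76:175
    #   rank 2 writes/reads indices 176:225
end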
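
# A minimal serial sketch (again not part of the file above) of the HDF5 round
# trip that `save_restart_file` and `load_restart_file` perform: write a flat
# dataset plus attributes, then read the attributes back for a sanity check.
# It only uses HDF5.jl calls that also appear above; the file name and the
# fake nodal array are made up for the example.
using HDF5

let
    u = rand(4, 4, 6)                  # fake nodal data: nodes × nodes × elements
    h5open("restart_example.h5", "w") do file
        attributes(file)["ndims"] = 2
        attributes(file)["n_elements"] = size(u, 3)
        file["variables_1"] = vec(u)   # restart files store flat 1D arrays
        attributes(file["variables_1"])["name"] = "rho"
    end

    h5open("restart_example.h5", "r") do file
        # The same kind of sanity check as in `load_restart_file`
        read(attributes(file)["n_elements"]) == size(u, 3) ||
            error("restart mismatch: number of elements differs")
        v = reshape(read(file["variables_1"]), size(u))  # undo the flattening
        @assert v == u
    end
end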