Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
bevyengine
GitHub Repository: bevyengine/bevy
Path: blob/main/examples/shader_advanced/custom_shader_instancing.rs
6849 views
1
//! A shader that renders a mesh multiple times in one draw call.
//!
//! Bevy will automatically batch and instance your meshes as long as you use the same
//! `Handle<Material>` and `Handle<Mesh>` for all of your instances.
//!
//! This example is intended for advanced users and shows how to make a custom instancing
//! implementation using Bevy's low-level rendering API.
//! It's generally recommended to try the built-in instancing before going with this approach.
9
10
use bevy::pbr::SetMeshViewBindingArrayBindGroup;
11
use bevy::{
12
camera::visibility::NoFrustumCulling,
13
core_pipeline::core_3d::Transparent3d,
14
ecs::{
15
query::QueryItem,
16
system::{lifetimeless::*, SystemParamItem},
17
},
18
mesh::{MeshVertexBufferLayoutRef, VertexBufferLayout},
19
pbr::{
20
MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
21
},
22
prelude::*,
23
render::{
24
extract_component::{ExtractComponent, ExtractComponentPlugin},
25
mesh::{allocator::MeshAllocator, RenderMesh, RenderMeshBufferInfo},
26
render_asset::RenderAssets,
27
render_phase::{
28
AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
29
RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
30
},
31
render_resource::*,
32
renderer::RenderDevice,
33
sync_world::MainEntity,
34
view::{ExtractedView, NoIndirectDrawing},
35
Render, RenderApp, RenderStartup, RenderSystems,
36
},
37
};
38
use bytemuck::{Pod, Zeroable};
39
40
/// This example uses a shader source file from the assets subdirectory
41
const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";
42
43
fn main() {
44
App::new()
45
.add_plugins((DefaultPlugins, CustomMaterialPlugin))
46
.add_systems(Startup, setup)
47
.run();
48
}
49
50
fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
51
commands.spawn((
52
Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
53
InstanceMaterialData(
54
(1..=10)
55
.flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
56
.map(|(x, y)| InstanceData {
57
position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
58
scale: 1.0,
59
color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
60
})
61
.collect(),
62
),
63
// NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform.
64
// As the cube is at the origin, if its Aabb moves outside the view frustum, all the
65
// instanced cubes will be culled.
66
// The InstanceMaterialData contains the 'GlobalTransform' information for this custom
67
// instancing, and that is not taken into account with the built-in frustum culling.
68
// We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker
69
// component to avoid incorrect culling.
70
NoFrustumCulling,
71
));
72
73
// camera
74
commands.spawn((
75
Camera3d::default(),
76
Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
77
// We need this component because we use `draw_indexed` and `draw`
78
// instead of `draw_indirect_indexed` and `draw_indirect` in
79
// `DrawMeshInstanced::render`.
80
NoIndirectDrawing,
81
));
82
}
83
84
#[derive(Component, Deref)]
85
struct InstanceMaterialData(Vec<InstanceData>);
86
87
impl ExtractComponent for InstanceMaterialData {
88
type QueryData = &'static InstanceMaterialData;
89
type QueryFilter = ();
90
type Out = Self;
91
92
fn extract_component(item: QueryItem<'_, '_, Self::QueryData>) -> Option<Self> {
93
Some(InstanceMaterialData(item.0.clone()))
94
}
95
}
96
97
struct CustomMaterialPlugin;
98
99
impl Plugin for CustomMaterialPlugin {
100
fn build(&self, app: &mut App) {
101
app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
102
app.sub_app_mut(RenderApp)
103
.add_render_command::<Transparent3d, DrawCustom>()
104
.init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
105
.add_systems(RenderStartup, init_custom_pipeline)
106
.add_systems(
107
Render,
108
(
109
queue_custom.in_set(RenderSystems::QueueMeshes),
110
prepare_instance_buffers.in_set(RenderSystems::PrepareResources),
111
),
112
);
113
}
114
}
115
116
#[derive(Clone, Copy, Pod, Zeroable)]
117
#[repr(C)]
118
struct InstanceData {
119
position: Vec3,
120
scale: f32,
121
color: [f32; 4],
122
}
123
124
fn queue_custom(
125
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
126
custom_pipeline: Res<CustomPipeline>,
127
mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
128
pipeline_cache: Res<PipelineCache>,
129
meshes: Res<RenderAssets<RenderMesh>>,
130
render_mesh_instances: Res<RenderMeshInstances>,
131
material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
132
mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
133
views: Query<(&ExtractedView, &Msaa)>,
134
) {
135
let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
136
137
for (view, msaa) in &views {
138
let Some(transparent_phase) = transparent_render_phases.get_mut(&view.retained_view_entity)
139
else {
140
continue;
141
};
142
143
let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
144
145
let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
146
let rangefinder = view.rangefinder3d();
147
for (entity, main_entity) in &material_meshes {
148
let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
149
else {
150
continue;
151
};
152
let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else {
153
continue;
154
};
155
let key =
156
view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
157
let pipeline = pipelines
158
.specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
159
.unwrap();
160
transparent_phase.add(Transparent3d {
161
entity: (entity, *main_entity),
162
pipeline,
163
draw_function: draw_custom,
164
distance: rangefinder.distance_translation(&mesh_instance.translation),
165
batch_range: 0..1,
166
extra_index: PhaseItemExtraIndex::None,
167
indexed: true,
168
});
169
}
170
}
171
}
172
173
#[derive(Component)]
174
struct InstanceBuffer {
175
buffer: Buffer,
176
length: usize,
177
}
178
179
fn prepare_instance_buffers(
180
mut commands: Commands,
181
query: Query<(Entity, &InstanceMaterialData)>,
182
render_device: Res<RenderDevice>,
183
) {
184
for (entity, instance_data) in &query {
185
let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
186
label: Some("instance data buffer"),
187
contents: bytemuck::cast_slice(instance_data.as_slice()),
188
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
189
});
190
commands.entity(entity).insert(InstanceBuffer {
191
buffer,
192
length: instance_data.len(),
193
});
194
}
195
}
196
197
#[derive(Resource)]
198
struct CustomPipeline {
199
shader: Handle<Shader>,
200
mesh_pipeline: MeshPipeline,
201
}
202
203
fn init_custom_pipeline(
204
mut commands: Commands,
205
asset_server: Res<AssetServer>,
206
mesh_pipeline: Res<MeshPipeline>,
207
) {
208
commands.insert_resource(CustomPipeline {
209
shader: asset_server.load(SHADER_ASSET_PATH),
210
mesh_pipeline: mesh_pipeline.clone(),
211
});
212
}
213
214
impl SpecializedMeshPipeline for CustomPipeline {
215
type Key = MeshPipelineKey;
216
217
fn specialize(
218
&self,
219
key: Self::Key,
220
layout: &MeshVertexBufferLayoutRef,
221
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
222
let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
223
224
descriptor.vertex.shader = self.shader.clone();
225
descriptor.vertex.buffers.push(VertexBufferLayout {
226
array_stride: size_of::<InstanceData>() as u64,
227
step_mode: VertexStepMode::Instance,
228
attributes: vec![
229
VertexAttribute {
230
format: VertexFormat::Float32x4,
231
offset: 0,
232
shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes
233
},
234
VertexAttribute {
235
format: VertexFormat::Float32x4,
236
offset: VertexFormat::Float32x4.size(),
237
shader_location: 4,
238
},
239
],
240
});
241
descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
242
Ok(descriptor)
243
}
244
}
245
246
type DrawCustom = (
247
SetItemPipeline,
248
SetMeshViewBindGroup<0>,
249
SetMeshViewBindingArrayBindGroup<1>,
250
SetMeshBindGroup<2>,
251
DrawMeshInstanced,
252
);
253
254
struct DrawMeshInstanced;
255
256
impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
257
type Param = (
258
SRes<RenderAssets<RenderMesh>>,
259
SRes<RenderMeshInstances>,
260
SRes<MeshAllocator>,
261
);
262
type ViewQuery = ();
263
type ItemQuery = Read<InstanceBuffer>;
264
265
#[inline]
266
fn render<'w>(
267
item: &P,
268
_view: (),
269
instance_buffer: Option<&'w InstanceBuffer>,
270
(meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
271
pass: &mut TrackedRenderPass<'w>,
272
) -> RenderCommandResult {
273
// A borrow check workaround.
274
let mesh_allocator = mesh_allocator.into_inner();
275
276
let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
277
else {
278
return RenderCommandResult::Skip;
279
};
280
let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else {
281
return RenderCommandResult::Skip;
282
};
283
let Some(instance_buffer) = instance_buffer else {
284
return RenderCommandResult::Skip;
285
};
286
let Some(vertex_buffer_slice) =
287
mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id)
288
else {
289
return RenderCommandResult::Skip;
290
};
291
292
pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
293
pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
294
295
match &gpu_mesh.buffer_info {
296
RenderMeshBufferInfo::Indexed {
297
index_format,
298
count,
299
} => {
300
let Some(index_buffer_slice) =
301
mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id)
302
else {
303
return RenderCommandResult::Skip;
304
};
305
306
pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format);
307
pass.draw_indexed(
308
index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
309
vertex_buffer_slice.range.start as i32,
310
0..instance_buffer.length as u32,
311
);
312
}
313
RenderMeshBufferInfo::NonIndexed => {
314
pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
315
}
316
}
317
RenderCommandResult::Success
318
}
319
}
320
321