Skip to content

Commit 632d532

Browse files
feat!: make ProgrammableStage::entry_point optional in wgpu-core
1 parent 99fa77b commit 632d532

File tree

5 files changed

+120
-19
lines changed

5 files changed

+120
-19
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Bottom level categories:
102102
```
103103
- `wgpu::Id` now implements `PartialOrd`/`Ord` allowing it to be put in `BTreeMap`s. By @cwfitzgerald and @9291Sam in [#5176](https://github.com/gfx-rs/wgpu/pull/5176)
104104
- `wgpu::CommandEncoder::write_timestamp` requires now the new `wgpu::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS` feature which is available on all native backends but not on WebGPU (due to a spec change `write_timestamp` is no longer supported on WebGPU). By @wumpf in [#5188](https://github.com/gfx-rs/wgpu/pull/5188)
105+
- BREAKING CHANGE: [`wgpu_core::pipeline::ProgrammableStageDescriptor`](https://docs.rs/wgpu-core/latest/wgpu_core/pipeline/struct.ProgrammableStageDescriptor.html#structfield.entry_point) is now optional. By @ErichDonGubler in [#????](https://github.com/gfx-rs/wgpu/pull/????).
105106

106107
#### GLES
107108

wgpu-core/src/device/resource.rs

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2679,14 +2679,21 @@ impl<A: HalApi> Device<A> {
26792679
let mut shader_binding_sizes = FastHashMap::default();
26802680
let io = validation::StageIo::default();
26812681

2682+
let final_entry_point_name;
2683+
26822684
{
26832685
let stage = wgt::ShaderStages::COMPUTE;
26842686

2687+
final_entry_point_name = shader_module.finalize_entry_point_name(
2688+
stage,
2689+
desc.stage.entry_point.as_ref().map(|ep| ep.as_ref()),
2690+
)?;
2691+
26852692
if let Some(ref interface) = shader_module.interface {
26862693
let _ = interface.check_stage(
26872694
&mut binding_layout_source,
26882695
&mut shader_binding_sizes,
2689-
&desc.stage.entry_point,
2696+
&final_entry_point_name,
26902697
stage,
26912698
io,
26922699
None,
@@ -2714,7 +2721,7 @@ impl<A: HalApi> Device<A> {
27142721
label: desc.label.to_hal(self.instance_flags),
27152722
layout: pipeline_layout.raw(),
27162723
stage: hal::ProgrammableStage {
2717-
entry_point: desc.stage.entry_point.as_ref(),
2724+
entry_point: final_entry_point_name.as_ref(),
27182725
module: shader_module.raw(),
27192726
},
27202727
};
@@ -3086,6 +3093,7 @@ impl<A: HalApi> Device<A> {
30863093
};
30873094

30883095
let vertex_shader_module;
3096+
let vertex_entry_point_name;
30893097
let vertex_stage = {
30903098
let stage_desc = &desc.vertex.stage;
30913099
let stage = wgt::ShaderStages::VERTEX;
@@ -3102,12 +3110,19 @@ impl<A: HalApi> Device<A> {
31023110

31033111
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
31043112

3113+
vertex_entry_point_name = vertex_shader_module
3114+
.finalize_entry_point_name(
3115+
stage,
3116+
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
3117+
)
3118+
.map_err(stage_err)?;
3119+
31053120
if let Some(ref interface) = vertex_shader_module.interface {
31063121
io = interface
31073122
.check_stage(
31083123
&mut binding_layout_source,
31093124
&mut shader_binding_sizes,
3110-
&stage_desc.entry_point,
3125+
&vertex_entry_point_name,
31113126
stage,
31123127
io,
31133128
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
@@ -3118,11 +3133,12 @@ impl<A: HalApi> Device<A> {
31183133

31193134
hal::ProgrammableStage {
31203135
module: vertex_shader_module.raw(),
3121-
entry_point: stage_desc.entry_point.as_ref(),
3136+
entry_point: &vertex_entry_point_name,
31223137
}
31233138
};
31243139

31253140
let mut fragment_shader_module = None;
3141+
let fragment_entry_point_name;
31263142
let fragment_stage = match desc.fragment {
31273143
Some(ref fragment_state) => {
31283144
let stage = wgt::ShaderStages::FRAGMENT;
@@ -3138,13 +3154,24 @@ impl<A: HalApi> Device<A> {
31383154

31393155
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
31403156

3157+
fragment_entry_point_name = shader_module
3158+
.finalize_entry_point_name(
3159+
stage,
3160+
fragment_state
3161+
.stage
3162+
.entry_point
3163+
.as_ref()
3164+
.map(|ep| ep.as_ref()),
3165+
)
3166+
.map_err(stage_err)?;
3167+
31413168
if validated_stages == wgt::ShaderStages::VERTEX {
31423169
if let Some(ref interface) = shader_module.interface {
31433170
io = interface
31443171
.check_stage(
31453172
&mut binding_layout_source,
31463173
&mut shader_binding_sizes,
3147-
&fragment_state.stage.entry_point,
3174+
&fragment_entry_point_name,
31483175
stage,
31493176
io,
31503177
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
@@ -3156,7 +3183,7 @@ impl<A: HalApi> Device<A> {
31563183

31573184
if let Some(ref interface) = shader_module.interface {
31583185
shader_expects_dual_source_blending = interface
3159-
.fragment_uses_dual_source_blending(&fragment_state.stage.entry_point)
3186+
.fragment_uses_dual_source_blending(&fragment_entry_point_name)
31603187
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
31613188
stage,
31623189
error,
@@ -3165,7 +3192,7 @@ impl<A: HalApi> Device<A> {
31653192

31663193
Some(hal::ProgrammableStage {
31673194
module: shader_module.raw(),
3168-
entry_point: fragment_state.stage.entry_point.as_ref(),
3195+
entry_point: &fragment_entry_point_name,
31693196
})
31703197
}
31713198
None => None,

wgpu-core/src/pipeline.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,20 @@ impl<A: HalApi> ShaderModule<A> {
9292
pub(crate) fn raw(&self) -> &A::ShaderModule {
9393
self.raw.as_ref().unwrap()
9494
}
95+
96+
pub(crate) fn finalize_entry_point_name(
97+
&self,
98+
stage_bit: wgt::ShaderStages,
99+
entry_point: Option<&str>,
100+
) -> Result<validation::FinalizedEntryPointName, validation::StageError> {
101+
match &self.interface {
102+
Some(interface) => interface.finalize_entry_point_name(stage_bit, entry_point),
103+
None => entry_point
104+
.map(|ep| ep.to_string())
105+
.map(validation::FinalizedEntryPointName::new)
106+
.ok_or(validation::StageError::NoEntryPointFound),
107+
}
108+
}
95109
}
96110

97111
#[derive(Clone, Debug)]
@@ -213,9 +227,13 @@ impl CreateShaderModuleError {
213227
pub struct ProgrammableStageDescriptor<'a> {
214228
/// The compiled shader module for this stage.
215229
pub module: ShaderModuleId,
216-
/// The name of the entry point in the compiled shader. There must be a function with this name
217-
/// in the shader.
218-
pub entry_point: Cow<'a, str>,
230+
/// The name of the entry point in the compiled shader. The name is selected using the
231+
/// following logic:
232+
///
233+
/// * If `Some(name)` is specified, there must be a function with this name in the shader.
234+
/// * If a single entry point associated with this stage must be in the shader, then proceed as
235+
/// if `Some(…)` was specified with that entry point's name.
236+
pub entry_point: Option<Cow<'a, str>>,
219237
}
220238

221239
/// Number of implicit bind groups derived at pipeline creation.

wgpu-core/src/validation.rs

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::{
44
FastHashMap, FastHashSet,
55
};
66
use arrayvec::ArrayVec;
7-
use std::{collections::hash_map::Entry, fmt};
7+
use std::{collections::hash_map::Entry, fmt, ops::Deref};
88
use thiserror::Error;
99
use wgt::{BindGroupLayoutEntry, BindingType};
1010

@@ -283,6 +283,16 @@ pub enum StageError {
283283
},
284284
#[error("Location[{location}] is provided by the previous stage output but is not consumed as input by this stage.")]
285285
InputNotConsumed { location: wgt::ShaderLocation },
286+
#[error(
287+
"Unable to select an entry point: no entry point was found in the provided shader module"
288+
)]
289+
NoEntryPointFound,
290+
#[error(
291+
"Unable to select an entry point: \
292+
multiple entry points were found in the provided shader module, \
293+
but no entry point was specified"
294+
)]
295+
MultipleEntryPointsFound,
286296
}
287297

288298
fn map_storage_format_to_naga(format: wgt::TextureFormat) -> Option<naga::StorageFormat> {
@@ -971,6 +981,30 @@ impl Interface {
971981
}
972982
}
973983

984+
pub fn finalize_entry_point_name(
985+
&self,
986+
stage_bit: wgt::ShaderStages,
987+
entry_point_name: Option<&str>,
988+
) -> Result<FinalizedEntryPointName, StageError> {
989+
let stage = Self::shader_stage_from_stage_bit(stage_bit);
990+
Ok(FinalizedEntryPointName::new(
991+
entry_point_name
992+
.map(|ep| ep.to_string())
993+
.map(Ok)
994+
.unwrap_or_else(|| {
995+
let mut entry_points = self
996+
.entry_points
997+
.keys()
998+
.filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
999+
let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
1000+
entry_points
1001+
.next()
1002+
.ok_or(StageError::MultipleEntryPointsFound)?;
1003+
Ok(first.clone())
1004+
})?,
1005+
))
1006+
}
1007+
9741008
pub(crate) fn shader_stage_from_stage_bit(stage_bit: wgt::ShaderStages) -> naga::ShaderStage {
9751009
match stage_bit {
9761010
wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex,
@@ -984,7 +1018,7 @@ impl Interface {
9841018
&self,
9851019
layouts: &mut BindingLayoutSource<'_>,
9861020
shader_binding_sizes: &mut FastHashMap<naga::ResourceBinding, wgt::BufferSize>,
987-
entry_point_name: &str,
1021+
entry_point_name: &FinalizedEntryPointName,
9881022
stage_bit: wgt::ShaderStages,
9891023
inputs: StageIo,
9901024
compare_function: Option<wgt::CompareFunction>,
@@ -993,10 +1027,11 @@ impl Interface {
9931027
// we need to look for one with the right execution model.
9941028
let shader_stage = Self::shader_stage_from_stage_bit(stage_bit);
9951029
let pair = (shader_stage, entry_point_name.to_string());
996-
let entry_point = self
997-
.entry_points
998-
.get(&pair)
999-
.ok_or(StageError::MissingEntryPoint(pair.1))?;
1030+
let entry_point = match self.entry_points.get(&pair) {
1031+
Some(some) => some,
1032+
None => return Err(StageError::MissingEntryPoint(pair.1)),
1033+
};
1034+
let (_stage, entry_point_name) = pair;
10001035

10011036
// check resources visibility
10021037
for &handle in entry_point.resources.iter() {
@@ -1294,3 +1329,23 @@ pub fn validate_color_attachment_bytes_per_sample(
12941329

12951330
Ok(())
12961331
}
1332+
1333+
/// An entry point name finalized with [`Interface::finalize_entry_point_name`].
1334+
#[derive(Debug)]
1335+
pub struct FinalizedEntryPointName {
1336+
inner: String,
1337+
}
1338+
1339+
impl Deref for FinalizedEntryPointName {
1340+
type Target = str;
1341+
1342+
fn deref(&self) -> &Self::Target {
1343+
self.inner.as_ref()
1344+
}
1345+
}
1346+
1347+
impl FinalizedEntryPointName {
1348+
pub(crate) fn new(inner: String) -> Self {
1349+
Self { inner }
1350+
}
1351+
}

wgpu/src/backend/wgpu_core.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,7 @@ impl crate::Context for ContextWgpuCore {
11021102
vertex: pipe::VertexState {
11031103
stage: pipe::ProgrammableStageDescriptor {
11041104
module: desc.vertex.module.id.into(),
1105-
entry_point: Borrowed(desc.vertex.entry_point),
1105+
entry_point: Some(Borrowed(desc.vertex.entry_point)),
11061106
},
11071107
buffers: Borrowed(&vertex_buffers),
11081108
},
@@ -1112,7 +1112,7 @@ impl crate::Context for ContextWgpuCore {
11121112
fragment: desc.fragment.as_ref().map(|frag| pipe::FragmentState {
11131113
stage: pipe::ProgrammableStageDescriptor {
11141114
module: frag.module.id.into(),
1115-
entry_point: Borrowed(frag.entry_point),
1115+
entry_point: Some(Borrowed(frag.entry_point)),
11161116
},
11171117
targets: Borrowed(frag.targets),
11181118
}),
@@ -1160,7 +1160,7 @@ impl crate::Context for ContextWgpuCore {
11601160
layout: desc.layout.map(|l| l.id.into()),
11611161
stage: pipe::ProgrammableStageDescriptor {
11621162
module: desc.module.id.into(),
1163-
entry_point: Borrowed(desc.entry_point),
1163+
entry_point: Some(Borrowed(desc.entry_point)),
11641164
},
11651165
};
11661166

0 commit comments

Comments
 (0)