Skip to content

Implement as_hal for BLASes and TLASes #7303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)
- Added `CommandEncoder::transition_resources()` for native API interop, and allowing users to slightly optimize barriers. By @JMS55 in [#6678](https://github.com/gfx-rs/wgpu/pull/6678).
- Add `wgpu_hal::vulkan::Adapter::texture_format_as_raw` for native API interop. By @JMS55 in [#7228](https://github.com/gfx-rs/wgpu/pull/7228).

- Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183) .
- Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183).
- Add `as_hal` for both acceleration structures. By @Vecvec in [#7303](https://github.com/gfx-rs/wgpu/pull/7303).


#### Naga
Expand Down
56 changes: 56 additions & 0 deletions wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ use crate::{
FastHashSet,
};

use crate::id::{BlasId, TlasId};

struct TriangleBufferStore<'a> {
vertex_buffer: Arc<Buffer>,
vertex_transition: Option<PendingTransition<BufferUses>>,
Expand Down Expand Up @@ -61,6 +63,60 @@ struct TlasBufferStore {
}

impl Global {
pub fn command_encoder_mark_acceleration_structures_built(
&self,
command_encoder_id: CommandEncoderId,
blas_ids: &[BlasId],
tlas_ids: &[TlasId],
) -> Result<(), BuildAccelerationStructureError> {
profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

let hub = &self.hub;

let cmd_buf = hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id());

let device = &cmd_buf.device;

device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

let build_command_index = NonZeroU64::new(
device
.last_acceleration_structure_build_command_index
.fetch_add(1, Ordering::Relaxed),
)
.unwrap();

let mut cmd_buf_data = cmd_buf.data.lock();
let mut cmd_buf_data_guard = cmd_buf_data.record()?;
let cmd_buf_data = &mut *cmd_buf_data_guard;

cmd_buf_data.blas_actions.reserve(blas_ids.len());

cmd_buf_data.tlas_actions.reserve(tlas_ids.len());

for blas in blas_ids {
let blas = hub.blas_s.get(*blas).get()?;
cmd_buf_data.blas_actions.push(BlasAction {
blas,
kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
});
}

for tlas in tlas_ids {
let tlas = hub.tlas_s.get(*tlas).get()?;
cmd_buf_data.tlas_actions.push(TlasAction {
tlas,
kind: crate::ray_tracing::TlasActionKind::Build {
build_index: build_command_index,
dependencies: Vec::new(),
},
});
}

Ok(())
}
// Currently this function is very similar to its safe counterpart, however certain parts of it are very different,
// making for the two to be implemented differently, the main difference is this function has separate buffers for each
// of the TLAS instances while the other has one large buffer
Expand Down
50 changes: 50 additions & 0 deletions wgpu-core/src/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ use crate::{
Label, LabelHelpers, SubmissionIndex,
};

use crate::id::{BlasId, TlasId};

/// Information about the wgpu-core resource.
///
/// Each type representing a `wgpu-core` resource, like [`Device`],
Expand Down Expand Up @@ -1412,6 +1414,54 @@ impl Global {

hal_queue_callback(hal_queue)
}

/// # Safety
///
/// - The raw blas handle must not be manually destroyed
pub unsafe fn blas_as_hal<A: HalApi, F: FnOnce(Option<&A::AccelerationStructure>) -> R, R>(
&self,
id: BlasId,
hal_blas_callback: F,
) -> R {
profiling::scope!("Blas::as_hal");

let hub = &self.hub;

if let Ok(blas) = hub.blas_s.get(id).get() {
let snatch_guard = blas.device.snatchable_lock.read();
let hal_blas = blas
.try_raw(&snatch_guard)
.ok()
.and_then(|b| b.as_any().downcast_ref());
hal_blas_callback(hal_blas)
} else {
hal_blas_callback(None)
}
}

/// # Safety
///
/// - The raw tlas handle must not be manually destroyed
pub unsafe fn tlas_as_hal<A: HalApi, F: FnOnce(Option<&A::AccelerationStructure>) -> R, R>(
&self,
id: TlasId,
hal_tlas_callback: F,
) -> R {
profiling::scope!("Blas::as_hal");

let hub = &self.hub;

if let Ok(tlas) = hub.tlas_s.get(id).get() {
let snatch_guard = tlas.device.snatchable_lock.read();
let hal_tlas = tlas
.try_raw(&snatch_guard)
.ok()
.and_then(|t| t.as_any().downcast_ref());
hal_tlas_callback(hal_tlas)
} else {
hal_tlas_callback(None)
}
}
}

/// A texture that has been marked as destroyed and is staged for actual deletion soon.
Expand Down
24 changes: 24 additions & 0 deletions wgpu/src/api/blas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,30 @@ impl Blas {
pub fn handle(&self) -> Option<u64> {
self.handle
}

/// Returns the inner hal Acceleration Structure using a callback. The hal acceleration structure
/// will be `None` if the backend type argument does not match with this wgpu Blas
///
/// This method will start the wgpu_core level command recording.
///
/// # Safety
///
/// - The raw handle obtained from the hal Acceleration Structure must not be manually destroyed
#[cfg(wgpu_core)]
pub unsafe fn as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&mut self,
hal_blas_callback: F,
) -> R {
if let Some(blas) = self.inner.as_core_opt() {
unsafe { blas.context.blas_as_hal::<A, F, R>(blas, hal_blas_callback) }
} else {
hal_blas_callback(None)
}
}
}

/// Context version of [BlasTriangleGeometry].
Expand Down
16 changes: 16 additions & 0 deletions wgpu/src/api/command_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ impl CommandEncoder {

/// [`Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE`] must be enabled on the device in order to call these functions.
impl CommandEncoder {
/// Mark acceleration structures as being built. ***Should only*** be used with wgpu-hal
/// functions, all wgpu functions already mark acceleration structures as built.
///
/// # Safety
///
/// - All acceleration structures must have been build in this command encoder.
/// - All BLASes inputted must have been built before all TLASes that were inputted here and
/// which use them.
pub unsafe fn mark_acceleration_structures_built<'a>(
&self,
blas: impl IntoIterator<Item = &'a Blas>,
tlas: impl IntoIterator<Item = &'a Tlas>,
) {
self.inner
.mark_acceleration_structures_built(&mut blas.into_iter(), &mut tlas.into_iter())
}
/// Build bottom and top level acceleration structures.
///
/// Builds the BLASes then the TLASes, but does ***not*** build the BLASes into the TLASes,
Expand Down
27 changes: 27 additions & 0 deletions wgpu/src/api/tlas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,33 @@ static_assertions::assert_impl_all!(Tlas: WasmNotSendSync);

crate::cmp::impl_eq_ord_hash_proxy!(Tlas => .shared.inner);

impl Tlas {
/// Returns the inner hal Acceleration Structure using a callback. The hal acceleration structure
/// will be `None` if the backend type argument does not match with this wgpu Tlas
///
/// This method will start the wgpu_core level command recording.
///
/// # Safety
///
/// - The raw handle obtained from the hal Acceleration Structure must not be manually destroyed
/// - If the raw handle is build,
#[cfg(wgpu_core)]
pub unsafe fn as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&mut self,
hal_tlas_callback: F,
) -> R {
if let Some(tlas) = self.shared.inner.as_core_opt() {
unsafe { tlas.context.tlas_as_hal::<A, F, R>(tlas, hal_tlas_callback) }
} else {
hal_tlas_callback(None)
}
}
}

/// Entry for a top level acceleration structure build.
/// Used with raw instance buffers for an unvalidated builds.
/// See [`TlasPackage`] for the safe version.
Expand Down
10 changes: 9 additions & 1 deletion wgpu/src/backend/webgpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use wgt::Backends;
use js_sys::Promise;
use wasm_bindgen::{prelude::*, JsCast};

use crate::{dispatch, SurfaceTargetUnsafe};
use crate::{dispatch, Blas, SurfaceTargetUnsafe, Tlas};

use defined_non_null_js_value::DefinedNonNullJsValue;

Expand Down Expand Up @@ -3084,6 +3084,14 @@ impl dispatch::CommandEncoderInterface for WebCommandEncoder {
);
}

fn mark_acceleration_structures_built<'a>(
&self,
_blas: &mut dyn Iterator<Item = &'a Blas>,
_tlas: &mut dyn Iterator<Item = &'a Tlas>,
) {
unimplemented!("Raytracing not implemented for web");
}

fn build_acceleration_structures_unsafe_tlas<'a>(
&self,
_blas: &mut dyn Iterator<Item = &'a crate::BlasBuildEntry<'a>>,
Expand Down
52 changes: 50 additions & 2 deletions wgpu/src/backend/wgpu_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ use wgt::WasmNotSendSync;
use crate::{
api,
dispatch::{self, BufferMappedRangeInterface},
BindingResource, BufferBinding, BufferDescriptor, CompilationInfo, CompilationMessage,
BindingResource, Blas, BufferBinding, BufferDescriptor, CompilationInfo, CompilationMessage,
CompilationMessageType, ErrorSource, Features, Label, LoadOp, MapMode, Operations,
ShaderSource, SurfaceTargetUnsafe, TextureDescriptor,
ShaderSource, SurfaceTargetUnsafe, TextureDescriptor, Tlas,
};

#[derive(Clone)]
Expand Down Expand Up @@ -267,6 +267,30 @@ impl ContextWgpuCore {
}
}

pub unsafe fn blas_as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&self,
blas: &CoreBlas,
hal_blas_callback: F,
) -> R {
unsafe { self.0.blas_as_hal::<A, F, R>(blas.id, hal_blas_callback) }
}

pub unsafe fn tlas_as_hal<
A: wgc::hal_api::HalApi,
F: FnOnce(Option<&A::AccelerationStructure>) -> R,
R,
>(
&self,
tlas: &CoreTlas,
hal_tlas_callback: F,
) -> R {
unsafe { self.0.tlas_as_hal::<A, F, R>(tlas.id, hal_tlas_callback) }
}

pub fn generate_report(&self) -> wgc::global::GlobalReport {
self.0.generate_report()
}
Expand Down Expand Up @@ -2486,6 +2510,30 @@ impl dispatch::CommandEncoderInterface for CoreCommandEncoder {
}
}

fn mark_acceleration_structures_built<'a>(
&self,
blas: &mut dyn Iterator<Item = &'a Blas>,
tlas: &mut dyn Iterator<Item = &'a Tlas>,
) {
let blas = blas
.map(|b| b.inner.as_core().id)
.collect::<SmallVec<[_; 4]>>();
let tlas = tlas
.map(|t| t.shared.inner.as_core().id)
.collect::<SmallVec<[_; 4]>>();
if let Err(cause) = self
.context
.0
.command_encoder_mark_acceleration_structures_built(self.id, &blas, &tlas)
{
self.context.handle_error_nolabel(
&self.error_sink,
cause,
"CommandEncoder::build_acceleration_structures_unsafe_tlas",
);
}
}

fn build_acceleration_structures_unsafe_tlas<'a>(
&self,
blas: &mut dyn Iterator<Item = &'a crate::BlasBuildEntry<'a>>,
Expand Down
7 changes: 6 additions & 1 deletion wgpu/src/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#![allow(clippy::too_many_arguments)] // It's fine.
#![allow(missing_docs, clippy::missing_safety_doc)] // Interfaces are not documented

use crate::{WasmNotSend, WasmNotSendSync};
use crate::{Blas, Tlas, WasmNotSend, WasmNotSendSync};

use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
use core::{any::Any, fmt::Debug, future::Future, hash::Hash, ops::Range, pin::Pin};
Expand Down Expand Up @@ -311,6 +311,11 @@ pub trait CommandEncoderInterface: CommonTraits {
destination: &DispatchBuffer,
destination_offset: crate::BufferAddress,
);
fn mark_acceleration_structures_built<'a>(
&self,
blas: &mut dyn Iterator<Item = &'a Blas>,
tlas: &mut dyn Iterator<Item = &'a Tlas>,
);

fn build_acceleration_structures_unsafe_tlas<'a>(
&self,
Expand Down