From 3907ea62c7ffe88bc5759f7909addab5c714700c Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Fri, 5 Jan 2024 00:19:14 -0500 Subject: [PATCH] Try a debug print approach --- .github/workflows/test.yml | 2 +- burn-compute/src/compute.rs | 9 ++++++-- burn-wgpu/src/compute/base.rs | 39 +++++++++++++++++++++++++++++++-- burn-wgpu/src/compute/kernel.rs | 12 +++++----- burn-wgpu/src/graphics.rs | 2 ++ 5 files changed, 53 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 999c70d0b6..8fb6d83699 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -110,7 +110,7 @@ jobs: # run: cargo test -p burn-wgpu --color=always -- --color=always --test-threads 1 # run: cargo test tests::module -p burn-wgpu --color=always -- --color=always --test-threads 1 # run: cargo test tests::matmul -p burn-wgpu --color=always -- --color=always --test-threads 1 - run: cargo test can_run_kernel -p burn-wgpu --color=always -- --color=always --test-threads 1 + run: cargo test can_run_kernel -p burn-wgpu --color=always -- --color=always --test-threads 1 --nocapture # - name: Run cargo clippy for stable version # if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std' diff --git a/burn-compute/src/compute.rs b/burn-compute/src/compute.rs index 33f5d34337..4d854f6075 100644 --- a/burn-compute/src/compute.rs +++ b/burn-compute/src/compute.rs @@ -28,12 +28,15 @@ where where Init: Fn() -> ComputeClient, { + println!("dbg 4_1"); let mut clients = self.clients.lock(); + println!("dbg 4_2"); if clients.is_none() { Self::register_inner(device, init(), &mut clients); } + println!("dbg 4_3"); match clients.deref_mut() { Some(clients) => match clients.get(device) { Some(client) => client.clone(), @@ -68,16 +71,18 @@ where client: ComputeClient, clients: &mut Option>>, ) { + println!("dbg 8_1"); if clients.is_none() { *clients = Some(HashMap::new()); } - + println!("dbg 8_2"); if let Some(clients) = clients { if clients.contains_key(device) { panic!("Client already created for device {:?}", device); } - + println!("dbg 8_3"); clients.insert(device.clone(), client); + println!("dbg 8_4"); } } } diff --git a/burn-wgpu/src/compute/base.rs b/burn-wgpu/src/compute/base.rs index 864b40151b..b3becf7bee 100644 --- a/burn-wgpu/src/compute/base.rs +++ b/burn-wgpu/src/compute/base.rs @@ -26,9 +26,11 @@ static COMPUTE: Compute, Channel> = Com /// Get the [compute client](ComputeClient) for the given [device](WgpuDevice). pub fn compute_client(device: &WgpuDevice) -> ComputeClient { + println!("dbg 3_1"); let device = Arc::new(device); - + println!("dbg 3_2"); COMPUTE.client(&device, move || { + println!("dbg 5_1"); pollster::block_on(create_client::(&device)) }) } @@ -42,14 +44,17 @@ pub async fn init_async(device: &WgpuDevice) { } async fn create_client(device: &WgpuDevice) -> ComputeClient { + println!("dbg 6_1"); let (device_wgpu, queue, info) = select_device::(device).await; + println!("dbg 6_2"); log::info!( "Created wgpu compute server on device {:?} => {:?}", device, info ); + println!("dbg 6_3"); // TODO: Support a way to modify max_tasks without std. let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") { Ok(value) => value @@ -58,16 +63,22 @@ async fn create_client(device: &WgpuDevice) -> ComputeClient 64, // 64 tasks by default }; + println!("dbg 6_4"); let device = Arc::new(device_wgpu); + println!("dbg 6_5"); let storage = WgpuStorage::new(device.clone()); + println!("dbg 6_6"); let memory_management = SimpleMemoryManagement::new( storage, DeallocStrategy::new_period_tick(max_tasks * 2), SliceStrategy::Ratio(0.8), ); + println!("dbg 6_7"); let server = WgpuServer::new(memory_management, device, queue, max_tasks); + println!("dbg 6_8"); let channel = Channel::new(server); + println!("dbg 6_9"); ComputeClient::new(channel, Arc::new(Mutex::new(Tuner::new()))) } @@ -75,14 +86,18 @@ async fn create_client(device: &WgpuDevice) -> ComputeClient( device: &WgpuDevice, ) -> (wgpu::Device, wgpu::Queue, wgpu::AdapterInfo) { + println!("dbg 7_1"); #[cfg(target_family = "wasm")] let adapter = select_adapter::(device).await; + println!("dbg 7_2"); #[cfg(not(target_family = "wasm"))] let adapter = select_adapter::(device); + println!("dbg 7_3"); let limits = adapter.limits(); + println!("dbg 7_4"); let (device, queue) = adapter .request_device( &DeviceDescriptor { @@ -102,6 +117,7 @@ pub async fn select_device( }) .unwrap(); + dbg!((&device, &queue, adapter.get_info())); (device, queue, adapter.get_info()) } @@ -119,20 +135,27 @@ async fn select_adapter(_device: &WgpuDevice) -> wgpu::Adapter { fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { use wgpu::DeviceType; + println!("dbg 8_1"); let instance = wgpu::Instance::default(); + println!("dbg 8_1_1"); let mut adapters_other = Vec::new(); + println!("dbg 8_1_2"); let mut adapters = Vec::new(); + println!("dbg 8_2"); instance .enumerate_adapters(G::backend().into()) .for_each(|adapter| { + println!("dbg 8_3"); let device_type = adapter.get_info().device_type; + println!("dbg 8_4"); if let DeviceType::Other = device_type { adapters_other.push(adapter); return; } + println!("dbg 8_5"); let is_same_type = match device { WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, @@ -141,18 +164,23 @@ fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { WgpuDevice::BestAvailable => true, }; + println!("dbg 8_6"); + dbg!(&adapter); if is_same_type { adapters.push(adapter); } }); + println!("dbg 8_7"); fn select( num: usize, error: &str, mut adapters: Vec, mut adapters_other: Vec, ) -> wgpu::Adapter { + println!("dbg 8_8"); if adapters.len() <= num { + println!("dbg 8_9"); if adapters_other.len() <= num { panic!( "{}, adapters {:?}, other adapters {:?}", @@ -167,10 +195,12 @@ fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { .collect::>(), ); } else { + println!("dbg 8_10"); return adapters_other.remove(num); } } + println!("dbg 8_11"); adapters.remove(num) } @@ -195,13 +225,16 @@ fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { ), WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), WgpuDevice::BestAvailable => { + println!("dbg 8_12"); let mut most_performant_adapter = None; let mut current_score = -1; + println!("dbg 8_13"); adapters .into_iter() .chain(adapters_other) .for_each(|adapter| { + println!("dbg 8_14"); let info = adapter.get_info(); let score = match info.device_type { DeviceType::DiscreteGpu => 5, @@ -211,7 +244,7 @@ fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { DeviceType::VirtualGpu => 2, DeviceType::Cpu => 1, }; - + println!("dbg 8_15"); if score > current_score { most_performant_adapter = Some(adapter); current_score = score; @@ -219,6 +252,8 @@ fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { }); if let Some(adapter) = most_performant_adapter { + println!("dbg 8_16"); + dbg!(&adapter); adapter } else { panic!("No adapter found for graphics API {:?}", G::default()); diff --git a/burn-wgpu/src/compute/kernel.rs b/burn-wgpu/src/compute/kernel.rs index 821ec06138..bb70d398d5 100644 --- a/burn-wgpu/src/compute/kernel.rs +++ b/burn-wgpu/src/compute/kernel.rs @@ -83,6 +83,7 @@ mod tests { #[test] fn can_run_kernel() { + println!("dbg 1_1"); binary!( operator: |elem: Elem| Operator::Add { lhs: Variable::Input(0, elem), @@ -92,9 +93,9 @@ mod tests { elem_in: f32, elem_out: f32 ); - + println!("dbg 1_2"); let client = compute_client::(&WgpuDevice::default()); - + println!("dbg 1_3"); let lhs: Vec = vec![0., 1., 2., 3., 4., 5., 6., 7.]; let rhs: Vec = vec![10., 11., 12., 6., 7., 3., 1., 0.]; let info: Vec = vec![1, 1, 8, 1, 8, 1, 8]; @@ -103,16 +104,15 @@ mod tests { let rhs = client.create(bytemuck::cast_slice(&rhs)); let out = client.empty(core::mem::size_of::() * 8); let info = client.create(bytemuck::cast_slice(&info)); - type Kernel = KernelSettings, f32, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>; let kernel = Box::new(StaticKernel::::new(WorkGroup::new(1, 1, 1))); - + println!("dbg 1_4"); client.execute(kernel, &[&lhs, &rhs, &out, &info]); - + println!("dbg 1_5"); let data = client.read(&out).read_sync().unwrap(); let output: &[f32] = bytemuck::cast_slice(&data); - + println!("dbg 1_6"); assert_eq!(output, [10., 12., 14., 9., 11., 8., 7., 7.]); } } diff --git a/burn-wgpu/src/graphics.rs b/burn-wgpu/src/graphics.rs index 9ebf20b48e..852db50b7c 100644 --- a/burn-wgpu/src/graphics.rs +++ b/burn-wgpu/src/graphics.rs @@ -78,6 +78,7 @@ impl GraphicsApi for WebGpu { impl GraphicsApi for AutoGraphicsApi { fn backend() -> wgpu::Backend { + println!("dbg 2_1"); // Allow overriding AutoGraphicsApi backend with ENV var in std test environments #[cfg(not(no_std))] #[cfg(test)] @@ -98,6 +99,7 @@ impl GraphicsApi for AutoGraphicsApi { } } + println!("dbg 2_2"); // In a no_std environment or if the environment variable is not set #[cfg(target_os = "macos")] return wgpu::Backend::Metal;