
Commit 3ce1393

Small cleanup in CUDA sparse code (#466)
1 parent 782c854 commit 3ce1393

File tree

crates/bullet_cuda_backend/src/ops/sparse.rs
crates/bullet_cuda_backend/src/ops/sparse/bwd.rs

2 files changed: +91 -110 lines changed


crates/bullet_cuda_backend/src/ops/sparse.rs

Lines changed: 91 additions & 5 deletions
@@ -1,16 +1,102 @@
-mod bwd;
 mod fwd;
 
-use acyclib::{device::function, graph::ir::operation::sparse::SparseAffineImpl};
+use acyclib::{
+    device::{function, operation::DiffableFromOutput},
+    graph::ir::operation::sparse::SparseAffineImpl,
+};
 
-use crate::{CudaDevice, kernel::Kernel};
+use crate::{
+    CudaDevice,
+    kernel::{Expr, Kernel, KernelArgs, KernelInput},
+};
 
 impl SparseAffineImpl for CudaDevice {
     type Bwd = Kernel;
     type Fwd = Kernel;
 
-    fn bwd(op: function::BackpropSparseAffineActivate<Self>) -> Self::Bwd {
-        bwd::kernel(op)
+    fn bwd(desc: function::BackpropSparseAffineActivate<Self>) -> Self::Bwd {
+        const MAXIMUM_BLOCKS_Y: i32 = 32768;
+
+        let output_shape = desc.weights_shape * desc.input_shape;
+        let indices = desc.indices;
+
+        assert_eq!(desc.weights_shape.size(), desc.weights_grads.shape().size());
+        assert_eq!(desc.input_shape.size(), indices.shape().size());
+        assert_eq!(desc.input_shape.cols(), 1);
+        assert_eq!(output_shape.cols(), 1);
+
+        let bias = desc.biases_grads.as_ref().map(|x| x.batch_size().is_some());
+
+        let batched = indices.batch_size().is_some();
+        let nnz = indices.sparse().nnz;
+        let m = output_shape.rows();
+
+        let code = include_str!("sparse/bwd.cu")
+            .lines()
+            .skip(8)
+            .map(|x| format!("{x}\n"))
+            .collect::<String>()
+            .replace(
+                "INV_DERIV",
+                match desc.activation {
+                    DiffableFromOutput::Identity => "1.0F",
+                    DiffableFromOutput::ReLU => "x > 0.0F ? 1.0F : 0.0F",
+                    DiffableFromOutput::CReLU => "x > 0.0F && x < 1.0F ? 1.0F : 0.0F",
+                    DiffableFromOutput::SCReLU => "x > 0.0F && x < 1.0F ? 2.0F * sqrtf(x) : 0.0F",
+                    DiffableFromOutput::SqrReLU => "x > 0.0F ? 2.0F * sqrtf(x) : 0.0F",
+                    DiffableFromOutput::Sigmoid => "x * (1.0F - x)",
+                },
+            )
+            .replace("DECL_MAXY", &MAXIMUM_BLOCKS_Y.to_string())
+            .replace("DECL_M", &m.to_string())
+            .replace("DECL_NNZ", &nnz.to_string())
+            .replace("BIAS_ARG", if bias.is_some() { ",float* Bg" } else { "" })
+            .replace(
+                "BIAS_BACKPROP",
+                match bias {
+                    None => "",
+                    Some(true) => "if (tE != 0.0F) { atomicAdd(&Bg[m * loc + row], tE); }",
+                    Some(false) => "if (tE != 0.0F) { atomicAdd(&Bg[row], tE); }",
+                },
+            );
+
+        let batch_size = Expr::Var;
+
+        let mut inputs = vec![
+            KernelInput::Size(batch_size.clone()),
+            KernelInput::Slice { slice: indices, layout: Some(nnz), mutable: false, batched, shape: desc.input_shape },
+            KernelInput::Slice { slice: desc.output, layout: None, mutable: false, batched, shape: output_shape },
+            KernelInput::Slice { slice: desc.output_grads, layout: None, mutable: false, batched, shape: output_shape },
+            KernelInput::Slice {
+                slice: desc.weights_grads,
+                layout: None,
+                mutable: true,
+                batched: false,
+                shape: desc.weights_shape,
+            },
+        ];
+
+        if let Some(bias) = desc.biases_grads {
+            let batched = bias.batch_size().is_some();
+            let shape = bias.shape();
+            assert_eq!(shape.size(), output_shape.size());
+
+            inputs.push(KernelInput::Slice { slice: bias, layout: None, mutable: true, batched, shape: output_shape });
+        }
+
+        let maxy = Expr::Const(MAXIMUM_BLOCKS_Y);
+        let threads = m.min(1024);
+        let chunks = m.div_ceil(threads);
+        let ky = batch_size.min(&maxy);
+        let kz = (batch_size + maxy.clone() - 1) / maxy;
+        let grid_dim = [Expr::Const(chunks as i32), ky, kz];
+        let block_dim = [Expr::Const(threads as i32), Expr::Const(1), Expr::Const(1)];
+
+        let shared_mem_bytes = Expr::Const(0);
+
+        let args = KernelArgs { inputs, grid_dim, block_dim, shared_mem_bytes };
+
+        unsafe { Kernel::new("SparseAffineActiveBackward".to_string(), code, args).unwrap() }
     }
 
     fn fwd(op: function::SparseAffineActivate<Self>) -> Self::Fwd {
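A note on the launch geometry computed at the end of bwd: each block runs min(m, 1024) threads over a chunk of the m output rows, and the batch is split across the y and z grid dimensions because CUDA limits the y extent of a grid (the code caps it at MAXIMUM_BLOCKS_Y = 32768). With illustrative numbers not taken from the commit, m = 2048 and a batch of 100000 positions give threads = 1024, chunks = 2, ky = min(100000, 32768) = 32768 and kz = ceil(100000 / 32768) = 4, i.e. a (2, 32768, 4) grid of (1024, 1, 1) blocks.

The kernel source itself (sparse/bwd.cu) is not part of this diff, so the following is only a sketch of a plausible shape for it, to illustrate how the textual placeholders are consumed. The tokens INV_DERIV, DECL_MAXY, DECL_M, DECL_NNZ, BIAS_ARG and BIAS_BACKPROP are substituted by the Rust above before Kernel::new compiles the source; the parameter names I, O, Og and Wg plus the -1 index-padding convention are assumptions for this sketch, while Bg, tE, loc, row and m do appear in the actual substitution strings.

// Sketch only: a guessed shape for the templated bwd.cu, not the real file.
// ALL_CAPS tokens are replaced with concrete text by the Rust code above.
extern "C" __global__ void SparseAffineActiveBackward(
    const int batchSize,
    const int* I,    // assumed name: sparse input indices, DECL_NNZ per position
    const float* O,  // assumed name: forward outputs, DECL_M per position
    const float* Og, // assumed name: gradients w.r.t. the forward outputs
    float* Wg        // assumed name: weight gradients, accumulated atomically
    BIAS_ARG)        // expands to ",float* Bg" when bias gradients are present
{
    // blockIdx.x walks chunks of the m output rows; y and z jointly index the batch.
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int loc = DECL_MAXY * blockIdx.z + blockIdx.y;
    const int m = DECL_M;

    if (row >= m || loc >= batchSize) return;

    // x is the forward output value; INV_DERIV is the activation derivative
    // expressed in terms of the output (hence DiffableFromOutput).
    const float x = O[m * loc + row];
    const float adj = INV_DERIV;
    const float tE = adj * Og[m * loc + row];

    BIAS_BACKPROP // e.g. if (tE != 0.0F) { atomicAdd(&Bg[m * loc + row], tE); }

    for (int i = 0; i < DECL_NNZ; i++) {
        const int j = I[DECL_NNZ * loc + i];
        if (j == -1) break; // assumed padding convention for unused slots
        if (tE != 0.0F) atomicAdd(&Wg[m * j + row], tE);
    }
}

Baking the activation derivative and the problem sizes into the source as text means each configuration compiles its own kernel variant, keeping the hot loop free of runtime branching on the activation or shapes.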

crates/bullet_cuda_backend/src/ops/sparse/bwd.rs

Lines changed: 0 additions & 105 deletions
This file was deleted.

0 commit comments
