Skip to content

Commit 213331b

Browse files
committed
more simd
1 parent 433e71f commit 213331b

File tree

4 files changed

+156
-52
lines changed

4 files changed

+156
-52
lines changed

src/graph.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ impl Graph {
8181
Entry::Vacant(entry) => {
8282
let mut v = allocate_vec(grad.len());
8383
v[..].clone_from_slice(grad);
84-
8584
entry.insert(v);
8685
}
8786
}

src/lib.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,17 @@ impl MinimumOps for ANode {
380380
convert_binops! { impl MinimumOps, minimum for ANode, ANode }
381381
forward_ref_binop! { impl MinimumOps, minimum for ANode, ANode }
382382

383+
pub trait UnaryOps {
384+
fn sqrt(self) -> ANode;
385+
fn sin(self) -> ANode;
386+
}
387+
388+
impl UnaryOps for ANode {
389+
fn sqrt(self) -> ANode {
390+
SquareRoot::new(self)
391+
}
392+
fn sin(self) -> ANode {
393+
Sin::new(self)
394+
}
395+
396+
}

src/ops.rs

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ macro_rules! run_unary_op {
272272
if left_len == out_len {
273273
$func(ArrayInput($left), ArrayOutput($out));
274274
} else if left_len == 1 {
275-
$func(BroadcastInput($left[0], out_len), ArrayOutput($out));
275+
$func(BroadcastInput($left, out_len), ArrayOutput($out));
276276
} else if out_len == 1 {
277-
$func(ArrayInput($left[0], out_len), BroadcastOutput($out:tt, left_len));
277+
$func(ArrayInput($left), BroadcastOutput($out, left_len));
278278
} else {
279279
panic!("Left length: {}, Output Length: {}", left_len, out_len);
280280
}
@@ -403,11 +403,11 @@ impl Subtract {
403403
}
404404

405405
fn compute(left: &ANode, right: &ANode) -> MPVec {
406-
let (lv, rv) = Broadcast::from_pair(left.value(), right.value());
407-
let mut out = allocate_vec(lv.len);
408-
out.iter_mut().zip(lv.zip(rv)).for_each(|(oi, (lvi, rvi))| {
409-
*oi = lvi - rvi
410-
});
406+
let x = left.value();
407+
let y = right.value();
408+
let mut out = Broadcast::allocate_out(x, y);
409+
let o = &mut out;
410+
run_binary_op!(x, y, o, simd_sub);
411411
out
412412
}
413413
}
@@ -431,12 +431,12 @@ impl Node for Subtract {
431431
fn compute_grad(&self, grad: &[DType], child_grads: &mut [&mut [DType]]) {
432432
// f(x,y) = x - y
433433
// df(x,y)/dx = 1
434-
// df(x,y)/dy = -1
435-
let mut out = Updater::new(&mut child_grads[0], grad.len());
436-
grad.iter().for_each(|gi| out.add(*gi));
434+
let out = &mut child_grads[0];
435+
run_unary_op!(grad, out, simd_iadd);
437436

438-
let mut out = Updater::new(&mut child_grads[1], grad.len());
439-
grad.iter().for_each(|gi| out.add(-*gi));
437+
// df(x,y)/dy = -1
438+
let out = &mut child_grads[1];
439+
run_unary_op!(grad, out, grad_sub_y);
440440
}
441441

442442
}
@@ -539,21 +539,10 @@ impl Node for Divide {
539539
let out = &mut child_grads[0];
540540
run_binary_op!(grad, y, out, grad_div_x);
541541

542-
/*
543-
let ly = Broadcast::sized(y, child_grads[0].len());
544-
let mut out = Updater::new(&mut child_grads[0], grad.len());
545-
grad.iter().zip(ly).for_each(|(gi, yi)| out.add(*gi / *yi));
546-
*/
547-
548542
let out = &mut child_grads[1];
543+
// df(x,y)/dy = -x / y ^ 2
549544
run_trinary_op!(grad, x, y, out, grad_div_y);
550545

551-
// df(x,y)/dy = -x / y ^ 2
552-
/*
553-
let (lx, ly) = Broadcast::from_pair(x, y);
554-
let mut out = Updater::new(&mut child_grads[1], lx.len);
555-
grad.iter().zip(lx.zip(ly)).for_each(|(gi, (xi, yi))| out.add(*gi * -*xi / yi.powf(2f32)));
556-
*/
557546
}
558547

559548
}
@@ -664,7 +653,7 @@ impl Node for SquareRoot {
664653
fn requires_grad(&self) -> bool { false }
665654

666655
fn compute_grad(&self, grad: &[DType], child_grads: &mut [&mut [DType]]) {
667-
let x = self.1[0].value();
656+
let x = self.value();
668657

669658
// df(x)/dx = (1/2) / x ^ 0.5
670659
child_grads[0].iter_mut().zip(grad.iter().zip(x)).for_each(|(outi, (gi, xi))| {
@@ -905,7 +894,9 @@ impl Exp {
905894
fn compute(left: &ANode) -> MPVec {
906895
let lv = left.value();
907896
let mut out = allocate_vec(lv.len());
908-
out.iter_mut().zip(lv.iter()).for_each(|(oi, lvi)| *oi = lvi.exp());
897+
let o = &mut out;
898+
run_unary_op!(lv, o, simd_exp);
899+
//out.iter_mut().zip(lv.iter()).for_each(|(oi, lvi)| *oi = lvi.exp());
909900
out
910901
}
911902

@@ -1314,6 +1305,21 @@ mod tests {
13141305
assert_eq!(y_grad, &[3.]);
13151306
}
13161307

1308+
#[test]
1309+
fn test_sqrt() {
1310+
let x = Variable::new(vec![4., 9.]);
1311+
let res = SquareRoot::new(x.clone());
1312+
assert_eq!(res.value(), &[2., 3.]);
1313+
1314+
let mut graph = Graph::new();
1315+
graph.backward(&res);
1316+
1317+
let x_1_g = 1f32 / (2f32 * 2f32);
1318+
let x_2_g = 1f32 / (2f32 * 3f32);
1319+
let x_grad = graph.get_grad(&x).unwrap();
1320+
assert_eq!(x_grad, &[x_1_g, x_2_g]);
1321+
}
1322+
13171323
#[test]
13181324
fn test_div() {
13191325
let x = Variable::new(vec![0., 1.]);
@@ -1466,7 +1472,7 @@ mod tests {
14661472
let x = Variable::new(vec![1., 2., 3.]);
14671473

14681474
let x_slice = x.slice(1, 2);
1469-
let mut out = x_slice * 2.;
1475+
let out = x_slice * 2.;
14701476

14711477
let mut graph = Graph::new();
14721478
graph.backward(&out);

src/vecops.rs

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,55 @@ unsafe fn hsum_avx_ps(v: __m256) -> f32 {
168168
_mm_cvtss_f32(sum_128)
169169
}
170170

171+
#[inline(always)]
172+
pub unsafe fn _mm256_exp_ps(x: __m256) -> __m256 {
173+
// Constants
174+
let ln2 = _mm256_set1_ps(0.6931471805599453); // ln(2)
175+
let ln2_inv = _mm256_set1_ps(1.4426950408889634); // 1/ln(2)
176+
177+
// Scale input by 1/ln(2)
178+
let scaled = _mm256_mul_ps(x, ln2_inv);
179+
// Round scaled value to nearest integer: n = round(x/ln2)
180+
let n = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
181+
// r = x - n*ln2
182+
let r = _mm256_sub_ps(x, _mm256_mul_ps(n, ln2));
183+
184+
// Compute a polynomial approximation of exp(r):
185+
// exp(r) ~ 1 + r + r²/2 + r³/6 + r⁴/24
186+
let c2 = _mm256_set1_ps(1.0); // coefficient for r
187+
let c3 = _mm256_set1_ps(0.5); // coefficient for r²: 1/2
188+
let c4 = _mm256_set1_ps(0.16666667); // coefficient for r³: 1/6
189+
let c5 = _mm256_set1_ps(0.04166667); // coefficient for r⁴: 1/24
190+
191+
let r2 = _mm256_mul_ps(r, r);
192+
let r3 = _mm256_mul_ps(r2, r);
193+
let r4 = _mm256_mul_ps(r3, r);
194+
195+
let poly = _mm256_add_ps(
196+
_mm256_add_ps(
197+
_mm256_add_ps(
198+
_mm256_add_ps(r, c2),
199+
_mm256_mul_ps(r2, c3)
200+
),
201+
_mm256_mul_ps(r3, c4)
202+
),
203+
_mm256_mul_ps(r4, c5)
204+
);
205+
206+
// Compute 2^n using IEEE754 bit-level conversion:
207+
// First, convert n (float) to an integer
208+
let int_n = _mm256_cvtps_epi32(n);
209+
// For a 32-bit float, the exponent field is biased by 127.
210+
// So 2^n is represented by (n + 127) << 23.
211+
let bias = _mm256_set1_epi32(127);
212+
let exp_int = _mm256_add_epi32(int_n, bias);
213+
let exp_int = _mm256_slli_epi32(exp_int, 23);
214+
let two_n = _mm256_castsi256_ps(exp_int);
215+
216+
// Reconstruct exp(x) ≈ exp(r) * 2^n
217+
_mm256_mul_ps(poly, two_n)
218+
}
219+
171220
macro_rules! avx_detect {
172221
($block:expr) => {
173222
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
@@ -183,7 +232,7 @@ macro_rules! avx_detect {
183232

184233
macro_rules! unary_op {
185234
($fname:ident, $sim_op:expr, $fallback_op:expr) => {
186-
pub unsafe fn $fname(
235+
pub fn $fname(
187236
a: impl Input,
188237
mut out: impl Output
189238
) {
@@ -194,12 +243,14 @@ macro_rules! unary_op {
194243
let mut i = 0;
195244

196245
avx_detect! {
197-
// Process in chunks of 8 floats
198-
while i + 8 <= length {
199-
let va = a.fill_256(i);
200-
let res = $sim_op(va);
201-
out.store(res, i);
202-
i += 8;
246+
unsafe {
247+
// Process in chunks of 8 floats
248+
while i + 8 <= length {
249+
let va = a.fill_256(i);
250+
let res = $sim_op(va);
251+
out.store(res, i);
252+
i += 8;
253+
}
203254
}
204255
}
205256

@@ -215,7 +266,7 @@ macro_rules! unary_op {
215266

216267
macro_rules! binary_op {
217268
($fname:ident, $sim_op:expr, $(&mut)? $fallback_op:expr) => {
218-
pub unsafe fn $fname(
269+
pub fn $fname(
219270
a: impl Input,
220271
b: impl Input,
221272
mut out: impl Output
@@ -228,13 +279,15 @@ macro_rules! binary_op {
228279
let mut i = 0;
229280

230281
avx_detect! {
231-
// Process in chunks of 8 floats
232-
while i + 8 <= length {
233-
let va = a.fill_256(i);
234-
let vb = b.fill_256(i);
235-
let res = $sim_op(va, vb);
236-
out.store(res, i);
237-
i += 8;
282+
unsafe {
283+
// Process in chunks of 8 floats
284+
while i + 8 <= length {
285+
let va = a.fill_256(i);
286+
let vb = b.fill_256(i);
287+
let res = $sim_op(va, vb);
288+
out.store(res, i);
289+
i += 8;
290+
}
238291
}
239292
}
240293

@@ -250,7 +303,7 @@ macro_rules! binary_op {
250303

251304
macro_rules! trinary_op {
252305
($fname:ident, $sim_op:expr, $fallback_op:expr) => {
253-
pub unsafe fn $fname(
306+
pub fn $fname(
254307
a: impl Input,
255308
b: impl Input,
256309
c: impl Input,
@@ -265,14 +318,16 @@ macro_rules! trinary_op {
265318
let mut i = 0;
266319

267320
avx_detect! {
268-
// Process in chunks of 8 floats
269-
while i + 8 <= length {
270-
let va = a.fill_256(i);
271-
let vb = b.fill_256(i);
272-
let vc = c.fill_256(i);
273-
let res = $sim_op(va, vb, vc);
274-
out.store(res, i);
275-
i += 8;
321+
unsafe {
322+
// Process in chunks of 8 floats
323+
while i + 8 <= length {
324+
let va = a.fill_256(i);
325+
let vb = b.fill_256(i);
326+
let vc = c.fill_256(i);
327+
let res = $sim_op(va, vb, vc);
328+
out.store(res, i);
329+
i += 8;
330+
}
276331
}
277332
}
278333

@@ -317,3 +372,33 @@ binary_op!(
317372
|xi, yi| { xi * yi }
318373
);
319374

375+
binary_op!(
376+
simd_sub,
377+
_mm256_sub_ps,
378+
|xi, yi| { xi - yi }
379+
);
380+
381+
binary_op!(
382+
simd_add,
383+
_mm256_add_ps,
384+
|xi, yi| { xi - yi }
385+
);
386+
387+
unary_op!(
388+
simd_iadd,
389+
|vo| {vo},
390+
|xi| {xi}
391+
);
392+
393+
unary_op!(
394+
grad_sub_y,
395+
|vo| { _mm256_xor_ps(vo, _mm256_set1_ps(-0f32))},
396+
|xi: f32| {-xi}
397+
);
398+
399+
unary_op!(
400+
simd_exp,
401+
_mm256_exp_ps,
402+
|xi: f32| {xi.exp()}
403+
);
404+

0 commit comments

Comments
 (0)