Skip to content

Commit 5a4f4f4

Browse files
committed
Call set_device in tests for valid cuda context in threaded runs
1 parent 57cc29e commit 5a4f4f4

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

src/algorithm/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,10 +1444,12 @@ dim_reduce_by_key_nan_func_def!(
14441444
mod tests {
14451445
use super::super::core::c32;
14461446
use super::{imax_all, imin_all, product_nan_all, sum_all, sum_nan_all};
1447+
use crate::core::set_device;
14471448
use crate::randu;
14481449

14491450
#[test]
14501451
fn all_reduce_api() {
1452+
set_device(0);
14511453
let a = randu!(c32; 10, 10);
14521454
println!("Reduction of complex f32 matrix: {:?}", sum_all(&a));
14531455

@@ -1469,6 +1471,7 @@ mod tests {
14691471

14701472
#[test]
14711473
fn all_ireduce_api() {
1474+
set_device(0);
14721475
let a = randu!(c32; 10);
14731476
println!("Reduction of complex f32 matrix: {:?}", imin_all(&a));
14741477

src/core/data.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,13 +962,15 @@ mod tests {
962962
use super::reorder_v2;
963963

964964
use super::super::defines::BorderType;
965+
use super::super::device::set_device;
965966
use super::super::random::randu;
966967
use super::pad;
967968

968969
use crate::dim4;
969970

970971
#[test]
971972
fn check_reorder_api() {
973+
set_device(0);
972974
let a = randu::<f32>(dim4!(4, 5, 2, 3));
973975

974976
let _transposed = reorder_v2(&a, 1, 0, None);
@@ -979,6 +981,7 @@ mod tests {
979981

980982
#[test]
981983
fn check_pad_api() {
984+
set_device(0);
982985
let a = randu::<f32>(dim4![3, 3]);
983986
let begin_dims = dim4!(0, 0, 0, 0);
984987
let end_dims = dim4!(2, 2, 0, 0);

src/core/index.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,7 @@ impl SeqInternal {
655655
mod tests {
656656
use super::super::array::Array;
657657
use super::super::data::constant;
658+
use super::super::device::set_device;
658659
use super::super::dim4::Dim4;
659660
use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
660661
use super::super::index::{cols, rows};
@@ -665,6 +666,7 @@ mod tests {
665666

666667
#[test]
667668
fn non_macro_seq_index() {
669+
set_device(0);
668670
// ANCHOR: non_macro_seq_index
669671
let dims = Dim4::new(&[5, 5, 1, 1]);
670672
let a = randu::<f32>(dims);
@@ -690,6 +692,7 @@ mod tests {
690692

691693
#[test]
692694
fn seq_index() {
695+
set_device(0);
693696
// ANCHOR: seq_index
694697
let dims = dim4!(5, 5, 1, 1);
695698
let a = randu::<f32>(dims);
@@ -701,18 +704,19 @@ mod tests {
701704

702705
#[test]
703706
fn non_macro_seq_assign() {
707+
set_device(0);
704708
// ANCHOR: non_macro_seq_assign
705-
let mut a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
709+
let mut a = constant(2.0 as f32, dim4!(5, 3));
706710
//print(&a);
707711
// 2.0 2.0 2.0
708712
// 2.0 2.0 2.0
709713
// 2.0 2.0 2.0
710714
// 2.0 2.0 2.0
711715
// 2.0 2.0 2.0
712716

713-
let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
714-
let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
715-
assign_seq(&mut a, seqs, &b);
717+
let b = constant(1.0 as f32, dim4!(3, 3));
718+
let seqs = [seq!(1:3:1), seq!()];
719+
assign_seq(&mut a, &seqs, &b);
716720
//print(&a);
717721
// 2.0 2.0 2.0
718722
// 1.0 1.0 1.0
@@ -724,6 +728,7 @@ mod tests {
724728

725729
#[test]
726730
fn non_macro_seq_array_index() {
731+
set_device(0);
727732
// ANCHOR: non_macro_seq_array_index
728733
let values: [f32; 3] = [1.0, 2.0, 3.0];
729734
let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
@@ -751,6 +756,7 @@ mod tests {
751756

752757
#[test]
753758
fn seq_array_index() {
759+
set_device(0);
754760
// ANCHOR: seq_array_index
755761
let values: [f32; 3] = [1.0, 2.0, 3.0];
756762
let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
@@ -762,6 +768,7 @@ mod tests {
762768

763769
#[test]
764770
fn non_macro_seq_array_assign() {
771+
set_device(0);
765772
// ANCHOR: non_macro_seq_array_assign
766773
let values: [f32; 3] = [1.0, 2.0, 3.0];
767774
let indices = Array::new(&values, dim4!(3, 1, 1, 1));
@@ -793,6 +800,7 @@ mod tests {
793800

794801
#[test]
795802
fn setrow() {
803+
set_device(0);
796804
// ANCHOR: setrow
797805
let a = randu::<f32>(dim4!(5, 5, 1, 1));
798806
//print(&a);
@@ -817,6 +825,7 @@ mod tests {
817825

818826
#[test]
819827
fn get_row() {
828+
set_device(0);
820829
// ANCHOR: get_row
821830
let a = randu::<f32>(dim4!(5, 5));
822831
// [5 5 1 1]
@@ -840,6 +849,7 @@ mod tests {
840849

841850
#[test]
842851
fn get_rows() {
852+
set_device(0);
843853
// ANCHOR: get_rows
844854
let a = randu::<f32>(dim4!(5, 5));
845855
// [5 5 1 1]

src/core/macros.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ macro_rules! randn {
353353
mod tests {
354354
use super::super::array::Array;
355355
use super::super::data::constant;
356+
use super::super::device::set_device;
356357
use super::super::index::index;
357358
use super::super::random::randu;
358359

@@ -377,6 +378,7 @@ mod tests {
377378

378379
#[test]
379380
fn seq_view() {
381+
set_device(0);
380382
let mut dim4d = dim4!(5, 3, 2, 1);
381383
dim4d[2] = 1;
382384

@@ -387,14 +389,17 @@ mod tests {
387389

388390
#[test]
389391
fn seq_view2() {
392+
set_device(0);
390393
// ANCHOR: seq_view2
391394
let a = randu::<f32>(dim4!(5, 5));
392395
let _sub = view!(a[1:3:1, 1:1:0]); // 1:1:0 means all elements along axis
393-
// ANCHOR_END: seq_view2
396+
397+
// ANCHOR_END: seq_view2
394398
}
395399

396400
#[test]
397401
fn view_macro() {
402+
set_device(0);
398403
let dims = dim4!(5, 5, 2, 1);
399404
let a = randu::<f32>(dims);
400405
let b = a.clone();
@@ -421,6 +426,7 @@ mod tests {
421426

422427
#[test]
423428
fn eval_assign_seq_indexed_array() {
429+
set_device(0);
424430
let dims = dim4!(5, 5);
425431
let mut a = randu::<f32>(dims);
426432
//print(&a);
@@ -456,6 +462,7 @@ mod tests {
456462

457463
#[test]
458464
fn eval_assign_array_to_seqd_array() {
465+
set_device(0);
459466
// ANCHOR: macro_seq_assign
460467
let mut a = randu::<f32>(dim4!(5, 5));
461468
let b = randu::<f32>(dim4!(2, 2));
@@ -465,6 +472,7 @@ mod tests {
465472

466473
#[test]
467474
fn macro_seq_array_assign() {
475+
set_device(0);
468476
// ANCHOR: macro_seq_array_assign
469477
let values: [f32; 3] = [1.0, 2.0, 3.0];
470478
let indices = Array::new(&values, dim4!(3));
@@ -479,6 +487,7 @@ mod tests {
479487

480488
#[test]
481489
fn constant_macro() {
490+
set_device(0);
482491
let _zeros_1d = constant!(0.0f32; 10);
483492
let _zeros_2d = constant!(0.0f64; 5, 5);
484493
let _ones_3d = constant!(1u32; 3, 3, 3);
@@ -490,6 +499,7 @@ mod tests {
490499

491500
#[test]
492501
fn rand_macro() {
502+
set_device(0);
493503
let _ru5x5 = randu!(5, 5);
494504
let _rn5x5 = randn!(5, 5);
495505
let _ruu32_5x5 = randu!(u32; 5, 5);

0 commit comments

Comments
 (0)