Skip to content

Commit 3e9abd1

Browse files
committed
Fix row/col/slice index type for respective functions
1 parent 00067c8 commit 3e9abd1

File tree

2 files changed

+75
-21
lines changed

2 files changed

+75
-21
lines changed

examples/helloworld.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ fn main() {
1313
);
1414
println!("Revision: {}", get_revision());
1515

16-
let num_rows: u64 = 5;
17-
let num_cols: u64 = 3;
16+
let num_rows: i64 = 5;
17+
let num_cols: i64 = 3;
1818
let values: [f32; 3] = [1.0, 2.0, 3.0];
1919
let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
2020

2121
af_print!("Indices ", indices);
2222

23-
let dims = Dim4::new(&[num_rows, num_cols, 1, 1]);
23+
let dims = Dim4::new(&[num_rows as u64, num_cols as u64, 1, 1]);
2424

2525
let mut a = randu::<f32>(dims);
2626
af_print!("Create a 5-by-3 float matrix on the GPU", a);

src/core/index.rs

Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ where
287287
/// print(&a);
288288
/// print(&row(&a, 4));
289289
/// ```
290-
pub fn row<T>(input: &Array<T>, row_num: u64) -> Array<T>
290+
pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
291291
where
292292
T: HasAfEnum,
293293
{
@@ -301,7 +301,7 @@ where
301301
}
302302

303303
/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
304-
pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: u64)
304+
pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: i64)
305305
where
306306
T: HasAfEnum,
307307
{
@@ -313,22 +313,24 @@ where
313313
}
314314

315315
/// Get an Array with all rows from `first` to `last` in the `input` Array
316-
pub fn rows<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
316+
pub fn rows<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
317317
where
318318
T: HasAfEnum,
319319
{
320+
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
320321
index(
321322
input,
322-
&[Seq::new(first as f64, last as f64, 1.0), Seq::default()],
323+
&[Seq::new(first as f64, last as f64, step), Seq::default()],
323324
)
324325
}
325326

326327
/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
327-
pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: u64, last: u64)
328+
pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: i64, last: i64)
328329
where
329330
T: HasAfEnum,
330331
{
331-
let seqs = [Seq::new(first as f64, last as f64, 1.0), Seq::default()];
332+
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
333+
let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()];
332334
assign_seq(inout, &seqs, new_rows)
333335
}
334336

@@ -344,7 +346,7 @@ where
344346
/// println!("Grab last col of the random matrix");
345347
/// print(&col(&a, 4));
346348
/// ```
347-
pub fn col<T>(input: &Array<T>, col_num: u64) -> Array<T>
349+
pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
348350
where
349351
T: HasAfEnum,
350352
{
@@ -358,7 +360,7 @@ where
358360
}
359361

360362
/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
361-
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: u64)
363+
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: i64)
362364
where
363365
T: HasAfEnum,
364366
{
@@ -370,29 +372,31 @@ where
370372
}
371373

372374
/// Get all cols from `first` to `last` in the `input` Array
373-
pub fn cols<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
375+
pub fn cols<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
374376
where
375377
T: HasAfEnum,
376378
{
379+
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
377380
index(
378381
input,
379-
&[Seq::default(), Seq::new(first as f64, last as f64, 1.0)],
382+
&[Seq::default(), Seq::new(first as f64, last as f64, step)],
380383
)
381384
}
382385

383386
/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
384-
pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: u64, last: u64)
387+
pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: i64, last: i64)
385388
where
386389
T: HasAfEnum,
387390
{
388-
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, 1.0)];
391+
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
392+
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)];
389393
assign_seq(inout, &seqs, new_cols)
390394
}
391395

392396
/// Get `slice_num`^th slice from `input` Array
393397
///
394398
/// Slices indicate that the indexing is along 3rd dimension
395-
pub fn slice<T>(input: &Array<T>, slice_num: u64) -> Array<T>
399+
pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
396400
where
397401
T: HasAfEnum,
398402
{
@@ -407,7 +411,7 @@ where
407411
/// Set slice `slice_num` in `inout` Array to a new Array `new_slice`
408412
///
409413
/// Slices indicate that the indexing is along 3rd dimension
410-
pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: u64)
414+
pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
411415
where
412416
T: HasAfEnum,
413417
{
@@ -422,29 +426,31 @@ where
422426
/// Get slices from `first` to `last` in `input` Array
423427
///
424428
/// Slices indicate that the indexing is along 3rd dimension
425-
pub fn slices<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
429+
pub fn slices<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
426430
where
427431
T: HasAfEnum,
428432
{
433+
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
429434
let seqs = [
430435
Seq::default(),
431436
Seq::default(),
432-
Seq::new(first as f64, last as f64, 1.0),
437+
Seq::new(first as f64, last as f64, step),
433438
];
434439
index(input, &seqs)
435440
}
436441

437442
/// Set `first` to `last` slices of `inout` Array to a new Array `new_slices`
438443
///
439444
/// Slices indicate that the indexing is along 3rd dimension
440-
pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: u64, last: u64)
445+
pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: i64, last: i64)
441446
where
442447
T: HasAfEnum,
443448
{
449+
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
444450
let seqs = [
445451
Seq::default(),
446452
Seq::default(),
447-
Seq::new(first as f64, last as f64, 1.0),
453+
Seq::new(first as f64, last as f64, step),
448454
];
449455
assign_seq(inout, &seqs, new_slices)
450456
}
@@ -644,6 +650,7 @@ mod tests {
644650
use super::super::data::constant;
645651
use super::super::dim4::Dim4;
646652
use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
653+
use super::super::index::{cols, rows};
647654
use super::super::random::randu;
648655
use super::super::seq::Seq;
649656

@@ -800,4 +807,51 @@ mod tests {
800807
// 0.7896
801808
// ANCHOR_END: setrow
802809
}
810+
811+
#[test]
812+
fn get_row() {
813+
// ANCHOR: get_row
814+
let a = randu::<f32>(dim4!(5, 5));
815+
// [5 5 1 1]
816+
// 0.6010 0.5497 0.1583 0.3636 0.6755
817+
// 0.0278 0.2864 0.3712 0.4165 0.6105
818+
// 0.9806 0.3410 0.3543 0.5814 0.5232
819+
// 0.2126 0.7509 0.6450 0.8962 0.5567
820+
// 0.0655 0.4105 0.9675 0.3712 0.7896
821+
let _r = row(&a, -1);
822+
// [1 5 1 1]
823+
// 0.0655 0.4105 0.9675 0.3712 0.7896
824+
let _c = col(&a, -1);
825+
// [5 1 1 1]
826+
// 0.6755
827+
// 0.6105
828+
// 0.5232
829+
// 0.5567
830+
// 0.7896
831+
// ANCHOR_END: get_row
832+
}
833+
834+
#[test]
835+
fn get_rows() {
836+
// ANCHOR: get_rows
837+
let a = randu::<f32>(dim4!(5, 5));
838+
// [5 5 1 1]
839+
// 0.6010 0.5497 0.1583 0.3636 0.6755
840+
// 0.0278 0.2864 0.3712 0.4165 0.6105
841+
// 0.9806 0.3410 0.3543 0.5814 0.5232
842+
// 0.2126 0.7509 0.6450 0.8962 0.5567
843+
// 0.0655 0.4105 0.9675 0.3712 0.7896
844+
let _r = rows(&a, -1, -2);
845+
// [2 5 1 1]
846+
// 0.2126 0.7509 0.6450 0.8962 0.5567
847+
// 0.0655 0.4105 0.9675 0.3712 0.7896
848+
let _c = cols(&a, -1, -3);
849+
// [5 3 1 1]
850+
// 0.1583 0.3636 0.6755
851+
// 0.3712 0.4165 0.6105
852+
// 0.3543 0.5814 0.5232
853+
// 0.6450 0.8962 0.5567
854+
// 0.9675 0.3712 0.7896
855+
// ANCHOR_END: get_rows
856+
}
803857
}

0 commit comments

Comments
 (0)