Skip to content

Commit

Permalink
Adding try_realize, and realize doesn't return result (#758)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored May 1, 2023
1 parent f57df15 commit 66d79ce
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 108 deletions.
2 changes: 1 addition & 1 deletion examples/01-tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn main() {
// `realize` method helps us move between dynamic and known size for the dimensions,
// if the conversion is incompatible, it may result in a runtime error
let a: Tensor<(usize, usize), f32, _> = dev.zeros_like(&(2, 3));
let _: Tensor<(usize, Const<3>), f32, _> = a.realize().expect("`a` should have 3 columns");
let _: Tensor<(usize, Const<3>), f32, _> = a.try_realize().expect("`a` should have 3 columns");

// each of the creation methods also supports specifying the shape on the function
// note to change the dtype we specify the dtype as the 2nd generic parameter
Expand Down
4 changes: 2 additions & 2 deletions examples/02-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ fn main() {
a = a + 0.5;
let b: Tensor<(usize, Const<5>), f32, _> = dev.sample_uniform_like(&(3, Const));
// note the use of `realize`
let _: Tensor<(Const<3>, usize), f32, _> = a + b.realize().expect("`b` should have 3 rows");
let _: Tensor<(Const<3>, usize), f32, _> = a + b.try_realize().expect("`b` should have 3 rows");

// then we have things like matrix and vector multiplication:
let a: Tensor<(usize, Const<5>), f32, _> = dev.sample_normal_like(&(3, Const));
let b: Tensor<(usize, usize), f32, _> = dev.sample_normal_like(&(5, 7));
// if type inference is not possible, we explicitly provide the shape for `realize`
let _: Tensor<(usize, usize), f32, _> = a.matmul(
b.realize::<(Const<5>, usize)>()
b.try_realize::<(Const<5>, usize)>()
.expect("`b` should have 5 rows"),
);

Expand Down
2 changes: 1 addition & 1 deletion examples/13-housing-nn-in-struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl Predictor {
let batched: Tensor<Rank2<1, 2>, _, _> = input.clone().broadcast();

// convert static size tensor to variable sized tensor
let batched_realized: Tensor<(usize, Const<2>), _, _> = batched.realize().unwrap();
let batched_realized: Tensor<(usize, Const<2>), _, _> = batched.try_realize().unwrap();
assert_eq!(batched_realized.shape(), &(1 as usize, Const::<2>));

// call predict on batches
Expand Down
12 changes: 3 additions & 9 deletions src/tensor_ops/attention_reshape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,9 @@ mod tests {

let (q, k, v) = dev.attention_reshape(&qkv, &past_key, &past_value);

let q = q
.realize::<(Const<NUM_HEADS>, Const<1>, Const<HEAD_DIM>)>()
.unwrap();
let k = k
.realize::<(Const<NUM_HEADS>, Const<HEAD_DIM>, Const<4>)>()
.unwrap();
let v = v
.realize::<(Const<NUM_HEADS>, Const<4>, Const<HEAD_DIM>)>()
.unwrap();
let q = q.realize::<(Const<NUM_HEADS>, Const<1>, Const<HEAD_DIM>)>();
let k = k.realize::<(Const<NUM_HEADS>, Const<HEAD_DIM>, Const<4>)>();
let v = v.realize::<(Const<NUM_HEADS>, Const<4>, Const<HEAD_DIM>)>();

assert_close_to_literal!(q, [[[1.0; HEAD_DIM]; 1]; NUM_HEADS]);
assert_close_to_literal!(
Expand Down
31 changes: 20 additions & 11 deletions src/tensor_ops/concat_along/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ mod cuda_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
/// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize().unwrap();
/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize();
/// ```
///
/// Along Axis 1:
Expand All @@ -44,7 +44,7 @@ mod cuda_kernel;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
/// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize().unwrap();
/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
/// ```
pub trait TryConcatAlong<Ax>: Sized {
type Output;
Expand Down Expand Up @@ -192,11 +192,14 @@ mod tests {
let b: Tensor<Rank3<3, 3, 4>, TestDtype, _> = dev.sample_normal();
let a_dyn = a
.leaky_trace()
.realize::<(usize, Const<3>, Const<4>)>()
.try_realize::<(usize, Const<3>, Const<4>)>()
.unwrap();
let b_dyn = b
.clone()
.try_realize::<(usize, Const<3>, Const<4>)>()
.unwrap();
let b_dyn = b.clone().realize::<(usize, Const<3>, Const<4>)>().unwrap();
let c = (a_dyn, b_dyn).concat_along(Axis::<0>);
let c = c.realize::<(Const<5>, Const<3>, Const<4>)>().unwrap();
let c = c.try_realize::<(Const<5>, Const<3>, Const<4>)>().unwrap();
let a_arr = a.array();
let b_arr = b.array();
let c_arr = c.array();
Expand All @@ -222,11 +225,14 @@ mod tests {
let b: Tensor<Rank3<2, 3, 4>, TestDtype, _> = dev.sample_normal();
let a_dyn = a
.leaky_trace()
.realize::<(Const<2>, usize, Const<4>)>()
.try_realize::<(Const<2>, usize, Const<4>)>()
.unwrap();
let b_dyn = b
.clone()
.try_realize::<(Const<2>, usize, Const<4>)>()
.unwrap();
let b_dyn = b.clone().realize::<(Const<2>, usize, Const<4>)>().unwrap();
let c = (a_dyn, b_dyn).concat_along(Axis::<1>);
let c = c.realize::<(Const<2>, Const<5>, Const<4>)>().unwrap();
let c = c.try_realize::<(Const<2>, Const<5>, Const<4>)>().unwrap();
let a_arr = a.array();
let b_arr = b.array();
let c_arr = c.array();
Expand All @@ -251,11 +257,14 @@ mod tests {
let b: Tensor<Rank3<2, 3, 3>, TestDtype, _> = dev.sample_normal();
let a_dyn = a
.leaky_trace()
.realize::<(Const<2>, Const<3>, usize)>()
.try_realize::<(Const<2>, Const<3>, usize)>()
.unwrap();
let b_dyn = b
.clone()
.try_realize::<(Const<2>, Const<3>, usize)>()
.unwrap();
let b_dyn = b.clone().realize::<(Const<2>, Const<3>, usize)>().unwrap();
let c = (a_dyn, b_dyn).concat_along(Axis::<2>);
let c = c.realize::<(Const<2>, Const<3>, Const<5>)>().unwrap();
let c = c.try_realize::<(Const<2>, Const<3>, Const<5>)>().unwrap();
let a_arr = a.array();
let b_arr = b.array();
let c_arr = c.array();
Expand Down
143 changes: 78 additions & 65 deletions src/tensor_ops/realize_to.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,42 @@
use crate::{shapes::*, tensor::*};

/// Changes order of dimensions/axes
/// Realizes the concrete shape of the tensor as another compatible shape,
/// or returns the original tensor if the new shape's dimensions are incompatible.
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
/// let a = a.realize::<(usize, usize)>();
/// let mut a = a.realize::<Rank2<2, 3>>();
/// match a.try_realize::<(usize, Const<4>)>() {
/// Ok(new) => println!("Shape was properly realized, returned new tensor"),
/// Err(old) => println!("Shape could not be realized, returned the original tensor"),
/// }
/// ```
pub trait RealizeTo: HasErr + HasShape {
/// Realizes the concrete shape of the tensor as another compatible shape,
/// or returns the original tensor if the new shape's dimensions are incompatible.
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
/// let a = a.realize::<(usize, usize)>().unwrap();
/// let mut a = a.realize::<Rank2<2, 3>>().unwrap();
/// match a.realize::<(usize, Const<4>)>() {
/// Ok(new) => println!("Shape was properly realized, returned new tensor"),
/// Err(old) => println!("Shape could not be realized, returned the original tensor"),
/// }
/// ```
fn realize<Dst: Shape<Concrete = <<Self as HasShape>::Shape as Shape>::Concrete>>(
self,
) -> Self::WithShape<Dst>
where
Self::Shape: RealizeShapeTo<Dst>,
Self: std::fmt::Debug,
{
self.try_realize::<Dst>().unwrap()
}

/// Realizes the concrete shape of the tensor as another compatible shape,
/// or returns the original tensor if the new shape's dimensions are incompatible.
fn try_realize<Dst: Shape<Concrete = <<Self as HasShape>::Shape as Shape>::Concrete>>(
self,
) -> Result<Self::WithShape<Dst>, Self>
where
Self::Shape: RealizeShapeTo<Dst>;
}

impl<S: Shape, E: Dtype, D: DeviceStorage, T: Tape<E, D>> RealizeTo for Tensor<S, E, D, T> {
fn realize<Dst: Shape<Concrete = S::Concrete>>(self) -> Result<Self::WithShape<Dst>, Self>
fn try_realize<Dst: Shape<Concrete = S::Concrete>>(self) -> Result<Self::WithShape<Dst>, Self>
where
Self::Shape: RealizeShapeTo<Dst>,
{
Expand Down Expand Up @@ -51,45 +64,39 @@ mod tests {
fn test_realize_2d() {
let dev: TestDevice = Default::default();
let src: Tensor<Rank2<2, 3>, TestDtype, _> = dev.sample_normal();
let dst: Tensor<(Const<2>, usize), TestDtype, _> =
src.clone().realize::<(Const<2>, usize)>().unwrap();
let dst = src.clone().realize::<(Const<2>, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
let src = dst;
let dst: Tensor<(usize, Const<3>), TestDtype, _> =
src.clone().realize::<(usize, Const<3>)>().unwrap();
let dst = src.clone().realize::<(usize, Const<3>)>();
assert_eq!(src.as_vec(), dst.as_vec());
let mut src = dst;
let dst: Tensor<(usize, usize), TestDtype, _> =
src.clone().realize::<(usize, usize)>().unwrap();
let dst: Tensor<(usize, usize), TestDtype, _> = src.clone().realize::<(usize, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
src = src.realize::<(usize, Const<4>)>().unwrap_err();
src = src.realize::<(Const<1>, usize)>().unwrap_err();
src = src.realize::<(Const<2>, Const<4>)>().unwrap_err();
src = src.realize::<(Const<3>, Const<2>)>().unwrap_err();
src = src.try_realize::<(usize, Const<4>)>().unwrap_err();
src = src.try_realize::<(Const<1>, usize)>().unwrap_err();
src = src.try_realize::<(Const<2>, Const<4>)>().unwrap_err();
src = src.try_realize::<(Const<3>, Const<2>)>().unwrap_err();
assert_eq!(src.as_vec(), dst.as_vec());
}

#[test]
fn test_realize_3d() {
let dev: TestDevice = Default::default();
let src: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.sample_normal();
let dst: Tensor<(Const<3>, usize, Const<7>), TestDtype, _> = src
.clone()
.realize::<(Const<3>, usize, Const<7>)>()
.unwrap();
let dst = src.clone().realize::<(Const<3>, usize, Const<7>)>();
assert_eq!(src.as_vec(), dst.as_vec());
let src = dst;
let dst: Tensor<(usize, Const<5>, usize), TestDtype, _> =
src.clone().realize::<(usize, Const<5>, usize)>().unwrap();
let dst = src.clone().realize::<(usize, Const<5>, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
let mut src = dst;
let dst: Tensor<(usize, usize, usize), TestDtype, _> =
src.clone().realize::<(usize, usize, usize)>().unwrap();
let dst = src.clone().realize::<(usize, usize, usize)>();
assert_eq!(src.as_vec(), dst.as_vec());
// Ensure we get back the original tensor on error
src = src.realize::<(usize, Const<2>, usize)>().unwrap_err();
src = src.realize::<(Const<3>, Const<1>, Const<7>)>().unwrap_err();
src = src.realize::<(usize, usize, Const<3>)>().unwrap_err();
src = src.try_realize::<(usize, Const<2>, usize)>().unwrap_err();
src = src
.try_realize::<(Const<3>, Const<1>, Const<7>)>()
.unwrap_err();
src = src.try_realize::<(usize, usize, Const<3>)>().unwrap_err();
assert_eq!(src.as_vec(), dst.as_vec());
}

Expand All @@ -99,29 +106,29 @@ mod tests {
let src: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.sample_normal();
let dst: Tensor<(Const<3>, usize, Const<7>, usize), TestDtype, _> = src
.clone()
.realize::<(Const<3>, usize, Const<7>, usize)>()
.try_realize::<(Const<3>, usize, Const<7>, usize)>()
.unwrap();
assert_eq!(src.as_vec(), dst.as_vec());
let src = dst;
let dst: Tensor<(usize, usize, usize, usize), TestDtype, _> = src
.clone()
.realize::<(usize, usize, usize, usize)>()
.try_realize::<(usize, usize, usize, usize)>()
.unwrap();
assert_eq!(src.as_vec(), dst.as_vec());
let mut src = dst;
let dst: Tensor<(usize, Const<5>, Const<7>, Const<9>), TestDtype, _> = src
.clone()
.realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.try_realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.unwrap();
assert_eq!(src.as_vec(), dst.as_vec());
src = src
.realize::<(usize, Const<2>, usize, Const<9>)>()
.try_realize::<(usize, Const<2>, usize, Const<9>)>()
.unwrap_err();
src = src
.realize::<(Const<3>, Const<1>, Const<7>, Const<9>)>()
.try_realize::<(Const<3>, Const<1>, Const<7>, Const<9>)>()
.unwrap_err();
src = src
.realize::<(usize, usize, Const<3>, usize)>()
.try_realize::<(usize, usize, Const<3>, usize)>()
.unwrap_err();
assert_eq!(src.as_vec(), dst.as_vec());
}
Expand All @@ -133,7 +140,7 @@ mod tests {
let g1 = t.leaky_trace().exp().sum().backward();
let g2 = t
.leaky_trace()
.realize::<(usize, usize)>()
.try_realize::<(usize, usize)>()
.unwrap()
.exp()
.sum()
Expand All @@ -148,7 +155,7 @@ mod tests {
let g1 = t.leaky_trace().exp().sum().backward();
let g2 = t
.leaky_trace()
.realize::<(usize, usize, usize)>()
.try_realize::<(usize, usize, usize)>()
.unwrap()
.exp()
.sum()
Expand All @@ -163,7 +170,7 @@ mod tests {
let g1 = t.leaky_trace().exp().sum().backward();
let g2 = t
.leaky_trace()
.realize::<(usize, usize, usize, usize)>()
.try_realize::<(usize, usize, usize, usize)>()
.unwrap()
.exp()
.sum()
Expand All @@ -176,39 +183,45 @@ mod tests {
let dev: TestDevice = Default::default();

let x: Tensor<Rank2<3, 5>, TestDtype, _> = dev.sample_normal();
let x = x.realize::<(Const<3>, usize)>().unwrap();
let x = x.realize::<(usize, Const<5>)>().unwrap();
let _ = x.realize::<(usize, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, usize)>().unwrap();
let x = x.try_realize::<(usize, Const<5>)>().unwrap();
let _ = x.try_realize::<(usize, usize)>().unwrap();

let x: Tensor<Rank3<3, 5, 7>, TestDtype, _> = dev.sample_normal();
let x = x.realize::<(Const<3>, Const<5>, usize)>().unwrap();
let x = x.realize::<(Const<3>, usize, Const<7>)>().unwrap();
let x = x.realize::<(usize, Const<5>, Const<7>)>().unwrap();
let x = x.realize::<(Const<3>, usize, usize)>().unwrap();
let x = x.realize::<(usize, Const<5>, usize)>().unwrap();
let x = x.realize::<(usize, usize, Const<7>)>().unwrap();
let _ = x.realize::<(usize, usize, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, Const<5>, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, usize, Const<7>)>().unwrap();
let x = x.try_realize::<(usize, Const<5>, Const<7>)>().unwrap();
let x = x.try_realize::<(Const<3>, usize, usize)>().unwrap();
let x = x.try_realize::<(usize, Const<5>, usize)>().unwrap();
let x = x.try_realize::<(usize, usize, Const<7>)>().unwrap();
let _ = x.try_realize::<(usize, usize, usize)>().unwrap();

let x: Tensor<Rank4<3, 5, 7, 9>, TestDtype, _> = dev.sample_normal();
let x = x
.realize::<(Const<3>, Const<5>, Const<7>, usize)>()
.try_realize::<(Const<3>, Const<5>, Const<7>, usize)>()
.unwrap();
let x = x
.try_realize::<(Const<3>, Const<5>, usize, Const<9>)>()
.unwrap();
let x = x
.try_realize::<(Const<3>, usize, Const<7>, Const<9>)>()
.unwrap();
let x = x
.try_realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.unwrap();
let x = x
.realize::<(Const<3>, Const<5>, usize, Const<9>)>()
.try_realize::<(Const<3>, Const<5>, usize, usize)>()
.unwrap();
let x = x
.realize::<(Const<3>, usize, Const<7>, Const<9>)>()
.try_realize::<(Const<3>, usize, usize, Const<9>)>()
.unwrap();
let x = x
.realize::<(usize, Const<5>, Const<7>, Const<9>)>()
.try_realize::<(usize, usize, Const<7>, Const<9>)>()
.unwrap();
let x = x.realize::<(Const<3>, Const<5>, usize, usize)>().unwrap();
let x = x.realize::<(Const<3>, usize, usize, Const<9>)>().unwrap();
let x = x.realize::<(usize, usize, Const<7>, Const<9>)>().unwrap();
let x = x.realize::<(Const<3>, usize, usize, usize)>().unwrap();
let x = x.realize::<(usize, Const<5>, usize, usize)>().unwrap();
let x = x.realize::<(usize, usize, Const<7>, usize)>().unwrap();
let x = x.realize::<(usize, usize, usize, Const<9>)>().unwrap();
let _ = x.realize::<(usize, usize, usize, usize)>().unwrap();
let x = x.try_realize::<(Const<3>, usize, usize, usize)>().unwrap();
let x = x.try_realize::<(usize, Const<5>, usize, usize)>().unwrap();
let x = x.try_realize::<(usize, usize, Const<7>, usize)>().unwrap();
let x = x.try_realize::<(usize, usize, usize, Const<9>)>().unwrap();
let _ = x.try_realize::<(usize, usize, usize, usize)>().unwrap();
}
}
Loading

0 comments on commit 66d79ce

Please sign in to comment.