
Add transpose/permute_axes operation #108

Closed · Tracked by #159 · Fixed by #169
coreylowman opened this issue Jul 20, 2022 · 14 comments
Comments

@coreylowman (Owner) commented Jul 20, 2022

This would be used for matmuls & transformer implementations. The hard part is how to do this without moving data around a lot. I believe other frameworks just change strides & dimensions without actually moving data around, which is what we can/should do as well. Definitely a weakness of storing actual arrays.

Originally posted by @jafioti in #34 (comment)
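
For context, here is a minimal sketch of the zero-copy trick described above (illustrative only, not dfdx code): transposing a 2d view just swaps the shape and stride entries, and indexing goes through the strides, so no element ever moves.

// Illustrative only: a 2d view over flat storage where transpose is
// pure metadata manipulation.
struct View2 {
    shape: [usize; 2],
    strides: [usize; 2], // elements to skip per step along each axis
}

impl View2 {
    // Row-major view over an m x n buffer.
    fn new(m: usize, n: usize) -> Self {
        Self { shape: [m, n], strides: [n, 1] }
    }

    // Zero-copy transpose: swap shape and strides, leave the data alone.
    fn transpose(self) -> Self {
        Self {
            shape: [self.shape[1], self.shape[0]],
            strides: [self.strides[1], self.strides[0]],
        }
    }

    fn get(&self, data: &[f32], i: usize, j: usize) -> f32 {
        data[i * self.strides[0] + j * self.strides[1]]
    }
}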

@jafioti (Contributor) commented Jul 21, 2022

@coreylowman Do you think that at some point it will make more sense to use flat arrays for tensor storage, with strides and shapes, or to keep the actual nested arrays?

@coreylowman (Owner, Author) commented

@jafioti I was thinking about this more overnight - I think we could actually just do both (so I take back my comment in the main post). We could keep the array storage and then just convert to slices for computations, or store them as slices and have the frontend interface expose them as Rust arrays.

@vikigenius (Contributor) commented

+1 for storing strides and dimensions alongside a flat array. This should lead to a clean API for reshaping (views) and reducing over arbitrary dimensions.
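
As a sketch of that point (my illustration, not dfdx code): with flat storage, reducing over either axis of a 2d tensor is the same loop, and only the stride bookkeeping changes.

// Sum a row-major m x n buffer over `axis` (0 or 1). The loop body is
// identical for both axes; only the stride/shape selection differs.
fn sum_axis(data: &[f32], shape: [usize; 2], axis: usize) -> Vec<f32> {
    let strides = [shape[1], 1];
    let keep = 1 - axis; // the axis that survives the reduction
    let mut out = vec![0.0; shape[keep]];
    for i in 0..shape[keep] {
        for j in 0..shape[axis] {
            out[i] += data[i * strides[keep] + j * strides[axis]];
        }
    }
    out
}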

@jafioti (Contributor) commented Jul 21, 2022

I'd love to just do flat and call it a day right now, but one issue is that we'd be forced to go full nightly. The flat array inside a Tensor3D<A, B, C> would be [f32; {A * B * C}], and using A * B * C as an array length in a type requires nightly.
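
Concretely, the fully-typed flat layout would look something like this (a sketch; it only builds on nightly because A * B * C appears in type position):

#![feature(generic_const_exprs)]

// Nightly only: the array length is a const expression over the
// type's const generic parameters.
struct Tensor3D<const A: usize, const B: usize, const C: usize>
where
    [(); A * B * C]:,
{
    data: [f32; A * B * C],
}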

@coreylowman (Owner, Author) commented

We should be able to do boxed slices without nightly though; all tensors would hold an Rc<[f32]>. The number of elements isn't captured in the type, so I'm not sure if Rust would still be able to auto-vectorize things, but we won't know unless we try!

The biggest changes for this will be in the Device traits, which are all implemented using recursive array traits.
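
A sketch of the stable-Rust version of that idea (hypothetical, not dfdx's actual definition): the element count drops out of the type and is enforced at construction instead.

use std::rc::Rc;

// Stable Rust: the flat slice's length is not part of the type,
// so no const expressions appear in type position.
struct Tensor3D<const A: usize, const B: usize, const C: usize> {
    data: Rc<[f32]>, // invariant: data.len() == A * B * C
}

impl<const A: usize, const B: usize, const C: usize> Tensor3D<A, B, C> {
    fn zeros() -> Self {
        // A * B * C in expression position is fine on stable.
        Self { data: vec![0.0; A * B * C].into() }
    }
}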

@jafioti (Contributor) commented Jul 23, 2022

@coreylowman Have you looked into how to do this? I'm going to need transpose for my MultiHeadAttention implementation, so I could either try to do this (though I would rather keep it as a separate PR) or just implement an inefficient copying transpose system temporarily; once this lands we can switch transpose to a more efficient impl.

@coreylowman (Owner, Author) commented

@jafioti Yeah, I tried out slices and also a separate transpose operation. Here are my notes:

  1. Slices - this would be a huge change, and I'm not exactly sure how it interacts with const generics. I think we'd need some separate data storage struct for tensors to use that doesn't have the const generics. For example, if you transpose a Tensor2D<M, N>, then both the type should change to Tensor2D<N, M> and the underlying stride should change.

  2. Transpose operation - this seems way more feasible in the short term, though it would do some extra memory copies (I'm not sure how much they would actually impact things). I recommend not moving Linear to this yet, because it'd have to put the weight matrix on the tape, then do the transpose, then put the tape back onto the input so all the operations are recorded. This shouldn't cost anything runtime-wise, but it would be a less clear implementation.

So tl;dr: the less efficient transpose operation, IMO.

@jafioti (Contributor) commented Jul 25, 2022

Ok, I ended up not needing transpose for MultiHeadAttention, but I implemented a forward pass of it anyway for Tensor3D:

// Assumes a trait along the lines of:
//   trait Transpose<const I: usize, const J: usize> { type Output; fn transpose(self) -> Self::Output; }
impl<const A: usize, const B: usize, const C: usize> Transpose<0, 1> for Tensor3D<A, B, C> {
    type Output = Tensor3D<B, A, C>;
    fn transpose(self) -> Self::Output {
        let mut new = Tensor3D::zeros();
        // Copy data
        let data = self.data();        // &[[[f32; C]; B]; A]
        let new_data = new.mut_data(); // &mut [[[f32; C]; A]; B]

        #[allow(clippy::needless_range_loop)]
        for i in 0..B {
            for j in 0..A {
                // [f32; C] is Copy, so each assignment copies a whole innermost row
                new_data[i][j] = data[j][i];
            }
        }
        new
    }
}

I agree this is atrocious, and it doesn't even properly pass the grads backward. Currently it's not used for anything, so we can scrap it and try a better way.
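
(For reference, the missing backward rule is mechanical - a sketch with hypothetical names, using the same nested-array layout as above: the gradient of a (0, 1) transpose is just the (0, 1) transpose of the output gradient.)

// Accumulate the output gradient back through a (0, 1) transpose by
// applying the same permutation to it.
fn transpose_backward<const A: usize, const B: usize, const C: usize>(
    grad_output: &[[[f32; C]; A]; B],    // gradient w.r.t. the (B, A, C) output
    grad_input: &mut [[[f32; C]; B]; A], // gradient w.r.t. the (A, B, C) input
) {
    for i in 0..B {
        for j in 0..A {
            for k in 0..C {
                grad_input[j][i][k] += grad_output[i][j][k];
            }
        }
    }
}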

As far as the slices change, I think it would be fine to have a separate data storage class not connected to const generics, since only we would be using it, and once it's linked correctly to the tensors and their operations we won't need to worry about it anymore. I also think it would be more or less necessary for when we work on GPU kernels. It would also cut down on the matmul functions, since no separate matmul_transpose function is needed when transpose is effectively zero-cost.

I really think this should be prioritized so that work going forward doesn't need to be redone when the switch eventually happens.

@vikigenius (Contributor) commented

@jafioti @coreylowman +1 on prioritizing the slices change. Now is the time to make massive changes like this: once we have more users (which we will once we add GPU support), it will be harder to do, since far more downstream functions will depend on it.

Things like reduce across multiple dimensions etc. will be much easier with a slices/strides-based approach.

@coreylowman (Owner, Author) commented

It's going to require some thinking & design work to implement things like reduction across multiple dimensions & permutations in a way that doesn't need a method per possible permutation/axis and is also user friendly. Permutation, for example, has to update both the const generic parameters (at compile time) and the strides (at run time).

I'm going to focus on conv nets (& hopefully help with transformers, given enough time) to get more usable features in first & help illuminate the path for the strides rewrite. Those two things will require a lot of additions to the internals, which will help focus the rewrite. For example, transformers already added a ton of functionality that wasn't necessarily obvious we'd need.

@coreylowman coreylowman changed the title Add transpose operation Add transpose/permute_axes operation Aug 21, 2022
@coreylowman (Owner, Author) commented

I don't know whether I should be proud that I figured out how to implement this via macros or horrified...

impl_permute!(0, 1, 3, 2);
impl_permute!(0, 2, 1, 3);
impl_permute!(0, 2, 3, 1);
impl_permute!(0, 3, 1, 2);
impl_permute!(0, 3, 2, 1);
impl_permute!(1, 0, 2, 3);
impl_permute!(1, 0, 3, 2);
impl_permute!(1, 2, 0, 3);
impl_permute!(1, 2, 3, 0);
impl_permute!(1, 3, 0, 2);
impl_permute!(1, 3, 2, 0);
impl_permute!(2, 0, 1, 3);
impl_permute!(2, 0, 3, 1);
impl_permute!(2, 1, 0, 3);
impl_permute!(2, 1, 3, 0);
impl_permute!(2, 3, 0, 1);
impl_permute!(2, 3, 1, 0);
impl_permute!(3, 0, 1, 2);
impl_permute!(3, 0, 2, 1);
impl_permute!(3, 1, 0, 2);
impl_permute!(3, 1, 2, 0);
impl_permute!(3, 2, 0, 1);
impl_permute!(3, 2, 1, 0);
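
For reference, here is a self-contained sketch of what one of these invocations presumably expands to (hypothetical Permute4/Tensor4D types, shown for the (0, 2, 1, 3) permutation). The permutation appears three times - in the trait's const arguments, in the Output type, and in the runtime index math - which is exactly why a macro per permutation is tempting.

trait Permute4<const I: usize, const J: usize, const K: usize, const L: usize> {
    type Output;
    fn permute(self) -> Self::Output;
}

struct Tensor4D<const A: usize, const B: usize, const C: usize, const D: usize> {
    data: Vec<f32>, // flat, row-major, length A * B * C * D
}

impl<const A: usize, const B: usize, const C: usize, const D: usize>
    Permute4<0, 2, 1, 3> for Tensor4D<A, B, C, D>
{
    type Output = Tensor4D<A, C, B, D>;
    fn permute(self) -> Self::Output {
        let mut data = vec![0.0; A * B * C * D];
        for a in 0..A {
            for b in 0..B {
                for c in 0..C {
                    for d in 0..D {
                        // input (a, b, c, d) -> output (a, c, b, d)
                        data[((a * C + c) * B + b) * D + d] =
                            self.data[((a * B + b) * C + c) * D + d];
                    }
                }
            }
        }
        Tensor4D { data }
    }
}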

@jafioti (Contributor) commented Aug 22, 2022

@coreylowman Oh ... wow.
Yeah, the optimal setup would be a single generic permute function that takes in 4 const arguments and, based on those, returns different shapes, but I have no idea if that's even possible today (probably not). Until then it seems like we're stuck with these macros.

I would think a slightly better way to do this would be a proc macro that generates every possible permutation, so you can't miss any, but that would require a separate crate for proc macros and be a big hassle, so this seems good for now.

@coreylowman (Owner, Author) commented

Yeah, as of now you'd need to programmatically specify the output type, which isn't possible on stable at least (maybe on nightly?). I think it might look something like this:

trait Permute3<I, J, K>: HasAxis<I> + HasAxis<J> + HasAxis<K> {
    fn permute(
        self,
    ) -> Tensor3D<
        <Self as HasAxis<I>>::SIZE,
        <Self as HasAxis<J>>::SIZE,
        <Self as HasAxis<K>>::SIZE,
    >;
}

I'm not sure what feature on nightly allows you to use associated consts as generic arguments.

I think the other problem (the permuted index) is easier to solve:

for m in 0..M {
    for n in 0..N {
        for o in 0..O {
            // write input element (m, n, o) to its permuted spot in the output
            *permuted3_idx::<I, J, K>(output.mut_data(), [m, n, o]) = input.data()[m][n][o];
        }
    }
}
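
(A guess at what permuted3_idx could compute - my assumption, since the thread doesn't define it, and shown over a flat output slice rather than the nested arrays the call above uses: select axes I, J, K from the input coordinate, then resolve against the output's strides.)

// Hypothetical helper: returns a mutable reference to the output element
// that input coordinate `idx` maps to under the (I, J, K) permutation.
fn permuted3_idx<const I: usize, const J: usize, const K: usize>(
    out: &mut [f32],
    out_strides: [usize; 3],
    idx: [usize; 3],
) -> &mut f32 {
    // The output coordinate is (idx[I], idx[J], idx[K]).
    &mut out[idx[I] * out_strides[0] + idx[J] * out_strides[1] + idx[K] * out_strides[2]]
}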

@coreylowman (Owner, Author) commented

Yep, confirmed that with generic_const_exprs this is possible with just traits. Here's a 2d version: https://play.rust-lang.org/?version=nightly&mode=debug&edition=2021&gist=114492951baef2c0c68deb51a59ab401.

I'd like permute to be available on stable though. So for now I'm going to do macros, and then I'll make a follow-up issue to refactor to this method once generic_const_exprs is stable.
