
[Question] implementing Conv3d #50

Closed
NewBornRustacean opened this issue Apr 28, 2024 · 2 comments

Comments

@NewBornRustacean
Contributor

Hello! Thanks for this awesome project! @jafioti

Recently I've been working on Conv3d (#28, #29) and I have a question:

Is it OK to define Axes7 and R7?
If we do, I suspect it could turn into shotgun surgery across the codebase...

My draft is here:

impl<
    const CHANNELS_IN: usize,
    const CHANNELS_OUT: usize,
    const KERNELX: usize,
    const KERNELY: usize,
    const KERNELZ: usize,
    const STRIDEX: usize,
    const STRIDEY: usize,
    const STRIDEZ: usize,
    const DILATIONX: usize,
    const DILATIONY: usize,
    const DILATIONZ: usize,
    const CHANNELS_IN_TIMES_KERNELX_KERNELY_KERNELZ: usize,
> Conv3D<
    CHANNELS_IN,
    CHANNELS_OUT,
    KERNELX,
    KERNELY,
    KERNELZ,
    STRIDEX,
    STRIDEY,
    STRIDEZ,
    DILATIONX,
    DILATIONY,
    DILATIONZ,
    CHANNELS_IN_TIMES_KERNELX_KERNELY_KERNELZ,
>
{
    pub fn forward<
        const DIMX_IN: usize,
        const DIMY_IN: usize,
        const DIMZ_IN: usize,
        const DIMX_OUT: usize,
        const DIMY_OUT: usize,
        const DIMZ_OUT: usize,
        const DIMX_TIMES_DIMY_DIMZ_OUT: usize,
    >(
        &self,
        input: GraphTensor<R4<CHANNELS_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>,
    ) -> GraphTensor<R4<CHANNELS_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>> {
        let input_pooled = input
            .pool_last_dim::<R5<CHANNELS_IN, DIMX_IN, DIMY_OUT, DIMZ_OUT, KERNELY>>(
                KERNELY.into(),
                STRIDEY.into(),
                DILATIONY
            )
            .permute::<_, Axes5<0, 2, 3, 4, 1>>()
            .pool_last_dim::<R6<CHANNELS_IN, DIMY_OUT, DIMZ_OUT, KERNELY, DIMX_OUT, KERNELX>>(
                KERNELX.into(),
                STRIDEX.into(),
                DILATIONX
            )
            .permute::<_, Axes6<0, 5, 2, 3, 4, 1>>()
            .pool_last_dim::<R7<CHANNELS_IN, DIMZ_OUT, KERNELZ, DIMX_OUT, KERNELX, DIMY_OUT, KERNELY>>(
                KERNELZ.into(),
                STRIDEZ.into(),
                DILATIONZ
            )
            .permute::<_, Axes7<0, 6, 2, 3, 4, 5, 1>>()
            .reshape::<R2<CHANNELS_IN_TIMES_KERNELX_KERNELY_KERNELZ, DIMX_TIMES_DIMY_DIMZ_OUT>>();

        self.weight
            .matmul(input_pooled)
            .reshape::<R4<CHANNELS_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>()
    }
}
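For intuition, here is a minimal sketch of the same "unfold, then matmul" (im2col) structure on plain Vecs, restricted to stride 1 and dilation 1. All names, layouts, and helpers below are assumptions for illustration only; this is not luminal's API, just the algorithm the pool/permute/reshape chain expresses on GraphTensors.

```rust
// Naive im2col-based 3D convolution on flat Vecs (row-major layout).
// Illustrative sketch only: stride 1, dilation 1, no padding.

/// Unfold a (c_in, x, y, z) volume into a (c_in*kx*ky*kz, xo*yo*zo) matrix,
/// where (xo, yo, zo) are the output spatial dims.
fn im2col(
    input: &[f32],
    (c_in, x, y, z): (usize, usize, usize, usize),
    (kx, ky, kz): (usize, usize, usize),
) -> (Vec<f32>, (usize, usize, usize)) {
    let (xo, yo, zo) = (x - kx + 1, y - ky + 1, z - kz + 1);
    let n = xo * yo * zo;
    let mut cols = vec![0.0f32; c_in * kx * ky * kz * n];
    let mut row = 0;
    for c in 0..c_in {
        for dx in 0..kx {
            for dy in 0..ky {
                for dz in 0..kz {
                    let mut col = 0;
                    for ox in 0..xo {
                        for oy in 0..yo {
                            for oz in 0..zo {
                                // Flat index into the (c_in, x, y, z) volume.
                                let idx =
                                    ((c * x + ox + dx) * y + oy + dy) * z + oz + dz;
                                cols[row * n + col] = input[idx];
                                col += 1;
                            }
                        }
                    }
                    row += 1;
                }
            }
        }
    }
    (cols, (xo, yo, zo))
}

/// conv = weight (c_out, c_in*kx*ky*kz) x cols, i.e. the final matmul/reshape
/// step of the draft, done with explicit loops.
fn conv3d(
    input: &[f32],
    dims: (usize, usize, usize, usize),
    weight: &[f32],
    c_out: usize,
    k: (usize, usize, usize),
) -> (Vec<f32>, (usize, usize, usize)) {
    let (cols, (xo, yo, zo)) = im2col(input, dims, k);
    let rows = dims.0 * k.0 * k.1 * k.2;
    let n = xo * yo * zo;
    let mut out = vec![0.0f32; c_out * n];
    for co in 0..c_out {
        for j in 0..n {
            let mut acc = 0.0;
            for r in 0..rows {
                acc += weight[co * rows + r] * cols[r * n + j];
            }
            out[co * n + j] = acc;
        }
    }
    (out, (xo, yo, zo))
}

fn main() {
    // 1 input channel, 2x2x2 volume of ones; 1 output channel, 2x2x2 kernel of ones.
    let input = vec![1.0f32; 8];
    let weight = vec![1.0f32; 8]; // (c_out = 1, c_in * kx * ky * kz = 8)
    let (out, shape) = conv3d(&input, (1, 2, 2, 2), &weight, 1, (2, 2, 2));
    assert_eq!(shape, (1, 1, 1));
    assert_eq!(out, vec![8.0]);
    println!("conv3d output {:?}, spatial shape {:?}", out, shape);
}
```

The unfold produces exactly the (CHANNELS_IN_TIMES_KERNELX_KERNELY_KERNELZ, DIMX_TIMES_DIMY_DIMZ_OUT) matrix the draft reshapes to before the matmul.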

I'm not sure I'm going in the right direction, so any comments would be appreciated!

Have a nice day :)

@jafioti
Owner

jafioti commented Apr 28, 2024

Great work! It looks correct to me on first pass. Is there any way to implement it without requiring a 7D tensor? Maybe by merging some of the dimensions before the last pool_last_dim? Ideally we keep the max tensor dims at 6, in order to keep the ShapeTracker small, because it's stored on the stack.

If not, I'll look into adding a 7th dimension.
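One way to see why merging could work: fusing two adjacent axes into one preserves the total element count while dropping the rank by one, so a rank-7 view (channels, three output axes, three kernel axes) can be held at rank 6 if two already-finished axes are fused before the final pool. A tiny bookkeeping sketch; `merge_adjacent` is a hypothetical helper for illustration, not part of luminal, and which axes may be fused in practice depends on the preceding permute:

```rust
/// Collapse axes i and i+1 of a shape into one axis of size shape[i] * shape[i+1].
fn merge_adjacent(shape: &[usize], i: usize) -> Vec<usize> {
    let mut out = shape.to_vec();
    let next = out.remove(i + 1);
    out[i] *= next;
    out
}

fn main() {
    // A conceptual rank-7 shape like the draft's
    // (C_IN, DIMZ_OUT, KERNELZ, DIMX_OUT, KERNELX, DIMY_OUT, KERNELY).
    let rank7 = vec![3usize, 6, 2, 6, 2, 6, 2];
    // Fusing any two adjacent axes keeps the element count but drops the rank,
    // bringing the shape back within a 6-axis ShapeTracker.
    let rank6 = merge_adjacent(&rank7, 1);
    assert_eq!(rank6.len(), 6);
    assert_eq!(
        rank7.iter().product::<usize>(),
        rank6.iter().product::<usize>()
    );
    println!("{:?} -> {:?}", rank7, rank6);
}
```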

@NewBornRustacean
Contributor Author

Thanks! I'll look into it!
(I'm quite a newbie, just like my nickname says; this is really helpful!)
