Very slow compilation of resnet18 #234

Closed
coreylowman opened this issue Oct 10, 2022 · 2 comments · Fixed by #304

Comments

@coreylowman
Owner

The following program takes an extremely long time to compile (on the order of minutes? tbh it hasn't finished compiling for me). I believe this is because each `Conv2D` layer below needs its own concrete function with different const values.
The model has 20 `Conv2D` layers, so there can be up to 20 different convolution forward functions.

```rust
#![feature(generic_const_exprs)]
#![allow(incomplete_features)]

use dfdx::prelude::*;
use std::time::Instant;

fn main() {
    // Two 3x3 convs with an identity skip connection.
    type BasicBlock<const C: usize> = (
        Residual<(
            Conv2D<C, C, 3, 1, 1>,
            BatchNorm2D<C>,
            ReLU,
            Conv2D<C, C, 3, 1, 1>,
            BatchNorm2D<C>,
        )>,
        ReLU,
    );

    // Strided block mapping C -> D channels; the skip path uses a strided 1x1 conv.
    type Downsample<const C: usize, const D: usize> = (
        GeneralizedResidual<
            (
                Conv2D<C, D, 3, 2, 1>,
                BatchNorm2D<D>,
                ReLU,
                Conv2D<D, D, 3, 1, 1>,
                BatchNorm2D<D>,
            ),
            (Conv2D<C, D, 1, 2, 0>, BatchNorm2D<D>),
        >,
        ReLU,
    );

    // Stem: 7x7 stride-2 conv followed by a 3x3 stride-2 max pool.
    type Layer0 = (
        Conv2D<3, 64, 7, 2, 3>,
        BatchNorm2D<64>,
        ReLU,
        MaxPool2D<3, 2, 1>,
    );

    type Resnet18<const NUM_CLASSES: usize> = (
        (
            Layer0,
            (BasicBlock<64>, BasicBlock<64>),
            (Downsample<64, 128>, BasicBlock<128>),
            (Downsample<128, 256>, BasicBlock<256>),
            (Downsample<256, 512>, BasicBlock<512>),
        ),
        AvgPoolGlobal,
        Linear<512, NUM_CLASSES>,
    );

    // In-memory size of a few of the model types.
    println!("{:?}", std::mem::size_of::<BasicBlock<128>>());
    println!("{:?}", std::mem::size_of::<Downsample<128, 256>>());
    println!("{:?}", std::mem::size_of::<Resnet18<10>>());

    let mut m: Box<Resnet18<10>> = Default::default();
    let img: Tensor3D<3, 224, 224> = TensorCreator::zeros();
    // Time a single traced forward pass.
    let start = Instant::now();
    let _out = m.forward_mut(img.trace());
    println!("{:?}", start.elapsed());
}
```
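To make the monomorphization argument concrete, here is a minimal stable-Rust sketch using a hypothetical `Conv2D` stand-in (not dfdx's type: it stores no weights and only computes output sizes). Because the five const parameters are part of the type, every distinct combination, such as the stem conv `Conv2D<3, 64, 7, 2, 3>` versus the strided `Conv2D<64, 128, 3, 2, 1>`, is a separate type with its own compiled copy of `out_size`. The formula `(h + 2 * P - K) / S + 1` is the standard convolution output-size rule; expressing it at the type level for the output tensor shape is presumably why the program above needs `generic_const_exprs`, whereas here it is an ordinary runtime function.

```rust
/// Hypothetical stand-in for a conv layer: the const parameters mirror
/// `Conv2D<IN, OUT, KERNEL, STRIDE, PADDING>` from the issue, but nothing is stored.
struct Conv2D<const IN: usize, const OUT: usize, const K: usize, const S: usize, const P: usize>;

impl<const IN: usize, const OUT: usize, const K: usize, const S: usize, const P: usize>
    Conv2D<IN, OUT, K, S, P>
{
    /// Output height/width for a square input of side `h`.
    /// The compiler emits a separate copy of this function per parameter set.
    fn out_size(h: usize) -> usize {
        (h + 2 * P - K) / S + 1
    }
}

fn main() {
    // Stem conv from `Layer0`: 224x224 -> 112x112.
    let stem = Conv2D::<3, 64, 7, 2, 3>::out_size(224);
    // Strided 3x3 conv and 1x1 shortcut from `Downsample<64, 128>`: 56x56 -> 28x28.
    let main_path = Conv2D::<64, 128, 3, 2, 1>::out_size(56);
    let shortcut = Conv2D::<64, 128, 1, 2, 0>::out_size(56);
    println!("{stem} {main_path} {shortcut}");
}
```

Identical parameter sets share one instantiation, so only new combinations of channels, kernel, stride, and padding add work for the compiler.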
@coreylowman changed the title from "Very slow compilation of resnet18 forward" to "Very slow compilation of resnet18" on Oct 10, 2022
@coreylowman
Owner Author

Reproduced this with a simpler case here: https://play.rust-lang.org/?version=nightly&mode=release&edition=2021&gist=f8dd6c7f4d89497c3a793cfc77a5cc56

This seems to be an issue with `generic_const_exprs` bounds on associated types. With only one bound on the associated type it compiles very quickly, but with two bounds compilation becomes very slow.

Oddly enough, using generic output parameters instead of associated output types doesn't show the same slowdown.
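The two trait shapes being compared can be sketched on stable Rust as follows. The names `ModuleAssoc`, `ModuleGeneric`, `Double`, `chain_assoc`, and `chain_generic` are hypothetical and only mirror the general pattern, not dfdx's actual `Module` trait. Chaining with associated outputs puts the bound on a projection (`A::Output`), while the generic-output version names the intermediate type `M` explicitly. The sketch does not use `generic_const_exprs`, so it does not reproduce the slowdown itself.

```rust
/// Shape 1: the output type is an associated type, fixed by each impl.
trait ModuleAssoc<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Shape 2: the output type is a generic parameter of the trait.
trait ModuleGeneric<Input, Output> {
    fn forward(&self, input: Input) -> Output;
}

/// Hypothetical layer that doubles every element of a fixed-size array.
struct Double;

impl<const N: usize> ModuleAssoc<[f32; N]> for Double {
    type Output = [f32; N];
    fn forward(&self, input: [f32; N]) -> [f32; N] {
        input.map(|x| x * 2.0)
    }
}

impl<const N: usize> ModuleGeneric<[f32; N], [f32; N]> for Double {
    fn forward(&self, input: [f32; N]) -> [f32; N] {
        input.map(|x| x * 2.0)
    }
}

/// Chain two modules with associated outputs: the second bound is on `A::Output`.
fn chain_assoc<I, A, B>(a: &A, b: &B, x: I) -> B::Output
where
    A: ModuleAssoc<I>,
    B: ModuleAssoc<A::Output>,
{
    b.forward(a.forward(x))
}

/// The same chain with generic outputs: the intermediate type `M` is a parameter.
fn chain_generic<I, M, O, A, B>(a: &A, b: &B, x: I) -> O
where
    A: ModuleGeneric<I, M>,
    B: ModuleGeneric<M, O>,
{
    b.forward(a.forward(x))
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0];
    let y = chain_assoc(&Double, &Double, x);
    // `M` is spelled out so inference does not depend on there being a single impl.
    let z: [f32; 3] = chain_generic::<_, [f32; 3], _, _, _>(&Double, &Double, x);
    println!("{:?} {:?}", y, z);
}
```

With more than one impl per layer, the generic-output shape can leave intermediate types ambiguous, which is why `main` names `M` explicitly; the associated-type shape infers it from the impl.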

@coreylowman
Owner Author
