
Adding f16 as Dtype #696
Merged: 52 commits into main, Apr 27, 2023
Conversation

coreylowman (Owner) commented Apr 11, 2023

Resolves #423

coreylowman (Owner, Author):

FYI @opfromthestart @nkoppel: if you guys want to work on this, you can open PRs into this feature branch.

Review threads (outdated, resolved): Cargo.toml, src/lib.rs (x2)
coreylowman (Owner, Author) commented Apr 20, 2023

I'm wondering if we should enforce f16/bf16 tests passing. There are particular tests where I just don't think f16 will have the accuracy required to pass, or we need a way to relax the tolerance even more.

Here are the current failing tests for f16:

    losses::tests::test_hard_crossentropy
    nn::batchnorm1d::tests::test_batchnorm1d_2d_forward_mut
    nn::batchnorm2d::tests::test_batchnorm2d_4d_forward_mut
    nn::linear::tests::test_forward_1d
    nn::linear::tests::test_linear_initialize
    nn::residual::tests::test_residual_gradients
    nn::unbiased_linear::tests::test_forward_1d
    optim::adam::tests::test_adam_decoupled_decay
    optim::adam::tests::test_custom_adam_one_params
    optim::adam::tests::test_default_adam_params
    tensor_ops::matmul::tests::test_matmul_vec_normal
    tensor_ops::matmul::tests::test_matmul_vec_transpose
    tensor_ops::matmul::tests::test_small_matmul_mm
    tensor_ops::max_to::tests::test_max_axis_0_2d
    tensor_ops::mean_to::tests::test_mean_axis_0_2d
    tensor_ops::pool2d::tests::test_pool2d_3d_max2d
    tensor_ops::sum_to::tests::test_sum_axes_3d_to_1d
    tensor_ops::sum_to::tests::test_sum_broadcasted
    tensor_ops::upscale2d::tests::test_bilinear_upscale2d_batched
    tensor_ops::upscale2d::tests::test_upscale2d_bilinear_even
    tensor_ops::upscale2d::tests::test_upscale2d_bilinear_uneven
    tensor_ops::upscale2d::tests::test_upscale2d_nearest_uneven

ViliamVadocz (Contributor) commented Apr 26, 2023

I made a few more tests pass with some fixes: 254 passed; 103 failed.
How many tests are expected to pass?

EDIT: Now up to 255 passed; 102 failed.

coreylowman (Owner, Author):

On compute_cap 86 I'm only getting 8 failures

coreylowman (Owner, Author):

They are likely failing because sum is failing. I'm running cargo +nightly test --tests -F test-f16,cuda sum to run just the sum tests.

ViliamVadocz (Contributor):

In that case I'll just implement the atomicAdd directly instead of fiddling around with atomicCAS. It does mean that compatibility code will leak into the min_to and max_to files, since they also use atomicCAS on shorts.
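As a rough, hypothetical sketch of the "implement atomicAdd directly" route (not necessarily the code this PR or #742 ended up with): on compute capability 7.0 and higher, CUDA already exposes a native atomicAdd overload for __half, so a kernel compiled for sm_70 or newer can call it without any CAS emulation; older architectures would still need a CAS-based fallback like the one further down this thread. The kernel and its name (sum_into_f16) below are illustrative only.

#include <cuda_fp16.h>

// Illustrative kernel (hypothetical name): accumulate src into a single f16
// cell using the hardware __half atomicAdd available on sm_70+.
__global__ void sum_into_f16(const __half* src, __half* dst, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        atomicAdd(dst, src[i]);  // native half-precision atomic add
    }
}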

coreylowman (Owner, Author):

Okay this is where I'm at now:

__device__ __half atomicAdd(__half* address, __half val) {
    // atomicCAS works on 32-bit words, so operate on the aligned unsigned int
    // that contains the target 16-bit half; `align` is 2 if the half sits in
    // the upper 16 bits of that word, 0 if it sits in the lower 16 bits.
    size_t align = reinterpret_cast<size_t>(address) & 2;
    unsigned int *address_as_u32 = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - align);
    unsigned int old = *address_as_u32;
    unsigned int assumed;

    do {
        assumed = old;
        // Extract the relevant half, add val, and splice the sum back into
        // the 32-bit word without disturbing the neighboring half.
        __half sum16 = __ushort_as_half(align ? (old >> 16) : (old & 0xffff)) + val;
        unsigned int sum32 = (unsigned int) __half_as_ushort(sum16);
        old = align ? ((sum32 << 16) | (old & 0xffff)) : ((old & 0xffff0000) | sum32);
        // Retry if another thread changed the word since we read it.
        old = atomicCAS(address_as_u32, assumed, old);
    } while (assumed != old);
    // Return the previous value, matching built-in atomicAdd semantics.
    return __ushort_as_half(align ? (old >> 16) : (old & 0xffff));
}

375 passed, 18 failed.

It seems like this doesn't handle inf properly, as some of the errors I'm still getting are:

---- tensor_ops::max_to::tests::test_max_axis_0_2d stdout ----
thread 'tensor_ops::max_to::tests::test_max_axis_0_2d' panicked at 'lhs != rhs | -inf != 3', src/tensor_ops/max_to/mod.rs:97:9

---- tensor_ops::max_to::tests::test_max_axis_1_2d stdout ----
thread 'tensor_ops::max_to::tests::test_max_axis_1_2d' panicked at 'lhs != rhs | -inf != 2', src/tensor_ops/max_to/mod.rs:112:9

---- tensor_ops::max_to::tests::test_max_negative_zero stdout ----
thread 'tensor_ops::max_to::tests::test_max_negative_zero' panicked at 'lhs != rhs | -inf != 0', src/tensor_ops/max_to/mod.rs:136:9

---- tensor_ops::min_to::tests::test_min_axis_0_2d stdout ----
thread 'tensor_ops::min_to::tests::test_min_axis_0_2d' panicked at 'lhs != rhs | inf != 1', src/tensor_ops/min_to/mod.rs:97:9

---- tensor_ops::min_to::tests::test_min_axis_1_2d stdout ----
thread 'tensor_ops::min_to::tests::test_min_axis_1_2d' panicked at 'lhs != rhs | inf != 1', src/tensor_ops/min_to/mod.rs:112:9

---- tensor_ops::min_to::tests::test_min_negative_zero stdout ----
thread 'tensor_ops::min_to::tests::test_min_negative_zero' panicked at 'lhs != rhs | inf != -0', src/tensor_ops/min_to/mod.rs:136:9

ViliamVadocz (Contributor):

All those failing tests are min and max, which use the probably broken atomicCAS. I'm almost done with my attempt.
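For illustration only (a hypothetical sketch, not the code from PR #742): a CAS-based atomicMax for __half can reuse the same aligned-32-bit-word trick as the atomicAdd above, but has to compare the halves numerically rather than as raw unsigned shorts, since bit-pattern comparisons mis-order negative values (including -inf, whose sign bit makes it look like a huge unsigned number). The function name atomic_max_f16 is made up for this sketch.

#include <cuda_fp16.h>

// Hypothetical CAS-based 16-bit max, mirroring the atomicAdd emulation above.
__device__ __half atomic_max_f16(__half* address, __half val) {
    size_t align = reinterpret_cast<size_t>(address) & 2;
    unsigned int* address_as_u32 =
        reinterpret_cast<unsigned int*>(reinterpret_cast<char*>(address) - align);
    unsigned int old = *address_as_u32;
    unsigned int assumed;
    do {
        assumed = old;
        __half cur = __ushort_as_half(align ? (old >> 16) : (old & 0xffff));
        // Numeric comparison instead of comparing raw ushort bit patterns.
        __half mx = (__half2float(val) > __half2float(cur)) ? val : cur;
        unsigned int mx32 = (unsigned int) __half_as_ushort(mx);
        unsigned int replacement =
            align ? ((mx32 << 16) | (old & 0xffff)) : ((old & 0xffff0000) | mx32);
        old = atomicCAS(address_as_u32, assumed, replacement);
    } while (assumed != old);
    // Return the previous value, matching atomic read-modify-write semantics.
    return __ushort_as_half(align ? (old >> 16) : (old & 0xffff));
}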

ViliamVadocz (Contributor):

PR #742 is up

coreylowman (Owner, Author):

@ViliamVadocz nice work, all the tests pass for me now! 🚀 (other than the ones I broke from reverting optimizer kernels)

Review threads (outdated, resolved): src/optim/adam/mod.rs (x2), src/optim/rmsprop/mod.rs, src/optim/sgd/mod.rs
coreylowman changed the title from "[WIP] [Feature Branch] Adding f16 as Dtype" to "Adding f16 as Dtype" on Apr 27, 2023
coreylowman merged commit 7626de4 into main on Apr 27, 2023
coreylowman deleted the f16 branch on April 27, 2023 at 18:32
coreylowman mentioned this pull request on Apr 27, 2023
Successfully merging this pull request may close these issues: Add f16 dtype support for tensors

3 participants