[feat] Dropout partial bw fusion (second take) #164

Merged
merged 7 commits on Jan 3, 2022
small cleanup, removing old files and fixing broken links
blefaudeux committed Jan 3, 2022
commit c40dd647fa0d96077ec5014c27eef99d89802abe
4 changes: 2 additions & 2 deletions BENCHMARKS.md
@@ -58,7 +58,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png)

-![Fused linear layers throughput in fp16 - training](docs/plots/fused_linea/FusedLinear_fp16_FW_BW_gelu.png)
+![Fused linear layers throughput in fp16 - training](docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png)

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png)

@@ -74,7 +74,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_none.png)

-![Fused linear layers throughput in fp16 - training](docs/plots/fused_line/FusedLinear_fp16_FW_BW_none.png)
+![Fused linear layers throughput in fp16 - training](docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png)

### Fused layer norm

Binary files not shown (8 files).
2 changes: 1 addition & 1 deletion tests/test_triton_dropout.py
@@ -110,7 +110,7 @@ def test_dropout(shape, amp, bias, p):
# Check that the drop probability is about right
y = triton_dropout(x, p=p)
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
-assert abs(drop_p - p) < 0.1
+assert abs(drop_p - p) < 0.01


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
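For context, the tighter `0.01` tolerance is reasonable because the test tensors are large, so the empirical drop rate concentrates closely around `p`. A minimal sketch of the same check with plain PyTorch dropout (not the Triton kernel; the shape and `p` below are chosen only for illustration):

```python
# Illustrative only: standalone drop-rate check using torch.nn.functional.dropout.
import torch

p = 0.3
x = torch.ones(256, 16384)  # ~4.2M elements, all non-zero
y = torch.nn.functional.dropout(x, p=p)

# Fraction of elements that were zeroed out
drop_p = (y.numel() - y.count_nonzero()) / y.numel()

# The empirical rate has std-dev sqrt(p * (1 - p) / N) ~= 2e-4 at this size,
# so a 0.01 tolerance leaves a comfortable margin.
assert abs(drop_p - p) < 0.01
```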
10 changes: 4 additions & 6 deletions xformers/triton/k_dropout.py
@@ -27,12 +27,10 @@ def _get_4_bin_masks(seed, rand_offsets, p):
# NOTE: We keep the random numbers as-is here (integers over int32),
# and convert the threshold instead, for speed

-# The initial distribution is -2**31 / 2**31
-# so our float threshold in between [0, 1]
-# The full computation is:
-# 2 ** 32 * p - 2 ** 31 => full range * p - half range (to offset in between -2**31 and 2 **31)
-threshold = 2147483648.0 * (2.0 * p - 1.0)
-threshold = threshold.to(tl.int32)
+# The initial distribution is over [-2**31, 2**31 - 1],
+# while our float threshold lies in [0, 1]
+# The full computation is: `start_point + full range * p`
+threshold = (-2147483648.0 + 4294967295.0 * p).to(tl.int32)
rand_mask1 = rand1 > threshold
rand_mask2 = rand2 > threshold
rand_mask3 = rand3 > threshold
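For reference, the integer-threshold trick used in `_get_4_bin_masks` can be sanity-checked outside Triton. A minimal PyTorch sketch (the probability and sample count below are arbitrary, for illustration only):

```python
# Illustrative only: compare raw int32-range random draws against a
# pre-converted integer threshold instead of converting every draw to a float.
import torch

p = 0.25
rand = torch.randint(-2**31, 2**31, (1_000_000,), dtype=torch.int64)

# start_point + full range * p, as in the kernel comment above
threshold = int(-2147483648.0 + 4294967295.0 * p)

keep_mask = rand > threshold            # each element is kept with probability ~(1 - p)
print(keep_mask.float().mean().item())  # ~0.75
```

Converting the threshold once instead of converting every random value to a float is the speed motivation stated in the NOTE above.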