
Functorch nvfuser revisions #363

Closed
wants to merge 27 commits
Changes from 1 commit (of 27 commits)
9474fc3
added nvfuser implementation, benchmark for biasReluDropout
Jul 7, 2022
5ea028e
reformatted fuse pattern
Jul 8, 2022
8453069
revised benchmarking, nvfused patterns
Jul 11, 2022
fdd6b16
adds BiasDropoutRes and BiasDropoutResLayernorm patterns, minor edits
Jul 13, 2022
291f439
unit testing for all fused patterns, minor edits
Jul 19, 2022
5004562
benchmarking for all nvfused patterns
Jul 19, 2022
ea85ea4
mypy wip
Jul 19, 2022
568c09a
benchmarking nvfuser patterns, adding plots, minor testing changes
Jul 22, 2022
7c7f6de
fixing mypy errors
Jul 25, 2022
8c59bb9
fixed benchmarking bug, minor test change
Jul 25, 2022
fd82a43
final benchmark plots, benchmark edits
Jul 25, 2022
bd4499a
nvfuser documentation, minor edits
Jul 26, 2022
b004d87
fixing functorch version error, documentation revisions
Jul 26, 2022
14cc332
Merge branch 'main' into op_fusion_functorch
yuanandonly Jul 26, 2022
9ea013a
fixing circleci functorch errors, mypy errors
Jul 26, 2022
c774755
circleci config wip
Jul 27, 2022
4f18220
circleci test wip
Jul 27, 2022
d5e0765
wip2
Jul 27, 2022
477c208
testing revisions, circleci fixes, minor changes
Jul 27, 2022
7d9d659
changelog changes, fixes functorch flag bug
Jul 27, 2022
339a556
circle-ci fix
Jul 27, 2022
5d8221d
circle-ci spacing fix
Jul 27, 2022
d9199f0
build error wip
Jul 27, 2022
bcf746e
revised documentation, reverted circleci config
Jul 27, 2022
bd5b799
Fix functorch errors, circleci issue, testing changes
yuanandonly Jul 27, 2022
a6f3221
updating changelog
yuanandonly Jul 28, 2022
33431d0
added mlp plots, mlp functionality to switch weights to nvfused mlp
yuanandonly Aug 11, 2022
fixed benchmarking bug, minor test change
Chris Yuan committed Jul 25, 2022
commit 8c59bb9318736a811d35e26848a71e1d56a339f1
8 changes: 7 additions & 1 deletion tests/test_nvfuser.py
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.


+import logging
+
 import pytest
 import torch
 import torch.nn as nn
@@ -15,13 +17,17 @@

 _gpu_available = torch.cuda.is_available()

-if xformers._is_functorch_available:
+xformers._is_functorch_available = True
+
+try:
     from xformers.components.nvfuser import (
         NVFusedBiasActivationDropout,
         NVFusedBiasDropoutRes,
         NVFusedBiasDropoutResLayerNorm,
     )
     from xformers.components.nvfuser.utils import build_nvfused
+except ImportError as e:
+    logging.warning(f"Functorch is not available in test_nvfuser.py. \nError {e}")

 FUSED_PATTERNS = (
     [
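The hunk above swaps a hard flag check for an optional-import guard: the imports are attempted, and on failure the test module logs a warning instead of failing collection. A minimal standalone sketch of that pattern (the module name `fancy_fused_ops` and the flag `fused_available` are illustrative, not xformers names):

```python
import logging

# Optional-import guard: try to pull in an accelerator-specific module and fall
# back gracefully when it is missing, so the rest of the module still loads.
# "fancy_fused_ops" is a hypothetical optional dependency.
fused_available = True
try:
    import fancy_fused_ops  # noqa: F401
except ImportError as e:
    fused_available = False
    logging.warning(f"Fused ops are not available, skipping related tests. Error: {e}")

print(fused_available)
```

Tests that need the fused ops can then be gated on `fused_available` (e.g. with `pytest.mark.skipif`) rather than erroring at import time.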
10 changes: 5 additions & 5 deletions xformers/benchmarks/benchmark_nvfuser.py
@@ -183,14 +183,14 @@ def step(fn, residual, x):
     for testcase in testcases:
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
-        # torch.cuda.synchronize()
+        torch.cuda.synchronize()

         time = triton.testing.do_bench(
             lambda: testcase.function(x=a), grad_to_none=[a, b]
         )[0]

-        # torch.cuda.synchronize()
-        max_memory = torch.cuda.max_memory_allocated() / 2**20
+        torch.cuda.synchronize()
+        max_memory = torch.cuda.max_memory_allocated() // 2**20

         key = f"B={B}, M={M}, K={K}"
         if key not in results:
@@ -211,7 +211,7 @@ def step(fn, residual, x):
             units="GB/s",
         )
         pretty_print(
-            results,
+            results_mem,
             title="\n --- PEAK MEMORY Type: {} {} --- ".format(pattern_str, dtype),
             units="MB",
         )
@@ -230,7 +230,7 @@ def step(fn, residual, x):
             legend_loc="upper left",
         )
         pretty_plot(
-            results,
+            results_mem,
             title="MAXMEM-{}-FW{}-{}{}-{}{}".format(
                 pattern_str,
                 "+BW" if backward else "",
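The benchmarking bug this commit fixes is a bookkeeping slip: the peak-memory table and plot were fed the bandwidth dict (`results`) instead of the memory dict (`results_mem`), and the byte count was converted with float division (`/`) rather than floor division (`//`), so "MB" values printed as floats. A small self-contained sketch of the corrected bookkeeping; `to_mib` and `record` are illustrative helpers, not xformers APIs:

```python
def to_mib(num_bytes: int) -> int:
    # Whole mebibytes, matching the diff's switch from / to // (floor division).
    return num_bytes // 2**20

def record(table: dict, key: str, name: str, value) -> None:
    # Each table maps a shape key like "B=8, M=512, K=1024" to {testcase: measurement}.
    table.setdefault(key, {})[name] = value

results: dict = {}      # bandwidth numbers (GB/s) -> pass to the bandwidth report
results_mem: dict = {}  # peak memory (MiB)        -> pass to the memory report

key = "B=8, M=512, K=1024"
record(results, key, "eager", 123.4)
record(results_mem, key, "eager", to_mib(300 * 2**20 + 12345))

print(results_mem[key]["eager"])  # 300
```

Keeping the two dicts strictly separated, and naming them distinctly at every `pretty_print`/`pretty_plot` call site, is exactly what the `results` → `results_mem` change above restores.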