[functorch] [Benchmark] Layer norm patterns (pytorch/functorch#311)
anijain2305 authored and bigfootjon committed Jul 21, 2022
1 parent 85576db commit 98bbf3d
Showing 3 changed files with 131 additions and 0 deletions.
4 changes: 4 additions & 0 deletions functorch/benchmarks/ls_patterns/benchmark_helper.py
@@ -70,6 +70,10 @@ def time_with_torch_timer(fn, args, string_id):
env = {"fn": fn, "a": args[0], "b": args[1]}
fn_call = "fn(a, b)"
grad_none = "a.grad = b.grad = None"
elif len(args) == 1:
env = {"fn": fn, "a": args[0]}
fn_call = "fn(a)"
grad_none = "a.grad = None"

print("################################################")
print(f"#### Torch Timer for {string_id} starts #########")
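For context, a minimal sketch of how `time_with_torch_timer` presumably consumes these strings further down in the helper (the `torch.utils.benchmark.Timer` wiring below is an assumption for illustration, not part of this diff):

# Hypothetical sketch, not part of the commit: env, fn_call, and grad_none
# are presumably passed to torch.utils.benchmark.Timer roughly like this.
import torch.utils.benchmark as benchmark

def run_torch_timer(env, fn_call, grad_none, iters=100):
    timer = benchmark.Timer(
        stmt=f"{grad_none}; {fn_call}",  # clear grads, then invoke the function
        globals=env,                     # binds fn and its tensor args by name
    )
    print(timer.timeit(iters))           # runs the statement iters times, prints stats

The single-argument branch added here is what lets the helper time unary patterns such as the layernorm_sigmoid script below.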
67 changes: 67 additions & 0 deletions functorch/benchmarks/ls_patterns/bias_dropout_res_layernorm.py
@@ -0,0 +1,67 @@
import torch
import time
from functorch.compile import memory_efficient_operator_authoring, clear_compile_cache
import benchmark_helper


batch_size = 32
seq_len = 196
hidden_size = 1024
def bias_dropout_res_layernorm(input, bias, residual):
    a = torch.add(input, bias)
    b = torch.nn.functional.dropout(a, p=0.7, training=True)
    c = b + residual
    d = torch.nn.functional.layer_norm(c, normalized_shape=(hidden_size,))
    return d


fn = bias_dropout_res_layernorm

clear_compile_cache()

# Set inputs
device = "cuda"
dtype = torch.float16
# batch_size = 2
# seq_len = 4
# hidden_size = 3
input = torch.randn(
    batch_size, seq_len, hidden_size, requires_grad=True, device=device, dtype=dtype
)
bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
residual = torch.randn(
    batch_size, seq_len, hidden_size, requires_grad=False, device=device, dtype=dtype
)


# Get the optimized function
opt_fn = memory_efficient_operator_authoring(fn, compiler_name="torchscript_nvfuser")

# Use this to print the graphs for NVFuser
with torch.jit.fuser("fuser2"):
    for _ in range(10):
        fwd = opt_fn(input, bias, residual)
        loss = fwd.sum()
        loss.backward()

# Profile cuda kernels
benchmark_helper.profile_cuda_kernels(fn, (input, bias, residual), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.profile_cuda_kernels(
        opt_fn, (input, bias, residual), "AOTAutograd"
    )


# Time it with Torch Timer
benchmark_helper.time_with_torch_timer(fn, (input, bias, residual), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_torch_timer(
        opt_fn, (input, bias, residual), "AOTAutograd"
    )

# Time it with manual Timer
benchmark_helper.time_with_manual_timer(fn, (input, bias, residual), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_manual_timer(
        opt_fn, (input, bias, residual), "AOTAutograd"
    )
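The script also measures wall-clock time via `benchmark_helper.time_with_manual_timer`; a plausible sketch of such a manual CUDA timing loop follows (the body is an assumption, only the helper's name appears in this commit):

# Hypothetical sketch, not part of the commit: manual wall-clock timing in
# the spirit of benchmark_helper.time_with_manual_timer.
import time
import torch

def manual_time(fn, args, iters=100):
    torch.cuda.synchronize()              # drain already-queued work first
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()              # kernel launches are async; wait for completion
    return (time.perf_counter() - start) / iters  # average seconds per call

Synchronizing before and after the loop is what makes wall-clock numbers meaningful for asynchronous CUDA kernels.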
60 changes: 60 additions & 0 deletions functorch/benchmarks/ls_patterns/layernorm_sigmoid.py
@@ -0,0 +1,60 @@
import torch
import time
from functorch.compile import memory_efficient_operator_authoring, clear_compile_cache
import benchmark_helper


batch_size = 8192
hidden_size = 512
def layernorm_sigmoid(inp):
    a = torch.nn.functional.layer_norm(inp, normalized_shape=(hidden_size,))
    b = torch.sigmoid(a)
    return b


fn = layernorm_sigmoid

clear_compile_cache()

# Set inputs
device = "cuda"
dtype = torch.float16
# batch_size = 2
# seq_len = 4
# hidden_size = 3
inp = torch.randn(
    batch_size, hidden_size, requires_grad=True, device=device, dtype=dtype
)


# Get the optimized function
opt_fn = memory_efficient_operator_authoring(fn, compiler_name="torchscript_nvfuser")

# Use this to print the graphs for NVFuser
with torch.jit.fuser("fuser2"):
    for _ in range(10):
        fwd = opt_fn(inp)
        loss = fwd.sum()
        loss.backward()

# Profile cuda kernels
benchmark_helper.profile_cuda_kernels(fn, (inp,), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.profile_cuda_kernels(
        opt_fn, (inp,), "AOTAutograd"
    )


# Time it with Torch Timer
benchmark_helper.time_with_torch_timer(fn, (inp,), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_torch_timer(
        opt_fn, (inp,), "AOTAutograd"
    )

# Time it with manual Timer
benchmark_helper.time_with_manual_timer(fn, (inp,), "Eager")
with torch.jit.fuser("fuser2"):
    benchmark_helper.time_with_manual_timer(
        opt_fn, (inp,), "AOTAutograd"
    )
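Because `layernorm_sigmoid` contains no dropout, eager and compiled outputs are directly comparable; a hypothetical sanity check (not part of this commit) could be appended as:

# Hypothetical sanity check, not part of the commit: the pattern is
# deterministic, so the nvFuser-compiled path should match eager within
# fp16-appropriate tolerances.
with torch.jit.fuser("fuser2"):
    res = opt_fn(inp)
ref = fn(inp)
torch.testing.assert_close(res, ref, rtol=1e-3, atol=1e-3)

The same check would not apply to `bias_dropout_res_layernorm` above, since dropout randomizes its output on every call.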
