11"""
2- (Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
3- ==========================================================================================
2+ (Beta) Scaled Dot Product Attention (SDPA)λ‘ κ³ μ±λ₯ νΈλμ€ν¬λ¨Έ(Transformers) ꡬννκΈ°
3+ =================================================================================
44
55
6- **Author:** `Driss Guessous <https://github.com/drisspg>`_
6+ **μ μ:** `Driss Guessous <https://github.com/drisspg>`_
7+ **λ²μ:** `μ΄κ°ν¬ <https://github.com/khleexv>`_
78"""
89
910######################################################################
10- # Summary
11- # ~~~~~~~~
11+ # μμ½
12+ # ~~~~
1213#
# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
# For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
#
# Overview
# ~~~~~~~~
# At a high level, this PyTorch function calculates the
# scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can
# be written in PyTorch using existing functions, a fused implementation can provide
# large performance benefits over a naive implementation.
#
# Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs, the function will dispatch into one of the following
# implementations:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#
# .. note::
#
#    This tutorial requires PyTorch 2.0.0 or later.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
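
# As noted in the overview, SDPA can also be written with existing PyTorch ops.
# The sketch below covers only the default case used above (no attention mask,
# no dropout, the default 1/sqrt(E) scaling) and should agree with the fused
# call up to floating point tolerance.
def naive_sdpa(query, key, value):
    scale = query.size(-1) ** -0.5
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

print(torch.allclose(naive_sdpa(query, key, value),
                     F.scaled_dot_product_attention(query, key, value),
                     atol=1e-5))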


######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through all of
# them while measuring performance.
#

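# As a quick illustration of the dispatcher control described above (a minimal
# sketch using the ``torch.backends.cuda.sdp_kernel`` flags available in
# PyTorch 2.0), the context manager below disables the flash and
# memory-efficient kernels so that only the C++ math implementation remains
# selectable for this call:
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    F.scaled_dot_product_attention(query, key, value)
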
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyperparameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
# - If you don't have a GPU and are running on CPU then the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports,
#   flash attention or memory-efficient attention might have failed.
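#
# For example, you can query the compute capability of the current GPU before
# interpreting the numbers above (a small sketch; it only prints something
# when CUDA is available):

if torch.cuda.is_available():
    print(torch.cuda.get_device_capability())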


######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by the
# `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#

class CausalSelfAttention(nn.Module):


#####################################################################
# ``NestedTensor`` and Dense tensor support
# -----------------------------------------
#
# SDPA supports both ``NestedTensor`` and Dense tensor inputs. ``NestedTensors`` handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch. For more information about ``NestedTensors`` see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and the `NestedTensors Tutorial <https://tutorials.pytorch.kr/prototype/nestedtensor.html>`__.
#
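# A tiny sketch of the idea: sequences of different lengths are packed into a
# single ``NestedTensor`` with no padding (the construction below assumes the
# default layout and dtype; see the linked docs for the full API).
seqs = [torch.randn(5, 8), torch.randn(3, 8)]
nested = torch.nested.nested_tensor(seqs)
print(nested.is_nested)  # True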

import random
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):


######################################################################
# Using SDPA with ``torch.compile``
# =================================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
#

batch_size = 32
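
# A minimal sketch of the pattern benchmarked below, assuming a CUDA-capable
# setup and that ``model``, ``embed_dimension``, ``dtype``, and
# ``max_sequence_len`` from the earlier sections are still in scope. The first
# call to the compiled module pays the one-time compilation cost, so it is
# usually excluded from timing.
x = torch.rand(batch_size, max_sequence_len, embed_dimension, device=device, dtype=dtype)
compiled_model = torch.compile(model)
compiled_model(x)  # warm-up call that triggers compilation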

######################################################################
#
# The exact execution time depends on your machine; however, these are the results on mine:
# The non-compiled module runs in 166.616 microseconds
# The compiled module runs in 166.726 microseconds
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#

from torch.profiler import profile, record_function, ProfilerActivity
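
# A sketch of the profiling pattern used in this section (assuming ``model``
# and ``x`` from above are still in scope): collect CUDA activity only when a
# GPU is present and sort the report accordingly.
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)
sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))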

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results.
# ::
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")


######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausalSelfAttention``
# is, then the overhead of PyTorch can be hidden.
#
# In reality, your module does not normally consist of a singular
# ``CausalSelfAttention`` block. When experimenting with the `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module took the time per train step from ``6090.49ms`` to
# ``3273.17ms``! This was done on commit ``ae3a8d5`` of NanoGPT training on
# the Shakespeare dataset.
#


######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert that a certain
# implementation is used on GPU. As well, we built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch
# compilable. In the process we have shown how the profiling tools can
# be used to explore the performance characteristics of a user-defined
# module.
#