AsyncCollectiveTensor: prevent wait_tensor() calls on graph inputs from getting DCE'd (pytorch#125677)

@wanchaol was seeing the loss eventually become NaN when compiling individual transformer blocks in torchtitan - with this patch I no longer see the NaN loss.

The problem is the following:

(1) It is possible for graph inputs to a compiled region to be AsyncCollectiveTensors. In particular: when we compile individual transformer blocks in the llama model, the first layer (the embedding layer) is run in eager mode, and it outputs an AsyncCollectiveTensor that is fed to the first transformer block.

(2) Ideally, that AsyncCollectiveTensor graph input would desugar into a `wait_tensor()` op that shows up at the beginning of the graph.

(3) The way this is supposed to happen: AOTAutograd traces through the `__torch_dispatch__` of AsyncCollectiveTensor, tracing out a `wait_tensor()` call before dispatching to any of the other ops in the function being traced.

(4) However, `trigger_wait()` was getting called in a way where its output was ignored (and `self.elem` was returned directly), which caused the `wait_tensor()` ops to get DCE'd.

Pull Request resolved: pytorch#125677
Approved by: https://github.com/wanchaol, https://github.com/yifuwang
ghstack dependencies: pytorch#125676
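For context, a minimal torch.fx sketch of the failure mode (not the actual AsyncCollectiveTensor code; the `x + 1` op is a hypothetical stand-in for the traced `wait_tensor()` call): an op whose output has no users gets dropped by dead-code elimination.

```python
import torch
import torch.fx

def forward(x):
    # Stand-in for the buggy pattern: a wait-style op is traced,
    # but its result is discarded and the un-waited value is returned.
    waited = x + 1   # hypothetical stand-in for a wait_tensor()-style call
    return x * 2     # "waited" has no users downstream

gm = torch.fx.symbolic_trace(forward)
print(gm.graph)  # the add node is present right after tracing

# Any pass that runs dead-code elimination removes the unused node,
# which is what happened to the traced wait_tensor() calls.
gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.graph)  # the add node is gone
```

Per the description above, the fix is to make the returned tensor actually depend on the traced wait (i.e. return the waited tensor rather than `self.elem` directly), so the `wait_tensor()` node keeps a user and survives DCE.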