fix: various fixes and enhancements #46
Conversation
Signed-off-by: Piotr Sokólski <piotr@themetavoice.xyz>
        ).transpose(
            1, 2
        )  # (B, nh, T, hs) -> (B, T, nh, hs)

        return y

-    def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor:
+    def _vanilla_attn(self, c_x: torch.Tensor) -> torch.Tensor:
Why do we need this for this PR? For context, this used to be used as a test; otherwise _torch_attn does the job?
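(For reference, a minimal sketch of a plain-PyTorch attention path using F.scaled_dot_product_attention; this is not the repo's actual _torch_attn, and the assumed packed q/k/v layout (B, T, 3, nh, hs) for c_x is a guess based on the shapes in the hunk above:)

import torch
import torch.nn.functional as F

def torch_attn_sketch(c_x: torch.Tensor) -> torch.Tensor:
    # Assumed layout: c_x packs q/k/v as (B, T, 3, nh, hs).
    q, k, v = c_x.unbind(dim=2)                       # each (B, T, nh, hs)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (B, nh, T, hs)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return y.transpose(1, 2)                          # (B, T, nh, hs), matching the comment in the hunk

# e.g. torch_attn_sketch(torch.randn(2, 8, 3, 4, 16))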
@@ -161,7 +166,7 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
        seed=1337,
        device=device,
        dtype=GlobalState.config.dtype,
-       compile=False,
+       compile=GlobalState.config.compile,
This will not work right now and will cause a recompilation at each time-step.
it does not seem to do that for me
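(For reference, one way to check whether torch.compile really recompiles at each step is to turn on recompile logging; a toy sketch, not this repo's decode loop:)

import torch

# Print every Dynamo recompilation and the guard that triggered it (PyTorch 2.x).
# The same output is available via the TORCH_LOGS="recompiles" environment variable.
torch._logging.set_logs(recompiles=True)

@torch.compile
def step(x: torch.Tensor, pos: int) -> torch.Tensor:
    return x * pos  # stand-in for one decoding time-step

x = torch.randn(4)
for pos in range(8):
    step(x, pos)  # a recompile logged at every iteration would confirm the concern above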
did you try with vanilla kv-cache or flash decoding?
Also, we probably need to change the mode for torch.compile during inference to get the most out of it, and to use CUDA graphs.
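(For reference, the usual inference-time pattern is mode="reduce-overhead", which is what enables CUDA graph capture for the compiled region; a sketch with a toy decode step, not this PR's code:)

import torch

def decode_one_token(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a single autoregressive decoding step.
    return torch.relu(x @ x.T)

# "reduce-overhead" captures the region with CUDA graphs, which pays off for the
# many small kernels launched per decoding step; fullgraph=True fails loudly on
# graph breaks instead of silently splitting the graph.
decode_step = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(16, 16, device=device)
for _ in range(4):
    decode_step(x)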
        flash_attn_with_kvcache,
    )
except ImportError:
    warnings.warn("flash_attn not installed, make sure to replace attention mechanism with torch_attn")
nit: improve the warning by spelling out the required change to the command instead.
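For example, something along these lines (a sketch only; the --attention flag is a placeholder for whatever option actually selects the attention implementation here):

import warnings

# Placeholder wording: "--attention torch_attn" stands in for the real CLI/config option.
warnings.warn(
    "flash_attn is not installed. Either install it (`pip install flash-attn`) "
    "or switch the attention implementation to torch_attn, e.g. `--attention torch_attn`."
)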
This PR enables: