
fix: various fixes and enhancements #46

Merged: 4 commits merged into main on Feb 12, 2024
Conversation

@pyetras (Contributor) commented on Feb 12, 2024

This PR enables:

  • decoding without flash attention
  • early termination when decoding (see the sketch after this list)
  • prompt-free guidance
  • kv-cache with any dtype
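As a rough, illustrative sketch of the early-termination and prompt-free (classifier-free-style) guidance behaviour; names such as model, eos_token_id, and guidance_scale are assumptions for illustration, not identifiers from this PR:

import torch

def combine_guidance(cond_logits: torch.Tensor,
                     uncond_logits: torch.Tensor,
                     guidance_scale: float) -> torch.Tensor:
    # Classifier-free-style combine: push the prompted distribution away
    # from the prompt-free one by guidance_scale.
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)

@torch.no_grad()
def greedy_decode(model, tokens: torch.Tensor, eos_token_id: int,
                  max_new_tokens: int = 256, guidance_scale: float = 1.0) -> torch.Tensor:
    for _ in range(max_new_tokens):
        cond = model(tokens)[:, -1, :]            # logits conditioned on the full prompt
        uncond = model(tokens[:, -1:])[:, -1, :]  # prompt-free logits (illustrative)
        logits = combine_guidance(cond, uncond, guidance_scale)
        next_token = logits.argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
        if (next_token == eos_token_id).all():    # early termination on EOS
            break
    return tokens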

Piotr Sokolski added 4 commits on February 12, 2024 at 10:12
Signed-off-by: Piotr Sokólski <piotr@themetavoice.xyz>
@pyetras (Contributor, Author) commented on Feb 12, 2024

Fixes: #19, #22
Might fix, but not tested: #1

pyetras merged commit 43f97a0 into main on Feb 12, 2024
pyetras mentioned this pull request on Feb 12, 2024
        ).transpose(
            1, 2
        )  # (B, nh, T, hs) -> (B, T, nh, hs)

        return y

    def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor:
    def _vanilla_attn(self, c_x: torch.Tensor) -> torch.Tensor:
@vatsalaggarwal (Member) commented on Feb 12, 2024:

why do we need this for this PR? For context, this used to be used as a test; otherwise _torch_attn does the job?
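For context on the question above, a torch-native fallback along these lines (a minimal sketch, not the code from this PR) avoids the flash_attn dependency by using PyTorch's built-in scaled_dot_product_attention:

import torch
import torch.nn.functional as F

def vanilla_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      causal: bool = True) -> torch.Tensor:
    # q, k, v: (B, nh, T, hs). Uses PyTorch's fused SDPA kernel where available,
    # so flash_attn is not required.
    y = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return y.transpose(1, 2)  # (B, nh, T, hs) -> (B, T, nh, hs)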

@@ -161,7 +166,7 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
         seed=1337,
         device=device,
         dtype=GlobalState.config.dtype,
-        compile=False,
+        compile=GlobalState.config.compile,
Member:

this will not work right now and will cause recompilations at each time step

@pyetras (Contributor, Author):

it does not seem to do that for me

Member:

did you try with vanilla kv-cache or flash decoding?

Member:

also, we probably need to change the mode for torch.compile during inference to get the most out of it, and to use CUDA graphs
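As a rough illustration of that suggestion (a sketch, not code from this repository): "reduce-overhead" mode enables CUDA graphs to cut per-step launch overhead, and dynamic shapes are one way to avoid recompiling as the sequence grows at each decode step.

import torch

model = torch.nn.Linear(8, 8)  # stand-in for the decoder

# "reduce-overhead" mode uses CUDA graphs to reduce per-step launch overhead.
fast = torch.compile(model, mode="reduce-overhead")

# dynamic=True asks the compiler for shape-generic kernels, so a growing
# sequence length does not trigger a recompilation at every decode step.
flexible = torch.compile(model, dynamic=True)

x = torch.randn(4, 8)
print(fast(x).shape, flexible(x).shape)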

        flash_attn_with_kvcache,
    )
except ImportError:
    warnings.warn("flash_attn not installed, make sure to replace attention mechanism with torch_attn")
Member:

nit: improve the warning by stating the required change in the command instead
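One way to act on this nit, as a sketch only; the flag and config names below are assumptions, not taken from this repository:

import warnings

try:
    from flash_attn import flash_attn_with_kvcache  # noqa: F401
except ImportError:
    warnings.warn(
        "flash_attn is not installed. Switch the attention mechanism to "
        "torch_attn, e.g. by passing a (hypothetical) --attn_impl torch_attn "
        "flag or setting attn_impl = 'torch_attn' in the model config."
    )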
