From 51962916b082b8d8768d3bff99dcbd19b3bc2b8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gianluca=20Macr=C3=AC?= <56147945+gianlucamacri@users.noreply.github.com>
Date: Thu, 2 Nov 2023 13:07:36 +0100
Subject: [PATCH 1/4] Update llama_attn_replace_sft.py

---
 llama_attn_replace_sft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_attn_replace_sft.py b/llama_attn_replace_sft.py
index 5100940e..d957e014 100644
--- a/llama_attn_replace_sft.py
+++ b/llama_attn_replace_sft.py
@@ -36,7 +36,7 @@ def forward_flashattn(
     attention_mask: [bsz, q_len]
     """
     if not self.training:
-        raise ValueError("This function is only for training. For inference, please use forward_flashattn_inference.")
+        warnings.warn("This function should be used just for training as it may exhibit reduced inference performances. For inference, please use forward_flashattn_inference.")

     if output_attentions:
         warnings.warn(

From 363c1958b43bd993361b062efb777d8930dccfb2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gianluca=20Macr=C3=AC?= <56147945+gianlucamacri@users.noreply.github.com>
Date: Thu, 2 Nov 2023 13:08:52 +0100
Subject: [PATCH 2/4] Update llama_attn_replace.py

---
 llama_attn_replace.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_attn_replace.py b/llama_attn_replace.py
index 0a14c8ba..c4980bd9 100644
--- a/llama_attn_replace.py
+++ b/llama_attn_replace.py
@@ -34,7 +34,7 @@ def forward_flashattn(
     attention_mask: [bsz, q_len]
     """
     if not self.training:
-        raise ValueError("This function is only for training. For inference, please use forward_flashattn_inference.")
+        warnings.warn("This function should be used just for training as it may exhibit reduced inference performances. For inference, please use forward_flashattn_inference.")

     if output_attentions:
         warnings.warn(

From 0eb1a2a84c6828998e72f47b8be20b06243ad1ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gianluca=20Macr=C3=AC?= <56147945+gianlucamacri@users.noreply.github.com>
Date: Thu, 2 Nov 2023 13:11:04 +0100
Subject: [PATCH 3/4] Update llama_attn_replace_sft.py

---
 llama_attn_replace_sft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_attn_replace_sft.py b/llama_attn_replace_sft.py
index d957e014..9f897605 100644
--- a/llama_attn_replace_sft.py
+++ b/llama_attn_replace_sft.py
@@ -36,7 +36,7 @@ def forward_flashattn(
     attention_mask: [bsz, q_len]
     """
     if not self.training:
-        warnings.warn("This function should be used just for training as it may exhibit reduced inference performances. For inference, please use forward_flashattn_inference.")
+        warnings.warn("This function should be used just for training as it may exhibit reduced inference performance. For inference, please use forward_flashattn_inference.")

     if output_attentions:
         warnings.warn(

From b5c6809515d47ef4be34bd86b2074d540a1b897a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gianluca=20Macr=C3=AC?= <56147945+gianlucamacri@users.noreply.github.com>
Date: Thu, 2 Nov 2023 13:11:21 +0100
Subject: [PATCH 4/4] Update llama_attn_replace.py

---
 llama_attn_replace.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_attn_replace.py b/llama_attn_replace.py
index c4980bd9..68ca77ac 100644
--- a/llama_attn_replace.py
+++ b/llama_attn_replace.py
@@ -34,7 +34,7 @@ def forward_flashattn(
     attention_mask: [bsz, q_len]
     """
     if not self.training:
-        warnings.warn("This function should be used just for training as it may exhibit reduced inference performances. For inference, please use forward_flashattn_inference.")
+        warnings.warn("This function should be used just for training as it may exhibit reduced inference performance. For inference, please use forward_flashattn_inference.")

     if output_attentions:
         warnings.warn(
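
For context, the behavioral difference these patches introduce is that a call to the patched forward_flashattn on a model in eval mode no longer aborts with a ValueError; it emits a non-fatal warning and proceeds. The snippet below is a minimal, self-contained sketch of that distinction only; the AttnStub class and its two forward methods are hypothetical stand-ins, not code from the repository.

import warnings


class AttnStub:
    """Illustrative stand-in for the patched attention module."""

    def __init__(self, training: bool):
        self.training = training

    def forward_old(self):
        # Pre-patch behavior: hard failure when the module is not in training mode.
        if not self.training:
            raise ValueError(
                "This function is only for training. "
                "For inference, please use forward_flashattn_inference."
            )
        return "attention output"

    def forward_new(self):
        # Post-patch behavior: non-fatal warning; the call still completes.
        if not self.training:
            warnings.warn(
                "This function should be used just for training as it may "
                "exhibit reduced inference performance. For inference, "
                "please use forward_flashattn_inference."
            )
        return "attention output"


module = AttnStub(training=False)

try:
    module.forward_old()
except ValueError as err:
    print(f"old behavior: aborted with ValueError: {err}")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = module.forward_new()
    print(f"new behavior: warned ({caught[0].message}) and returned {out!r}")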