
feat: enable pytorch xpu support for non-attention models #2561

Merged
merged 1 commit into huggingface:main on Oct 14, 2024

Conversation

dvrogozh
Contributor

The XPU backend is available natively (without IPEX) in PyTorch starting from PyTorch 2.4. This commit extends TGI to cover the case when a user has XPU support through PyTorch 2.4 but does not have IPEX installed. Models that don't require attention can work; for attention-based models, more work is needed to provide an attention implementation.

Tested with the following models:

  • teknium/OpenHermes-2.5-Mistral-7B
  • bigscience/bloom-560m
  • google/gemma-7b
  • google/flan-t5-xxl

CC: @Narsil
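
For reference, a minimal sketch (not the code in this PR; the helper name and structure are illustrative) of probing for native XPU support versus the IPEX plugin on PyTorch 2.4+:

import importlib.util

import torch


def xpu_backend() -> str:
    """Report which XPU path is usable: 'ipex', 'native', or 'none'."""
    # IPEX is an optional plugin; its presence means the custom kernels
    # (e.g. paged attention) are available.
    has_ipex = importlib.util.find_spec("intel_extension_for_pytorch") is not None
    # torch.xpu ships with upstream PyTorch starting from 2.4.
    has_native_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    if has_native_xpu and has_ipex:
        return "ipex"
    if has_native_xpu:
        return "native"  # attention-free models only, per this PR
    return "none"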

@sywangyi
Contributor

Please hold the PR; the PyTorch version in Dockerfile_intel for XPU is 2.3. See https://github.com/huggingface/text-generation-inference/blob/main/Dockerfile_intel#L94

@sywangyi
Contributor

And what do you mean by non-attention models?
teknium/OpenHermes-2.5-Mistral-7B
bigscience/bloom-560m
google/gemma-7b
google/flan-t5-xxl

These models all have attention per my understanding; we currently use the paged-attention implementation from IPEX to support them.

@dvrogozh
Contributor Author

Please hold the PR; the PyTorch version in Dockerfile_intel for XPU is 2.3

In the Dockerfile, yes. It's possible, however, to build TGI from source against a different version of PyTorch. I have a PyTorch build from main, which is the current 2.6 candidate.

@dvrogozh
Contributor Author

These models all have attention per my understanding

These models have a fallback mechanism which is triggered if no attention is available. For example, for gemma it's defined as in the code snippet below; the fallback is on line 828. Since the IPEX container provides attention, you probably did not notice it. As you can see in this PR, the fallback() path was not patched for XPU support.

if FLASH_ATTENTION:
    return FlashCausalLM(
        model_id=model_id,
        model_class=FlashGemmaForCausalLM,
        revision=revision,
        quantize=quantize,
        speculator=speculator,
        dtype=dtype,
        # Works better for these models
        default_dtype=torch.bfloat16,
        trust_remote_code=trust_remote_code,
        lora_adapter_ids=lora_adapter_ids,
    )
elif sharded:
    raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else:
    return CausalLM.fallback(
        model_id,
        revision,

@dvrogozh
Contributor Author

These models have a fallback mechanism which is triggered if no attention is available.

To be more precise, it is triggered if one of the attention-enabled models fails to load (which is potentially not the same thing as no attention being available). The logic is defined here:

try:
    # (earlier imports omitted in the quoted snippet)
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False

@sywangyi
Contributor

sywangyi commented Sep 25, 2024

Have you compared the fallback performance with the current path? Per my understanding, the fallback path has a performance limitation, and TensorParallel/GPTQ/AWQ are not supported. So usually, if I find that a model needs to fall back to transformers, I will implement the custom model in TGI.

@sywangyi
Contributor

Also, we use Dockerfile_intel to generate the text-generation-inference Docker image, and users just download it directly (see https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference). That's why I suggest you hold the change.

@Narsil
Collaborator

Narsil commented Oct 1, 2024

Following this thread, my understanding is that this PR isn't ready yet, correct?

@dvrogozh
Contributor Author

dvrogozh commented Oct 1, 2024

Following this thread, my understanding is that this PR isn't ready yet, correct?

My opinion as the author of the patch is that it's ready. It enables TGI to work with upstream PyTorch with XPU support for those models which are currently possible, i.e. those which don't require custom kernels. On top of this PR, further work needs to happen to provide implementations of the missing kernels. A few ways would be possible here:

  1. Implement the kernels with Triton (a minimal sketch follows after this list)
  2. Implement the kernels with SYCL (but this might need additional work on the PyTorch side to expose SYCL in the C++ extension API; see "xpu: support torch.utils.cpp_extension APIs to build SYCL kernels", pytorch/pytorch#132944)
  3. Use IPEX as a custom-ops library for the missing kernels (this also requires additional work on the IPEX side to expose such a custom-ops library rather than intermixing it with the PyTorch plugin functionality, as it currently does)
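
As a rough illustration of what option 1 would look like (a generic element-wise Triton kernel, not one of the attention kernels TGI actually needs, and assuming a Triton build with an XPU backend):

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the tensors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out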

I wonder what the TGI maintainers' opinion on this topic is. In particular, are there plans to rewrite the attention kernels residing in TGI with Triton? Maybe such a step is planned, or is being worked on by anyone at the moment?

I also think that the PR improves handling of the XPU path regardless of whether IPEX or upstream PyTorch XPU is used: some of the conditions which were corrected are now more logical. If someone wants to use TGI against upstream PyTorch XPU, they will need to build it from source; I did not provide a Dockerfile for this case. The existing IPEX path continues to work.

Apparently, another opinion is expressed above: that IPEX should remain the path used for now. This, however, ties TGI on XPU to an older PyTorch version and restricts trying out TGI against upstream PyTorch XPU for those who are willing to do so. I consider the proposed PR a step in the right direction in any case, because the ultimate logic should be to enable all possible features with the base stack (PyTorch), then enhance with additional features (custom kernels and additional third-party libraries). The current path with IPEX does this the other way around, due to the initial nature of IPEX as a plugin for PyTorch. Things are changing, however, with XPU being available right out of the box in PyTorch and the plugin aspect of IPEX going away. This PR is a step towards this change.

@Narsil
Collaborator

Narsil commented Oct 2, 2024

are there plans to rewrite the attention kernels residing in TGI with Triton? Maybe such a step is planned, or is being worked on by anyone at the moment?

Triton has serious drawbacks as a technology for production because it's a JIT environment, meaning there's unbounded compilation time during the initial runtime phase (which could be long, or even crash very late, because of all the compilations).

As much as I respect the work being done over there, it's really hard to use as long as it's not AOT compilation.
I know you can have a "precomputed" cache (vLLM does this a bit), but it's still something you have to run on every machine to get the proper cached version, and invalidation is really not fun there. And if you do not, then you get super slow startup/runtime.

If this JIT can be done entirely during the warmup, then it's already better (still not great because of the super slow startup times, but at least you're not getting atrocious runtimes initially). So far we've seen that custom-made CUDA kernels work much better in practice than Triton-made kernels.

We've also experimented with torch.compile, and it suffers from the same issues.
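
A minimal sketch of the warmup idea (illustrative, with torch.compile standing in for the JIT): compile lazily, then run one representative forward pass before serving traffic so the compilation cost is paid up front:

import torch


def warmup(model: torch.nn.Module, example_input: torch.Tensor) -> torch.nn.Module:
    compiled = torch.compile(model)  # compilation is deferred...
    with torch.inference_mode():
        compiled(example_input)      # ...and triggered here, during warmup
    return compiled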

@dvrogozh
Contributor Author

dvrogozh commented Oct 2, 2024

Triton has serious drawbacks <...>

Thank you for sharing this @Narsil, very interesting.

@Narsil Narsil merged commit 58848cb into huggingface:main Oct 14, 2024