
xpu: support xpu backend from stock pytorch (>=2.4) #2825

Merged: 1 commit merged into huggingface:main on Jun 13, 2024

Conversation

@dvrogozh (Contributor) commented on Jun 4, 2024:

Fixes: huggingface/transformers#31237

The XPU backend is available in stock PyTorch starting from version 2.4 [1]. This commit extends Hugging Face Accelerate to support XPU from both IPEX and stock PyTorch; IPEX is tried first.

I am raising this PR as a WIP draft to facilitate further discussion around enabling the XPU backend in Hugging Face, and to be able to communicate observed XPU issues back to PyTorch.

[1] pytorch/pytorch#114842

@EikanWang, @fengyuan14, @guangyey, @jgong5, @kding1, @sywangyi
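
For context, a minimal sketch of the probing order described above: try IPEX first, then fall back to the native torch.xpu in stock PyTorch >= 2.4. The helper name is hypothetical, not the PR's actual code:

  import importlib.util

  import torch

  def xpu_available() -> bool:  # hypothetical helper, not accelerate's API
      # IPEX, if installed, monkey-patches torch.xpu when built for XPU.
      if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
          import intel_extension_for_pytorch  # noqa: F401
      # Stock PyTorch >= 2.4 exposes torch.xpu natively; older builds may lack it.
      return hasattr(torch, "xpu") and torch.xpu.is_available()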

dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Jun 4, 2024
Fixes: huggingface#31237

XPU backend is available in the stock PyTorch starting from
version 2.4, see [1]. This commit extends huggingface transformers
to support XPU from both IPEX and the stock pytorch. IPEX is being
tried first.

See: pytorch/pytorch#114842
Requires: huggingface/accelerate#2825
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
@muellerzr (Collaborator) left a comment:

Thanks a bunch for doing this! Great start, just one question 🤗

Review threads: src/accelerate/utils/imports.py (outdated, resolved); src/accelerate/accelerator.py (resolved)
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Jun 5, 2024
@muellerzr (Collaborator) left a comment:

Thanks! Looking much better. Just a nit

Review thread: src/accelerate/utils/imports.py (outdated, resolved)
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1")
# torch.xpu.optimize is available only for xpu via IPEX
if hasattr(torch.xpu, "optimize"):
    model, optimizer = torch.xpu.optimize(
@dvrogozh (Contributor, Author) commented:

@muellerzr: I think I figured out why there are two different calls (torch.xpu.optimize and ipex.optimize) in the current code. It turns out that IPEX can be built in two distinct ways:

  1. The first way is to build it with Intel GPU support, i.e. IPEX-XPU. In this build, monkey-patches are applied and torch.xpu.optimize() is exposed as a handy wrapper around ipex.optimize(), which is also available.
  2. The second way is to build IPEX with CPU support only. In this case there is no GPU support, no monkey-patches, and no XPU; the only thing available is ipex.optimize().

So it seems that the current Hugging Face Accelerate code supports both paths, IPEX-XPU and IPEX-CPU, judging from the source code. It is interesting to note that https://github.com/huggingface/accelerate/blob/main/docs/source/usage_guides/ipex.md talks only about IPEX-CPU and does not mention IPEX-XPU, so I wonder whether IPEX-XPU is fully enabled in Hugging Face.

Summarizing, I think I need to update my PR to cover all 3 cases: 1) IPEX-XPU, 2) IPEX-CPU, 3) XPU in stock PyTorch. I will try to rewrite it so that it is clear which options are on the table, and I will add some comments to the code.
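
For illustration, a sketch that distinguishes the three cases using checks that mirror the printouts below (the function name is hypothetical, not accelerate's code):

  import importlib.util

  import torch

  def xpu_stack() -> str:  # hypothetical helper
      if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
          import intel_extension_for_pytorch as ipex  # noqa: F401
          # IPEX-XPU builds monkey-patch torch.xpu.optimize(); IPEX-CPU builds do not.
          if hasattr(torch, "xpu") and hasattr(torch.xpu, "optimize"):
              return "IPEX-XPU: ipex.optimize() and torch.xpu.optimize() available"
          return "IPEX-CPU: only ipex.optimize() available, no XPU device"
      if hasattr(torch, "xpu") and torch.xpu.is_available():
          return "stock PyTorch >= 2.4: torch.xpu available, no ipex.optimize()"
      return "no XPU backend detected"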

FYI: the easiest way to check the behavior is probably to try the IPEX containers from https://hub.docker.com/r/intel/intel-optimized-pytorch. Here are some printouts:

# IPEX CPU
$ docker run -it --rm --privileged intel/intel-extension-for-pytorch:2.3.0-pip-base python3 -c 'import torch; import intel_extension_for_pytorch; print(torch.xpu.is_available())'
False

$ docker run -it --rm --privileged intel/intel-extension-for-pytorch:2.3.0-pip-base python3 -c 'import torch; import intel_extension_for_pytorch as ipex; print(hasattr(ipex, "optimize"))'
True

$ docker run -it --rm --privileged intel/intel-extension-for-pytorch:2.3.0-pip-base python3 -c 'import torch; import intel_extension_for_pytorch as ipex; print(hasattr(torch.xpu, "optimize"))'
False

# IPEX XPU
$ docker run -it --rm --privileged intel/intel-extension-for-pytorch:2.1.30-xpu python3 -c 'import torch; import intel_extension_for_pytorch; print(torch.xpu.is_available())'
True

$ docker run -it --rm --privileged intel/intel-extension-for-pytorch:2.1.30-xpu python3 -c 'import torch; import intel_extension_for_pytorch as ipex; print(hasattr(ipex, "optimize"))'
True

$ docker run -it --rm --privileged intel/intel-extension-for-pytorch:2.1.30-xpu python3 -c 'import torch; import intel_extension_for_pytorch as ipex; print(hasattr(torch.xpu, "optimize"))'
True

@dvrogozh (Contributor, Author) commented:

@muellerzr: I reworked the PR according to the above. Please review it again.

@dvrogozh dvrogozh requested a review from muellerzr June 7, 2024 16:15
@dvrogozh dvrogozh force-pushed the xpu branch 2 times, most recently from e918737 to 9b94a01 Compare June 7, 2024 16:33
@dvrogozh dvrogozh marked this pull request as ready for review June 7, 2024 17:00
@dvrogozh dvrogozh changed the title from "[WIP] xpu: support xpu backend from stock pytorch (>=2.4)" to "xpu: support xpu backend from stock pytorch (>=2.4)" on Jun 7, 2024
@dvrogozh (Contributor, Author) commented on Jun 7, 2024:

I tried this PR (+ huggingface/transformers#31238) as much as I could in the IPEX-CPU, IPEX-XPU, PyTorch-XPU, and PyTorch-CPU scenarios. I ran some tests from accelerate and transformers and some examples from transformers; all seem to engage XPU when expected. I am promoting these PRs from draft for qualified review. Let me know if there are any concerns or feedback that needs to be addressed.

@dvrogozh (Contributor, Author) commented:

Applied doc-builder style src/accelerate docs/source --max_len 119 to fix formatting issues identified by CI.

@dvrogozh (Contributor, Author) commented:

@muellerzr: can you please run CI again? Also, is there anything else I can fix in this PR to get it merged?

@dvrogozh (Contributor, Author) commented on Jun 11, 2024:

I did not see such a failure before on this PR. Could this be something random? I cannot associate the failure with the changes made. I also ran the test locally on CPU and it passed. @muellerzr, can you please advise?

FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_with_hooks_use_pytorch - assert 0.0007739067077636719 > 0.001
 +  where 0.0007739067077636719 = abs((4.019573211669922 - 4.0203471183776855))
 +    where 4.0203471183776855 = get_signature(Linear(in_features=2, out_features=4, bias=True))

dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Jun 12, 2024
@dvrogozh (Contributor, Author) commented:

@SunMarc: thank you for retriggering the failed CI. I see it is passing now, so my assumption that this was a sporadic failure appears to be correct.

@SunMarc, @muellerzr: I have outlined the current status of the XPU backend in PyTorch in huggingface/transformers#31237. There are a number of issues in the XPU backend which are being worked on right now. I believe, however, that this PR and the transformers PR (huggingface/transformers#31238) are ready as a first step to enable the XPU backend in Hugging Face, on top of which we can gradually improve the support. Can you please outline the acceptance requirements for these PRs on the Hugging Face side?

@muellerzr (Collaborator) left a comment:

Thanks! This looks great to me, thank you for the improvement!

cc @SunMarc for a second pair of eyes, else we can merge it after the nit has been addressed!

Comment on lines 389 to 392
if importlib.util.find_spec("torch") is None:
return False

import torch
@muellerzr (Collaborator) commented:

This part can actually be removed, as accelerate always requires PyTorch :)

@dvrogozh (Contributor, Author) replied:

Indeed. I copied this from is_npu_available above, which by the way also re-imports torch. Do you want me to also fix is_npu_available() in this PR?

if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
    return False
import torch
import torch_npu  # noqa: F401

@dvrogozh (Contributor, Author) replied:

> This part can actually be removed, as accelerate always requires PyTorch :)

Fixed. I will submit the npu/mlu cleanup in a separate PR unless you tell me otherwise.

@dvrogozh (Contributor, Author) replied:

> I will submit the npu/mlu cleanup in a separate PR unless you tell me otherwise.

Submitted #2856.

@muellerzr muellerzr requested a review from SunMarc June 13, 2024 14:12
Fixes: huggingface/transformers#31237

XPU backend is available in the stock PyTorch starting from
version 2.4, see [1]. This commit extends huggingface accelerate
to support XPU from both IPEX and the stock pytorch. IPEX is being
tried first.

See: pytorch/pytorch#114842
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
@SunMarc (Member) left a comment:

LGTM! Just a small nit.

)
else:
    # ipex.optimize() is available only for IPEX, both IPEX-CPU and IPEX-XPU
    if is_ipex_available():
@SunMarc (Member) commented:

Maybe change that to self.state.use_ipex (or at least add it)?

@dvrogozh (Contributor, Author) replied:

I'm afraid this might break the IPEX-XPU path.

Note that use_ipex is currently used on the CPU path to decide whether the IPEX-CPU optimization should be applied:

if self.device.type == "cpu" and self.state.use_ipex:

In the IPEX-XPU case this flag was not used, and I am not sure whether it would be True there.
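
To make the concern concrete, a standalone sketch of the branching (parameter names are illustrative stand-ins for Accelerator state, not the actual attributes):

  def pick_optimize_path(device_type: str, use_ipex: bool, ipex_available: bool) -> str:
      if device_type == "cpu" and use_ipex:
          return "ipex.optimize() on CPU"  # existing IPEX-CPU gate
      if device_type == "xpu" and ipex_available:
          # Gating this branch on use_ipex too could silently skip the
          # IPEX-XPU optimization, since use_ipex is only set on the CPU path.
          return "torch.xpu.optimize() on XPU"
      return "no IPEX optimization"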

@SunMarc (Member) replied:

Sounds good!

dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Jun 13, 2024
@muellerzr (Collaborator) left a comment:

Thanks for doing this! Next step: transformers :)

(Also, if you want the version this will be released in: accelerate==0.32.0.)

@muellerzr muellerzr merged commit 3b5a00e into huggingface:main Jun 13, 2024
22 of 23 checks passed
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Jun 13, 2024
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Jun 14, 2024
ydshieh pushed a commit to dvrogozh/transformers that referenced this pull request Jun 14, 2024
ydshieh pushed a commit to huggingface/transformers that referenced this pull request Jun 14, 2024
* xpu: support xpu backend from stock pytorch (>=2.4)

Fixes: #31237

XPU backend is available in the stock PyTorch starting from
version 2.4, see [1]. This commit extends huggingface transformers
to support XPU from both IPEX and the stock pytorch. IPEX is being
tried first.

See: pytorch/pytorch#114842
Requires: huggingface/accelerate#2825
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

* xpu: enable gpt2 and decision_transformer tests for xpu pytorch backend

Note that running xpu tests requires TRANSFORMERS_TEST_DEVICE_SPEC=spec.py
passed to the test runner:

  import torch
  DEVICE_NAME = 'xpu'
  MANUAL_SEED_FN = torch.xpu.manual_seed
  EMPTY_CACHE_FN = torch.xpu.empty_cache
  DEVICE_COUNT_FN = torch.xpu.device_count

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

---------

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
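
For reference, a hypothetical invocation passing the spec file to the test runner (the exact test path below is an assumption, not from the commit message):

  # run the XPU-enabled tests with the device spec above (illustrative path)
  TRANSFORMERS_TEST_DEVICE_SPEC=spec.py python -m pytest tests/models/gpt2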
itazap pushed a commit to huggingface/transformers that referenced this pull request Jun 17, 2024
itazap pushed a commit to huggingface/transformers that referenced this pull request Jun 18, 2024
itazap pushed a commit to huggingface/transformers that referenced this pull request Jun 20, 2024
Successfully merging this pull request may close these issues.

xpu: Support new PyTorch XPU backend (>=2.4)
6 participants