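Bug description
In `lightning/fabric/accelerators/tpu.py`, the `_parse_tpu_devices` function hard-codes a maximum of 8 devices: it only accepts `devices` values of 1, 8, or a single-element list `[<1-8>]`. In `torch_xla/distributed/xla_multiprocessing.py`, a separate validator only allows either 1 device or all available devices (`dev_count` in the traceback, 32 here).
When working with Amazon Trainium, the large trn1.32xlarge instances come equipped with 16 accelerators with 2 cores each, for a total of 32 devices. Setting `devices` to either 8 or 32 causes a validation error before training starts: 8 passes Lightning's check but is rejected by torch_xla, which expects 1 or 32, while 32 would satisfy torch_xla but is rejected by Lightning, which caps the value at 8. The problem does not appear when only 1 accelerator is used, since that value is accepted by both checks.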
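To make the conflict concrete, here is a minimal sketch that restates the two checks as plain Python. It is reconstructed only from the error messages and the torch_xla snippet quoted in this report, not from the full library source, and `lightning_accepts` / `torch_xla_accepts` are hypothetical helper names.

```python
# Simplified restatements of the two device checks, reconstructed from the
# error messages in this report (not the actual library source).


def lightning_accepts(devices):
    """Mirrors: "`devices` can only be 1, 8 or [<1-8>] for TPUs." """
    if isinstance(devices, int):
        return devices in (1, 8)
    if isinstance(devices, list):
        return len(devices) == 1 and 1 <= devices[0] <= 8
    return False


def torch_xla_accepts(num_devices, dev_count):
    """Mirrors: `num_devices not in [1, dev_count]` raises ValueError."""
    return num_devices in (1, dev_count)


# trn1.32xlarge: 16 Trainium accelerators x 2 cores each
dev_count = 16 * 2

# Device counts that pass both checks on this host
both_ok = [
    n for n in range(1, dev_count + 1)
    if lightning_accepts(n) and torch_xla_accepts(n, dev_count)
]
print(dev_count)  # 32
print(both_ok)    # [1]
```

Only `devices=1` satisfies both rules on a 32-device host, which matches the behavior described above.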
```
File ~/miniconda3/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:201, in _pre_fork_setup(num_devices)
199 num_devices = dev_count
200 elif num_devices not in [1, dev_count]:
--> 201 raise ValueError(
202 'The number of devices must be either 1 or {}, got {} instead'.format(
203 dev_count, num_devices))
204 total_devices = _get_world_size() * num_devices
205 if total_devices > 1 and not os.environ.get(xenv.SERVICE_ADDRESS, None):
206 # In multi-processing mode, even if there is only one XLA host, we still
207 # bring up the mesh service.
ValueError: The number of devices must be either 1 or 32, got 8 instead
```
and
```
File ~/miniconda3/lib/python3.10/site-packages/lightning/fabric/accelerators/tpu.py:158, in _parse_tpu_devices(devices)
155 devices = _parse_tpu_devices_str(devices.strip())
157 if not _tpu_devices_valid(devices):
--> 158 raise TypeError("`devices` can only be 1, 8 or [<1-8>] for TPUs.")
160 return devices
TypeError: `devices` can only be 1, 8 or [<1-8>] for TPUs.
```
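One possible direction for a fix, sketched under loud assumptions: instead of hard-coding 8, the Lightning-side check could compare against the detected device count and mirror torch_xla's "1 or all" rule. `devices_valid` and its `available` parameter are hypothetical and not part of Lightning's API; a real fix would also need to obtain the count from the XLA/Neuron runtime rather than take it as an argument.

```python
from typing import Optional, Sequence, Union


def devices_valid(devices: Optional[Union[int, Sequence[int]]], available: int) -> bool:
    """Hypothetical relaxed check (not Lightning's API): bound the value by
    the detected device count instead of a hard-coded 8."""
    if devices is None:
        # Assumption: None means "use every available device".
        return True
    if isinstance(devices, bool):
        # bool is a subclass of int; reject True/False explicitly.
        return False
    if isinstance(devices, int):
        # Mirror torch_xla's rule: either 1 process or all local devices.
        return devices in (1, available)
    if isinstance(devices, (list, tuple)) and len(devices) == 1:
        # A single device index, following the 1-based `[<1-8>]` convention
        # from the error message, but bounded by the detected count.
        index = devices[0]
        return (
            isinstance(index, int)
            and not isinstance(index, bool)
            and 1 <= index <= available
        )
    return False


# Example: a trn1.32xlarge host exposing 32 devices.
for value in (1, 8, 32, [17], [33]):
    print(value, devices_valid(value, available=32))
```

With `available=32`, this accepts 1, 32, and a single index such as `[17]`, and rejects 8 and `[33]`; with `available=8` it keeps the current behavior for integer and single-index values (1, 8, or one index from 1 to 8).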
What version are you seeing the problem on?
v2.0, v2.1, v2.2
How to reproduce the bug
Error messages and logs
Environment
CUDA:
- GPU: None
- available: False
- version: 11.7
Lightning:
- lightning: 2.2.3
- lightning-cloud: 0.5.68
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.2.3
- torch: 1.13.0
- torch-neuronx: 1.13.1.1.14.0
- torch-xla: 1.13.1+torchneurone
- torchmetrics: 1.3.2
- torchvision: 0.14.0
Packages:
- absl-py: 2.1.0
- aiohttp: 3.9.5
- aiohttp-cors: 0.7.0
- aiosignal: 1.3.1
- alembic: 1.13.1
- anaconda-anon-usage: 0.4.4
- aniso8601: 9.0.1
- annotated-types: 0.6.0
- anyio: 4.3.0
- archspec: 0.2.3
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.3.0
- asttokens: 2.4.1
- async-lru: 2.0.4
- async-timeout: 4.0.3
- attrs: 23.2.0
- aws-neuronx-runtime-discovery: 2.9
- babel: 2.14.0
- beautifulsoup4: 4.12.3
- bio: 1.7.0
- biopython: 1.83
- biothings-client: 0.3.1
- bleach: 6.1.0
- blessed: 1.20.0
- blinker: 1.8.1
- boltons: 23.0.0
- boto3: 1.34.93
- botocore: 1.34.93
- brotli: 1.0.9
- cachetools: 5.3.3
- certifi: 2024.2.2
- cffi: 1.16.0
- charset-normalizer: 2.0.4
- click: 8.1.7
- cloud-tpu-client: 0.10
- cloudpickle: 3.0.0
- colorful: 0.5.6
- comm: 0.2.2
- conda: 24.4.0
- conda-content-trust: 0.2.0
- conda-libmamba-solver: 24.1.0
- conda-package-handling: 2.2.0
- conda-package-streaming: 0.9.0
- contourpy: 1.2.1
- croniter: 1.3.15
- cryptography: 42.0.5
- cycler: 0.12.1
- datasets: 2.19.0
- dateutils: 0.6.12
- debugpy: 1.8.1
- decorator: 5.1.1
- deepdiff: 7.0.1
- defusedxml: 0.7.1
- deprecated: 1.2.14
- dill: 0.3.8
- distlib: 0.3.8
- distro: 1.8.0
- dm-tree: 0.1.8
- docker: 7.0.0
- docutils: 0.21.2
- ec2-metadata: 2.10.0
- editor: 1.6.6
- entrypoints: 0.4
- exceptiongroup: 1.2.1
- executing: 2.0.1
- farama-notifications: 0.0.4
- fastapi: 0.88.0
- fastjsonschema: 2.19.1
- filelock: 3.14.0
- flask: 3.0.3
- fonttools: 4.51.0
- fqdn: 1.5.1
- frozenlist: 1.4.1
- fsspec: 2023.12.2
- gitdb: 4.0.11
- gitpython: 3.1.43
- google-api-core: 1.34.1
- google-api-python-client: 1.8.0
- google-auth: 2.29.0
- google-auth-httplib2: 0.2.0
- googleapis-common-protos: 1.63.0
- gprofiler-official: 1.0.0
- graphene: 3.3
- graphql-core: 3.2.3
- graphql-relay: 3.2.0
- greenlet: 3.0.3
- grpcio: 1.62.2
- gunicorn: 21.2.0
- gymnasium: 0.28.1
- h11: 0.14.0
- httpcore: 1.0.5
- httplib2: 0.22.0
- httptools: 0.6.1
- httpx: 0.27.0
- huggingface-hub: 0.22.2
- idna: 3.7
- imageio: 2.34.1
- importlib-metadata: 7.0.0
- inquirer: 3.2.4
- ipykernel: 6.29.4
- ipython: 8.24.0
- ipywidgets: 8.1.2
- islpy: 2023.1
- isoduration: 20.11.0
- itsdangerous: 2.2.0
- jax-jumpy: 1.0.0
- jedi: 0.19.1
- jinja2: 3.1.3
- jmespath: 1.0.1
- joblib: 1.4.0
- json5: 0.9.25
- jsonpatch: 1.33
- jsonpointer: 2.1
- jsonschema: 4.21.1
- jsonschema-specifications: 2023.12.1
- jupyter: 1.0.0
- jupyter-client: 8.6.1
- jupyter-console: 6.6.3
- jupyter-core: 5.7.2
- jupyter-events: 0.10.0
- jupyter-lsp: 2.2.5
- jupyter-server: 2.14.0
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.1.8
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.1
- jupyterlab-widgets: 3.0.10
- kiwisolver: 1.4.5
- lazy-loader: 0.4
- libmambapy: 1.5.8
- libneuronxla: 0.5.971
- lightning: 2.2.3
- lightning-cloud: 0.5.68
- lightning-utilities: 0.11.2
- linkify-it-py: 2.0.3
- lockfile: 0.12.2
- lz4: 4.3.3
- mako: 1.3.3
- markdown: 3.6
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.8.4
- matplotlib-inline: 0.1.7
- mdit-py-plugins: 0.4.0
- mdurl: 0.1.2
- memray: 1.12.0
- menuinst: 2.0.2
- mistune: 3.0.2
- mlflow: 2.12.1
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- multiprocess: 0.70.16
- mygene: 3.2.2
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 2.6.3
- neuronx-cc: 2.13.72.0+78a426937
- notebook: 7.1.3
- notebook-shim: 0.2.4
- numpy: 1.25.2
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.1.105
- oauth2client: 4.1.3
- opencensus: 0.11.4
- opencensus-context: 0.1.3
- opentelemetry-api: 1.24.0
- opentelemetry-exporter-otlp: 1.24.0
- opentelemetry-exporter-otlp-proto-common: 1.24.0
- opentelemetry-exporter-otlp-proto-grpc: 1.24.0
- opentelemetry-exporter-otlp-proto-http: 1.24.0
- opentelemetry-proto: 1.24.0
- opentelemetry-sdk: 1.24.0
- opentelemetry-semantic-conventions: 0.45b0
- ordered-set: 4.1.0
- overrides: 7.7.0
- packaging: 23.2
- pandas: 2.2.2
- pandocfilters: 1.5.1
- parso: 0.8.4
- pexpect: 4.9.0
- pgzip: 0.3.5
- pillow: 10.3.0
- pip: 23.3.1
- platformdirs: 3.10.0
- pluggy: 1.0.0
- polars: 0.20.23
- pooch: 1.8.1
- prometheus-client: 0.20.0
- prompt-toolkit: 3.0.43
- proto-plus: 1.23.0
- protobuf: 3.19.6
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-spy: 0.3.14
- pyarrow: 15.0.2
- pyarrow-hotfix: 0.6
- pyasn1: 0.6.0
- pyasn1-modules: 0.4.0
- pycosat: 0.6.6
- pycparser: 2.21
- pydantic: 1.10.15
- pydantic-core: 2.18.2
- pygments: 2.17.2
- pyjwt: 2.8.0
- pyparsing: 3.1.2
- pysocks: 1.7.1
- python-daemon: 3.0.1
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.0.1
- python-json-logger: 2.0.7
- python-multipart: 0.0.9
- pytorch-lightning: 2.2.3
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.0.2
- qtconsole: 5.5.1
- qtpy: 2.4.1
- querystring-parser: 1.2.4
- ray: 2.12.0
- ray-cpp: 2.12.0
- readchar: 4.0.6
- referencing: 0.35.0
- regex: 2024.4.28
- requests: 2.31.0
- requests-unixsocket: 0.3.0
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rich: 13.7.1
- rpds-py: 0.18.0
- rsa: 4.9
- ruamel.yaml: 0.17.21
- ruamel.yaml.clib: 0.2.6
- runs: 1.2.2
- s3transfer: 0.10.1
- safetensors: 0.4.3
- scikit-image: 0.23.2
- scikit-learn: 1.4.2
- scipy: 1.11.2
- send2trash: 1.8.3
- setuptools: 68.2.2
- shellingham: 1.5.4
- six: 1.16.0
- smart-open: 7.0.4
- smmap: 5.0.1
- sniffio: 1.3.1
- soupsieve: 2.5
- sqlalchemy: 2.0.29
- sqlparse: 0.5.0
- stack-data: 0.6.3
- starlette: 0.22.0
- starsessions: 1.3.0
- sympy: 1.12
- tensorboardx: 2.6.2.2
- terminado: 0.18.1
- textual: 0.58.0
- threadpoolctl: 3.5.0
- tifffile: 2024.4.24
- tinycss2: 1.3.0
- tokenizers: 0.19.1
- tomli: 2.0.1
- torch: 1.13.0
- torch-neuronx: 1.13.1.1.14.0
- torch-xla: 1.13.1+torchneurone
- torchmetrics: 1.3.2
- torchvision: 0.14.0
- tornado: 6.4
- tqdm: 4.65.0
- traitlets: 5.14.3
- transformers: 4.40.1
- triton: 2.3.0
- truststore: 0.8.0
- typer: 0.12.3
- types-python-dateutil: 2.9.0.20240316
- typing-extensions: 4.11.0
- tzdata: 2024.1
- uc-micro-py: 1.0.3
- uri-template: 1.3.0
- uritemplate: 3.0.1
- urllib3: 2.1.0
- uvicorn: 0.29.0
- uvloop: 0.19.0
- virtualenv: 20.26.1
- watchfiles: 0.21.0
- wcwidth: 0.2.13
- webcolors: 1.13
- webencodings: 0.5.1
- websocket-client: 1.8.0
- websockets: 11.0.3
- werkzeug: 3.0.2
- wheel: 0.41.2
- widgetsnbextension: 4.0.10
- wrapt: 1.16.0
- xmod: 1.8.1
- xxhash: 3.4.1
- yarl: 1.9.4
- zipp: 3.18.1
- zstandard: 0.19.0
System:
- OS: Linux
- architecture:
  - 64bit
  - ELF
- processor: x86_64
- python: 3.10.14
- release: 5.10.214-202.855.amzn2.x86_64
- version: #1 SMP Tue Apr 9 06:57:12 UTC 2024
More info
No response