diff --git a/master/_modules/index.html b/master/_modules/index.html index b4793584905..0a37467c31d 100644 --- a/master/_modules/index.html +++ b/master/_modules/index.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/core/xla_model.html b/master/_modules/torch_xla/core/xla_model.html index d242277c772..70f637c6c98 100644 --- a/master/_modules/torch_xla/core/xla_model.html +++ b/master/_modules/torch_xla/core/xla_model.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/debug/metrics.html b/master/_modules/torch_xla/debug/metrics.html index 6bd1872d25b..c407efe9c86 100644 --- a/master/_modules/torch_xla/debug/metrics.html +++ b/master/_modules/torch_xla/debug/metrics.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/distributed/parallel_loader.html b/master/_modules/torch_xla/distributed/parallel_loader.html index c78f48258fd..a54d7b47561 100644 --- a/master/_modules/torch_xla/distributed/parallel_loader.html +++ b/master/_modules/torch_xla/distributed/parallel_loader.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/distributed/spmd/xla_sharding.html b/master/_modules/torch_xla/distributed/spmd/xla_sharding.html index e9d2f2e487d..e6aa595881a 100644 --- a/master/_modules/torch_xla/distributed/spmd/xla_sharding.html +++ b/master/_modules/torch_xla/distributed/spmd/xla_sharding.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/distributed/xla_multiprocessing.html b/master/_modules/torch_xla/distributed/xla_multiprocessing.html index 0bd860cb33e..060ff8e8af4 100644 --- a/master/_modules/torch_xla/distributed/xla_multiprocessing.html +++ b/master/_modules/torch_xla/distributed/xla_multiprocessing.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/experimental/eager.html b/master/_modules/torch_xla/experimental/eager.html index 4e535a5e614..b29f38c5c0a 100644 --- a/master/_modules/torch_xla/experimental/eager.html +++ b/master/_modules/torch_xla/experimental/eager.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/runtime.html b/master/_modules/torch_xla/runtime.html index cf829f77c56..1bd0d79eed4 100644 --- a/master/_modules/torch_xla/runtime.html +++ b/master/_modules/torch_xla/runtime.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_modules/torch_xla/torch_xla.html b/master/_modules/torch_xla/torch_xla.html index 2e909d3fbf8..82babae13cb 100644 --- a/master/_modules/torch_xla/torch_xla.html +++ b/master/_modules/torch_xla/torch_xla.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/_sources/features/distop.md.txt b/master/_sources/features/distop.md.txt new file mode 100644 index 00000000000..49355bf827f --- /dev/null +++ b/master/_sources/features/distop.md.txt @@ -0,0 +1,89 @@ +# Support of Torch Distributed API in PyTorch/XLA +Before the 2.5 release, PyTorch/XLA only supported collective ops through our custom API call `torch_xla.core.xla_model.*`. In the 2.5 release, we adopt `torch.distributed.*` in PyTorch/XLA for both Dynamo and non-Dynamo cases. +## Collective ops lowering +### Collective ops lowering stack +After introducing the [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173), dynamo can support the collective ops with reimplementing lowering in PyTorch/XLA. The collective op is only traceable through `torch.ops._c10d_functional` call. Below is the figure that shows how the collective op, `all_reduce` in this case, is lowered between torch and torch_xla: + + +Alt Text + +_Figure 1. Collective ops lowering stack_ + +### non-Dynamo case +Collective ops are lowered through registering the `ProcessGroupXla`, which is derived from `ProcessGroup`: + +```Python +# torch_xla/distributed/xla_backend.py +def _create_xla_process_group(prefix_store, rank, size, timeout): + assert not xr.is_spmd( + ), "XLA backend is not supported with SPMD. Please use a CPU process group instead." + return ProcessGroupXla(prefix_store, rank, size, timeout) + + +def _register_xla_backend(): + dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla') + + +class ProcessGroupXla(ProcessGroup): + ... + def allreduce(self, tensors, all_reduce_options): + ... + def allgather(self, output_tensors_list, input_tensors, opts=None): + ... 
+``` + +The corresponding xla dist backend is initialized when we call: +```Python +def _mp_fn(rank): + dist.init_process_group("xla", init_method='xla://') + +In this way, collective ops will be called based on the progress group instance: + + # E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py + @_exception_logger + def all_gather(tensor_list, tensor, group=None, async_op=False): + ... + group = group or _get_default_group() + work = group.allgather([tensor_list], [tensor]) # uses ProcessGroupXla.allgather instead +``` + +### Dynamo case +For dynamo case, certain collective ops are remapped to the new function in [pytorch/torch/distributed/_functional_collectives.py](https://github.com/pytorch/pytorch/blob/v2.5.0-rc10/torch/distributed/_functional_collectives.py#L1129-L1150). For example, `all_reduce()` will be mapped to `all_reduce_inplace()`, where eventually `torch.ops._c10d_functional.all_reduce()`. Once we reach the _c10d_functional, we can rewrite the op through PyTorch/Xla lowering: + + +```C++ +at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp, + std::string /*group_name*/) {...} + +TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) { + m.impl("all_reduce", all_reduce); +} +``` + + +## API description + +For release 2.5, we now support four collective operations for both Dynamo and non-Dynamo cases. Our goal is to align the distributed operation (dist op) APIs with PyTorch's upstream implementation. While the function signatures remain consistent, certain input restrictions still apply. +For instance, specifying multiple groups for distributed collective operations is not yet supported. For usage examples, refer to [test_collective_ops_tpu.py](https://github.com/pytorch/xla/blob/v2.5.0-rc10/test/pjrt/test_collective_ops_tpu.py), which demonstrates the use of dist ops in both Dynamo and non-Dynamo scenarios. 
+Below are the details for each operation: +```Python +dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM) +``` +`all_reduce` performs an in-place reduction on the `input` tensor by aggregating data from all nodes. + +```Python +dist.all_gather_into_tensor(output, input) +``` +`all_gather_into_tensor` gathers the input tensor from all nodes and updates the `output` tensor in-place. It also returns an alias of the output. + +```Python +dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM) +``` +`reduce_scatter_tensor` reduces the input tensor across all nodes and distributes the result to the `output` tensor in-place. It returns an alias of the output. + +```Python +dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None) +``` +`all_to_all_single` function performs an all-to-all communication, updating the output tensor in-place and returning its alias. + +Note: Although `output_split_sizes` and `input_split_sizes` are accepted as arguments, they must be either None or set to all 1s. This limitation reflects a compromise between maintaining PyTorch’s API signature and the constraints of the XLA AllToAll operation. diff --git a/master/_static/img/dist_op_stack.png b/master/_static/img/dist_op_stack.png new file mode 100644 index 00000000000..5d7215fe49a Binary files /dev/null and b/master/_static/img/dist_op_stack.png differ diff --git a/master/accelerators/gpu.html b/master/accelerators/gpu.html index 579b8290ede..faf33cb22c0 100644 --- a/master/accelerators/gpu.html +++ b/master/accelerators/gpu.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/accelerators/tpu.html b/master/accelerators/tpu.html index 17f2075211e..a6f8dad7302 100644 --- a/master/accelerators/tpu.html +++ b/master/accelerators/tpu.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/contribute/bazel.html b/master/contribute/bazel.html index 76a391aa37b..5b8d4ac64bb 100644 --- a/master/contribute/bazel.html +++ b/master/contribute/bazel.html @@ -266,7 +266,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/contribute/codegen_migration.html b/master/contribute/codegen_migration.html index 947ae84dc9f..15a0974ff3f 100644 --- a/master/contribute/codegen_migration.html +++ b/master/contribute/codegen_migration.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/contribute/configure-environment.html b/master/contribute/configure-environment.html index a9da0ce6e04..2fee7c7e4c5 100644 --- a/master/contribute/configure-environment.html +++ b/master/contribute/configure-environment.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/contribute/op_lowering.html b/master/contribute/op_lowering.html index f72149e6fac..a4545bb7cea 100644 --- a/master/contribute/op_lowering.html +++ b/master/contribute/op_lowering.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/contribute/plugins.html b/master/contribute/plugins.html index c04ae74e83d..b571154740d 100644 --- a/master/contribute/plugins.html +++ b/master/contribute/plugins.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/features/distop.html b/master/features/distop.html new file mode 100644 index 00000000000..055e6ba747a --- /dev/null +++ b/master/features/distop.html @@ -0,0 +1,831 @@ + + + + + + + + + + + + Support of Torch Distributed API in PyTorch/XLA — PyTorch/XLA master documentation + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+
+ + + + + +
+
+
+ + + + + + + + + + + +
+
+
+ + + + + + + + + + + + + + + + +
+ +
    + +
  • + + + Docs + + > +
  • + + +
  • Support of Torch Distributed API in PyTorch/XLA
  • + + +
  • + + + + + +
  • + +
+ + +
+
+ +
+ Shortcuts +
+
+ +
+
+ + + + + + +
+ +
+
+ +
+

Support of Torch Distributed API in PyTorch/XLA

+

Before the 2.5 release, PyTorch/XLA only supported collective ops through our custom API call torch_xla.core.xla_model.*. In the 2.5 release, we adopt torch.distributed.* in PyTorch/XLA for both Dynamo and non-Dynamo cases.

+
+

Collective ops lowering

+
+

Collective ops lowering stack

+

With the introduction of the traceable collective communication APIs, Dynamo can support collective ops by reimplementing their lowering in PyTorch/XLA. A collective op is only traceable through a torch.ops._c10d_functional call. The figure below shows how a collective op, all_reduce in this case, is lowered between torch and torch_xla:

+

Alt Text

+

*Figure 1. Collective ops lowering stack*

+
+
+

non-Dynamo case

+

Collective ops are lowered through registering the ProcessGroupXla, which is derived from ProcessGroup:

+
# torch_xla/distributed/xla_backend.py
+def _create_xla_process_group(prefix_store, rank, size, timeout):
+  assert not xr.is_spmd(
+  ), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
+  return ProcessGroupXla(prefix_store, rank, size, timeout)
+
+
+def _register_xla_backend():
+  dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla')
+
+
+class ProcessGroupXla(ProcessGroup):
+  ...
+  def allreduce(self, tensors, all_reduce_options):
+    ...
+  def allgather(self, output_tensors_list, input_tensors, opts=None):
+    ...
+
+
+

The corresponding xla dist backend is initialized when we call:

+
def _mp_fn(rank):
+  dist.init_process_group("xla", init_method='xla://')
+
+In this way, collective ops will be called based on the process group instance:
+
+  # E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
+  @_exception_logger
+  def all_gather(tensor_list, tensor, group=None, async_op=False):
+    ...
+    group = group or _get_default_group()
+    work = group.allgather([tensor_list], [tensor]) # uses ProcessGroupXla.allgather instead
+
+
+
+
+

Dynamo case

+

For the Dynamo case, certain collective ops are remapped to new functions in pytorch/torch/distributed/_functional_collectives.py. For example, all_reduce() is mapped to all_reduce_inplace(), which eventually calls torch.ops._c10d_functional.all_reduce(). Once we reach _c10d_functional, we can rewrite the op through PyTorch/XLA lowering:

+
at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
+                      std::string /*group_name*/)  {...}
+
+TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
+  m.impl("all_reduce", all_reduce);
+}
+
+
+
+
+
+

API description

+

For release 2.5, we now support four collective operations for both Dynamo and non-Dynamo cases. Our goal is to align the distributed operation (dist op) APIs with PyTorch’s upstream implementation. While the function signatures remain consistent, certain input restrictions still apply. +For instance, specifying multiple groups for distributed collective operations is not yet supported. For usage examples, refer to test_collective_ops_tpu.py, which demonstrates the use of dist ops in both Dynamo and non-Dynamo scenarios. +Below are the details for each operation:

+
dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM)
+
+
+

all_reduce performs an in-place reduction on the input tensor by aggregating data from all nodes.

+
dist.all_gather_into_tensor(output, input)
+
+
+

all_gather_into_tensor gathers the input tensor from all nodes and updates the output tensor in-place. It also returns an alias of the output.

+
dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM)
+
+
+

reduce_scatter_tensor reduces the input tensor across all nodes and distributes the result to the output tensor in-place. It returns an alias of the output.

+
dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None)
+
+
+

all_to_all_single performs an all-to-all communication, updating the output tensor in-place and returning an alias of the output.

+

Note: Although output_split_sizes and input_split_sizes are accepted as arguments, they must be either None or set to all 1s. This limitation reflects a compromise between maintaining PyTorch’s API signature and the constraints of the XLA AllToAll operation.

+
+
+ + +
+ +
+ + +
+
+ + +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+
+
+

Docs

+

Access comprehensive developer documentation for PyTorch

+ View Docs +
+ +
+

Tutorials

+

Get in-depth tutorials for beginners and advanced developers

+ View Tutorials +
+ +
+

Resources

+

Find development resources and get your questions answered

+ View Resources +
+
+
+
+ + + + + + + + + +
+
+
+
+ + +
+
+
+ + +
+ + + + + + + + \ No newline at end of file diff --git a/master/features/pallas.html b/master/features/pallas.html index e83be794bba..df8af52e99f 100644 --- a/master/features/pallas.html +++ b/master/features/pallas.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/features/stablehlo.html b/master/features/stablehlo.html index d7dc59ed7a8..50f7d98d081 100644 --- a/master/features/stablehlo.html +++ b/master/features/stablehlo.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/features/triton.html b/master/features/triton.html index a0a6ada4e2c..38e03cda973 100644 --- a/master/features/triton.html +++ b/master/features/triton.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/genindex.html b/master/genindex.html index 2ebe684cdb6..c5211e05cae 100644 --- a/master/genindex.html +++ b/master/genindex.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/index.html b/master/index.html index 0405453491f..00df74234d1 100644 --- a/master/index.html +++ b/master/index.html @@ -266,7 +266,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/learn/api-guide.html b/master/learn/api-guide.html index 4d56a4e3b32..d157e7f31b0 100644 --- a/master/learn/api-guide.html +++ b/master/learn/api-guide.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/learn/dynamic_shape.html b/master/learn/dynamic_shape.html index 9f7c706933d..c1a72d44055 100644 --- a/master/learn/dynamic_shape.html +++ b/master/learn/dynamic_shape.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/learn/eager.html b/master/learn/eager.html index d3f1dca1804..781ab410279 100644 --- a/master/learn/eager.html +++ b/master/learn/eager.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/learn/pjrt.html b/master/learn/pjrt.html index 2e22aee7704..e553076c8ac 100644 --- a/master/learn/pjrt.html +++ b/master/learn/pjrt.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/learn/pytorch-on-xla-devices.html b/master/learn/pytorch-on-xla-devices.html index e4f4f6b27cf..127973e4b1b 100644 --- a/master/learn/pytorch-on-xla-devices.html +++ b/master/learn/pytorch-on-xla-devices.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/learn/troubleshoot.html b/master/learn/troubleshoot.html index 0ec66fc1450..52656f55793 100644 --- a/master/learn/troubleshoot.html +++ b/master/learn/troubleshoot.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/learn/xla-overview.html b/master/learn/xla-overview.html index 42aacdcf186..e5f2f717cd4 100644 --- a/master/learn/xla-overview.html +++ b/master/learn/xla-overview.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/notes/source_of_recompilation.html b/master/notes/source_of_recompilation.html index fce323dff7b..d1e4f36ae90 100644 --- a/master/notes/source_of_recompilation.html +++ b/master/notes/source_of_recompilation.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/objects.inv b/master/objects.inv index aae39531bd6..1aea5f1e6e2 100644 Binary files a/master/objects.inv and b/master/objects.inv differ diff --git a/master/perf/amp.html b/master/perf/amp.html index e562966d5e2..1d5443dbf1f 100644 --- a/master/perf/amp.html +++ b/master/perf/amp.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/ddp.html b/master/perf/ddp.html index 1ce933cc769..b78f0cea0bf 100644 --- a/master/perf/ddp.html +++ b/master/perf/ddp.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/dynamo.html b/master/perf/dynamo.html index b78c72fb2a4..423dfaf9e4a 100644 --- a/master/perf/dynamo.html +++ b/master/perf/dynamo.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/fori_loop.html b/master/perf/fori_loop.html index 616818a7595..e213fc5656a 100644 --- a/master/perf/fori_loop.html +++ b/master/perf/fori_loop.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/fsdp.html b/master/perf/fsdp.html index ba77fa35246..14a7b52d5f9 100644 --- a/master/perf/fsdp.html +++ b/master/perf/fsdp.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/fsdpv2.html b/master/perf/fsdpv2.html index 66ba1cddf5e..9690b3d8ebb 100644 --- a/master/perf/fsdpv2.html +++ b/master/perf/fsdpv2.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/quantized_ops.html b/master/perf/quantized_ops.html index a103c3cc1b0..eb885a68442 100644 --- a/master/perf/quantized_ops.html +++ b/master/perf/quantized_ops.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/recompilation.html b/master/perf/recompilation.html index 64a19e603da..e7def8c24f6 100644 --- a/master/perf/recompilation.html +++ b/master/perf/recompilation.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/spmd_advanced.html b/master/perf/spmd_advanced.html index 3a3da094bb3..08c9f404248 100644 --- a/master/perf/spmd_advanced.html +++ b/master/perf/spmd_advanced.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/spmd_basic.html b/master/perf/spmd_basic.html index 3903c1b1400..16bfaa74f6e 100644 --- a/master/perf/spmd_basic.html +++ b/master/perf/spmd_basic.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/spmd_distributed_checkpoint.html b/master/perf/spmd_distributed_checkpoint.html index a03d601927e..6ea7fa0f66c 100644 --- a/master/perf/spmd_distributed_checkpoint.html +++ b/master/perf/spmd_distributed_checkpoint.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/perf/spmd_gpu.html b/master/perf/spmd_gpu.html index c1b243ba80a..2e1e3712ed2 100644 --- a/master/perf/spmd_gpu.html +++ b/master/perf/spmd_gpu.html @@ -267,7 +267,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/py-modindex.html b/master/py-modindex.html index 08e0a5b01b0..3e16c114e38 100644 --- a/master/py-modindex.html +++ b/master/py-modindex.html @@ -268,7 +268,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/search.html b/master/search.html index fdef454288f..ca5d5203150 100644 --- a/master/search.html +++ b/master/search.html @@ -265,7 +265,7 @@
- master (2.6.0+gita0f81e5 ) + master (2.6.0+gitfa311ec )
diff --git a/master/searchindex.js b/master/searchindex.js index d43d53f6a26..f2ec1bf4e75 100644 --- a/master/searchindex.js +++ b/master/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["accelerators/gpu", "accelerators/tpu", "contribute/bazel", "contribute/codegen_migration", "contribute/configure-environment", "contribute/op_lowering", "contribute/plugins", "features/pallas", "features/stablehlo", "features/triton", "index", "learn/api-guide", "learn/dynamic_shape", "learn/eager", "learn/pjrt", "learn/pytorch-on-xla-devices", "learn/troubleshoot", "learn/xla-overview", "notes/source_of_recompilation", "perf/amp", "perf/ddp", "perf/dynamo", "perf/fori_loop", "perf/fsdp", "perf/fsdpv2", "perf/quantized_ops", "perf/recompilation", "perf/spmd_advanced", "perf/spmd_basic", "perf/spmd_distributed_checkpoint", "perf/spmd_gpu"], "filenames": ["accelerators/gpu.md", "accelerators/tpu.md", "contribute/bazel.md", "contribute/codegen_migration.md", "contribute/configure-environment.md", "contribute/op_lowering.md", "contribute/plugins.md", "features/pallas.md", "features/stablehlo.md", "features/triton.md", "index.rst", "learn/api-guide.rst", "learn/dynamic_shape.md", "learn/eager.md", "learn/pjrt.md", "learn/pytorch-on-xla-devices.md", "learn/troubleshoot.md", "learn/xla-overview.md", "notes/source_of_recompilation.md", "perf/amp.md", "perf/ddp.md", "perf/dynamo.md", "perf/fori_loop.md", "perf/fsdp.md", "perf/fsdpv2.md", "perf/quantized_ops.md", "perf/recompilation.md", "perf/spmd_advanced.md", "perf/spmd_basic.md", "perf/spmd_distributed_checkpoint.md", "perf/spmd_gpu.md"], "titles": ["Learn about GPUs", "Learn about TPUs", "Bazel in Pytorch/XLA", "Codegen migration Guide", "Configure a development environment", "OP Lowering Guide", "Custom Hardware Plugins", "Custom Kernels via Pallas", "Torch Export to StableHLO", "Custom GPU Kernels via Triton", "PyTorch/XLA documentation", "PyTorch/XLA API", "Dynamic shape", "Eager Mode + Compile API", "PJRT Runtime", 
"PyTorch on XLA Devices", "Troubleshoot", "Pytorch/XLA overview", "Source of recompilations in torch_xla", "Automatic Mixed Precision", "How to do DistributedDataParallel(DDP)", "TorchDynamo integration in PyTorch XLA", "Optimize memory utilization using while_loop", "Fully Sharded Data Parallel in PyTorch XLA", "Fully Sharded Data Parallel using SPMD", "Quantized Operations for XLA (Experimental feature)", "Source of recompilations in Pytorch/XLA", "PyTorch/XLA SPMD advanced topics", "PyTorch/XLA SPMD User Guide", "Distributed Checkpointing", "Running SPMD on GPU"], "terms": {"For": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 30], "inform": [0, 1, 4, 9, 11, 13, 14, 15, 16, 17, 18, 26, 30], "googl": [0, 1, 7, 14, 15], "cloud": [0, 1, 2, 4, 6, 10, 14, 15, 21, 29], "see": [0, 1, 2, 3, 4, 5, 6, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 23, 26], "machin": [0, 2, 4, 14, 16, 17, 30], "type": [0, 4, 6, 8, 11, 14, 15, 16, 17, 19, 20], "ar": [1, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 23, 25, 26, 27, 28, 29], "custom": [1, 3, 4, 10, 11, 18, 20, 23, 25, 26, 27, 28], "design": [1, 14, 15, 21, 24, 28], "ai": 1, "acceler": [1, 4, 11, 12, 14, 15, 17, 19], "which": [1, 2, 3, 5, 6, 8, 11, 12, 14, 15, 16, 17, 18, 19, 21, 23, 24, 26, 27, 29], "optim": [1, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 26], "train": [1, 7, 11, 12, 15, 16, 17, 19, 27, 29, 30], "infer": [1, 3, 11, 14, 19, 27, 30], "larg": [1, 12, 14, 17, 18, 23, 26, 28], "model": [1, 3, 5, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 23, 24, 26, 27, 28, 29, 30], "thei": [1, 2, 5, 6, 11, 14, 15, 16, 17, 18, 19, 26, 27, 28], "ideal": [1, 2, 3, 18, 21, 26], "varieti": 1, "us": [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 14, 15, 16, 17, 19, 21, 23, 27, 29, 30], "case": [1, 2, 3, 5, 8, 11, 14, 15, 16, 17, 21, 24, 27], "chatbot": 1, "code": [1, 3, 5, 8, 9, 11, 13, 14, 15, 16, 18, 20, 21, 26, 27], "gener": [1, 5, 11, 13, 14, 15, 16, 17, 18, 26], 
"media": 1, "content": [1, 11], "synthet": 1, "speech": 1, "vision": [1, 23], "servic": [1, 2, 14], "recommend": [1, 2, 3, 4, 5, 11, 13, 14, 15, 19, 27], "engin": [1, 16], "person": 1, "among": 1, "other": [1, 2, 3, 5, 7, 11, 12, 14, 15, 16, 17, 18, 19, 20, 25, 26, 28], "scale": [1, 8, 11, 14, 19, 21, 28], "cost": [1, 21], "effici": [1, 8, 16, 17, 21], "wide": [1, 5, 18, 26], "rang": [1, 5, 11, 14, 24, 27, 28], "workload": [1, 14, 15, 16, 27, 28], "span": [1, 3], "fine": 1, "tune": [1, 27], "provid": [1, 2, 3, 5, 6, 7, 8, 11, 15, 16, 17, 18, 19, 21, 22, 23, 25, 26, 27, 28, 29], "versatil": 1, "lead": [1, 16, 17], "framework": [1, 8, 10, 13, 18, 25, 26], "includ": [1, 2, 5, 11, 14, 16, 17, 18, 19, 22, 26, 29], "pytorch": [1, 5, 9, 12, 13, 14, 18, 19, 20, 22, 25, 29, 30], "jax": [1, 6, 7, 8, 14], "tensorflow": [1, 2, 6, 8, 11, 14, 16, 18, 26], "seamlessli": 1, "orchestr": 1, "through": [1, 3, 5, 6, 7, 15, 17, 18, 19, 26, 29], "integr": [1, 9, 10, 24, 25, 28], "kubernet": 1, "gke": 1, "leverag": [1, 9, 30], "dynam": [1, 3, 5, 10, 16, 17, 21], "schedul": [1, 17], "improv": [1, 14, 15, 16, 17, 19, 21, 27], "scalabl": 1, "all": [1, 2, 3, 5, 9, 11, 14, 15, 16, 17, 18, 19, 20, 23, 24, 26, 27, 29], "need": [1, 2, 3, 5, 11, 14, 15, 16, 17, 18, 20, 23, 24, 26, 27, 28], "simultan": 1, "look": [1, 3, 5, 15, 16, 17, 27], "simplest": 1, "wai": [1, 2, 5, 7, 11, 14, 15, 17, 18, 20, 21, 25, 26, 27], "develop": [1, 2, 9, 10, 13, 15, 20, 21, 25, 28], "can": [1, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 23, 24, 25, 27, 28, 29, 30], "also": [1, 2, 3, 5, 6, 8, 9, 11, 13, 14, 15, 16, 17, 18, 21, 23, 24, 25, 26, 27, 28], "vertex": 1, "fulli": [1, 10, 13, 14, 16, 28], "manag": [1, 7, 11, 19, 29], "platform": 1, "more": [1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13, 14, 15, 17, 18, 26, 27, 28, 30], "introduct": [1, 7], "set": [1, 2, 11, 14, 16, 17, 18, 19, 21, 23, 26, 27, 29], "up": [1, 2, 3, 14, 15, 17, 18, 21, 24, 26], "environ": [1, 2, 10, 14, 15, 17, 20, 27, 29], "resourc": [1, 
11, 16], "i": [2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 27, 29], "free": [2, 5, 12, 16, 19, 20, 23], "softwar": [2, 16], "tool": [2, 5, 17, 23], "autom": 2, "openxla": [2, 6, 13, 21, 25], "both": [2, 4, 5, 8, 14, 17, 18, 19, 21, 23, 24, 25, 26, 28, 29], "make": [2, 4, 9, 11, 13, 14, 15, 16, 17, 18, 20, 21, 26, 27], "good": [2, 3, 5, 17, 18, 26, 27], "fit": [2, 3, 17, 23], "well": [2, 3, 6, 8, 11, 14, 17, 18, 26, 28], "extern": [2, 4, 7], "seen": [2, 17, 21], "workspac": [2, 16], "file": [2, 4, 11, 14, 16, 17, 19, 20], "http_archiv": 2, "name": [2, 4, 5, 8, 11, 14, 16, 18, 24, 26, 27, 28], "org_tensorflow": 2, "strip_prefix": 2, "f7759359f8420d3ca7b9fd19493f2a01bd47b4ef": 2, "url": 2, "http": [2, 3, 4, 7, 9, 11, 14, 16, 17, 23, 27], "github": [2, 3, 4, 5, 9, 11, 14, 16, 17, 20, 23, 27], "com": [2, 3, 4, 7, 9, 11, 14, 16, 17, 23, 27], "archiv": 2, "tar": 2, "gz": 2, "pin": [2, 11], "updat": [2, 3, 15, 17, 18, 19, 26, 27], "point": [2, 3, 4, 5, 6, 8, 11, 17, 18, 19, 26], "thi": [2, 3, 4, 5, 6, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30], "repositori": [2, 14], "differ": [2, 6, 11, 15, 16, 17, 18, 20, 22, 23, 26, 27, 28], "revis": 2, "patch": [2, 16], "mai": [2, 3, 6, 14, 15, 16, 17, 18, 19, 26, 27], "ad": [2, 5, 11, 15, 17, 18, 21, 22, 26, 27], "resolv": 2, "prepar": 2, "hermet": 2, "mechan": 2, "deploi": 2, "becaus": [2, 3, 8, 13, 14, 15, 17, 19, 27], "local": [2, 4, 11, 14, 15, 16, 27], "checkout": [2, 16], "ha": [2, 3, 4, 5, 7, 11, 13, 14, 15, 17, 18, 26, 27, 28], "built": [2, 4], "from": [2, 3, 4, 5, 7, 8, 9, 11, 12, 16, 17, 19, 20, 21, 22, 23, 24, 27, 28, 29], "sourc": [2, 3, 5, 6, 8, 10, 11, 16], "instal": [2, 3, 4, 5, 6, 7, 8, 9, 14, 16, 17], "system": [2, 28], "version": [2, 3, 4, 7, 14, 17, 19, 27], "compat": [2, 8, 14, 25, 29], "e": [2, 4, 6, 8, 11, 12, 14, 16, 17, 18, 19, 23, 25, 26, 27], "g": [2, 4, 6, 8, 11, 12, 14, 16, 17, 18, 25, 26, 27, 29], "codegen": [2, 5, 10], "torchgen": 
[2, 3], "python": [2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 16, 17, 18, 20, 21, 26, 27], "modul": [2, 7, 8, 11, 15, 16, 20, 23, 24, 27], "should": [2, 3, 4, 5, 6, 8, 9, 11, 13, 14, 15, 16, 17, 18, 19, 22, 23, 26, 27, 29], "The": [2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 17, 18, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30], "directori": [2, 3, 5, 8], "either": [2, 5, 11, 14, 16, 18, 19, 26], "bzl": 2, "overriden": 2, "command": [2, 3, 4, 14, 15, 16, 17, 20, 23], "line": [2, 3, 11, 13, 15, 16, 17, 18, 23, 26], "override_repositori": 2, "path": [2, 6, 8, 11, 15, 16, 18, 23, 26], "export": [2, 3, 4, 5, 10, 14, 16, 17], "tf_repo": 2, "torch_repo": 2, "pleas": [2, 3, 5, 8, 11, 14, 15, 16, 17, 19, 23, 24, 25, 27, 30], "sure": [2, 15, 16], "overridden": [2, 3], "appropri": [2, 17], "been": [2, 5, 11, 14, 15, 17, 18, 26, 27], "use_cuda": 2, "0": [2, 3, 4, 6, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 30], "setup": [2, 3, 6, 15, 20], "py": [2, 3, 4, 5, 9, 12, 13, 14, 15, 16, 17, 20, 23, 27, 30], "bdist_wheel": 2, "expect": [2, 3, 6, 9, 13, 14, 16, 18, 21, 25, 26], "object": [2, 11, 27], "present": [2, 29], "new_local_repositori": 2, "build_fil": 2, "pytorch_local_dir": 2, "header": 2, "directli": [2, 3, 5, 6, 11, 14, 15, 16, 17, 18, 19, 23, 26, 27, 29], "share": [2, 3, 6, 14, 15, 16, 27], "libtorch": 2, "so": [2, 3, 6, 9, 11, 12, 14, 15, 16, 17, 18, 23, 26, 29], "same": [2, 3, 5, 6, 8, 9, 11, 13, 14, 15, 16, 17, 18, 19, 22, 25, 26, 27, 28, 30], "where": [2, 4, 7, 11, 12, 14, 15, 16, 17, 18, 23, 24, 26], "lib": [2, 6], "contain": [2, 3, 5, 6, 8, 9, 11, 14, 16, 17, 18, 26], "work": [2, 3, 11, 12, 14, 15, 16, 17, 18, 20, 21, 25, 26, 27, 28], "": [2, 4, 5, 6, 7, 8, 11, 13, 14, 15, 16, 17, 19, 20, 21, 25, 27, 28, 29], "requir": [2, 3, 5, 11, 12, 14, 15, 16, 17, 18, 19, 26, 27, 29, 30], "pass": [2, 5, 8, 9, 11, 14, 17, 19, 20, 27], "isystemextern": 2, "compil": [2, 5, 6, 8, 9, 10, 11, 12, 14, 17, 18, 19, 21, 24, 25, 26, 28, 29], "find": [2, 3, 5, 8, 14, 16, 17, 
20, 24], "satisfi": [2, 27], "them": [2, 3, 5, 8, 11, 14, 15, 16, 17, 18, 26], "some": [2, 3, 5, 11, 12, 13, 14, 15, 16, 20, 25, 27], "user": [2, 4, 6, 8, 10, 13, 14, 15, 16, 17, 18, 21, 22, 24, 25, 26, 27, 29], "bring": [2, 3, 24], "pybind11": 2, "embed": 2, "link": [2, 3], "against": [2, 20], "libpython": 2, "instead": [2, 11, 13, 14, 15, 16, 17, 18, 20, 21, 23, 26, 27, 29], "These": [2, 3, 5, 7, 14, 17, 25, 29], "pybind11_emb": 2, "option": [2, 3, 4, 6, 8, 11, 14, 16, 17, 25, 27, 29], "transit": [2, 15], "simpl": [2, 3, 7, 14, 17, 19, 23, 28], "torch_xla": [2, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], "csrc": [2, 5], "runtim": [2, 3, 4, 6, 10, 15, 16, 20, 23, 27, 28, 29], "configr": 2, "via": [2, 4, 10, 14, 22, 23, 24, 27, 28], "bazelrc": 2, "take": [2, 3, 8, 9, 11, 15, 16, 17, 18, 26, 27], "flag": [2, 3, 11, 12, 19], "config": [2, 4], "remote_cach": 2, "configur": [2, 3, 5, 10, 11, 14, 16, 17, 29], "gcloud": [2, 4, 14, 15, 17], "usual": [2, 3, 5, 13, 15, 16], "faster": [2, 14, 17, 18, 21, 26], "authent": [2, 14], "easi": [2, 14, 15, 18, 26], "express": [2, 24, 28], "complex": [2, 9, 21], "lot": [2, 15, 16, 17, 18, 26], "gain": [2, 14], "have": [2, 3, 4, 5, 6, 7, 8, 11, 14, 15, 16, 17, 18, 20, 21, 23, 24, 26, 27, 29], "singl": [2, 11, 13, 18, 20, 21, 23, 24, 26, 27, 28, 30], "graph": [2, 8, 9, 11, 13, 14, 15, 16, 17, 18, 20, 21, 26, 27], "everyth": [2, 18, 20, 26], "therefor": [2, 16, 17], "separ": [2, 3, 5, 15, 17, 21, 23, 24], "rest": [2, 14, 16, 18, 26], "plu": [2, 20, 22], "whole": [2, 11, 13, 18, 21, 26], "everythin": 2, "els": [2, 16, 18, 26], "enough": [2, 17, 18, 26], "normal": [2, 3, 14, 18, 24, 26, 27], "achiev": [2, 5, 13, 20], "invok": [2, 3, 21, 27], "standard": [2, 8], "c": [2, 3, 5, 11, 14, 16, 18, 19, 26], "bind": [2, 8], "simpli": [2, 14], "_xlac": [2, 9, 16, 18, 26], "client": [2, 6, 11, 14], "togeth": [2, 13, 14, 15, 20, 23, 27], "when": [2, 3, 5, 9, 11, 12, 13, 14, 15, 16, 17, 19, 21, 23, 27, 28, 29], 
"chang": [2, 5, 12, 15, 16, 17, 18, 19, 20, 25, 26, 27], "abl": [2, 15, 18, 26, 29], "without": [2, 5, 11, 14, 16, 17, 27, 28, 29], "iter": [2, 11, 12, 15, 16, 17, 21, 27], "cycl": 2, "come": [2, 11, 18, 26], "There": [2, 3, 13, 15, 16, 17, 18, 20, 21, 26, 27], "plenti": 2, "backend": [2, 3, 11, 13, 14, 18, 21, 22, 25, 26, 27, 29], "we": [2, 3, 4, 5, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 30], "our": [2, 3, 4, 5, 6, 7, 8, 12, 14, 15, 16, 18, 19, 20, 21, 26, 27], "gc": [2, 29], "storag": [2, 4, 7, 15, 16, 17, 23, 29], "you": [2, 4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 20, 21, 23, 24, 27, 28, 30], "under": [2, 3, 5, 11, 14, 15, 20], "disabl": [2, 11, 13, 16, 17], "default": [2, 5, 11, 13, 14, 15, 16, 17, 19, 23, 27, 29], "speed": [2, 17, 18, 21, 26], "increment": [2, 3], "huge": [2, 16, 17, 18, 20, 26], "margin": 2, "almost": [2, 28], "alwai": [2, 14, 15, 16, 18, 26, 28], "enabl": [2, 9, 11, 12, 13, 16, 17, 19, 20, 25, 27, 28, 29], "ci": [2, 5], "To": [2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 23, 24, 26, 27, 29, 30], "ensur": [2, 8, 11, 18, 24, 26, 27, 29], "credenti": 2, "auth": [2, 14], "applic": [2, 16, 25, 29], "login": [2, 17], "launch": [2, 11, 14, 15, 17, 20, 21, 23], "browser": 2, "gcp": [2, 4, 14], "variou": [2, 9], "individu": [2, 23, 24, 28], "who": [2, 20], "access": [2, 3, 5, 7, 11, 14, 15, 16, 17, 18, 26, 29], "project": [2, 4, 6, 14, 15, 17], "one": [2, 3, 5, 7, 8, 11, 14, 15, 16, 17, 18, 21, 22, 23, 24, 26, 27, 28, 30], "onli": [2, 3, 5, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 22, 26, 28, 29], "specifi": [2, 8, 11, 15, 17, 23, 27], "google_default_credenti": 2, "token": [2, 13, 17, 25], "out": [2, 5, 8, 11, 12, 13, 14, 15, 16, 17, 19, 21, 27], "box": [2, 5, 27], "log": [2, 16, 17], "permiss": 2, "add": [2, 3, 5, 8, 9, 11, 15, 16, 17, 18, 21, 22, 23, 26], "new": [2, 3, 4, 5, 13, 15, 16, 17, 18, 21, 26, 27], "role": 2, "In": [2, 3, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 26, 27, 28, 29], 
"account": [2, 17], "kei": [2, 4, 6, 14, 16, 17, 29], "google_credenti": 2, "On": [2, 14, 29], "docker": [2, 8], "network": [2, 11, 14, 15, 16, 19, 27], "cloudbuild": 2, "down": [2, 5, 17], "imag": [2, 14, 17, 18, 20, 23, 26], "do": [2, 3, 5, 10, 12, 14, 15, 16, 17, 18, 19, 23, 25, 26, 27], "doe": [2, 3, 11, 12, 14, 15, 16, 17, 18, 19, 26, 27], "read": [2, 4, 5, 11, 14, 27], "write": [2, 5, 9, 11, 15, 28], "silo": 2, "each": [2, 3, 5, 9, 11, 14, 15, 16, 17, 18, 21, 23, 24, 26, 27, 28, 29], "uniqu": [2, 15, 17, 18, 26], "benefit": [2, 17, 24, 25, 29], "consist": [2, 8, 14], "remote_default_exec_properti": 2, "some_silo_kei": 2, "bazel_remote_cach": 2, "1": [2, 4, 6, 7, 8, 11, 13, 14, 15, 16, 19, 20, 21, 22, 23, 24, 27, 28, 30], "silo_nam": 2, "your": [2, 3, 6, 7, 8, 14, 15, 16, 17, 18, 20, 24, 26, 27, 29], "tpuvm_mod": 2, "gcloud_service_key_fil": 2, "application_default_credenti": 2, "json": [2, 8], "might": [2, 5, 11, 15, 16, 17, 18, 26], "help": [2, 16, 17, 18, 26], "too": [2, 16, 18, 26], "cannot": [2, 7, 17, 18, 19, 23, 26], "here": [2, 3, 5, 7, 8, 12, 15, 17, 18, 20, 21, 23, 24, 26, 27, 28, 29], "author": 2, "usernam": 2, "behavior": [2, 3, 5, 14, 15, 16, 19], "function": [2, 5, 6, 7, 8, 9, 11, 13, 15, 16, 17, 21, 22, 24, 25, 29], "intend": 2, "first": [2, 3, 4, 8, 9, 11, 12, 14, 16, 17, 20, 27, 28, 29, 30], "time": [2, 3, 4, 11, 12, 14, 15, 16, 17, 18, 21, 22, 26, 27], "slow": [2, 16, 17], "scratch": [2, 3], "veri": [2, 6, 7, 13, 15, 17, 18, 26], "fast": [2, 18, 26], "onc": [2, 11, 15, 16, 17, 18, 21, 26, 27], "again": [2, 3, 8, 15, 17], "bit": [2, 15, 25], "slower": [2, 16, 17, 20], "per": [2, 8, 11, 14, 15, 16, 19, 20, 21, 25], "until": [2, 11, 15, 17, 29], "next": [2, 11, 16, 17, 18, 25, 26, 27], "quit": 2, "current": [2, 6, 7, 8, 11, 12, 13, 14, 15, 17, 18, 20, 21, 22, 24, 25, 26, 27, 30], "migrat": [2, 10, 14], "futur": [2, 3, 4, 6, 8, 12, 14, 15, 16, 17, 18, 24, 26], "plafrom": 2, "cpp": [2, 5], "main": [2, 4, 8, 9, 13, 14, 27], "Of": 2, "cours": 2, 
"pjrt": [2, 10, 11, 15, 27], "Not": 2, "environment": 2, "variabl": [2, 4, 12, 14, 17, 18, 26], "miss": [2, 5, 11, 16], "common": [2, 14, 18, 24, 25, 26, 28, 29], "part": [2, 3, 6, 9, 11, 13, 14, 16, 17, 27], "ones": [2, 11, 18, 26], "helper": [2, 3, 8, 11], "script": [2, 3, 4, 7, 14, 15, 16, 17, 19, 20, 30], "run_test": 2, "sh": 2, "r": [2, 17], "xla_client": 2, "pure": [2, 3], "easili": [2, 5, 18, 21, 26], "execut": [2, 9, 11, 13, 14, 15, 17, 18, 19, 20, 21, 26, 27, 28, 30], "parallel": [2, 10, 11, 14, 16, 20, 27, 28], "sinc": [2, 3, 5, 14, 15, 16, 17, 18, 19, 20, 21, 24, 26, 29], "xrt": [2, 11], "port": [2, 14, 30], "gpu": [2, 5, 6, 7, 10, 12, 16, 17, 27], "tpu": [2, 3, 5, 6, 7, 10, 11, 12, 16, 20, 21, 22, 29, 30], "devic": [2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 16, 18, 19, 20, 21, 22, 25, 26, 28, 29], "avail": [2, 11, 14, 15, 16, 17, 18, 23, 26, 30], "reason": [2, 3, 5, 13, 14, 17, 20], "bundl": 2, "target": [2, 8, 13, 14, 15, 17, 18, 19, 21, 26], "sequenti": [2, 11], "calcul": 2, "visual": [2, 27], "lcov": 2, "describ": [2, 3, 4, 8, 11, 15, 17, 19, 20, 28], "document": [2, 3, 4, 5, 6, 8, 14, 15, 19, 20, 25], "editor": 2, "choic": [2, 18, 26], "gutter": 2, "vscode": 2, "power": 2, "like": [2, 3, 4, 5, 7, 11, 14, 15, 16, 17, 18, 19, 23, 26, 27], "clangd": 2, "refer": [2, 3, 5, 7, 8, 9, 12, 14, 15, 17, 23, 25, 27, 30], "autocomplet": 2, "semant": [2, 5, 16, 18, 26], "understand": [2, 17, 18, 26], "underli": [2, 11, 15], "stack": [2, 15, 16, 18, 19, 26, 27], "combin": [2, 5, 11, 18, 26], "studio": 2, "extens": [2, 4, 5, 6], "featur": [2, 7, 12, 14, 16, 20, 24, 27, 28, 29], "assist": 2, "edit": 2, "As": [2, 3, 17, 18, 24, 26], "distutil": 2, "ltc": 3, "lazi": [3, 16, 17, 18, 21, 26, 27], "tensor": [3, 5, 8, 11, 12, 14, 17, 19, 21, 22, 24, 25, 27, 28], "core": [3, 5, 8, 11, 13, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 28], "clean": [3, 16, 21], "exist": [3, 8, 11, 13, 14, 15, 16, 21, 27], "stub": 3, "over": [3, 11, 13, 14, 15, 17, 23, 29], "6": [3, 4, 5, 8, 11, 16, 
17, 18, 26], "were": [3, 15, 16, 17, 18, 26], "complet": [3, 11, 15, 16], "process": [3, 5, 6, 9, 11, 13, 14, 16, 17, 20, 23, 25], "found": [3, 14, 17], "ref": [3, 4, 14], "replac": [3, 17, 22], "support": [3, 6, 7, 8, 9, 11, 12, 14, 18, 21, 22, 23, 26, 27, 29, 30], "NOT": 3, "introduc": [3, 7, 13, 14, 16, 17, 20, 27], "ani": [3, 7, 8, 11, 12, 14, 15, 16, 17, 18, 19, 20, 23, 24, 26, 27, 28, 29, 30], "purpos": [3, 5, 25], "follow": [3, 5, 7, 8, 9, 11, 13, 14, 15, 16, 17, 18, 20, 23, 24, 26, 27, 30], "instruct": [3, 5, 17], "depend": [3, 4, 5, 12, 13, 15, 17, 18, 19, 26], "build": [3, 5, 15, 17, 23], "It": [3, 4, 5, 11, 12, 13, 15, 17, 18, 21, 23, 24, 25, 26, 27], "experi": [3, 5, 13, 14, 20, 29], "workstat": [3, 5], "cpu": [3, 5, 8, 11, 16, 17, 18, 23, 25, 26, 27, 29], "pjrt_devic": [3, 5, 6, 12, 14, 15, 16, 22, 30], "re": [3, 11, 13, 14, 16, 17, 18, 19, 22, 24, 26], "familiar": [3, 15, 24], "issu": [3, 5, 11, 13, 14, 15, 16, 17, 19, 20, 24], "3560": 3, "track": [3, 16, 29], "statu": [3, 16], "put": [3, 5, 15, 16, 20], "alia": [3, 11], "avoid": [3, 16, 17, 19], "duplic": 3, "mention": [3, 5, 18, 21, 26], "below": [3, 5, 8, 13, 14, 17, 18, 19, 26, 29, 30], "live": [3, 5, 11, 18, 26], "folder": [3, 4, 5], "except": [3, 5, 17, 27], "xla_native_funct": [3, 5], "yaml": [3, 5], "torch": [3, 4, 7, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29], "shape_infer": 3, "shape": [3, 5, 7, 8, 9, 10, 11, 16, 17, 22, 27, 28], "defin": [3, 5, 7, 9, 11, 17, 19, 22, 24, 27, 28], "input": [3, 5, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 22, 24, 27, 28, 29], "return": [3, 5, 6, 7, 8, 11, 13, 15, 16, 17, 18, 20, 21, 22, 25, 26, 27, 29], "output": [3, 4, 7, 8, 9, 11, 14, 15, 16, 19, 20, 21, 22, 23, 27], "manual": [3, 5, 7, 13, 16, 23], "gen_lazy_tensor": 3, "data": [3, 8, 10, 11, 13, 14, 15, 17, 18, 19, 21, 26, 28, 29], "aten": [3, 5, 16, 18, 26], "specif": [3, 11, 15, 17, 19, 20, 25], "run_gen_lazy_tensor": 3, "dest": 3, "lazy_ir": 3, "class": [3, 6, 8, 11, 20, 23, 
25, 29], "genlazyir": 3, "back": [3, 5, 8, 11, 15, 16, 17, 27], "todai": [3, 12], "most": [3, 6, 11, 14, 16, 21], "categori": [3, 24], "goal": [3, 4, 5, 13], "move": [3, 8, 11, 14, 16, 18, 20, 26, 29], "full_codegen": 3, "necessari": [3, 11, 16, 19], "call": [3, 5, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 26, 27, 29], "upstream": [3, 12, 21], "api": [3, 5, 10, 14, 15, 18, 20, 21, 23, 25, 26, 27, 28, 29], "xlanativefunct": [3, 5], "column": 3, "declar": [3, 5], "anoth": [3, 8, 12, 15, 16, 17, 18, 26], "wrap": [3, 5, 6, 7, 8, 11, 13, 15, 17, 19, 23, 24, 25, 27], "around": [3, 14, 18, 23, 26], "xlatensor": [3, 5, 11, 27], "construct": [3, 5, 15, 17, 23, 27, 28, 29], "aten_xla_typ": [3, 5], "Will": 3, "method": [3, 8, 11, 14, 19, 24, 27, 29], "map": [3, 5, 11], "node": [3, 5, 9, 16, 18, 26, 30], "remov": [3, 14, 16, 17], "tensor_method": [3, 5], "possibl": [3, 14, 15, 16, 17, 23, 24, 27], "multipl": [3, 8, 11, 13, 18, 21, 25, 26], "few": [3, 15, 16, 17, 18, 20, 26, 29], "simpler": [3, 14], "go": [3, 13, 15, 17, 19, 27], "unari": 3, "binari": [3, 6, 8, 21], "exampl": [3, 4, 5, 6, 8, 11, 12, 13, 14, 15, 16, 18, 20, 21, 25, 26, 27, 28, 29, 30], "characterist": 3, "fallback": [3, 5], "_adaptive_avg_pool3d": 3, "condit": [3, 18, 22, 26], "issupportedadaptivepool": 3, "xlahelp": 3, "i64list": 3, "self": [3, 5, 6, 8, 11, 17, 20, 25, 27], "size": [3, 9, 12, 14, 15, 16, 17, 18, 26, 29], "output_size_list": 3, "pool_dim": 3, "nativ": [3, 5, 13, 14, 16, 19, 20, 27], "call_fallback_fn": 3, "xla_fallback": 3, "aten_op": 3, "output_s": 3, "wip": 3, "evolv": 3, "At": [3, 6, 11], "self_tensor": 3, "static": [3, 12, 18, 26], "bool": [3, 11], "sync_upd": 3, "sys_util": 3, "getenvbool": 3, "xla_tensor_update_sync": 3, "true": [3, 11, 13, 14, 17, 18, 20, 23, 26, 27, 29], "xla_check": 3, "dst_tensor": 3, "updatefromtensor": 3, "sync": [3, 11, 13, 16, 17, 19], "complic": [3, 5, 7], "an": [3, 4, 5, 6, 7, 11, 14, 16, 17, 18, 19, 20, 21, 23, 24, 26, 27, 28, 29], "would": [3, 4, 5, 
11, 14, 15, 16, 17, 18, 22, 26], "someth": [3, 17], "ab": [3, 23], "const": [3, 5], "torch_lazy_fn_count": 3, "bridg": [3, 21], "atenfromxlatensor": 3, "getxlatensor": 3, "fail": [3, 11, 15, 16, 29], "explain": [3, 6, 15, 16, 17, 18, 26, 28], "later": [3, 17], "still": [3, 14, 15, 18, 19, 20, 26, 29], "snippet": [3, 15, 27], "auto": [3, 5, 11, 23, 29], "common_devic": 3, "getxladevic": 3, "torch_internal_assert": 3, "xlatensorptr": 3, "lazy_self": 3, "getxlatensororcreateforwrappednumb": 3, "nodeptr": 3, "reusenod": 3, "getirvalu": 3, "makenod": 3, "cachenod": 3, "creat": [3, 8, 9, 11, 14, 16, 17, 19, 20, 27, 29], "std": [3, 20], "get": [3, 5, 11, 12, 13, 14, 17, 18, 20, 23, 25, 26], "check": [3, 4, 5, 11, 15, 25, 28], "reus": [3, 15, 17, 19], "previou": [3, 14, 15, 17, 18, 26], "creation": [3, 11], "If": [3, 4, 5, 8, 11, 14, 15, 16, 17, 18, 25, 26, 27], "correspond": [3, 5, 11, 17, 19, 23, 27, 28], "cach": [3, 7, 11, 12, 17], "newli": [3, 8], "And": [3, 18, 20, 26, 27], "within": [3, 8, 11, 15, 16, 17, 25, 29], "note": [3, 4, 7, 8, 9, 11, 13, 14, 15, 16, 17, 18, 21, 23, 24, 25, 26, 28], "done": [3, 4, 7, 15, 16, 17, 18, 26], "public": [3, 14], "xlanod": 3, "xlavalu": 3, "opkind": [3, 5], "absoutputshap": 3, "num_output": [3, 18, 26], "mhash": 3, "string": [3, 11, 27], "tostr": 3, "overrid": [3, 11, 19], "stringstream": 3, "ss": 3, "str": [3, 6, 11], "xlaopvector": 3, "loweringcontext": 3, "loctx": 3, "A": [3, 4, 6, 11, 14, 15, 17, 18, 19, 24, 25, 26, 27], "coupl": [3, 15, 16], "thing": [3, 16, 17], "keep": [3, 4, 12, 14, 16, 18, 26], "mind": [3, 14, 16], "clone": [3, 14, 16, 17], "even": [3, 11, 14, 15, 16, 18, 20, 26], "everi": [3, 5, 7, 8, 11, 14, 15, 16, 18, 21, 26, 27, 29], "outputshap": 3, "xla_shap": 3, "overli": 3, "simplifi": 3, "buildxxxop": 3, "slightli": [3, 5, 11], "better": [3, 5, 13, 14, 15, 16, 17, 18, 21, 22, 26], "maximumoutputshap": 3, "lower_for_shape_fn": 3, "absl": 3, "xlaop": [3, 5], "operand": 3, "promot": 3, "max": [3, 18, 26, 29], 
"second": [3, 9, 12, 14, 16, 17, 20, 28, 30], "inferoutputshap": 3, "comput": [3, 4, 11, 14, 15, 16, 17, 18, 19, 26, 27, 28], "logic": [3, 11, 13, 18, 22, 26, 27, 28], "two": [3, 6, 11, 14, 16, 17, 18, 26, 27, 28], "xla_input": 3, "getoutputop": 3, "returnop": 3, "buildab": 3, "origin": [3, 8, 17], "genericop": 3, "modifi": [3, 17, 19, 21, 27], "abov": [3, 5, 6, 8, 12, 13, 14, 15, 16, 17, 18, 20, 21, 26, 28], "delet": 3, "sometim": [3, 17, 18, 26], "being": [3, 11, 15, 17, 20, 28], "tensor_op": 3, "cross": [3, 15, 27], "s1": [3, 27], "sub": 3, "mul": [3, 18, 26], "u2": 3, "v3": [3, 15, 20], "u3": 3, "v2": [3, 4, 15], "irnod": 3, "those": [3, 5, 8, 11, 16, 17, 20], "long": [3, 13, 16, 17, 18, 20, 26], "term": [3, 9, 13, 16, 18, 26], "rid": [3, 18, 26], "composit": [3, 5], "end": [3, 5, 9, 11, 12, 14, 15, 16, 17, 20, 23, 24], "exp": 3, "pow": 3, "norm_exp": 3, "vector": [3, 9], "involv": [3, 18, 26, 27], "don": [3, 5, 12, 13, 14, 15, 16, 18, 23, 26], "t": [3, 5, 8, 11, 12, 13, 14, 15, 16, 18, 19, 23, 24, 26, 27, 28, 29], "build_cpp_test": 3, "skip": [3, 5, 16, 21], "desir": [3, 8, 17, 29], "test_ptxla": 3, "gtest_filt": 3, "atenxlatensortest": 3, "testab": 3, "correct": [3, 18, 26], "counter": [3, 5, 11, 16], "correctli": [3, 16, 24], "gt": [3, 4, 8, 14, 17], "erf": 3, "erfc": 3, "erfinv": 3, "pull": [3, 8, 19, 20, 23], "3659": 3, "binary_cross_entropi": [3, 19], "backward": [3, 5, 8, 13, 14, 15, 19, 20, 21, 23, 24], "3809": 3, "scalar": [3, 5, 16, 18, 26], "addcdiv": 3, "addcmul": 3, "3768": 3, "neg": 3, "index": [3, 4, 6, 11, 14, 15, 16, 17, 30], "amin": 3, "amax": 3, "3771": 3, "special": [3, 8, 9, 17, 27], "partial": [3, 18, 23, 24, 26], "adaptive_avgpool3d": 3, "3790": 3, "guid": [4, 8, 10, 14, 15, 17, 23, 24, 27], "interact": [4, 14], "start": [4, 13, 14, 15, 16, 17], "colab": [4, 16], "kaggl": 4, "preinstal": [4, 14], "ecosystem": [4, 25], "packag": [4, 9, 10, 15, 17, 19, 20], "date": 4, "list": [4, 5, 11, 17, 19, 22, 27], "readm": [4, 16, 17], "prerequisit": 
4, "remot": 4, "quota": 4, "about": [4, 13, 14, 15, 17, 18, 26], "request": [4, 5, 11, 16, 17, 18, 19, 20, 26, 27], "offici": [4, 16], "ssh": [4, 14, 15, 17], "regist": [4, 5, 6, 14, 29], "agent": 4, "alreadi": [4, 7, 9, 11, 16, 17, 18, 20, 23, 26, 29], "befor": [4, 7, 8, 11, 14, 15, 16, 17, 18, 19, 20, 21, 24, 26, 27, 29], "begin": [4, 27], "zone": [4, 14, 15, 17], "tpu_typ": 4, "8": [4, 8, 9, 11, 13, 14, 15, 17, 18, 20, 21, 25, 26, 27, 28], "vm": [4, 14, 15, 16, 17, 20], "assum": [4, 6, 7, 11, 15, 18, 20, 24, 26, 27], "id_ed25519": 4, "ubuntu2204": 4, "base": [4, 11, 13, 14, 16, 17, 18, 23, 26, 27, 28], "metadata": [4, 16], "cat": [4, 19], "pub": 4, "ip": [4, 11, 14, 29, 30], "format": [4, 11, 16, 17, 21, 25], "valu": [4, 5, 8, 9, 11, 12, 14, 16, 17, 18, 22, 26, 27, 30], "networkendpoint": 4, "accessconfig": 4, "externalip": 4, "123": 4, "give": [4, 8, 16, 17, 25, 27, 28], "friendli": 4, "easier": [4, 13, 17, 18, 26], "echo": 4, "host": [4, 11, 14, 15, 16, 17, 19, 23, 29, 30], "n": [4, 11, 20, 25], "hostnam": 4, "test": [4, 6, 7, 8, 9, 12, 14, 20, 23, 30], "v": [4, 7, 8, 14, 18, 26], "palett": 4, "select": [4, 11, 14, 29], "visualstudio": 4, "doc": [4, 11, 13, 14, 15, 18, 24, 26, 27], "__": [4, 14], "just": [4, 7, 13, 14, 15, 18, 20, 23, 26, 29], "titl": [4, 14], "open": [4, 5, 6, 8, 14, 16], "window": 4, "termin": [4, 29], "mkdir": 4, "ptxla": 4, "Then": [4, 8, 17], "ui": 4, "venv": 4, "virtual": [4, 11], "latest": [4, 8], "releas": [4, 6, 7, 14, 15, 16, 17, 21, 23, 24, 25, 27], "pip": [4, 7, 8, 9, 17], "numpi": [4, 7, 8, 11, 17, 28], "f": [4, 7, 8, 11, 15, 20, 23, 25, 29], "googleapi": [4, 7, 17], "libtpu": [4, 6, 14], "html": [4, 7, 14, 23], "import": [4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 27, 28, 29], "set_device_typ": 4, "print": [4, 8, 11, 12, 14, 15, 16, 17, 18, 20, 21, 26, 27, 29], "real_devic": 4, "run": [4, 5, 7, 9, 10, 11, 12, 13, 14, 18, 19, 20, 21, 25, 26, 29], "2": [4, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 21, 22, 23, 25, 
27, 30], "3": [4, 5, 6, 7, 8, 9, 11, 13, 16, 17, 21, 22, 23, 25, 27], "4": [4, 6, 7, 8, 11, 14, 15, 16, 17, 18, 21, 22, 23, 25, 26, 27, 28], "5": [4, 8, 11, 12, 16, 17, 18, 20, 23, 25, 26], "7": [4, 11, 16, 20, 21], "number": [4, 9, 11, 12, 13, 14, 16, 17, 23, 27, 28], "vari": [4, 14, 18, 24, 26], "That": [4, 18, 26], "now": [4, 8, 9, 13, 14, 15, 17, 18, 26, 27], "realist": 4, "librari": [5, 6, 17, 28, 29], "offer": [5, 8, 24, 25], "implement": [5, 7, 8, 13, 14, 16, 18, 21, 23, 24, 26], "xla": [5, 8, 9, 12, 13, 14, 18, 20, 22, 24, 29, 30], "its": [5, 8, 12, 14, 15, 16, 20, 21, 23, 27, 28], "convert": [5, 11, 15, 20], "higher": [5, 16, 29], "level": [5, 16, 17, 21, 25, 29], "represent": [5, 11, 15, 17, 28], "hlo": [5, 11, 15, 16, 17], "beyond": 5, "scope": 5, "forward": [5, 8, 13, 19, 20, 21, 24, 25], "haven": [5, 18, 26], "yet": 5, "caus": [5, 11, 13, 14, 15, 16, 17, 18, 19, 26], "signific": [5, 16, 17, 21], "slowdown": [5, 16, 20], "must": [5, 6, 11, 14, 15, 16, 24, 29, 30], "best": [5, 7, 21, 25], "perform": [5, 7, 8, 9, 11, 13, 15, 19, 20, 21, 23, 25, 27], "what": [5, 15, 17], "debug": [5, 13, 18, 25, 26], "pt": [5, 14, 15, 16, 17], "profil": [5, 14], "_ctc_loss": [5, 16], "_ctc_loss_backward": [5, 16], "contribut": 5, "definit": [5, 15, 18, 26], "native_funct": 5, "after": [5, 8, 11, 14, 15, 16, 17, 18, 22, 26, 27], "kernel": [5, 8, 10, 18, 25, 26], "aten_fallback": 5, "h": 5, "search": 5, "repo": [5, 15, 16, 17, 20], "sequenc": [5, 11], "explicitli": [5, 15, 16, 17, 18, 19, 26], "compos": 5, "match": [5, 8, 11, 15, 16], "serv": 5, "interfac": [5, 6, 15, 16, 24, 29], "machineri": 5, "registerxla": 5, "registerautogradxla": 5, "entri": [5, 6, 8], "pytorch_xla": 5, "world": [5, 7, 14, 18, 21, 26, 29], "written": [5, 17, 29], "paramet": [5, 11, 14, 15, 16, 19, 20, 24, 27, 29, 30], "result": [5, 11, 12, 14, 15, 16, 17, 20, 22, 27], "dispatch": [5, 29], "wrapper": [5, 15, 20, 23, 24], "inplac": [5, 11, 27], "ir": [5, 8, 11, 16, 17, 18, 26], "insid": [5, 8, 15, 17, 
27], "stand": 5, "intermedi": [5, 14, 16, 17], "smaller": [5, 17, 18, 26], "inherit": 5, "dai": 5, "addit": [5, 6, 9, 14, 15, 16, 17, 19, 20], "unless": [5, 16, 18, 26], "want": [5, 11, 13, 14, 15, 16, 17, 18, 21, 26, 27, 30], "verifi": 5, "test_oper": 5, "test_aten_xla_tensor": 5, "yield": [5, 15, 16], "break": [5, 17, 18, 26], "grasp": 5, "capabl": 5, "how": [5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 26, 27, 30], "similar": [5, 14, 17, 20, 22, 25], "minim": [5, 17], "pr": [5, 16, 23], "vanilla": 5, "lerp": 5, "variant": [5, 11, 18, 19, 26], "lerp_": 5, "scalar_out": 5, "tensor_out": 5, "prototyp": [5, 8, 27], "weight": [5, 8, 11, 16, 24, 25], "lerp_out": 5, "howev": [5, 7, 8, 16, 17, 27], "namespac": [5, 16], "wrapper_scalar_lerp": 5, "No": [5, 12, 14, 18, 25, 26], "deviceguard": 5, "omit": [5, 14, 28, 30], "anonym": 5, "wrapper_scalar_lerp_": 5, "wrapper_scalar_lerp__tmp": 5, "_copy_from": 5, "m": [5, 8, 18, 23, 26], "impl": [5, 8], "torch_fn": 5, "automat": [5, 6, 10, 11, 14, 15, 16, 17, 18, 23, 26, 28, 29], "u": [5, 14, 16, 17, 18, 21, 26], "explicit": [5, 19, 23], "place": [5, 11, 17, 19, 27, 29], "ll": [5, 18, 26], "interned_str": 5, "symbol": [5, 18, 26], "submit": [5, 16, 17, 19], "team": [6, 21], "direclti": 6, "tf": [6, 16, 18, 26], "close": 6, "expos": [6, 14, 15, 17, 27], "deviceplugin": 6, "handl": [6, 13, 16, 18, 23, 24, 26, 27, 28], "short": [6, 16, 18, 26], "pjrtclient": 6, "mirror": 6, "pjrt_api": 6, "straightforward": [6, 11, 17], "detail": [6, 7, 8, 11, 12, 14, 15, 16, 17, 18, 26], "concret": [6, 18, 26], "placehold": 6, "pjrt_library_path": 6, "extra": [6, 20, 24], "multiprocess": [6, 11, 14, 15], "compon": 6, "least": [6, 17], "cpuplugin": 6, "def": [6, 7, 8, 9, 11, 13, 14, 15, 17, 20, 21, 22, 24, 25], "library_path": 6, "o": [6, 8, 14, 20], "join": [6, 11], "dirnam": 6, "__file__": 6, "pjrt_c_api_cpu_plugin": 6, "identifi": [6, 11, 29], "exmapl": 6, "pyproject": 6, "toml": 6, "torch_xla_cpu_plugin": 6, "With": [6, 7, 8, 12, 14, 18, 21, 
26], "initi": [6, 8, 11, 14, 15, 17, 20, 22, 29], "experiment": [6, 7, 8, 9, 10, 12, 13, 14, 15, 20, 21, 22, 24, 27, 29], "state": [6, 11, 23], "becom": [6, 7, 8, 14, 16, 17, 18, 26], "stabl": [6, 14, 23], "rise": 7, "openai": [7, 9], "triton": [7, 10], "popular": 7, "commun": [7, 11, 14, 15, 17, 21, 28], "instanc": [7, 11, 23, 29], "order": [7, 11, 15, 16, 17, 27, 28], "pariti": 7, "continu": [7, 14, 21], "push": 7, "let": [7, 14, 15, 16, 17, 21, 28], "custom_kernel": 7, "jax_import_guard": 7, "pl": [7, 14, 15, 27], "jnp": 7, "add_vectors_kernel": 7, "x_ref": 7, "y_ref": 7, "o_ref": 7, "x": [7, 8, 9, 11, 15, 16, 17, 18, 20, 22, 23, 24, 25, 26, 27, 28], "y": [7, 9, 11, 16, 17, 18, 23, 24, 25, 26, 27], "jit": [7, 9, 21], "add_vector": 7, "arrai": [7, 11, 17, 24, 28], "pallas_cal": 7, "out_shap": 7, "shapedtypestruct": 7, "dtype": [7, 8, 9, 14, 18, 19, 25, 26], "otherwis": [7, 11, 16, 17, 18, 24, 26], "program": [7, 8, 9, 11, 16, 17, 18, 21, 26, 27, 28], "hang": 7, "lock": 7, "q": [7, 8], "randn": [7, 8, 11, 13, 14, 15, 20, 21, 25, 27, 28], "128": [7, 8, 14, 23, 25, 30], "k": [7, 8, 16], "make_kernel_from_palla": 7, "pt_kernel": 7, "lambda": [7, 23], "liner": 7, "flash": [7, 9], "attent": [7, 9], "besid": 7, "op": [7, 8, 10, 11, 13, 16, 17, 18, 19, 26, 27, 28], "suppor": 7, "flash_attent": 7, "paged_attent": 7, "queri": [7, 14], "squeez": 7, "dim": [7, 11], "key_cach": 7, "value_cach": 7, "context_len": 7, "block_tabl": 7, "pages_per_compute_block": 7, "megacore_mod": 7, "none": [7, 8, 11, 16, 24, 27, 28], "vllm": 7, "util": [7, 10, 11, 15, 16, 20, 23, 24, 25, 29], "effect": [7, 11], "memori": [7, 10, 11, 12, 16, 17, 18, 23, 26], "kv": 7, "proper": [7, 28], "jax_nightly_releas": 7, "jaxlib_nightly_releas": 7, "exported_program_to_stablehlo": 8, "xla_model": [8, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 28], "xm": [8, 11, 13, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 28], "torchvis": [8, 13, 21], "xla_devic": [8, 11, 14, 15, 16, 17, 19, 20, 21, 22, 25, 28], "resnet18": 
[8, 13, 21], "sampl": [8, 11, 14, 16], "tupl": [8, 11, 18, 22, 24, 26, 28], "sample_input": 8, "224": [8, 13], "stablehlo_program": 8, "callabl": [8, 11, 23], "get_stablehlo_text": 8, "get_stablehlo_bytecod": [8, 11], "sample_input_xla": 8, "output2": 8, "allclos": 8, "atol": 8, "1e": [8, 16, 21], "One": [8, 11, 12, 17, 23], "tmp": [8, 15, 16, 23], "stablehlo_dir": 8, "empti": [8, 11], "doesn": [8, 15, 16, 18, 24, 26], "load": [8, 9, 11, 14, 16, 20, 23, 25, 29], "stablehlographmodul": 8, "stablehlo_program2": 8, "output3": 8, "server": [8, 11, 14, 17], "env": [8, 11, 14, 27], "nightli": [8, 16, 17, 23, 27], "resnet_tf": 8, "p": [8, 14, 16, 18, 26], "8500": 8, "mount": [8, 15], "model_nam": 8, "accomplish": 8, "tf_saved_model_integr": 8, "save_torch_module_as_tf_saved_model": 8, "nn": [8, 11, 14, 15, 20, 21, 23, 25, 27], "trace": [8, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 26, 27], "exported_model": 8, "exportedprogram": 8, "pathlik": 8, "stablehloexportopt": 8, "alias": [8, 16, 19], "save_torch_model_as_stablehlo": 8, "torchmodel": 8, "arg": [8, 11, 15, 17, 22, 23], "constant": [8, 16, 17, 27], "ndarrai": [8, 11], "human": 8, "readabl": [8, 17], "mlir": 8, "form": [8, 14, 16, 18, 26, 30], "posit": [8, 11], "argument": [8, 9, 11, 17, 19, 21, 23], "meta": 8, "l__fn___layers_15_feed_forward_w2": 8, "l__fn___layers_13_feed_forward_w1": 8, "l__fn___layers_3_attention_wo": 8, "l__fn___layers_12_ffn_norm_weight": 8, "l__fn___layers_25_attention_wo": 8, "serial": [8, 14, 15], "stablehlofunc": 8, "stage": 8, "guarante": [8, 11], "plan": [8, 12, 14], "major": 8, "agre": [8, 17], "scaled_dot_product_attent": 8, "decompos": 8, "low": [8, 12, 16], "dure": [8, 11, 16, 17, 21, 23, 27], "lower": [8, 10, 16, 17, 18, 19, 26], "captur": [8, 11, 16, 17], "downstream": [8, 19], "ml": [8, 28], "crucial": 8, "geneart": 8, "while": [8, 11, 17, 18, 20, 26], "pattern": [8, 16, 18, 21, 26], "bunch": 8, "challeng": 8, "error": [8, 11, 15, 16], "prone": 8, "robust": 8, "outlin": [8, 25], 
"stablehlocompositebuild": 8, "arbitari": 8, "region": [8, 11, 13, 16, 19, 27], "non": [8, 11, 13, 18, 19, 26, 28], "hardcod": [8, 27], "store": [8, 9, 11, 16], "attribut": 8, "retriev": [8, 11, 15, 18, 21, 26, 27], "show": [8, 14, 15, 16, 20], "pratic": 8, "scaled_product_attent": 8, "mark_pattern_util": 8, "__init__": [8, 20, 25], "super": [8, 20, 21], "q_proj": 8, "linear": [8, 11, 14, 15, 19, 20, 25], "bia": 8, "fals": [8, 11, 15, 23, 27], "k_proj": 8, "v_proj": 8, "builder": 8, "b": [8, 11, 14, 17, 18, 19, 21, 26, 28], "sdpa": 8, "25": [8, 12], "other_attr": 8, "val": 8, "mark_input": 8, "attn_out": 8, "mark_output": 8, "input_arg": 8, "10": [8, 11, 14, 15, 16, 17, 18, 20, 21, 22, 25, 26, 29], "stablehlo_gm": 8, "shown": [8, 14, 18, 26], "irtohlo": 8, "56": 8, "mhlo": 8, "cross_program_prefetch": 8, "input_output_alia": 8, "is_dynam": 8, "use_auto_spmd_partit": 8, "func": 8, "arg0": 8, "10x8x128xf32": 8, "arg1": 8, "128x128xf32": 8, "arg2": 8, "arg3": 8, "9": [8, 17, 18, 20, 23, 26], "composite_attribut": 8, "500000e": 8, "01": [8, 9], "f32": 8, "decomposit": 8, "11": [8, 16, 18, 26], "privat": [8, 14], "actual": [8, 13, 17, 18, 20, 26, 27], "encapsul": 8, "propag": [8, 16], "high": [9, 12, 17, 20, 25], "deep": [9, 10, 16], "learn": [9, 14], "languag": 9, "empow": 9, "full": [9, 11, 15, 16, 23], "potenti": [9, 11, 14, 16, 24], "oper": [9, 10, 11, 14, 15, 16, 17, 29], "given": [9, 11, 16, 17, 18, 20, 23, 26, 28], "add_kernel": 9, "x_ptr": 9, "pointer": 9, "y_ptr": 9, "output_ptr": 9, "n_element": 9, "block_siz": 9, "tl": 9, "constexpr": 9, "element": [9, 11, 18, 24, 26, 27], "blob": [9, 11, 14, 27], "tutori": [9, 16, 17, 20, 27], "l28": 9, "pid": 9, "program_id": 9, "axi": [9, 11, 24], "block_start": 9, "offset": 9, "arang": 9, "mask": [9, 16, 18, 26], "xla_triton": 9, "16": [9, 15, 17, 23, 25, 28], "int64": 9, "empty_lik": 9, "grid": 9, "cdiv": 9, "triton_cal": 9, "itself": [9, 11, 23], "kwarg": [9, 11, 23, 27], "payload": [9, 11, 14], "regard": [9, 15, 21], 
"buffer": [9, 11], "_xla_gpu_custom_cal": 9, "dep": 9, "connect": [10, 11, 14, 27], "overview": [10, 28], "eager": [10, 11, 18, 20, 25, 26], "mode": [10, 11, 18, 20, 25, 26, 27, 29], "troubleshoot": 10, "palla": 10, "stablehlo": [10, 11], "mix": [10, 11, 28], "precis": 10, "spmd": [10, 15, 17, 29], "advanc": [10, 28], "topic": [10, 28], "distribut": [10, 15, 16, 20, 23, 24, 27, 28], "checkpoint": [10, 14, 17, 23, 28], "distributeddataparallel": [10, 14], "ddp": [10, 14], "torchdynamo": 10, "while_loop": 10, "shard": [10, 11, 28, 29], "quantiz": 10, "recompil": [10, 12, 13, 15, 16, 17], "hardwar": [10, 11, 16, 17, 19], "plugin": [10, 14], "bazel": 10, "int": [11, 14, 18, 26, 27], "device_count": [11, 27], "address": [11, 14, 27, 30], "wait": [11, 16, 17], "pend": [11, 13], "whether": [11, 15, 19], "block": [11, 17, 23, 27], "finish": [11, 17], "full_graph": 11, "num_different_graphs_allow": 11, "lazytensor": [11, 13, 17], "repres": [11, 14, 18, 26], "happen": [11, 13, 14, 15, 16, 17, 18, 26], "decid": [11, 16, 18, 26], "funciton": 11, "act": [11, 15], "context": [11, 14, 16, 18, 19, 26], "throw": [11, 15], "info": [11, 16, 18, 26, 28], "exit": [11, 16, 19, 20], "pt_xla_debug": 11, "messag": [11, 16], "dump": [11, 16], "allow": [11, 15, 16, 17, 19, 27, 28, 29], "rais": [11, 16], "limit": [11, 14, 15], "exceed": 11, "usag": [11, 16, 17, 18, 23, 24, 26, 29], "foo": 11, "sin": 11, "co": 11, "foo2": 11, "compiled_foo2": 11, "manual_se": [11, 14], "seed": 11, "random": [11, 13, 14, 17, 25], "integ": [11, 16], "rng": [11, 14], "device_typ": 11, "local_process_count": 11, "local_device_count": 11, "total": [11, 18, 26, 28], "addressable_device_count": 11, "visibl": [11, 18, 26], "global_device_count": 11, "across": [11, 14, 15, 16, 23, 28], "global_runtime_device_count": [11, 24, 27, 28], "especi": [11, 14, 17, 21, 27], "world_siz": [11, 14, 19, 20, 23, 27], "particip": [11, 14], "job": [11, 17, 21], "global_ordin": [11, 14, 15, 20, 23], "global": [11, 14, 15, 27, 29], 
"ordin": [11, 15], "thread": [11, 14, 15, 16, 29], "predict": 11, "relationship": [11, 15, 16], "worker": [11, 14, 15, 17, 23, 29], "id": [11, 14, 16, 17], "nor": 11, "contigu": [11, 15, 16], "local_ordin": 11, "get_master_ip": 11, "master": [11, 14, 15, 29], "discoveri": 11, "use_spmd": [11, 27, 28, 29], "forc": [11, 14, 16, 18, 22, 26], "mean": [11, 14, 15, 16, 17, 18, 20, 24, 26, 27], "replic": [11, 27, 28], "spmd_advanc": 11, "md": [11, 14], "is_spmd": 11, "initialize_cach": [11, 15], "readonli": [11, 15], "persist": [11, 15, 29], "devkind": 11, "cuda": [11, 14, 15, 17, 18, 19, 25, 26, 30], "deprec": 11, "xla_device_hw": 11, "union": 11, "real": [11, 21], "is_master_ordin": 11, "multi": [11, 12, 27, 30], "num_host": 11, "boolean": 11, "indic": [11, 16, 17, 18, 26], "all_reduc": [11, 19], "reduce_typ": 11, "float": [11, 18, 19, 26], "group": [11, 14, 20, 27], "pin_layout": 11, "reduc": [11, 12, 13, 14, 15, 16, 17, 23], "reduce_sum": 11, "reduce_mul": 11, "reduce_and": 11, "reduce_or": 11, "reduce_min": 11, "reduce_max": 11, "appli": [11, 19, 23, 24, 29], "replica": [11, 14], "layout": [11, 25], "pine": 11, "prevent": [11, 17, 19, 21, 27], "corrupt": 11, "unpin": 11, "hlomodul": 11, "constrain": [11, 14], "hold": [11, 27, 28], "all_gath": [11, 14], "gather": [11, 27], "along": [11, 23], "dimens": [11, 12, 27, 28], "all_to_al": 11, "split_dimens": 11, "concat_dimens": 11, "split_count": 11, "alltoal": 11, "www": 11, "org": [11, 14, 23], "operation_semant": 11, "upon": 11, "split": 11, "concat": 11, "count": [11, 16], "add_step_closur": 11, "closur": 11, "run_async": 11, "step": [11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 26, 27, 29], "mani": [11, 14, 16, 17, 18, 26, 30], "report": 11, "consol": 11, "post": [11, 16], "tensorboard": [11, 17], "etc": [11, 13, 16, 18, 26, 27], "intermediari": 11, "inspect": 11, "typic": 11, "barrier": [11, 14, 15, 17], "materi": [11, 16, 17, 18, 26, 27], "queu": 11, "though": [11, 15, 20], "advis": 11, "throttl": 11, "event": 11, 
"asynchron": [11, 27, 29], "wait_device_op": 11, "async": [11, 21], "whose": [11, 12], "optimizer_step": [11, 15, 17, 19, 20, 23], "optimizer_arg": 11, "dict": [11, 23], "gradid": 11, "parallelload": [11, 27], "dataparallel": 11, "loader": [11, 16, 17, 21], "dictionari": 11, "gradient": [11, 15, 19, 23, 29], "save": [11, 16, 23, 29], "file_or_path": 11, "textio": 11, "master_onli": [11, 23], "global_mast": 11, "transfer": [11, 14, 16, 17, 27], "care": [11, 15, 16, 18, 26], "taken": [11, 15, 16, 18, 20, 26, 29], "view": [11, 15, 16], "recreat": [11, 15], "destin": [11, 15], "nest": [11, 23], "locat": 11, "control": [11, 12, 15, 16, 27], "obj_to_sav": 11, "path_to_sav": 11, "rendezv": 11, "tag": [11, 14], "byte": 11, "mesh": [11, 14, 24], "reach": 11, "xla_rendezv": 11, "sent": [11, 16], "exchang": 11, "mesh_reduc": 11, "reduce_fn": 11, "toxlatensorarena": 11, "reduct": 11, "receiv": 11, "copi": [11, 14, 15, 16, 17], "np": [11, 24, 28], "accuraci": [11, 20, 23], "test_accuraci": 11, "set_rng_stat": 11, "get_rng_stat": 11, "get_memory_info": 11, "memoryinfo": 11, "bytes_us": 11, "290816": 11, "bytes_limit": 11, "34088157184": 11, "peak_bytes_us": 11, "500816": 11, "get_stablehlo": 11, "var": [11, 27], "xla_hlo_debug": [11, 16], "root": [11, 18, 26], "bytecod": [11, 21], "parallel_load": [11, 14, 15, 16], "mpdeviceload": [11, 15, 17, 27], "dataload": [11, 15, 17, 20, 27, 29], "background": [11, 29], "upload": [11, 17, 27], "per_device_load": [11, 27], "constructor": 11, "train_device_load": 11, "train_load": [11, 15, 27], "xla_multiprocess": 11, "spawn": [11, 14, 15, 17], "fn": 11, "nproc": [11, 14], "daemon": 11, "start_method": 11, "moment": 11, "maximum": [11, 12, 17, 25], "valueerror": 11, "mark_shard": [11, 24, 27, 28], "xlashardedtensor": [11, 29], "partition_spec": [11, 27, 28], "annot": [11, 27, 28], "partit": [11, 27], "spec": [11, 27], "intern": [11, 14, 15, 16, 18, 26, 27, 30], "spmdpartition": [11, 27], "topologi": [11, 15, 27, 28], "device_mesh": [11, 27], 
"rank": [11, 14, 20, 23, 28, 29], "mesh_shap": [11, 24, 27, 28], "ax": [11, 27, 28], "impact": [11, 14, 16, 18, 20, 26], "dynamo_custom_op": 11, "dynamo": [11, 17, 21, 25], "recogniz": 11, "traceabl": 11, "xr": [11, 14, 15, 19, 20, 23, 24, 27, 28, 29], "num_devic": [11, 24, 27, 28], "device_id": [11, 24, 27, 28], "32": [11, 16, 17], "clear_shard": 11, "clear": 11, "cast": [11, 19], "t1": [11, 15, 16, 28], "get_1d_mesh": 11, "set_global_mesh": 11, "get_global_mesh": 11, "axis_nam": [11, 27], "v4": [11, 13, 14, 15, 17, 21, 27], "ravel": 11, "reshap": 11, "fill": 11, "assign": [11, 15, 17], "Its": 11, "length": [11, 18, 26], "len": [11, 17], "get_xla_supported_devic": 11, "get_logical_mesh": 11, "ordereddict": [11, 27, 28], "hybridmesh": [11, 27], "ici_mesh_shap": [11, 27], "dcn_mesh_shap": [11, 27], "hybrid": 11, "ici": 11, "dcn": [11, 27], "increas": 11, "intens": 11, "mdl": 11, "inner": [11, 23, 27], "outer": [11, 23, 24, 27], "slice": [11, 17, 27], "fsdp": [11, 23, 24, 27, 28], "eager_mod": [11, 13], "wa": [11, 14, 16, 17, 18, 26, 29], "d": [11, 12, 18, 19, 26], "eagerli": [11, 13, 15, 16, 18, 26], "metric": [11, 20], "metrics_report": [11, 16], "short_metrics_report": [11, 16], "counter_nam": 11, "metric_nam": 11, "activ": [11, 15, 16, 20, 23, 24, 25], "counter_valu": 11, "metric_data": 11, "total_sampl": 11, "accumul": 11, "retain": 11, "circular": 11, "sum": [11, 19, 23, 24], "natur": 12, "in_tensor": 12, "randint": [12, 25], "out_tensor": 12, "nonzero": [12, 16, 17, 18, 26], "word": [12, 18, 26], "further": [12, 17, 20], "categor": 12, "unbound": 12, "alloc": 12, "infinit": [12, 24], "phase": 12, "layer": [12, 13, 23, 24, 27], "perceptron": 12, "mlp": 12, "xla_experiment": 12, "masked_select": 12, "masked_scatt": 12, "your_script": [12, 17], "100": [12, 16, 23], "29": [12, 20, 21], "49": [12, 21], "20": [12, 15, 16, 20, 25], "03": 12, "102": 12, "hit": [12, 18, 26], "198": 12, "1953": 12, "motiv": 12, "excess": 12, "between": [12, 14, 15, 16, 17, 18, 19, 20, 
22, 26, 27], "figur": [12, 28], "half": 12, "drop": [12, 16], "try": [12, 16, 17, 18, 26], "python3": [12, 14, 15, 16, 17, 23], "test_dynamic_shape_model": 12, "testdynamicshapemodel": 12, "test_backward_pass_with_dynamic_input": 12, "expand": [12, 21], "feel": [12, 16, 20], "review": [12, 24], "rfc": [12, 27, 30], "align": 13, "64": [13, 21, 23], "mark_step": [13, 14, 15, 16, 17, 20], "drawback": 13, "approach": [13, 18, 20, 23, 26], "often": [13, 16, 18, 26], "confus": 13, "preprocess": [13, 25], "small": [13, 16, 17, 18, 20, 21, 26], "leak": 13, "expens": [13, 16, 18, 26], "hard": [13, 18, 20, 21, 26], "why": [13, 18, 26], "mitig": 13, "ux": 13, "mark": [13, 15], "compiled_model": 13, "right": [13, 18, 21, 26], "awai": 13, "pretti": [13, 15, 18, 20, 26], "straight": 13, "enter": 13, "reenabl": 13, "perfomr": 13, "compar": [13, 14, 15, 16, 20, 21, 22], "recommen": 13, "overhad": 13, "step_fn": 13, "loss_fn": [13, 14, 15, 19, 20, 21], "zero_grad": [13, 14, 15, 19, 20], "logit": [13, 24], "loss": [13, 14, 15, 19, 21, 23, 24], "ask": [13, 16, 18, 26], "refactor": 13, "decod": 13, "much": [13, 14, 15, 17, 18, 21, 26], "llama2": 13, "fake": [13, 29], "chip": [13, 14], "300": [13, 16], "observ": [13, 14, 20], "147": 13, "65": [13, 16], "45": 13, "train_decoder_only_bas": [13, 16], "perfomran": 13, "tri": [13, 17], "resnet50": [13, 14, 15, 21, 23], "exepct": 13, "loop": [13, 15, 16, 17, 18, 24, 26, 29], "meant": 13, "encount": [14, 16, 17], "bug": [14, 16, 20], "r2": [14, 16, 27], "init": [14, 15, 20, 21, 22], "renam": 14, "xla_backend": [14, 20, 29], "torchrun": [14, 15, 30], "init_method": [14, 29], "xpu": 14, "neuron": 14, "xrt_tpu_config": 14, "30": [14, 23], "thousand": 14, "preview": 14, "safe": 14, "section": [14, 15, 16, 17, 27], "broadcast": 14, "broadcast_master_param": 14, "pjrt_backend": 14, "diff": [14, 17], "dist": [14, 20, 29], "_mp_fn": [14, 15], "init_process_group": [14, 20, 29], "42": 14, "gradient_as_bucket_view": [14, 20], "mseloss": [14, 20], 
"sgd": [14, 15, 19, 20, 21], "lr": [14, 15, 20, 21, 23, 24], "001": [14, 20], "confirm": 14, "__name__": [14, 15, 20], "__main__": [14, 15, 20], "localservic": 14, "localhost": [14, 20], "51011": 14, "master_addr": [14, 20], "master_port": [14, 20], "12355": [14, 20, 30], "Or": [14, 15, 18, 26], "overhead": [14, 20, 21], "grpc": 14, "torchbench": 14, "35": [14, 16], "tpuvm": [14, 15, 17, 27], "2048": 14, "mnist": [14, 15, 16, 19], "test_train_mp_mnist": [14, 20], "fake_data": [14, 16, 20, 30], "alpha": [14, 15], "central2": [14, 17], "git": [14, 16, 17, 23], "depth": [14, 16], "branch": [14, 16, 18, 26], "test_train_mp_imagenet": [14, 16, 20], "batch_siz": [14, 23, 30], "256": 14, "num_epoch": [14, 20, 23], "By": [14, 18, 26], "tpu_process_bound": 14, "tpu_visible_chip": 14, "r1": 14, "13": [14, 15, 20, 22], "docker_imag": 14, "gcr": 14, "io": [14, 23], "sudo": [14, 17], "rm": 14, "privileg": 14, "net": [14, 17, 19], "gpu_num_devic": 14, "nnode": [14, 30], "num_gpu_devic": 14, "pjrt_distribut": 14, "physic": [14, 27, 28], "12": [14, 16, 21, 23], "number_gpu_vm": [14, 30], "node_rank": [14, 30], "current_node_rank": 14, "nproc_per_nod": [14, 30], "number_local_gpu_devic": 14, "rdzv_endpoint": [14, 30], "internal_ip_address": 14, "multinode_train": 14, "endpoint": [14, 30], "machine_0": 14, "machine_1": 14, "machine_0_internal_ip_address": [14, 30], "ident": 14, "page": 14, "although": [14, 18, 26], "mostli": [14, 23], "interchang": 14, "perspect": [14, 15], "subtl": 14, "importantli": 14, "architectur": [14, 23], "thu": [14, 16], "batch": [14, 15, 16, 17, 27], "latenc": 14, "deseri": 14, "send": [14, 15, 17, 27], "direct": [14, 16], "independ": [14, 15, 16], "constraint": [14, 16], "significantli": [14, 15, 17], "xla_dist": 14, "scp": [14, 15], "sdk": 14, "reimplement": 14, "collect": [14, 20, 21, 28, 29], "enhanc": 14, "stabil": [14, 16, 19], "xmp": [14, 15, 17], "substanti": 14, "practic": [14, 18, 24, 26], "unreli": 14, "due": [14, 16, 17, 30], "inbound": 14, 
"could": [14, 17, 18, 26, 27], "failur": 14, "entir": [14, 23], "restart": 14, "impos": 14, "middl": [14, 17, 18, 26], "unwant": 14, "permit": 14, "subset": 14, "old": 14, "alter": 14, "synchron": [14, 15, 17, 27, 29], "consid": [14, 17], "all_gather_object": 14, "gloo": [14, 20, 29], "subgroup": 14, "monitor": 14, "_": [14, 21, 22], "altern": [14, 18, 19, 25, 26], "less": [14, 18, 21, 26], "reliabl": 14, "than": [14, 16, 18, 20, 23, 26], "strongli": 14, "_all_gath": 14, "int32": 14, "zeros_lik": 14, "get_world_s": 14, "averag": 14, "task": 14, "175": 14, "chart": 14, "breakdown": 14, "tfrt": 14, "legaci": 14, "streamexecutor": 14, "tpu_legaci": 14, "comparison": [14, 28], "regular": [15, 16, 17, 25], "t0": 15, "matrix": 15, "multipli": [15, 28], "mm": [15, 19], "neural": 15, "l_in": 15, "l_out": 15, "floattensor": 15, "highlight": [15, 17], "nllloss": 15, "momentum": 15, "switch": [15, 16, 18, 20, 26], "acquir": 15, "mp_device_load": 15, "three": 15, "multithread": [15, 16], "own": [15, 23], "onto": 15, "preload": [15, 17], "overlap": [15, 17, 21, 27], "batches_per_execut": 15, "consolid": [15, 23], "all_reduce_gradi": 15, "remain": [15, 17, 18, 26, 30], "parent": 15, "talk": 15, "basi": 15, "howto": 15, "focu": [15, 18, 26], "train_mnist_xla": 15, "outsid": 15, "infrastructur": 15, "awar": 15, "fakedata": 15, "But": [15, 16, 18, 26], "immedi": [15, 27], "hand": 15, "record": [15, 16, 17], "defer": 15, "fuse": [15, 17], "invis": 15, "caller": 15, "insert": [15, 17], "paper": 15, "opaqu": [15, 16], "appear": [15, 16, 17], "unlik": [15, 17], "adjust": 15, "preserv": [15, 16], "appreci": 15, "accommod": 15, "previous": 15, "state_dict": [15, 23, 29], "footprint": 15, "xser": 15, "stream": 15, "amount": [15, 16, 17, 18, 26], "restor": 15, "load_state_dict": [15, 29], "unavail": [15, 16], "consum": [15, 18, 26], "disk": 15, "occur": 15, "opt": 15, "your_cache_path": 15, "mp_fn": 15, "xla_cache_": 15, "runnabl": [15, 20, 24], "subject": 16, "peculiar": 16, "detial": 16, 
"__version__": 16, "cu121": 16, "t2": [16, 28], "200": 16, "rx": 16, "conclud": 16, "diagnos": 16, "extrem": 16, "pt_xla_debug_level": 16, "slip": 16, "analyz": [16, 17], "summari": 16, "compiletim": 16, "frequent": 16, "21": 16, "transferfromdevicetim": 16, "23": 16, "hash": 16, "c74c3b91b855b2b123f833b0d5f86943": 16, "107": 16, "frame": 16, "trigger": [16, 17, 18, 26], "dk3": 16, "1055": 16, "44": 16, "__next__": 16, "train_loop_fn": 16, "48": [16, 20], "start_train": 16, "73": 16, "548000": 16, "gb": 16, "922460": 16, "547871": 16, "124478": 16, "028210": 16, "steptrac": 16, "frequenc": 16, "pair": 16, "met": 16, "spent": [16, 17], "destroi": 16, "percentil": 16, "totalsampl": 16, "202": 16, "06m09s401ms746": 16, "001u": 16, "valuer": 16, "778ms572": 16, "062u": 16, "rate": [16, 20], "425201": 16, "001ms32": 16, "778u": 16, "001ms61": 16, "283u": 16, "001ms79": 16, "236u": 16, "001ms110": 16, "973u": 16, "50": [16, 17, 22], "001ms228": 16, "773u": 16, "80": 16, "001ms339": 16, "183u": 16, "90": 16, "001ms434": 16, "305u": 16, "95": 16, "002ms921": 16, "063u": 16, "99": [16, 20], "21s102ms853": 16, "173u": 16, "cachedsynctensor": 16, "395": [16, 20], "area": 16, "rout": 16, "qualifi": 16, "33": [16, 20, 21], "_local_scalar_dens": 16, "epoch": [16, 17, 23], "clear_al": 16, "xla_dynamo_debug": 16, "bottleneck": [16, 17], "notebook": 16, "train_resnet_benchmark": 16, "behav": 16, "evalu": [16, 17, 18, 26], "suggest": 16, "certain": [16, 18, 19, 26], "bad": 16, "degrad": [16, 17], "speedup": [16, 21], "indirect": 16, "solut": [16, 18, 25, 26], "variat": 16, "pad": [16, 17, 18, 26], "fix": [16, 17, 21, 24], "translat": 16, "item": [16, 17], "substitut": 16, "flow": 16, "clip_grad_norm": 16, "problemat": 16, "clip_grad_norm_": 16, "dramat": 16, "total_norm": 16, "zero": [16, 23, 29], "param_norm": 16, "grad": 16, "norm": 16, "norm_typ": 16, "add_": 16, "clip_coef": 16, "max_norm": 16, "mul_": 16, "data_parallel": 16, "last": 16, "dataset": [16, 20, 23], "stride": 16, 
"reconstruct": 16, "shallow": 16, "ty": 16, "made": [16, 17, 18, 26, 27], "_get_xla_tensors_text": [16, 18, 26], "_get_xla_tensors_hlo": 16, "prior": [16, 29], "degre": 16, "xla_ir_debug": 16, "henc": [16, 21], "respons": [16, 17, 21, 29], "xla_save_tensors_fil": 16, "realli": [16, 18, 21, 26], "big": [16, 18, 26], "left": 16, "append": 16, "sheet": 16, "xla_save_tensors_fmt": 16, "text": 16, "dot": 16, "graphviz": 16, "xla_flag": 16, "xla_dump_to": 16, "dir_nam": 16, "unoptim": 16, "optimz": 16, "xla_metrics_fil": 16, "xla_save_hlo_fil": 16, "offend": 16, "xla_sync_wait": 16, "xla_use_eager_debug_mod": 16, "bypass": 16, "overal": [16, 17], "optimizaiton": 16, "tf_cpp_log_thread_id": 16, "tf_cpp_vmodul": 16, "vlog": 16, "tf_cpp_min_log_level": 16, "turn": 16, "warn": 16, "tf_vlog": 16, "xla_dump_hlo_graph": 16, "xla_util": 16, "cc": 16, "save1": 16, "xla_graph_executor": 16, "pjrt_computation_cli": 16, "dir": 16, "pytorch_test_with_slow": 16, "test_torch": 16, "test_put_xla_uint8": 16, "torch_test_devic": 16, "pytorch_test_bas": 16, "brief": 17, "basic": [17, 18, 20, 26], "reader": 17, "modif": 17, "fetch": 17, "discuss": [17, 28], "opcod": 17, "fed": 17, "four": 17, "attach": [17, 27], "callback": 17, "xla_tensor_z": 17, "cut": [17, 18, 26], "transferfromdevic": 17, "tell": [17, 18, 26], "properti": [17, 18, 26], "illustr": [17, 28], "suppos": 17, "tensors_on_devic": 17, "z": [17, 18, 26], "subgraph": [17, 18, 26], "signal": 17, "far": 17, "suitabl": 17, "trade": [17, 18, 26], "off": 17, "spend": 17, "fusion": 17, "worth": [17, 18, 26], "latter": [17, 23], "wheel": [17, 23], "runtime_vers": 17, "project_id": 17, "accelerator_typ": 17, "tpu_nam": 17, "your_tpu_nam": 17, "subnetwork": 17, "tpusubnet": 17, "pip3": 17, "cp38": 17, "linux_x86_64": 17, "whl": 17, "apt": 17, "libopenbla": 17, "dev": [17, 20], "libgl1": 17, "guidelin": 17, "progress": 17, "bar": 17, "rememb": 17, "txt2img": 17, "prompt": 17, "photograph": 17, "astronaut": 17, "ride": 17, "hors": 17, 
"relat": 17, "precision_scop": 17, "addition": [17, 19, 23], "particular": 17, "frozenclipembedd": 17, "simplic": [17, 18, 26], "ddim": 17, "top": 17, "attr": 17, "statement": [17, 18, 26], "stop": 17, "fall": [17, 24], "difficult": 17, "readi": 17, "investig": [17, 20], "cover": [17, 27], "huggingfac": 17, "sd": 17, "xl": 17, "cd": [17, 23], "text_to_imag": 17, "inference_tpu_single_devic": 17, "lora": 17, "model_id": 17, "stabilityai": 17, "pipelin": 17, "dpmsolvermultistepschedul": 17, "txt": 17, "invisible_watermark": 17, "transform": [17, 23, 28], "safetensor": 17, "licens": 17, "card": 17, "cli": 17, "_your_copied_token__": 17, "pipe": 17, "hour": 17, "wherea": 17, "likewis": 17, "gpt": 17, "15": 17, "min": 17, "subsequ": 17, "advantag": 17, "mayb": 17, "notic": 17, "piec": 17, "__call__": 17, "commit": 17, "caveat": 17, "rule": [17, 19], "thumb": 17, "durat": [17, 29], "constantli": 17, "idl": 17, "inference_tpu_": 17, "capture_profil": 17, "gap": 17, "xp": 17, "measur": 17, "portion": 17, "busi": 17, "scroll": 17, "occupi": 17, "demonstr": [17, 19, 24, 29], "displai": 17, "largest": 17, "zoom": 17, "timelin": 17, "period": 17, "examin": 17, "did": 17, "pipe_watermark": 17, "closer": 17, "preced": 17, "proceed": [17, 24], "watermark": 17, "cv2": 17, "pywt": 17, "leav": 17, "broken": 17, "rewrit": [17, 18, 26, 27], "rerun": 17, "scale_model_input": 17, "ran": 17, "my_funct": 17, "preocess": 17, "debug_single_process": 17, "magic": [17, 18, 26], "treat": 17, "xla_no_special_scalar": 17, "hurt": [18, 26], "perf": [18, 26], "pov": [18, 26], "sai": [18, 26], "assur": [18, 26], "gone": [18, 26], "coverag": [18, 26], "aim": [18, 24, 26], "explan": [18, 26], "mainli": [18, 26], "problem": [18, 26], "beginn": [18, 26], "propos": [18, 26], "reli": [18, 26], "impract": [18, 26], "assumpt": [18, 26], "ye": [18, 25, 26], "sentenc": [18, 26], "bucket": [18, 26, 29], "kinda": [18, 26], "anti": [18, 26], "frontend": [18, 26], "matter": [18, 26], "workaround": [18, 26], 
"okai": [18, 26], "teach": [18, 26], "produc": [18, 19, 20, 26], "theoret": [18, 26], "sort": [18, 26], "obviou": [18, 26], "s64": [18, 26], "inde": [18, 26], "_get_xla_tensor_dimension_s": [18, 26], "commonli": [18, 26], "wrong": [18, 26], "wors": [18, 26], "probabl": [18, 26], "know": [18, 20, 26], "upper": [18, 26], "nit": [18, 26], "rand": [18, 26], "solv": [18, 26], "kept": [18, 26], "earli": [18, 26], "accessor": [18, 26], "2d": [18, 24, 26], "implicitli": [18, 26], "doubl": [18, 26], "overload": [18, 26], "explod": [18, 26], "convers": [18, 26], "cheap": [18, 26], "ve": [18, 26], "hoc": [18, 26], "think": [18, 26], "verison": [18, 26], "bla": [18, 26], "blabla": [18, 26], "interpret": [18, 26], "proce": [18, 26], "adopt": [18, 26], "uglier": [18, 26], "win": [18, 26], "pars": [18, 26], "torchscript": [18, 26], "somehow": [18, 26], "merg": [18, 26], "lazili": [18, 26, 27, 29], "properli": [18, 26], "thought": [18, 26], "trivial": [18, 26], "effort": [18, 26, 27], "side": [18, 26], "bandwidth": [18, 26], "automag": [18, 26], "gold": [18, 26], "smart": [18, 26], "trick": [18, 26], "tbh": [18, 26], "longer": [18, 26], "unawar": [18, 26], "hope": [18, 26], "smash": [18, 26], "blocker": [18, 26], "ahead": [18, 26], "nnc": [18, 26], "exactli": [18, 26], "transpos": [18, 26], "brian": [18, 26], "hirsh": [18, 26], "bdhirsh": [18, 26], "question": [18, 26], "comment": [18, 26], "stick": [18, 26], "torch_warn": [18, 26], "yea": [18, 26], "hei": [18, 26], "won": [18, 19, 26], "blaze": [18, 26], "isn": [18, 26, 29], "abil": [18, 20, 26], "devirtu": [18, 26], "sound": [18, 26], "great": [18, 26], "carri": [18, 26, 27], "truth": [18, 26], "irvalu": [18, 26], "enforc": [18, 20, 26], "discrep": [18, 26], "followup": [18, 26], "1000": [18, 26], "my": [18, 26, 29], "presenc": [18, 26], "get_dimention_s": [18, 26], "didn": [18, 26], "exponenti": [18, 26], "blowup": [18, 26], "fewer": [18, 26], "opportun": [18, 26], "recogn": [18, 21, 26], "feasibl": [18, 26], "annoi": [18, 26], 
"wasn": [18, 26], "materiz": [18, 26], "combo": [18, 26], "extend": 19, "float32": 19, "datatyp": 19, "float16": 19, "bfloat16": [19, 25], "syncfre": 19, "autocast": 19, "summar": 19, "elig": 19, "suppli": 19, "addmm": 19, "addmm_": 19, "prefer": 19, "float64": 19, "respect": 19, "unlist": 19, "__matmul__": 19, "addbmm": 19, "addmv": 19, "addr": 19, "baddbmm": 19, "bmm": 19, "conv1d": 19, "conv2d": [19, 23], "conv3d": 19, "conv_transpose1d": 19, "conv_transpose2d": 19, "conv_transpose3d": 19, "matmul": 19, "relu": [19, 20], "prelu": 19, "max_pool2d": 19, "batch_norm": 19, "log_softmax": 19, "binary_cross_entropy_with_logit": 19, "prod": 19, "cdist": 19, "chloeski": 19, "invers": 19, "reflection_pad": 19, "replication_pad": 19, "mse_loss": 19, "cosine_embbeding_loss": 19, "nll_loss": 19, "multilabel_margin_loss": 19, "qr": 19, "svd": 19, "triangular_solv": 19, "linalg_svd": 19, "linalg_inv_ex": 19, "widest": 19, "index_copi": 19, "scaler": [19, 25], "gradscal": 19, "_fetch_gradi": 19, "xla_use_f16": 19, "underflow": 19, "imagenet": 19, "minimum": [20, 23, 24], "nccl": 20, "new_rank": 20, "ddp_model": 20, "final": [20, 27], "launcher": 20, "demo_fn": 20, "touch": [20, 29], "five": 20, "sy": 20, "tempfil": 20, "cleanup": 20, "destroy_process_group": 20, "toymodel": 20, "net1": 20, "1000000": 20, "net2": 20, "demo_bas": 20, "assert": 20, "graident_as_bucket_view": 20, "label": 20, "run_demo": 20, "tot": 20, "statist": 20, "unit": 20, "median": 20, "90th": 20, "deviat": 20, "cv": 20, "418": 20, "54": 20, "419": 20, "22": 20, "430": 20, "40": 20, "76": 20, "02": 20, "97": 20, "407": 20, "60": 20, "39": 20, "seem": 20, "17864": 20, "19": [20, 21], "20108": 20, "96": 20, "24351": 20, "74": 20, "5866": 20, "83": 20, "10701": 20, "11770": 20, "00": 20, "14313": 20, "78": 20, "3102": 20, "92": 20, "41": [20, 21], "round": 20, "heavili": [20, 21], "sens": 20, "amort": 20, "logdir": 20, "converg": 20, "caution": 20, "interest": 20, "known": 20, "crash": 20, "unmodifi": 21, 
"hook": 21, "biggest": [21, 23], "torchfx": 21, "technologi": 21, "fx": 21, "a_xla": 21, "b_xla": 21, "compiled_cod": 21, "eval_model": 21, "xla_resnet18": 21, "eval": 21, "dynamo_resnet18": 21, "no_grad": 21, "resent18": 21, "analysi": 21, "bench": 21, "59": 21, "resnext50_32x4d": 21, "91": 21, "alexnet": 21, "28": 21, "mobilenet_v2": 21, "18": 21, "62": 21, "mnasnet1_0": 21, "68": 21, "vgg16": 21, "bert_pytorch": 21, "squeezenet1_1": 21, "timm_vision_transform": 21, "52": 21, "geomean": 21, "04": 21, "train_model": 21, "crossentropyloss": 21, "pred": 21, "train_model_main": 21, "dynamo_train_model": 21, "xla_optim": 21, "weight_decai": 21, "extract": 21, "07": 21, "43": 21, "81": 21, "87": 21, "fwd": 21, "bwd": 21, "e2": 21, "hide": 21, "scenario": 21, "larger": [21, 23], "wit": 21, "promis": 21, "tradit": 21, "excit": 21, "upcom": [21, 27], "invest": 21, "matur": 21, "stori": 21, "_higher_order_op": 22, "fori_loop": 22, "cond_fn": 22, "body_fn": 22, "bodi": 22, "iteri": 22, "init_v": 22, "functionaltensor": 22, "lvl": 22, "cumul": 22, "ten": 22, "51": 22, "xlafullyshardeddataparallel": 23, "my_modul": [23, 24], "adam": [23, 24], "0001": [23, 24], "leftov": [23, 24], "arxiv": 23, "1910": 23, "02054": 23, "reshard_after_forward": 23, "test_train_mp_mnist_fsdp_with_ckpt": 23, "test_train_mp_imagenet_fsdp": 23, "interleav": 23, "submodul": 23, "fsdpvitmodel": 23, "checkpoint_modul": [23, 24], "3524": 23, "auto_wrap_polici": [23, 24], "size_based_auto_wrap_polici": 23, "polici": [23, 27], "100m": 23, "transformer_auto_wrap_polici": [23, 24], "transformer_layer_cl": [23, 24], "auto_wrapper_cal": 23, "remateri": 23, "resum": 23, "get_shard_metadata": 23, "consolidate_sharded_model_checkpoint": 23, "stitch": 23, "ckpt": 23, "shard_metadata": 23, "ckpt_path": 23, "pth": 23, "consolidate_sharded_ckpt": 23, "ckpt_prefix": 23, "your_sharded_checkpoint_fil": 23, "ckpt_suffix": 23, "_rank": 23, "inspir": 23, "structur": [23, 27], "fairscal": 23, "fullyshardeddataparallel": 
23, "readthedoc": 23, "en": 23, "resort": 23, "train_resnet_fsdp_auto_wrap": 23, "newer": 23, "recurs": [23, 24], "98": 23, "drop_last": 23, "use_nested_fsdp": 23, "use_gradient_checkpoint": 23, "final_ckpt": 23, "75": 23, "download": 23, "1k": 23, "datadir": 23, "test_set_batch_s": 23, "eval_interv": 23, "num_warmup_epoch": 23, "lr_scheduler_divide_every_n_epoch": 23, "lr_scheduler_divisor": 23, "residu": 23, "algorithm": [23, 24], "ronghanghu": 23, "vit_10b_fsdp_exampl": 23, "vit": 23, "fsdpv2": 24, "famou": 24, "enjoi": 24, "tabl": 24, "spmd_fully_sharded_data_parallel": 24, "spmdfullyshardeddataparallel": 24, "autowrap": 24, "decoderlay": 24, "functool": 24, "decoder_only_model": 24, "shard_output": 24, "0th": 24, "children": 24, "fork": 24, "hf": 24, "abstract": [25, 27], "blockwis": 25, "int4": 25, "analog": 25, "classifi": 25, "flexibl": 25, "choos": [25, 29], "docstr": 25, "xla_quantized_matmul": 25, "n_input_featur": 25, "n_output_featur": 25, "w_int": 25, "127": 25, "int8": 25, "matmul_output": 25, "quantized_matmul": 25, "x_xla": 25, "w_int_xla": 25, "scaler_xla": 25, "matmul_output_xla": 25, "w": 25, "f_dynamo": 25, "dynamo_out_xla": 25, "myqlinearforxlabackend": 25, "load_weight": 25, "processed_w": 25, "processed_scal": 25, "stuff": 25, "orig_model": 25, "mymodel": 25, "q_weight": 25, "q_weights_for_xla": 25, "process_for_xla": 25, "q_linear": 25, "xlaquantizedlinear": 25, "in_featur": 25, "out_featur": 25, "load_quantized_weight": 25, "channel": 25, "sym": 25, "asym": 25, "w8a16": 25, "w8a8": 25, "w4a8": 25, "gspmd": [27, 28], "proced": 27, "src": [27, 29], "_input_sharding_": 27, "4d": 27, "input_shard": 27, "shardingspec": 27, "input_mesh": 27, "s2": 27, "s3": 27, "s4": 27, "_after": 27, "_the": 27, "unnecessari": 27, "forth": 27, "techniqu": 27, "decis": 27, "nice": 27, "arrang": 27, "center": 27, "multislic": 27, "accept": 27, "denot": 27, "delai": 27, "subclass": 27, "__torch_dispatch__": 27, "global_tensor": 27, "strictli": 27, "local_shard": 
27, "xlashard": 27, "4e8e5511555073ce8b6d1a436bf808c9333dcac6": 27, "xla_sharded_tensor": 27, "l12": 27, "ongo": 27, "distributedtensor": 27, "proof": 27, "concept": [27, 28], "distribute_tensor": 27, "devicemesh": 27, "big_tensor": 27, "100000": 27, "88": 27, "my_dtensor": 27, "stai": 27, "dynamo_mark_shard": 27, "placement": 27, "visualize_tensor_shard": 27, "visualize_shard": 27, "rich": 27, "2x2": 27, "generated_t": 27, "use_color": 27, "style": 27, "tile": 27, "partial_repl": 27, "envvar": 27, "xla_auto_spmd": 27, "_tensor": 27, "distribute_modul": 27, "auto_polici": 27, "mymodul": 27, "sharded_model": 27, "behvaior": 27, "xla_auto_use_group_shard": 27, "reshard": 27, "xla_auto_spmd_mesh": 27, "unset": 27, "hint": 28, "strategi": 28, "th": 28, "cluster": 28, "interconnect": 28, "encourag": 28, "fist": 28, "paral": 28, "dedic": 29, "planner": 29, "spmdsaveplann": 29, "spmdloadplann": 29, "dist_cp": 29, "distributed_checkpoint": 29, "xc": 29, "storage_writ": 29, "filesystemwrit": 29, "checkpoint_dir": 29, "storage_read": 29, "filesystemread": 29, "all_step": 29, "save_async": 29, "unblock": 29, "preemption": 29, "detect": 29, "provis": 29, "queuedresourc": 29, "autocheckpoint": 29, "chkpt_on_preempt": 29, "fsspec": 29, "filesystem": 29, "prime_optim": 29, "chkpt_mgr": 29, "tracked_step": 29, "highest": 29, "best_step": 29, "prime": 29, "enumer": 29, "attempt": 29, "unprim": 29, "destruct": 29, "discov": 29, "nvidia": 30, "resnet": 30, "num_gpu_machin": 30, "rank_of_current_machin": 30, "machine_0_ip_address": 30, "training_or_inference_script_using_spmd": 30, "xla_use_spmd": 30, "test_train_spmd_imagenet": 30}, "objects": {"": [[11, 0, 0, "-", "torch_xla"]], "torch_xla": [[11, 1, 1, "", "compile"], [11, 1, 1, "", "device"], [11, 1, 1, "", "device_count"], [11, 1, 1, "", "devices"], [11, 0, 0, "-", "experimental"], [11, 1, 1, "", "manual_seed"], [11, 0, 0, "-", "runtime"], [11, 1, 1, "", "sync"]], "torch_xla.core": [[11, 0, 0, "-", "xla_model"]], 
"torch_xla.core.xla_model": [[11, 1, 1, "", "add_step_closure"], [11, 1, 1, "", "all_gather"], [11, 1, 1, "", "all_reduce"], [11, 1, 1, "", "all_to_all"], [11, 1, 1, "", "get_memory_info"], [11, 1, 1, "", "get_rng_state"], [11, 1, 1, "", "get_stablehlo"], [11, 1, 1, "", "get_stablehlo_bytecode"], [11, 1, 1, "", "is_master_ordinal"], [11, 1, 1, "", "mesh_reduce"], [11, 1, 1, "", "optimizer_step"], [11, 1, 1, "", "rendezvous"], [11, 1, 1, "", "save"], [11, 1, 1, "", "set_rng_state"], [11, 1, 1, "", "wait_device_ops"], [11, 1, 1, "", "xla_device"], [11, 1, 1, "", "xla_device_hw"]], "torch_xla.debug": [[11, 0, 0, "-", "metrics"]], "torch_xla.debug.metrics": [[11, 1, 1, "", "counter_names"], [11, 1, 1, "", "counter_value"], [11, 1, 1, "", "metric_data"], [11, 1, 1, "", "metric_names"], [11, 1, 1, "", "metrics_report"], [11, 1, 1, "", "short_metrics_report"]], "torch_xla.distributed": [[11, 0, 0, "-", "parallel_loader"], [11, 0, 0, "-", "spmd"], [11, 0, 0, "-", "xla_multiprocessing"]], "torch_xla.distributed.parallel_loader": [[11, 2, 1, "", "MpDeviceLoader"]], "torch_xla.distributed.spmd": [[11, 2, 1, "", "HybridMesh"], [11, 2, 1, "", "Mesh"], [11, 1, 1, "", "clear_sharding"], [11, 1, 1, "", "get_1d_mesh"], [11, 1, 1, "", "get_global_mesh"], [11, 1, 1, "", "mark_sharding"], [11, 1, 1, "", "set_global_mesh"]], "torch_xla.distributed.xla_multiprocessing": [[11, 1, 1, "", "spawn"]], "torch_xla.experimental": [[11, 1, 1, "", "eager_mode"]], "torch_xla.runtime": [[11, 1, 1, "", "addressable_device_count"], [11, 1, 1, "", "device_type"], [11, 1, 1, "", "get_master_ip"], [11, 1, 1, "", "global_device_count"], [11, 1, 1, "", "global_ordinal"], [11, 1, 1, "", "global_runtime_device_count"], [11, 1, 1, "", "initialize_cache"], [11, 1, 1, "", "is_spmd"], [11, 1, 1, "", "local_device_count"], [11, 1, 1, "", "local_ordinal"], [11, 1, 1, "", "local_process_count"], [11, 1, 1, "", "use_spmd"], [11, 1, 1, "", "world_size"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": 
"py:class"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"]}, "titleterms": {"learn": [0, 1, 10], "about": [0, 1, 10], "gpu": [0, 9, 14, 19, 30], "tpu": [1, 4, 14, 15, 17, 19, 23, 27], "bazel": 2, "pytorch": [2, 3, 4, 6, 7, 8, 10, 11, 15, 16, 17, 21, 23, 26, 27, 28], "xla": [2, 3, 4, 6, 7, 10, 11, 15, 16, 17, 19, 21, 23, 25, 26, 27, 28], "depend": [2, 7, 9], "how": [2, 20, 25, 28], "build": 2, "librari": 2, "torch": [2, 8, 14, 27], "plugin": [2, 6], "remot": 2, "cach": [2, 15], "run": [2, 3, 8, 15, 16, 17, 27, 30], "test": [2, 3, 5, 16, 22], "code": [2, 4, 17, 25], "coverag": 2, "languag": 2, "server": 2, "codegen": 3, "migrat": 3, "guid": [3, 5, 28], "befor": [3, 5], "you": [3, 5, 18, 26], "start": [3, 5, 18, 26], "file": [3, 5, 8], "structur": [3, 5], "old": 3, "op": [3, 5, 25], "lower": [3, 5], "step": [3, 4], "1": [3, 17, 18, 26], "identifi": 3, "2": [3, 17, 18, 24, 26], "inspect": 3, "gener": [3, 8], "lazyir": 3, "h": 3, "3": [3, 18, 26], "implement": [3, 6], "miss": 3, "ir": 3, "function": 3, "torch_xla": [3, 11, 18], "csrc": 3, "ops_xla_shape_fn": 3, "cpp": 3, "4": 3, "ops_lower_fn": 3, "5": 3, "cleanup": 3, "verifi": 3, "result": 3, "sampl": 3, "pr": 3, "configur": 4, "develop": 4, "environ": [4, 16], "visual": 4, "studio": 4, "creat": [4, 15], "connect": 4, "your": 4, "set": 4, "up": 4, "workspac": 4, "next": 4, "understand": [5, 16], "oper": [5, 8, 18, 19, 25, 26], "unit": [5, 16], "tip": 5, "custom": [6, 7, 9], "hardwar": 6, "pjrt": [6, 14], "c": 6, "api": [6, 11, 13], "packag": 6, "kernel": [7, 9], "via": [7, 9], "palla": 7, "adopt": 7, "abov": 7, "compat": 7, "us": [7, 18, 20, 22, 24, 25, 26, 28], "built": 7, "flashattent": 7, "exampl": [7, 17, 19, 22, 23, 24], "usag": [7, 13, 22], "integr": [7, 21, 27], "pagedattent": 7, "export": 8, "stablehlo": 8, "save": [8, 15], "bytecod": 8, "disk": 8, "convert": [8, 17], "serv": 8, "common": [8, 16], "wrapper": 8, "i": [8, 18, 
26, 28], "want": 8, "directli": 8, "tf": 8, "saved_model": 8, "format": 8, "without": [8, 18, 26], "need": 8, "an": [8, 15], "separ": 8, "command": 8, "other": 8, "produc": 8, "save_as_stablehlo": 8, "preserv": 8, "high": 8, "level": 8, "composit": 8, "triton": 9, "document": 10, "acceler": 10, "featur": [10, 21, 25], "improv": 10, "workload": 10, "perform": [10, 14, 16, 17], "contribut": 10, "runtim": [11, 14], "xla_model": 11, "distribut": [11, 14, 29], "spmd": [11, 24, 27, 28, 30], "experiment": [11, 25], "debug": [11, 16, 27], "dynam": [12, 18, 26], "shape": [12, 18, 26], "bound": [12, 18, 26], "eager": 13, "mode": [13, 28], "compil": [13, 15, 16, 27], "basic": 13, "infer": [13, 17, 21], "train": [13, 14, 21, 23], "benchmark": [13, 16, 20], "tl": 14, "dr": 14, "benefit": 14, "quickstart": 14, "cpu": [14, 15], "pod": [14, 15, 17, 23, 27], "docker": 14, "singl": [14, 15, 17], "node": 14, "multi": [14, 15], "differ": 14, "from": [14, 15, 18, 26], "xrt": 14, "multithread": 14, "v2": 14, "v3": [14, 23], "chang": 14, "xm": 14, "rendezv": 14, "new": 14, "devic": [15, 17, 27], "tensor": [15, 16, 18, 26], "ar": 15, "model": [15, 25], "multipl": [15, 17], "process": [15, 29], "deep": 15, "dive": 15, "lazi": 15, "memori": [15, 22], "layout": 15, "move": 15, "load": [15, 27], "further": [15, 28], "read": [15, 28], "troubleshoot": 16, "saniti": 16, "check": 16, "version": 16, "A": 16, "simpl": [16, 22], "calcul": 16, "resnet": [16, 23], "With": 16, "fake": [16, 20], "data": [16, 20, 23, 24, 27], "tool": [16, 27], "auto": [16, 27], "metric": 16, "analysi": [16, 17], "execut": 16, "get": 16, "report": 16, "The": 16, "clear": 16, "dynamo": 16, "profil": [16, 17], "known": 16, "caveat": 16, "quirk": 16, "more": 16, "variabl": 16, "combin": 16, "reproduc": 16, "ci": 16, "cd": 16, "failur": 16, "overview": 17, "setup": 17, "stabl": 17, "diffus": 17, "lightn": 17, "hf": 17, "sourc": [18, 26], "recompil": [18, 26], "let": [18, 26], "": [18, 26], "first": [18, 26], "some": [18, 26], 
"fact": [18, 26], "constraint": [18, 26], "input": [18, 26], "dataset": [18, 26], "output": [18, 24, 26], "can": [18, 26], "fix": [18, 26], "case": [18, 22, 26], "when": [18, 26], "queri": [18, 26], "its": [18, 26], "real": [18, 20, 26], "dimens": [18, 26], "what": [18, 26, 28], "control": [18, 22, 26], "flow": [18, 26], "conclus": [18, 26], "appendix": [18, 26], "automat": 19, "mix": 19, "precis": 19, "amp": 19, "best": 19, "practic": 19, "support": [19, 25], "do": 20, "distributeddataparallel": 20, "ddp": 20, "background": 20, "motiv": 20, "resnet50": 20, "mnist": [20, 23], "disclaim": 20, "torchdynamo": 21, "gap": 21, "take": 21, "awai": 21, "optim": [22, 27, 29], "util": 22, "while_loop": 22, "group": [22, 29], "pure": 22, "python": 22, "while": 22, "loop": 22, "fulli": [23, 24], "shard": [23, 24, 27], "parallel": [23, 24], "script": 23, "imagenet": 23, "instal": 23, "clone": 23, "repo": 23, "8": 23, "50": 23, "10": 23, "billion": 23, "paramet": 23, "gradient": 24, "checkpoint": [24, 29], "huggingfac": 24, "llama": 24, "quantiz": 25, "call": 25, "modul": 25, "swap": 25, "matrix": 25, "multipli": 25, "advanc": 27, "topic": 27, "awar": 27, "host": 27, "virtual": 27, "hybrid": 27, "mesh": [27, 28], "xlashardedtensor": 27, "dtensor": 27, "activ": 27, "user": 28, "partit": 28, "spec": 28, "checkpointmanag": 29, "restor": 29, "state": 29}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"Learn about GPUs": [[0, "learn-about-gpus"]], "Learn about TPUs": [[1, "learn-about-tpus"]], "Bazel in Pytorch/XLA": [[2, "bazel-in-pytorch-xla"]], "Bazel dependencies": [[2, "bazel-dependencies"]], "How to build XLA libraries": [[2, 
"how-to-build-xla-libraries"]], "How to build the Torch/XLA plugin": [[2, "how-to-build-the-torch-xla-plugin"]], "Remote caching": [[2, "remote-caching"]], "Running tests": [[2, "running-tests"]], "Code coverage": [[2, "code-coverage"]], "Language Server": [[2, "language-server"]], "Building PyTorch/XLA": [[2, "building-pytorch-xla"]], "Codegen migration Guide": [[3, "codegen-migration-guide"]], "Before you start": [[3, "before-you-start"], [5, "before-you-start"]], "File structure": [[3, "file-structure"], [5, "file-structure"]], "PyTorch Codegen files": [[3, "pytorch-codegen-files"]], "PyTorch/XLA Codegen files": [[3, "pytorch-xla-codegen-files"]], "PyTorch/XLA Old Op Lowering files": [[3, "pytorch-xla-old-op-lowering-files"]], "Codegen step by step": [[3, "codegen-step-by-step"]], "1. Identify the op": [[3, "identify-the-op"]], "2. Codegen the op and inspect the generated file": [[3, "codegen-the-op-and-inspect-the-generated-file"]], "LazyIr.h": [[3, "lazyir-h"]], "3. Implement the missing IR function": [[3, "implement-the-missing-ir-function"]], "torch_xla/csrc/ops/ops_xla_shape_fn.h": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-h"]], "torch_xla/csrc/ops/ops_xla_shape_fn.cpp": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-cpp"]], "4. Implement the lowering function": [[3, "implement-the-lowering-function"]], "torch_xla/csrc/ops/ops_lower_fn.cpp": [[3, "torch-xla-csrc-ops-ops-lower-fn-cpp"]], "5. 
Cleanup": [[3, "cleanup"]], "Run the test and verify the result": [[3, "run-the-test-and-verify-the-result"]], "Sample PRs": [[3, "sample-prs"]], "Configure a development environment": [[4, "configure-a-development-environment"]], "Visual Studio Code": [[4, "visual-studio-code"]], "Creating and connecting to your TPU": [[4, "creating-and-connecting-to-your-tpu"]], "Setting up a Visual Studio Code workspace with PyTorch/XLA": [[4, "setting-up-a-visual-studio-code-workspace-with-pytorch-xla"]], "Next steps": [[4, "next-steps"]], "OP Lowering Guide": [[5, "op-lowering-guide"]], "Understanding the operation": [[5, "understanding-the-operation"]], "Unit Test": [[5, "unit-test"]], "Tips": [[5, "tips"]], "Custom Hardware Plugins": [[6, "custom-hardware-plugins"]], "Implementing a PJRT Plugin": [[6, "implementing-a-pjrt-plugin"]], "PJRT C API Implementation": [[6, "pjrt-c-api-implementation"]], "PyTorch/XLA Plugin Package": [[6, "pytorch-xla-plugin-package"]], "Custom Kernels via Pallas": [[7, "custom-kernels-via-pallas"]], "Adopt the above kernel to be compatible with PyTorch/XLA": [[7, "adopt-the-above-kernel-to-be-compatible-with-pytorch-xla"]], "Use built-in kernels": [[7, "use-built-in-kernels"]], "FlashAttention": [[7, "id1"]], "Example usage": [[7, "example-usage"], [7, "id3"]], "Integration Example": [[7, "integration-example"], [7, "id4"]], "PagedAttention": [[7, "id2"]], "Dependencies": [[7, "dependencies"], [9, "dependencies"]], "Torch Export to StableHLO": [[8, "torch-export-to-stablehlo"]], "Saving StableHLO bytecodes to disk": [[8, "saving-stablehlo-bytecodes-to-disk"]], "Convert saved StableHLO for serving": [[8, "convert-saved-stablehlo-for-serving"]], "Common wrappers": [[8, "common-wrappers"]], "I want to save directly tf.saved_model format without needing to run an separate command.": [[8, "i-want-to-save-directly-tf-saved-model-format-without-needing-to-run-an-separate-command"]], "Other common wrappers": [[8, "other-common-wrappers"]], "Files produced 
by save_as_stablehlo.": [[8, "files-produced-by-save-as-stablehlo"]], "Preserving High-Level PyTorch Operations in StableHLO by generating stablehlo.composite": [[8, "preserving-high-level-pytorch-operations-in-stablehlo-by-generating-stablehlo-composite"]], "Custom GPU Kernels via Triton": [[9, "custom-gpu-kernels-via-triton"]], "PyTorch/XLA documentation": [[10, "pytorch-xla-documentation"]], "Learn about Pytorch/XLA": [[10, null]], "Learn about accelerators": [[10, null]], "PyTorch/XLA features": [[10, null]], "Improve Pytorch/XLA workload performance": [[10, null]], "Contribute to Pytorch/XLA": [[10, null]], "PyTorch/XLA API": [[11, "pytorch-xla-api"]], "torch_xla": [[11, "module-torch_xla"]], "runtime": [[11, "module-torch_xla.runtime"]], "xla_model": [[11, "module-torch_xla.core.xla_model"]], "distributed": [[11, "module-torch_xla.distributed.parallel_loader"]], "spmd": [[11, "module-torch_xla.distributed.spmd"]], "experimental": [[11, "module-torch_xla.experimental"]], "debug": [[11, "module-torch_xla.debug.metrics"]], "Dynamic shape": [[12, "dynamic-shape"]], "Bounded dynamic shape": [[12, "bounded-dynamic-shape"]], "Eager Mode + Compile API": [[13, "eager-mode-compile-api"]], "Basic Usage": [[13, "basic-usage"]], "Inference": [[13, "inference"], [21, "inference"]], "Training": [[13, "training"], [21, "training"]], "Benchmark": [[13, "benchmark"]], "PJRT Runtime": [[14, "pjrt-runtime"]], "TL;DR": [[14, "tl-dr"]], "Benefits": [[14, "benefits"]], "Quickstart": [[14, "quickstart"]], "CPU": [[14, "cpu"]], "TPU": [[14, "tpu"]], "Pods": [[14, "pods"]], "Docker": [[14, "docker"]], "GPU": [[14, "gpu"]], "Single-node GPU training": [[14, "single-node-gpu-training"]], "Multi-node GPU training": [[14, "multi-node-gpu-training"]], "Differences from XRT": [[14, "differences-from-xrt"]], "Multithreading on TPU v2/v3": [[14, "id3"]], "Changes to xm.rendezvous": [[14, "changes-to-xm-rendezvous"]], "PJRT and torch.distributed": [[14, "pjrt-and-torch-distributed"]], 
"Performance": [[14, "performance"]], "New TPU runtime": [[14, "new-tpu-runtime"]], "PyTorch on XLA Devices": [[15, "pytorch-on-xla-devices"]], "Creating an XLA Tensor": [[15, "creating-an-xla-tensor"]], "XLA Tensors are PyTorch Tensors": [[15, "xla-tensors-are-pytorch-tensors"]], "Running Models on XLA Devices": [[15, "running-models-on-xla-devices"]], "Running on a Single XLA Device": [[15, "running-on-a-single-xla-device"]], "Running on Multiple XLA Devices with Multi-processing": [[15, "running-on-multiple-xla-devices-with-multi-processing"]], "Running on TPU Pods": [[15, "running-on-tpu-pods"]], "XLA Tensor Deep Dive": [[15, "id3"]], "XLA Tensors are Lazy": [[15, "xla-tensors-are-lazy"]], "Memory Layout": [[15, "memory-layout"]], "Moving XLA Tensors to and from the CPU": [[15, "moving-xla-tensors-to-and-from-the-cpu"]], "Saving and Loading XLA Tensors": [[15, "saving-and-loading-xla-tensors"]], "Compilation Caching": [[15, "compilation-caching"]], "Further Reading": [[15, "further-reading"], [28, "further-reading"]], "Troubleshoot": [[16, "troubleshoot"]], "Sanity Check": [[16, "sanity-check"]], "Check PyTorch/XLA Version": [[16, "check-pytorch-xla-version"]], "Perform A Simple Calculation": [[16, "perform-a-simple-calculation"]], "Run Resnet With Fake Data": [[16, "run-resnet-with-fake-data"]], "Performance Debugging": [[16, "performance-debugging"]], "PyTorch/XLA Debugging Tool": [[16, "pytorch-xla-debugging-tool"]], "Perform A Auto-Metrics Analysis": [[16, "perform-a-auto-metrics-analysis"]], "Compilation & Execution Analysis": [[16, "compilation-execution-analysis"]], "Get A Metrics Report": [[16, "get-a-metrics-report"]], "Understand The Metrics Report": [[16, "understand-the-metrics-report"]], "Clear The Metrics Report": [[16, "clear-the-metrics-report"]], "PyTorch/XLA + Dynamo Debugging Tool": [[16, "pytorch-xla-dynamo-debugging-tool"]], "Performance Profiling": [[16, "performance-profiling"]], "Simple Benchmarking": [[16, "simple-benchmarking"]], 
"Known Performance Caveats": [[16, "known-performance-caveats"]], "XLA Tensor Quirks": [[16, "xla-tensor-quirks"]], "More Debugging Tools": [[16, "more-debugging-tools"]], "Environment Variables": [[16, "environment-variables"]], "Common Debugging Environment Variables Combinations": [[16, "common-debugging-environment-variables-combinations"]], "Reproducing PyTorch/XLA CI/CD unit test failures.": [[16, "reproducing-pytorch-xla-ci-cd-unit-test-failures"]], "Pytorch/XLA overview": [[17, "pytorch-xla-overview"]], "TPU Setup": [[17, "tpu-setup"]], "Converting code to PyTorch XLA": [[17, "converting-code-to-pytorch-xla"]], "Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device": [[17, "example-1-stable-diffusion-inference-in-pytorch-lightning-on-a-single-tpu-device"]], "Example 2. HF Stable Diffusion Inference": [[17, "example-2-hf-stable-diffusion-inference"]], "Running on a Single TPU device": [[17, "running-on-a-single-tpu-device"]], "Profiling and performance analysis": [[17, "profiling-and-performance-analysis"]], "Running on Multiple TPU Devices": [[17, "running-on-multiple-tpu-devices"]], "Running on Pods": [[17, "running-on-pods"]], "Source of recompilations in torch_xla": [[18, "source-of-recompilations-in-torch-xla"]], "Let\u2019s first start with some facts/constraints:": [[18, "lets-first-start-with-some-facts-constraints"], [26, "lets-first-start-with-some-facts-constraints"]], "#1. From input dataset.": [[18, "from-input-dataset"], [26, "from-input-dataset"]], "#2. 
From operator output": [[18, "from-operator-output"], [26, "from-operator-output"]], "2.1 Bounded dynamic shape can fix the case when you use the tensor with dynamic shape as a Tensor, without querying its real dimension.": [[18, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"], [26, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"]], "2.2 what if real dimension is queried on a tensor with dynamic shape?": [[18, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"], [26, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"]], "#3. From control flow": [[18, "from-control-flow"], [26, "from-control-flow"]], "Conclusion:": [[18, "conclusion"], [26, "conclusion"]], "Appendix:": [[18, "appendix"], [26, "appendix"]], "Automatic Mixed Precision": [[19, "automatic-mixed-precision"]], "AMP for XLA:TPU": [[19, "amp-for-xla-tpu"]], "AMP for XLA:TPU Best Practices": [[19, "amp-for-xla-tpu-best-practices"]], "Supported Operators": [[19, "supported-operators"]], "AMP for XLA:GPU": [[19, "amp-for-xla-gpu"]], "AMP for XLA:GPU Best Practices": [[19, "amp-for-xla-gpu-best-practices"]], "Examples": [[19, "examples"]], "How to do DistributedDataParallel(DDP)": [[20, "how-to-do-distributeddataparallel-ddp"]], "Background / Motivation": [[20, "background-motivation"]], "How to use DistributedDataParallel": [[20, "how-to-use-distributeddataparallel"]], "Benchmarking": [[20, "benchmarking"]], "Resnet50 with fake data": [[20, "resnet50-with-fake-data"]], "MNIST with fake data": [[20, "mnist-with-fake-data"]], "MNIST with real data": [[20, "mnist-with-real-data"]], "Disclaimer": [[20, "disclaimer"]], "TorchDynamo integration in PyTorch XLA": [[21, "torchdynamo-integration-in-pytorch-xla"]], "Integration": [[21, "integration"]], "Feature gaps": [[21, "feature-gaps"]], "Take away": [[21, 
"take-away"]], "Optimize memory utilization using while_loop": [[22, "optimize-memory-utilization-using-while-loop"]], "while_loop": [[22, "while-loop"]], "Usage:": [[22, "usage"]], "simple example with while_loop:": [[22, "simple-example-with-while-loop"]], "Control group test case": [[22, "control-group-test-case"]], "Control group example with pure python while loop": [[22, "control-group-example-with-pure-python-while-loop"]], "Fully Sharded Data Parallel in PyTorch XLA": [[23, "fully-sharded-data-parallel-in-pytorch-xla"]], "Example training scripts on MNIST and ImageNet": [[23, "example-training-scripts-on-mnist-and-imagenet"]], "Installation": [[23, "installation"]], "Clone PyTorch/XLA repo": [[23, "clone-pytorch-xla-repo"]], "Train MNIST on v3-8 TPU": [[23, "train-mnist-on-v3-8-tpu"]], "Train ImageNet with ResNet-50 on v3-8 TPU": [[23, "train-imagenet-with-resnet-50-on-v3-8-tpu"]], "Example training scripts on TPU pod (with 10 billion parameters)": [[23, "example-training-scripts-on-tpu-pod-with-10-billion-parameters"]], "Fully Sharded Data Parallel using SPMD": [[24, "fully-sharded-data-parallel-using-spmd"]], "Sharding output": [[24, "sharding-output"]], "Gradient checkpointing": [[24, "gradient-checkpointing"]], "HuggingFace Llama 2 Example": [[24, "huggingface-llama-2-example"]], "Quantized Operations for XLA (Experimental feature)": [[25, "quantized-operations-for-xla-experimental-feature"]], "How to use:": [[25, "how-to-use"]], "Call XLA quantized op in model code": [[25, "call-xla-quantized-op-in-model-code"]], "Module Swap": [[25, "module-swap"]], "Supported Quantized Operations:": [[25, "supported-quantized-operations"]], "Matrix Multiply": [[25, "matrix-multiply"]], "Source of recompilations in Pytorch/XLA": [[26, "source-of-recompilations-in-pytorch-xla"]], "PyTorch/XLA SPMD advanced topics": [[27, "pytorch-xla-spmd-advanced-topics"]], "Sharding-Aware Host-to-Device Data Loading": [[27, "sharding-aware-host-to-device-data-loading"]], "Virtual 
Device Optimization": [[27, "virtual-device-optimization"]], "Hybrid Mesh": [[27, "hybrid-mesh"]], "Running SPMD on TPU Pod": [[27, "running-spmd-on-tpu-pod"]], "XLAShardedTensor": [[27, "xlashardedtensor"]], "DTensor Integration": [[27, "dtensor-integration"]], "Activation Sharding for torch.compile": [[27, "activation-sharding-for-torch-compile"]], "SPMD Debugging Tool": [[27, "spmd-debugging-tool"]], "Auto-Sharding": [[27, "auto-sharding"]], "PyTorch/XLA SPMD User Guide": [[28, "pytorch-xla-spmd-user-guide"]], "What is PyTorch/XLA SPMD?": [[28, "what-is-pytorch-xla-spmd"]], "How to use PyTorch/XLA SPMD?": [[28, "how-to-use-pytorch-xla-spmd"]], "SPMD Mode": [[28, "spmd-mode"]], "Mesh": [[28, "mesh"]], "Partition Spec": [[28, "partition-spec"]], "Distributed Checkpointing": [[29, "distributed-checkpointing"]], "CheckpointManager": [[29, "checkpointmanager"]], "Restoring Optimizer State": [[29, "restoring-optimizer-state"]], "Process Groups": [[29, "process-groups"]], "Running SPMD on GPU": [[30, "running-spmd-on-gpu"]]}, "indexentries": {"hybridmesh (class in torch_xla.distributed.spmd)": [[11, "torch_xla.distributed.spmd.HybridMesh"]], "mesh (class in torch_xla.distributed.spmd)": [[11, "torch_xla.distributed.spmd.Mesh"]], "mpdeviceloader (class in torch_xla.distributed.parallel_loader)": [[11, "torch_xla.distributed.parallel_loader.MpDeviceLoader"]], "add_step_closure() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.add_step_closure"]], "addressable_device_count() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.addressable_device_count"]], "all_gather() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.all_gather"]], "all_reduce() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.all_reduce"]], "all_to_all() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.all_to_all"]], "clear_sharding() (in module torch_xla.distributed.spmd)": [[11, 
"torch_xla.distributed.spmd.clear_sharding"]], "compile() (in module torch_xla)": [[11, "torch_xla.compile"]], "counter_names() (in module torch_xla.debug.metrics)": [[11, "torch_xla.debug.metrics.counter_names"]], "counter_value() (in module torch_xla.debug.metrics)": [[11, "torch_xla.debug.metrics.counter_value"]], "device() (in module torch_xla)": [[11, "torch_xla.device"]], "device_count() (in module torch_xla)": [[11, "torch_xla.device_count"]], "device_type() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.device_type"]], "devices() (in module torch_xla)": [[11, "torch_xla.devices"]], "eager_mode() (in module torch_xla.experimental)": [[11, "torch_xla.experimental.eager_mode"]], "get_1d_mesh() (in module torch_xla.distributed.spmd)": [[11, "torch_xla.distributed.spmd.get_1d_mesh"]], "get_global_mesh() (in module torch_xla.distributed.spmd)": [[11, "torch_xla.distributed.spmd.get_global_mesh"]], "get_master_ip() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.get_master_ip"]], "get_memory_info() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.get_memory_info"]], "get_rng_state() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.get_rng_state"]], "get_stablehlo() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.get_stablehlo"]], "get_stablehlo_bytecode() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.get_stablehlo_bytecode"]], "global_device_count() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.global_device_count"]], "global_ordinal() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.global_ordinal"]], "global_runtime_device_count() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.global_runtime_device_count"]], "initialize_cache() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.initialize_cache"]], "is_master_ordinal() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.is_master_ordinal"]], 
"is_spmd() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.is_spmd"]], "local_device_count() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.local_device_count"]], "local_ordinal() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.local_ordinal"]], "local_process_count() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.local_process_count"]], "manual_seed() (in module torch_xla)": [[11, "torch_xla.manual_seed"]], "mark_sharding() (in module torch_xla.distributed.spmd)": [[11, "torch_xla.distributed.spmd.mark_sharding"]], "mesh_reduce() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.mesh_reduce"]], "metric_data() (in module torch_xla.debug.metrics)": [[11, "torch_xla.debug.metrics.metric_data"]], "metric_names() (in module torch_xla.debug.metrics)": [[11, "torch_xla.debug.metrics.metric_names"]], "metrics_report() (in module torch_xla.debug.metrics)": [[11, "torch_xla.debug.metrics.metrics_report"]], "module": [[11, "module-torch_xla"], [11, "module-torch_xla.core.xla_model"], [11, "module-torch_xla.debug.metrics"], [11, "module-torch_xla.distributed.parallel_loader"], [11, "module-torch_xla.distributed.spmd"], [11, "module-torch_xla.distributed.xla_multiprocessing"], [11, "module-torch_xla.experimental"], [11, "module-torch_xla.runtime"]], "optimizer_step() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.optimizer_step"]], "rendezvous() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.rendezvous"]], "save() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.save"]], "set_global_mesh() (in module torch_xla.distributed.spmd)": [[11, "torch_xla.distributed.spmd.set_global_mesh"]], "set_rng_state() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.set_rng_state"]], "short_metrics_report() (in module torch_xla.debug.metrics)": [[11, "torch_xla.debug.metrics.short_metrics_report"]], "spawn() (in module 
torch_xla.distributed.xla_multiprocessing)": [[11, "torch_xla.distributed.xla_multiprocessing.spawn"]], "sync() (in module torch_xla)": [[11, "torch_xla.sync"]], "torch_xla": [[11, "module-torch_xla"]], "torch_xla.core.xla_model": [[11, "module-torch_xla.core.xla_model"]], "torch_xla.debug.metrics": [[11, "module-torch_xla.debug.metrics"]], "torch_xla.distributed.parallel_loader": [[11, "module-torch_xla.distributed.parallel_loader"]], "torch_xla.distributed.spmd": [[11, "module-torch_xla.distributed.spmd"]], "torch_xla.distributed.xla_multiprocessing": [[11, "module-torch_xla.distributed.xla_multiprocessing"]], "torch_xla.experimental": [[11, "module-torch_xla.experimental"]], "torch_xla.runtime": [[11, "module-torch_xla.runtime"]], "use_spmd() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.use_spmd"]], "wait_device_ops() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.wait_device_ops"]], "world_size() (in module torch_xla.runtime)": [[11, "torch_xla.runtime.world_size"]], "xla_device() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.xla_device"]], "xla_device_hw() (in module torch_xla.core.xla_model)": [[11, "torch_xla.core.xla_model.xla_device_hw"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["accelerators/gpu", "accelerators/tpu", "contribute/bazel", "contribute/codegen_migration", "contribute/configure-environment", "contribute/op_lowering", "contribute/plugins", "features/distop", "features/pallas", "features/stablehlo", "features/triton", "index", "learn/api-guide", "learn/dynamic_shape", "learn/eager", "learn/pjrt", "learn/pytorch-on-xla-devices", "learn/troubleshoot", "learn/xla-overview", "notes/source_of_recompilation", "perf/amp", "perf/ddp", "perf/dynamo", "perf/fori_loop", "perf/fsdp", "perf/fsdpv2", "perf/quantized_ops", "perf/recompilation", "perf/spmd_advanced", "perf/spmd_basic", "perf/spmd_distributed_checkpoint", "perf/spmd_gpu"], "filenames": ["accelerators/gpu.md", 
"accelerators/tpu.md", "contribute/bazel.md", "contribute/codegen_migration.md", "contribute/configure-environment.md", "contribute/op_lowering.md", "contribute/plugins.md", "features/distop.md", "features/pallas.md", "features/stablehlo.md", "features/triton.md", "index.rst", "learn/api-guide.rst", "learn/dynamic_shape.md", "learn/eager.md", "learn/pjrt.md", "learn/pytorch-on-xla-devices.md", "learn/troubleshoot.md", "learn/xla-overview.md", "notes/source_of_recompilation.md", "perf/amp.md", "perf/ddp.md", "perf/dynamo.md", "perf/fori_loop.md", "perf/fsdp.md", "perf/fsdpv2.md", "perf/quantized_ops.md", "perf/recompilation.md", "perf/spmd_advanced.md", "perf/spmd_basic.md", "perf/spmd_distributed_checkpoint.md", "perf/spmd_gpu.md"], "titles": ["Learn about GPUs", "Learn about TPUs", "Bazel in Pytorch/XLA", "Codegen migration Guide", "Configure a development environment", "OP Lowering Guide", "Custom Hardware Plugins", "Support of Torch Distributed API in PyTorch/XLA", "Custom Kernels via Pallas", "Torch Export to StableHLO", "Custom GPU Kernels via Triton", "PyTorch/XLA documentation", "PyTorch/XLA API", "Dynamic shape", "Eager Mode + Compile API", "PJRT Runtime", "PyTorch on XLA Devices", "Troubleshoot", "Pytorch/XLA overview", "Source of recompilations in torch_xla", "Automatic Mixed Precision", "How to do DistributedDataParallel(DDP)", "TorchDynamo integration in PyTorch XLA", "Optimize memory utilization using while_loop", "Fully Sharded Data Parallel in PyTorch XLA", "Fully Sharded Data Parallel using SPMD", "Quantized Operations for XLA (Experimental feature)", "Source of recompilations in Pytorch/XLA", "PyTorch/XLA SPMD advanced topics", "PyTorch/XLA SPMD User Guide", "Distributed Checkpointing", "Running SPMD on GPU"], "terms": {"For": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 31], "inform": [0, 1, 4, 10, 12, 14, 15, 16, 17, 18, 19, 27, 31], "googl": [0, 1, 8, 15, 16], "cloud": [0, 1, 2, 4, 6, 11, 
15, 16, 22, 30], "see": [0, 1, 2, 3, 4, 5, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 27], "machin": [0, 2, 4, 15, 17, 18, 31], "type": [0, 4, 6, 9, 12, 15, 16, 17, 18, 20, 21], "ar": [1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 26, 27, 28, 29, 30], "custom": [1, 3, 4, 7, 11, 12, 19, 21, 24, 26, 27, 28, 29], "design": [1, 15, 16, 22, 25, 29], "ai": 1, "acceler": [1, 4, 12, 13, 15, 16, 18, 20], "which": [1, 2, 3, 5, 6, 7, 9, 12, 13, 15, 16, 17, 18, 19, 20, 22, 24, 25, 27, 28, 30], "optim": [1, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 27], "train": [1, 8, 12, 13, 16, 17, 18, 20, 28, 30, 31], "infer": [1, 3, 12, 15, 20, 28, 31], "larg": [1, 13, 15, 18, 19, 24, 27, 29], "model": [1, 3, 5, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 29, 30, 31], "thei": [1, 2, 5, 6, 7, 12, 15, 16, 17, 18, 19, 20, 27, 28, 29], "ideal": [1, 2, 3, 19, 22, 27], "varieti": 1, "us": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 14, 15, 16, 17, 18, 20, 22, 24, 28, 30, 31], "case": [1, 2, 3, 5, 9, 12, 15, 16, 17, 18, 22, 25, 28], "chatbot": 1, "code": [1, 3, 5, 9, 10, 12, 14, 15, 16, 17, 19, 21, 22, 27, 28], "gener": [1, 5, 12, 14, 15, 16, 17, 18, 19, 27], "media": 1, "content": [1, 12], "synthet": 1, "speech": 1, "vision": [1, 24], "servic": [1, 2, 15], "recommend": [1, 2, 3, 4, 5, 12, 14, 15, 16, 20, 28], "engin": [1, 17], "person": 1, "among": 1, "other": [1, 2, 3, 5, 8, 12, 13, 15, 16, 17, 18, 19, 20, 21, 26, 27, 29], "scale": [1, 9, 12, 15, 20, 22, 29], "cost": [1, 22], "effici": [1, 9, 17, 18, 22], "wide": [1, 5, 19, 27], "rang": [1, 5, 12, 15, 25, 28, 29], "workload": [1, 15, 16, 17, 28, 29], "span": [1, 3], "fine": 1, "tune": [1, 28], "provid": [1, 2, 3, 5, 6, 8, 9, 12, 16, 17, 18, 19, 20, 22, 23, 24, 26, 27, 28, 29, 30], "versatil": 1, "lead": [1, 17, 18], "framework": [1, 9, 11, 14, 19, 26, 27], "includ": [1, 2, 5, 12, 15, 17, 18, 19, 20, 23, 27, 30], "pytorch": [1, 5, 10, 13, 14, 15, 19, 20, 21, 23, 26, 30, 31], "jax": [1, 6, 8, 
9, 15], "tensorflow": [1, 2, 6, 9, 12, 15, 17, 19, 27], "seamlessli": 1, "orchestr": 1, "through": [1, 3, 5, 6, 7, 8, 16, 18, 19, 20, 27, 30], "integr": [1, 10, 11, 25, 26, 29], "kubernet": 1, "gke": 1, "leverag": [1, 10, 31], "dynam": [1, 3, 5, 11, 17, 18, 22], "schedul": [1, 18], "improv": [1, 15, 16, 17, 18, 20, 22, 28], "scalabl": 1, "all": [1, 2, 3, 5, 7, 10, 12, 15, 16, 17, 18, 19, 20, 21, 24, 25, 27, 28, 30], "need": [1, 2, 3, 5, 12, 15, 16, 17, 18, 19, 21, 24, 25, 27, 28, 29], "simultan": 1, "look": [1, 3, 5, 16, 17, 18, 28], "simplest": 1, "wai": [1, 2, 5, 7, 8, 12, 15, 16, 18, 19, 21, 22, 26, 27, 28], "develop": [1, 2, 10, 11, 14, 16, 21, 22, 26, 29], "can": [1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 24, 25, 26, 28, 29, 30, 31], "also": [1, 2, 3, 5, 6, 7, 9, 10, 12, 14, 15, 16, 17, 18, 19, 22, 24, 25, 26, 27, 28, 29], "vertex": 1, "fulli": [1, 11, 14, 15, 17, 29], "manag": [1, 8, 12, 20, 30], "platform": 1, "more": [1, 2, 3, 4, 5, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 27, 28, 29, 31], "introduct": [1, 8], "set": [1, 2, 7, 12, 15, 17, 18, 19, 20, 22, 24, 27, 28, 30], "up": [1, 2, 3, 15, 16, 18, 19, 22, 25, 27], "environ": [1, 2, 11, 15, 16, 18, 21, 28, 30], "resourc": [1, 12, 17], "i": [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 28, 30], "free": [2, 5, 13, 17, 20, 21, 24], "softwar": [2, 17], "tool": [2, 5, 18, 24], "autom": 2, "openxla": [2, 6, 14, 22, 26], "both": [2, 4, 5, 7, 9, 15, 18, 19, 20, 22, 24, 25, 26, 27, 29, 30], "make": [2, 4, 10, 12, 14, 15, 16, 17, 18, 19, 21, 22, 27, 28], "good": [2, 3, 5, 18, 19, 27, 28], "fit": [2, 3, 18, 24], "well": [2, 3, 6, 9, 12, 15, 18, 19, 27, 29], "extern": [2, 4, 8], "seen": [2, 18, 22], "workspac": [2, 17], "file": [2, 4, 12, 15, 17, 18, 20, 21], "http_archiv": 2, "name": [2, 4, 5, 9, 12, 15, 17, 19, 25, 27, 28, 29], "org_tensorflow": 2, "strip_prefix": 2, "f7759359f8420d3ca7b9fd19493f2a01bd47b4ef": 2, "url": 2, "http": [2, 3, 4, 8, 10, 12, 15, 
17, 18, 24, 28], "github": [2, 3, 4, 5, 10, 12, 15, 17, 18, 21, 24, 28], "com": [2, 3, 4, 8, 10, 12, 15, 17, 18, 24, 28], "archiv": 2, "tar": 2, "gz": 2, "pin": [2, 12], "updat": [2, 3, 7, 16, 18, 19, 20, 27, 28], "point": [2, 3, 4, 5, 6, 9, 12, 18, 19, 20, 27], "thi": [2, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31], "repositori": [2, 15], "differ": [2, 6, 12, 16, 17, 18, 19, 21, 23, 24, 27, 28, 29], "revis": 2, "patch": [2, 17], "mai": [2, 3, 6, 15, 16, 17, 18, 19, 20, 27, 28], "ad": [2, 5, 12, 16, 18, 19, 22, 23, 27, 28], "resolv": 2, "prepar": 2, "hermet": 2, "mechan": 2, "deploi": 2, "becaus": [2, 3, 9, 14, 15, 16, 18, 20, 28], "local": [2, 4, 12, 15, 16, 17, 28], "checkout": [2, 17], "ha": [2, 3, 4, 5, 8, 12, 14, 15, 16, 18, 19, 27, 28, 29], "built": [2, 4], "from": [2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17, 18, 20, 21, 22, 23, 24, 25, 28, 29, 30], "sourc": [2, 3, 5, 6, 9, 11, 12, 17], "instal": [2, 3, 4, 5, 6, 8, 9, 10, 15, 17, 18], "system": [2, 29], "version": [2, 3, 4, 8, 15, 18, 20, 28], "compat": [2, 9, 15, 26, 30], "e": [2, 4, 6, 7, 9, 12, 13, 15, 17, 18, 19, 20, 24, 26, 27, 28], "g": [2, 4, 6, 7, 9, 12, 13, 15, 17, 18, 19, 26, 27, 28, 30], "codegen": [2, 5, 11], "torchgen": [2, 3], "python": [2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 15, 17, 18, 19, 21, 22, 27, 28], "modul": [2, 8, 9, 12, 16, 17, 21, 24, 25, 28], "should": [2, 3, 4, 5, 6, 9, 10, 12, 14, 15, 16, 17, 18, 19, 20, 23, 24, 27, 28, 30], "The": [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31], "directori": [2, 3, 5, 9], "either": [2, 5, 7, 12, 15, 17, 19, 20, 27], "bzl": 2, "overriden": 2, "command": [2, 3, 4, 15, 16, 17, 18, 21, 24], "line": [2, 3, 12, 14, 16, 17, 18, 19, 24, 27], "override_repositori": 2, "path": [2, 6, 9, 12, 16, 17, 19, 24, 27], "export": [2, 3, 4, 5, 11, 15, 17, 18], "tf_repo": 2, "torch_repo": 2, "pleas": [2, 3, 5, 7, 9, 12, 15, 16, 17, 18, 20, 24, 25, 26, 28, 31], "sure": [2, 16, 
17], "overridden": [2, 3], "appropri": [2, 18], "been": [2, 5, 12, 15, 16, 18, 19, 27, 28], "use_cuda": 2, "0": [2, 3, 4, 6, 9, 10, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31], "setup": [2, 3, 6, 16, 21], "py": [2, 3, 4, 5, 7, 10, 13, 14, 15, 16, 17, 18, 21, 24, 28, 31], "bdist_wheel": 2, "expect": [2, 3, 6, 10, 14, 15, 17, 19, 22, 26, 27], "object": [2, 12, 28], "present": [2, 30], "new_local_repositori": 2, "build_fil": 2, "pytorch_local_dir": 2, "header": 2, "directli": [2, 3, 5, 6, 12, 15, 16, 17, 18, 19, 20, 24, 27, 28, 30], "share": [2, 3, 6, 15, 16, 17, 28], "libtorch": 2, "so": [2, 3, 6, 10, 12, 13, 15, 16, 17, 18, 19, 24, 27, 30], "same": [2, 3, 5, 6, 9, 10, 12, 14, 15, 16, 17, 18, 19, 20, 23, 26, 27, 28, 29, 31], "where": [2, 4, 7, 8, 12, 13, 15, 16, 17, 18, 19, 24, 25, 27], "lib": [2, 6], "contain": [2, 3, 5, 6, 9, 10, 12, 15, 17, 18, 19, 27], "work": [2, 3, 7, 12, 13, 15, 16, 17, 18, 19, 21, 22, 26, 27, 28, 29], "": [2, 4, 5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 18, 20, 21, 22, 26, 28, 29, 30], "requir": [2, 3, 5, 12, 13, 15, 16, 17, 18, 19, 20, 27, 28, 30, 31], "pass": [2, 5, 9, 10, 12, 15, 18, 20, 21, 28], "isystemextern": 2, "compil": [2, 5, 6, 9, 10, 11, 12, 13, 15, 18, 19, 20, 22, 25, 26, 27, 29, 30], "find": [2, 3, 5, 9, 15, 17, 18, 21, 25], "satisfi": [2, 28], "them": [2, 3, 5, 9, 12, 15, 16, 17, 18, 19, 27], "some": [2, 3, 5, 12, 13, 14, 15, 16, 17, 21, 26, 28], "user": [2, 4, 6, 9, 11, 14, 15, 16, 17, 18, 19, 22, 23, 25, 26, 27, 28, 30], "bring": [2, 3, 25], "pybind11": 2, "embed": 2, "link": [2, 3], "against": [2, 21], "libpython": 2, "instead": [2, 7, 12, 14, 15, 16, 17, 18, 19, 21, 22, 24, 27, 28, 30], "These": [2, 3, 5, 8, 15, 18, 26, 30], "pybind11_emb": 2, "option": [2, 3, 4, 6, 9, 12, 15, 17, 18, 26, 28, 30], "transit": [2, 16], "simpl": [2, 3, 8, 15, 18, 20, 24, 29], "torch_xla": [2, 4, 5, 6, 7, 8, 9, 10, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], "csrc": [2, 5], "runtim": [2, 3, 4, 6, 11, 16, 17, 
21, 24, 28, 29, 30], "configr": 2, "via": [2, 4, 11, 15, 23, 24, 25, 28, 29], "bazelrc": 2, "take": [2, 3, 9, 10, 12, 16, 17, 18, 19, 27, 28], "flag": [2, 3, 12, 13, 20], "config": [2, 4], "remote_cach": 2, "configur": [2, 3, 5, 11, 12, 15, 17, 18, 30], "gcloud": [2, 4, 15, 16, 18], "usual": [2, 3, 5, 14, 16, 17], "faster": [2, 15, 18, 19, 22, 27], "authent": [2, 15], "easi": [2, 15, 16, 19, 27], "express": [2, 25, 29], "complex": [2, 10, 22], "lot": [2, 16, 17, 18, 19, 27], "gain": [2, 15], "have": [2, 3, 4, 5, 6, 8, 9, 12, 15, 16, 17, 18, 19, 21, 22, 24, 25, 27, 28, 30], "singl": [2, 12, 14, 19, 21, 22, 24, 25, 27, 28, 29, 31], "graph": [2, 9, 10, 12, 14, 15, 16, 17, 18, 19, 21, 22, 27, 28], "everyth": [2, 19, 21, 27], "therefor": [2, 17, 18], "separ": [2, 3, 5, 16, 18, 22, 24, 25], "rest": [2, 15, 17, 19, 27], "plu": [2, 21, 23], "whole": [2, 12, 14, 19, 22, 27], "everythin": 2, "els": [2, 17, 19, 27], "enough": [2, 18, 19, 27], "normal": [2, 3, 15, 19, 25, 27, 28], "achiev": [2, 5, 14, 21], "invok": [2, 3, 22, 28], "standard": [2, 9], "c": [2, 3, 5, 12, 15, 17, 19, 20, 27], "bind": [2, 9], "simpli": [2, 15], "_xlac": [2, 10, 17, 19, 27], "client": [2, 6, 12, 15], "togeth": [2, 14, 15, 16, 21, 24, 28], "when": [2, 3, 5, 7, 10, 12, 13, 14, 15, 16, 17, 18, 20, 22, 24, 28, 29, 30], "chang": [2, 5, 13, 16, 17, 18, 19, 20, 21, 26, 27, 28], "abl": [2, 16, 19, 27, 30], "without": [2, 5, 12, 15, 17, 18, 28, 29, 30], "iter": [2, 12, 13, 16, 17, 18, 22, 28], "cycl": 2, "come": [2, 12, 19, 27], "There": [2, 3, 14, 16, 17, 18, 19, 21, 22, 27, 28], "plenti": 2, "backend": [2, 3, 7, 12, 14, 15, 19, 22, 23, 26, 27, 28, 30], "we": [2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 29, 31], "our": [2, 3, 4, 5, 6, 7, 8, 9, 13, 15, 16, 17, 19, 20, 21, 22, 27, 28], "gc": [2, 30], "storag": [2, 4, 8, 16, 17, 18, 24, 30], "you": [2, 4, 6, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 21, 22, 24, 25, 28, 29, 31], "under": [2, 3, 5, 12, 15, 16, 21], "disabl": 
[2, 12, 14, 17, 18], "default": [2, 5, 12, 14, 15, 16, 17, 18, 20, 24, 28, 30], "speed": [2, 18, 19, 22, 27], "increment": [2, 3], "huge": [2, 17, 18, 19, 21, 27], "margin": 2, "almost": [2, 29], "alwai": [2, 15, 16, 17, 19, 27, 29], "enabl": [2, 10, 12, 13, 14, 17, 18, 20, 21, 26, 28, 29, 30], "ci": [2, 5], "To": [2, 3, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 24, 25, 27, 28, 30, 31], "ensur": [2, 9, 12, 19, 25, 27, 28, 30], "credenti": 2, "auth": [2, 15], "applic": [2, 17, 26, 30], "login": [2, 18], "launch": [2, 12, 15, 16, 18, 21, 22, 24], "browser": 2, "gcp": [2, 4, 15], "variou": [2, 10], "individu": [2, 24, 25, 29], "who": [2, 21], "access": [2, 3, 5, 8, 12, 15, 16, 17, 18, 19, 27, 30], "project": [2, 4, 6, 15, 16, 18], "one": [2, 3, 5, 8, 9, 12, 15, 16, 17, 18, 19, 22, 23, 24, 25, 27, 28, 29, 31], "onli": [2, 3, 5, 7, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 27, 29, 30], "specifi": [2, 7, 9, 12, 16, 18, 24, 28], "google_default_credenti": 2, "token": [2, 14, 18, 26], "out": [2, 5, 9, 12, 13, 14, 15, 16, 17, 18, 20, 22, 28], "box": [2, 5, 28], "log": [2, 17, 18], "permiss": 2, "add": [2, 3, 5, 9, 10, 12, 16, 17, 18, 19, 22, 23, 24, 27], "new": [2, 3, 4, 5, 7, 14, 16, 17, 18, 19, 22, 27, 28], "role": 2, "In": [2, 3, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 27, 28, 29, 30], "account": [2, 18], "kei": [2, 4, 6, 15, 17, 18, 30], "google_credenti": 2, "On": [2, 15, 30], "docker": [2, 9], "network": [2, 12, 15, 16, 17, 20, 28], "cloudbuild": 2, "down": [2, 5, 18], "imag": [2, 15, 18, 19, 21, 24, 27], "do": [2, 3, 5, 11, 13, 15, 16, 17, 18, 19, 20, 24, 26, 27, 28], "doe": [2, 3, 12, 13, 15, 16, 17, 18, 19, 20, 27, 28], "read": [2, 4, 5, 12, 15, 28], "write": [2, 5, 10, 12, 16, 29], "silo": 2, "each": [2, 3, 5, 7, 10, 12, 15, 16, 17, 18, 19, 22, 24, 25, 27, 28, 29, 30], "uniqu": [2, 16, 18, 19, 27], "benefit": [2, 18, 25, 26, 30], "consist": [2, 7, 9, 15], "remote_default_exec_properti": 2, "some_silo_kei": 2, "bazel_remote_cach": 2, "1": [2, 
4, 6, 7, 8, 9, 12, 14, 15, 16, 17, 20, 21, 22, 23, 24, 25, 28, 29, 31], "silo_nam": 2, "your": [2, 3, 6, 8, 9, 15, 16, 17, 18, 19, 21, 25, 27, 28, 30], "tpuvm_mod": 2, "gcloud_service_key_fil": 2, "application_default_credenti": 2, "json": [2, 9], "might": [2, 5, 12, 16, 17, 18, 19, 27], "help": [2, 17, 18, 19, 27], "too": [2, 17, 19, 27], "cannot": [2, 8, 18, 19, 20, 24, 27], "here": [2, 3, 5, 8, 9, 13, 16, 18, 19, 21, 22, 24, 25, 27, 28, 29, 30], "author": 2, "usernam": 2, "behavior": [2, 3, 5, 15, 16, 17, 20], "function": [2, 5, 6, 7, 8, 9, 10, 12, 14, 16, 17, 18, 22, 23, 25, 26, 30], "intend": 2, "first": [2, 3, 4, 9, 10, 12, 13, 15, 17, 18, 21, 28, 29, 30, 31], "time": [2, 3, 4, 12, 13, 15, 16, 17, 18, 19, 22, 23, 27, 28], "slow": [2, 17, 18], "scratch": [2, 3], "veri": [2, 6, 8, 14, 16, 18, 19, 27], "fast": [2, 19, 27], "onc": [2, 7, 12, 16, 17, 18, 19, 22, 27, 28], "again": [2, 3, 9, 16, 18], "bit": [2, 16, 26], "slower": [2, 17, 18, 21], "per": [2, 9, 12, 15, 16, 17, 20, 21, 22, 26], "until": [2, 12, 16, 18, 30], "next": [2, 12, 17, 18, 19, 26, 27, 28], "quit": 2, "current": [2, 6, 8, 9, 12, 13, 14, 15, 16, 18, 19, 21, 22, 23, 25, 26, 27, 28, 31], "migrat": [2, 11, 15], "futur": [2, 3, 4, 6, 9, 13, 15, 16, 17, 18, 19, 25, 27], "plafrom": 2, "cpp": [2, 5], "main": [2, 4, 7, 9, 10, 14, 15, 28], "Of": 2, "cours": 2, "pjrt": [2, 11, 12, 16, 28], "Not": 2, "environment": 2, "variabl": [2, 4, 13, 15, 18, 19, 27], "miss": [2, 5, 12, 17], "common": [2, 15, 19, 25, 26, 27, 29, 30], "part": [2, 3, 6, 10, 12, 14, 15, 17, 18, 28], "ones": [2, 12, 19, 27], "helper": [2, 3, 9, 12], "script": [2, 3, 4, 8, 15, 16, 17, 18, 20, 21, 31], "run_test": 2, "sh": 2, "r": [2, 18], "xla_client": 2, "pure": [2, 3], "easili": [2, 5, 19, 22, 27], "execut": [2, 10, 12, 14, 15, 16, 18, 19, 20, 21, 22, 27, 28, 29, 31], "parallel": [2, 11, 12, 15, 17, 21, 28, 29], "sinc": [2, 3, 5, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 30], "xrt": [2, 12], "port": [2, 15, 31], "gpu": [2, 5, 6, 8, 11, 13, 
17, 18, 28], "tpu": [2, 3, 5, 6, 8, 11, 12, 13, 17, 21, 22, 23, 30, 31], "devic": [2, 3, 4, 5, 6, 7, 9, 11, 12, 13, 14, 15, 17, 19, 20, 21, 22, 23, 26, 27, 29, 30], "avail": [2, 12, 15, 16, 17, 18, 19, 24, 27, 31], "reason": [2, 3, 5, 14, 15, 18, 21], "bundl": 2, "target": [2, 9, 14, 15, 16, 18, 19, 20, 22, 27], "sequenti": [2, 12], "calcul": 2, "visual": [2, 28], "lcov": 2, "describ": [2, 3, 4, 9, 12, 16, 18, 20, 21, 29], "document": [2, 3, 4, 5, 6, 9, 15, 16, 20, 21, 26], "editor": 2, "choic": [2, 19, 27], "gutter": 2, "vscode": 2, "power": 2, "like": [2, 3, 4, 5, 8, 12, 15, 16, 17, 18, 19, 20, 24, 27, 28], "clangd": 2, "refer": [2, 3, 5, 7, 8, 9, 10, 13, 15, 16, 18, 24, 26, 28, 31], "autocomplet": 2, "semant": [2, 5, 17, 19, 27], "understand": [2, 18, 19, 27], "underli": [2, 12, 16], "stack": [2, 16, 17, 19, 20, 27, 28], "combin": [2, 5, 12, 19, 27], "studio": 2, "extens": [2, 4, 5, 6], "featur": [2, 8, 13, 15, 17, 21, 25, 28, 29, 30], "assist": 2, "edit": 2, "As": [2, 3, 18, 19, 25, 27], "distutil": 2, "ltc": 3, "lazi": [3, 17, 18, 19, 22, 27, 28], "tensor": [3, 5, 7, 9, 12, 13, 15, 18, 20, 22, 23, 25, 26, 28, 29], "core": [3, 5, 7, 9, 12, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 29], "clean": [3, 17, 22], "exist": [3, 9, 12, 14, 15, 16, 17, 22, 28], "stub": 3, "over": [3, 12, 14, 15, 16, 18, 24, 30], "6": [3, 4, 5, 9, 12, 17, 18, 19, 27], "were": [3, 16, 17, 18, 19, 27], "complet": [3, 12, 16, 17], "process": [3, 5, 6, 7, 10, 12, 14, 15, 17, 18, 21, 24, 26], "found": [3, 15, 18], "ref": [3, 4, 15], "replac": [3, 18, 23], "support": [3, 6, 8, 9, 10, 12, 13, 15, 19, 22, 23, 24, 27, 28, 30, 31], "NOT": 3, "introduc": [3, 7, 8, 14, 15, 17, 18, 21, 28], "ani": [3, 8, 9, 12, 13, 15, 16, 17, 18, 19, 20, 21, 24, 25, 27, 28, 29, 30, 31], "purpos": [3, 5, 26], "follow": [3, 5, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19, 21, 24, 25, 27, 28, 31], "instruct": [3, 5, 18], "depend": [3, 4, 5, 13, 14, 16, 18, 19, 20, 27], "build": [3, 5, 16, 18, 24], "It": [3, 4, 5, 7, 12, 13, 
14, 16, 18, 19, 22, 24, 25, 26, 27, 28], "experi": [3, 5, 14, 15, 21, 30], "workstat": [3, 5], "cpu": [3, 5, 7, 9, 12, 17, 18, 19, 24, 26, 27, 28, 30], "pjrt_devic": [3, 5, 6, 13, 15, 16, 17, 23, 31], "re": [3, 12, 14, 15, 17, 18, 19, 20, 23, 25, 27], "familiar": [3, 16, 25], "issu": [3, 5, 12, 14, 15, 16, 17, 18, 20, 21, 25], "3560": 3, "track": [3, 17, 30], "statu": [3, 17], "put": [3, 5, 16, 17, 21], "alia": [3, 7, 12], "avoid": [3, 17, 18, 20], "duplic": 3, "mention": [3, 5, 19, 22, 27], "below": [3, 5, 7, 9, 14, 15, 18, 19, 20, 27, 30, 31], "live": [3, 5, 12, 19, 27], "folder": [3, 4, 5], "except": [3, 5, 18, 28], "xla_native_funct": [3, 5], "yaml": [3, 5], "torch": [3, 4, 8, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30], "shape_infer": 3, "shape": [3, 5, 8, 9, 10, 11, 12, 17, 18, 23, 28, 29], "defin": [3, 5, 8, 10, 12, 18, 20, 23, 25, 28, 29], "input": [3, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 20, 23, 25, 28, 29, 30], "return": [3, 5, 6, 7, 8, 9, 12, 14, 16, 17, 18, 19, 21, 22, 23, 26, 27, 28, 30], "output": [3, 4, 7, 8, 9, 10, 12, 15, 16, 17, 20, 21, 22, 23, 24, 28], "manual": [3, 5, 8, 14, 17, 24], "gen_lazy_tensor": 3, "data": [3, 7, 9, 11, 12, 14, 15, 16, 18, 19, 20, 22, 27, 29, 30], "aten": [3, 5, 17, 19, 27], "specif": [3, 12, 16, 18, 20, 21, 26], "run_gen_lazy_tensor": 3, "dest": 3, "lazy_ir": 3, "class": [3, 6, 7, 9, 12, 21, 24, 26, 30], "genlazyir": 3, "back": [3, 5, 9, 12, 16, 17, 18, 28], "todai": [3, 13], "most": [3, 6, 12, 15, 17, 22], "categori": [3, 25], "goal": [3, 4, 5, 7, 14], "move": [3, 9, 12, 15, 17, 19, 21, 27, 30], "full_codegen": 3, "necessari": [3, 12, 17, 20], "call": [3, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 27, 28, 30], "upstream": [3, 7, 13, 22], "api": [3, 5, 11, 15, 16, 19, 21, 22, 24, 26, 27, 28, 29, 30], "xlanativefunct": [3, 5], "column": 3, "declar": [3, 5], "anoth": [3, 9, 13, 16, 17, 18, 19, 27], "wrap": [3, 5, 6, 8, 9, 12, 14, 16, 18, 20, 24, 25, 26, 28], "around": 
[3, 15, 19, 24, 27], "xlatensor": [3, 5, 12, 28], "construct": [3, 5, 16, 18, 24, 28, 29, 30], "aten_xla_typ": [3, 5], "Will": 3, "method": [3, 9, 12, 15, 20, 25, 28, 30], "map": [3, 5, 7, 12], "node": [3, 5, 7, 10, 17, 19, 27, 31], "remov": [3, 15, 17, 18], "tensor_method": [3, 5], "possibl": [3, 15, 16, 17, 18, 24, 25, 28], "multipl": [3, 7, 9, 12, 14, 19, 22, 26, 27], "few": [3, 16, 17, 18, 19, 21, 27, 30], "simpler": [3, 15], "go": [3, 14, 16, 18, 20, 28], "unari": 3, "binari": [3, 6, 9, 22], "exampl": [3, 4, 5, 6, 7, 9, 12, 13, 14, 15, 16, 17, 19, 21, 22, 26, 27, 28, 29, 30, 31], "characterist": 3, "fallback": [3, 5], "_adaptive_avg_pool3d": 3, "condit": [3, 19, 23, 27], "issupportedadaptivepool": 3, "xlahelp": 3, "i64list": 3, "self": [3, 5, 6, 7, 9, 12, 18, 21, 26, 28], "size": [3, 7, 10, 13, 15, 16, 17, 18, 19, 27, 30], "output_size_list": 3, "pool_dim": 3, "nativ": [3, 5, 14, 15, 17, 20, 21, 28], "call_fallback_fn": 3, "xla_fallback": 3, "aten_op": 3, "output_s": 3, "wip": 3, "evolv": 3, "At": [3, 6, 12], "self_tensor": 3, "static": [3, 13, 19, 27], "bool": [3, 12], "sync_upd": 3, "sys_util": 3, "getenvbool": 3, "xla_tensor_update_sync": 3, "true": [3, 12, 14, 15, 18, 19, 21, 24, 27, 28, 30], "xla_check": 3, "dst_tensor": 3, "updatefromtensor": 3, "sync": [3, 12, 14, 17, 18, 20], "complic": [3, 5, 8], "an": [3, 4, 5, 6, 7, 8, 12, 15, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 29, 30], "would": [3, 4, 5, 12, 15, 16, 17, 18, 19, 23, 27], "someth": [3, 18], "ab": [3, 24], "const": [3, 5, 7], "torch_lazy_fn_count": 3, "bridg": [3, 22], "atenfromxlatensor": 3, "getxlatensor": 3, "fail": [3, 12, 16, 17, 30], "explain": [3, 6, 16, 17, 18, 19, 27, 29], "later": [3, 18], "still": [3, 7, 15, 16, 19, 20, 21, 27, 30], "snippet": [3, 16, 28], "auto": [3, 5, 12, 24, 30], "common_devic": 3, "getxladevic": 3, "torch_internal_assert": 3, "xlatensorptr": 3, "lazy_self": 3, "getxlatensororcreateforwrappednumb": 3, "nodeptr": 3, "reusenod": 3, "getirvalu": 3, "makenod": 3, 
"cachenod": 3, "creat": [3, 9, 10, 12, 15, 17, 18, 20, 21, 28, 30], "std": [3, 7, 21], "get": [3, 5, 12, 13, 14, 15, 18, 19, 21, 24, 26, 27], "check": [3, 4, 5, 12, 16, 26, 29], "reus": [3, 16, 18, 20], "previou": [3, 15, 16, 18, 19, 27], "creation": [3, 12], "If": [3, 4, 5, 9, 12, 15, 16, 17, 18, 19, 26, 27, 28], "correspond": [3, 5, 7, 12, 18, 20, 24, 28, 29], "cach": [3, 8, 12, 13, 18], "newli": [3, 9], "And": [3, 19, 21, 27, 28], "within": [3, 9, 12, 16, 17, 18, 26, 30], "note": [3, 4, 7, 8, 9, 10, 12, 14, 15, 16, 17, 18, 19, 22, 24, 25, 26, 27, 29], "done": [3, 4, 8, 16, 17, 18, 19, 27], "public": [3, 15], "xlanod": 3, "xlavalu": 3, "opkind": [3, 5], "absoutputshap": 3, "num_output": [3, 19, 27], "mhash": 3, "string": [3, 7, 12, 28], "tostr": 3, "overrid": [3, 12, 20], "stringstream": 3, "ss": 3, "str": [3, 6, 12], "xlaopvector": 3, "loweringcontext": 3, "loctx": 3, "A": [3, 4, 6, 12, 15, 16, 18, 19, 20, 25, 26, 27, 28], "coupl": [3, 16, 17], "thing": [3, 17, 18], "keep": [3, 4, 13, 15, 17, 19, 27], "mind": [3, 15, 17], "clone": [3, 15, 17, 18], "even": [3, 12, 15, 16, 17, 19, 21, 27], "everi": [3, 5, 8, 9, 12, 15, 16, 17, 19, 22, 27, 28, 30], "outputshap": 3, "xla_shap": 3, "overli": 3, "simplifi": 3, "buildxxxop": 3, "slightli": [3, 5, 12], "better": [3, 5, 14, 15, 16, 17, 18, 19, 22, 23, 27], "maximumoutputshap": 3, "lower_for_shape_fn": 3, "absl": 3, "xlaop": [3, 5], "operand": 3, "promot": 3, "max": [3, 19, 27, 30], "second": [3, 10, 13, 15, 17, 18, 21, 29, 31], "inferoutputshap": 3, "comput": [3, 4, 12, 15, 16, 17, 18, 19, 20, 27, 28, 29], "logic": [3, 12, 14, 19, 23, 27, 28, 29], "two": [3, 6, 12, 15, 17, 18, 19, 27, 28, 29], "xla_input": 3, "getoutputop": 3, "returnop": 3, "buildab": 3, "origin": [3, 9, 18], "genericop": 3, "modifi": [3, 18, 20, 22, 28], "abov": [3, 5, 6, 9, 13, 14, 15, 16, 17, 18, 19, 21, 22, 27, 29], "delet": 3, "sometim": [3, 18, 19, 27], "being": [3, 12, 16, 18, 21, 29], "tensor_op": 3, "cross": [3, 16, 28], "s1": [3, 28], "sub": 
3, "mul": [3, 19, 27], "u2": 3, "v3": [3, 16, 21], "u3": 3, "v2": [3, 4, 16], "irnod": 3, "those": [3, 5, 9, 12, 17, 18, 21], "long": [3, 14, 17, 18, 19, 21, 27], "term": [3, 10, 14, 17, 19, 27], "rid": [3, 19, 27], "composit": [3, 5], "end": [3, 5, 10, 12, 13, 15, 16, 17, 18, 21, 24, 25], "exp": 3, "pow": 3, "norm_exp": 3, "vector": [3, 10], "involv": [3, 19, 27, 28], "don": [3, 5, 13, 14, 15, 16, 17, 19, 24, 27], "t": [3, 5, 9, 12, 13, 14, 15, 16, 17, 19, 20, 24, 25, 27, 28, 29, 30], "build_cpp_test": 3, "skip": [3, 5, 17, 22], "desir": [3, 9, 18, 30], "test_ptxla": 3, "gtest_filt": 3, "atenxlatensortest": 3, "testab": 3, "correct": [3, 19, 27], "counter": [3, 5, 12, 17], "correctli": [3, 17, 25], "gt": [3, 4, 9, 15, 18], "erf": 3, "erfc": 3, "erfinv": 3, "pull": [3, 9, 20, 21, 24], "3659": 3, "binary_cross_entropi": [3, 20], "backward": [3, 5, 9, 14, 15, 16, 20, 21, 22, 24, 25], "3809": 3, "scalar": [3, 5, 17, 19, 27], "addcdiv": 3, "addcmul": 3, "3768": 3, "neg": 3, "index": [3, 4, 6, 12, 15, 16, 17, 18, 31], "amin": 3, "amax": 3, "3771": 3, "special": [3, 9, 10, 18, 28], "partial": [3, 19, 24, 25, 27], "adaptive_avgpool3d": 3, "3790": 3, "guid": [4, 9, 11, 15, 16, 18, 24, 25, 28], "interact": [4, 15], "start": [4, 14, 15, 16, 17, 18], "colab": [4, 17], "kaggl": 4, "preinstal": [4, 15], "ecosystem": [4, 26], "packag": [4, 10, 11, 16, 18, 20, 21], "date": 4, "list": [4, 5, 12, 18, 20, 23, 28], "readm": [4, 17, 18], "prerequisit": 4, "remot": 4, "quota": 4, "about": [4, 14, 15, 16, 18, 19, 27], "request": [4, 5, 12, 17, 18, 19, 20, 21, 27, 28], "offici": [4, 17], "ssh": [4, 15, 16, 18], "regist": [4, 5, 6, 7, 15, 30], "agent": 4, "alreadi": [4, 8, 10, 12, 17, 18, 19, 21, 24, 27, 30], "befor": [4, 7, 8, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 28, 30], "begin": [4, 28], "zone": [4, 15, 16, 18], "tpu_typ": 4, "8": [4, 9, 10, 12, 14, 15, 16, 18, 19, 21, 22, 26, 27, 28, 29], "vm": [4, 15, 16, 17, 18, 21], "assum": [4, 6, 8, 12, 16, 19, 21, 25, 27, 28], 
"id_ed25519": 4, "ubuntu2204": 4, "base": [4, 7, 12, 14, 15, 17, 18, 19, 24, 27, 28, 29], "metadata": [4, 17], "cat": [4, 20], "pub": 4, "ip": [4, 12, 15, 30, 31], "format": [4, 12, 17, 18, 22, 26], "valu": [4, 5, 9, 10, 12, 13, 15, 17, 18, 19, 23, 27, 28, 31], "networkendpoint": 4, "accessconfig": 4, "externalip": 4, "123": 4, "give": [4, 9, 17, 18, 26, 28, 29], "friendli": 4, "easier": [4, 14, 18, 19, 27], "echo": 4, "host": [4, 12, 15, 16, 17, 18, 20, 24, 30, 31], "n": [4, 12, 21, 26], "hostnam": 4, "test": [4, 6, 8, 9, 10, 13, 15, 21, 24, 31], "v": [4, 8, 9, 15, 19, 27], "palett": 4, "select": [4, 12, 15, 30], "visualstudio": 4, "doc": [4, 12, 14, 15, 16, 19, 25, 27, 28], "__": [4, 15], "just": [4, 8, 14, 15, 16, 19, 21, 24, 27, 30], "titl": [4, 15], "open": [4, 5, 6, 9, 15, 17], "window": 4, "termin": [4, 30], "mkdir": 4, "ptxla": 4, "Then": [4, 9, 18], "ui": 4, "venv": 4, "virtual": [4, 12], "latest": [4, 9], "releas": [4, 6, 7, 8, 15, 16, 17, 18, 22, 24, 25, 26, 28], "pip": [4, 8, 9, 10, 18], "numpi": [4, 8, 9, 12, 18, 29], "f": [4, 8, 9, 12, 16, 21, 24, 26, 30], "googleapi": [4, 8, 18], "libtpu": [4, 6, 15], "html": [4, 8, 15, 24], "import": [4, 6, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 28, 29, 30], "set_device_typ": 4, "print": [4, 9, 12, 13, 15, 16, 17, 18, 19, 21, 22, 27, 28, 30], "real_devic": 4, "run": [4, 5, 8, 10, 11, 12, 13, 14, 15, 19, 20, 21, 22, 26, 27, 30], "2": [4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 22, 23, 24, 26, 28, 31], "3": [4, 5, 6, 8, 9, 10, 12, 14, 17, 18, 22, 23, 24, 26, 28], "4": [4, 6, 8, 9, 12, 15, 16, 17, 18, 19, 22, 23, 24, 26, 27, 28, 29], "5": [4, 7, 9, 12, 13, 17, 18, 19, 21, 24, 26, 27], "7": [4, 12, 17, 21, 22], "number": [4, 10, 12, 13, 14, 15, 17, 18, 24, 28, 29], "vari": [4, 15, 19, 25, 27], "That": [4, 19, 27], "now": [4, 7, 9, 10, 14, 15, 16, 18, 19, 27, 28], "realist": 4, "librari": [5, 6, 18, 29, 30], "offer": [5, 9, 25, 26], "implement": [5, 7, 8, 9, 14, 15, 17, 19, 22, 24, 25, 27], "xla": 
[5, 9, 10, 13, 14, 15, 19, 21, 23, 25, 30, 31], "its": [5, 7, 9, 13, 15, 16, 17, 21, 22, 24, 28, 29], "convert": [5, 12, 16, 21], "higher": [5, 17, 30], "level": [5, 17, 18, 22, 26, 30], "represent": [5, 12, 16, 18, 29], "hlo": [5, 12, 16, 17, 18], "beyond": 5, "scope": 5, "forward": [5, 9, 14, 20, 21, 22, 25, 26], "haven": [5, 19, 27], "yet": [5, 7], "caus": [5, 12, 14, 15, 16, 17, 18, 19, 20, 27], "signific": [5, 17, 18, 22], "slowdown": [5, 17, 21], "must": [5, 6, 7, 12, 15, 16, 17, 25, 30, 31], "best": [5, 8, 22, 26], "perform": [5, 7, 8, 9, 10, 12, 14, 16, 20, 21, 22, 24, 26, 28], "what": [5, 16, 18], "debug": [5, 14, 19, 26, 27], "pt": [5, 15, 16, 17, 18], "profil": [5, 15], "_ctc_loss": [5, 17], "_ctc_loss_backward": [5, 17], "contribut": 5, "definit": [5, 16, 19, 27], "native_funct": 5, "after": [5, 7, 9, 12, 15, 16, 17, 18, 19, 23, 27, 28], "kernel": [5, 9, 11, 19, 26, 27], "aten_fallback": 5, "h": 5, "search": 5, "repo": [5, 16, 17, 18, 21], "sequenc": [5, 12], "explicitli": [5, 16, 17, 18, 19, 20, 27], "compos": 5, "match": [5, 9, 12, 16, 17], "serv": 5, "interfac": [5, 6, 16, 17, 25, 30], "machineri": 5, "registerxla": 5, "registerautogradxla": 5, "entri": [5, 6, 9], "pytorch_xla": 5, "world": [5, 8, 15, 19, 22, 27, 30], "written": [5, 18, 30], "paramet": [5, 12, 15, 16, 17, 20, 21, 25, 28, 30, 31], "result": [5, 7, 12, 13, 15, 16, 17, 18, 21, 23, 28], "dispatch": [5, 30], "wrapper": [5, 16, 21, 24, 25], "inplac": [5, 12, 28], "ir": [5, 9, 12, 17, 18, 19, 27], "insid": [5, 9, 16, 18, 28], "stand": 5, "intermedi": [5, 15, 17, 18], "smaller": [5, 18, 19, 27], "inherit": 5, "dai": 5, "addit": [5, 6, 10, 15, 16, 17, 18, 20, 21], "unless": [5, 17, 19, 27], "want": [5, 12, 14, 15, 16, 17, 18, 19, 22, 27, 28, 31], "verifi": 5, "test_oper": 5, "test_aten_xla_tensor": 5, "yield": [5, 16, 17], "break": [5, 18, 19, 27], "grasp": 5, "capabl": 5, "how": [5, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 27, 28, 31], "similar": [5, 15, 18, 21, 23, 26], "minim": 
[5, 18], "pr": [5, 17, 24], "vanilla": 5, "lerp": 5, "variant": [5, 12, 19, 20, 27], "lerp_": 5, "scalar_out": 5, "tensor_out": 5, "prototyp": [5, 9, 28], "weight": [5, 9, 12, 17, 25, 26], "lerp_out": 5, "howev": [5, 8, 9, 17, 18, 28], "namespac": [5, 17], "wrapper_scalar_lerp": 5, "No": [5, 13, 15, 19, 26, 27], "deviceguard": 5, "omit": [5, 15, 29, 31], "anonym": 5, "wrapper_scalar_lerp_": 5, "wrapper_scalar_lerp__tmp": 5, "_copy_from": 5, "m": [5, 7, 9, 19, 24, 27], "impl": [5, 7, 9], "torch_fn": 5, "automat": [5, 6, 11, 12, 15, 16, 17, 18, 19, 24, 27, 29, 30], "u": [5, 15, 17, 18, 19, 22, 27], "explicit": [5, 20, 24], "place": [5, 7, 12, 18, 20, 28, 30], "ll": [5, 19, 27], "interned_str": 5, "symbol": [5, 19, 27], "submit": [5, 17, 18, 20], "team": [6, 22], "direclti": 6, "tf": [6, 17, 19, 27], "close": 6, "expos": [6, 15, 16, 18, 28], "deviceplugin": 6, "handl": [6, 14, 17, 19, 24, 25, 27, 28, 29], "short": [6, 17, 19, 27], "pjrtclient": 6, "mirror": 6, "pjrt_api": 6, "straightforward": [6, 12, 18], "detail": [6, 7, 8, 9, 12, 13, 15, 16, 17, 18, 19, 27], "concret": [6, 19, 27], "placehold": 6, "pjrt_library_path": 6, "extra": [6, 21, 25], "multiprocess": [6, 12, 15, 16], "compon": 6, "least": [6, 18], "cpuplugin": 6, "def": [6, 7, 8, 9, 10, 12, 14, 15, 16, 18, 21, 22, 23, 25, 26], "library_path": 6, "o": [6, 9, 15, 21], "join": [6, 12], "dirnam": 6, "__file__": 6, "pjrt_c_api_cpu_plugin": 6, "identifi": [6, 12, 30], "exmapl": 6, "pyproject": 6, "toml": 6, "torch_xla_cpu_plugin": 6, "With": [6, 8, 9, 13, 15, 19, 22, 27], "initi": [6, 7, 9, 12, 15, 16, 18, 21, 23, 30], "experiment": [6, 8, 9, 10, 11, 13, 14, 15, 16, 21, 22, 23, 25, 28, 30], "state": [6, 12, 24], "becom": [6, 8, 9, 15, 17, 18, 19, 27], "stabl": [6, 15, 24], "xla_model": [7, 9, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 29], "adopt": [7, 19, 27], "traceabl": [7, 12], "commun": [7, 8, 12, 15, 16, 18, 22, 29], "reimplement": [7, 15], "_c10d_function": 7, "figur": [7, 13, 29], "show": [7, 9, 15, 16, 17, 
21], "all_reduc": [7, 12, 20], "between": [7, 13, 15, 16, 17, 18, 19, 20, 21, 23, 27, 28], "processgroupxla": 7, "deriv": 7, "processgroup": 7, "xla_backend": [7, 15, 21, 30], "_create_xla_process_group": 7, "prefix_stor": 7, "rank": [7, 12, 15, 21, 24, 29, 30], "timeout": 7, "assert": [7, 21], "xr": [7, 12, 15, 16, 20, 21, 24, 25, 28, 29, 30], "is_spmd": [7, 12], "spmd": [7, 11, 16, 18, 30], "group": [7, 12, 15, 21, 28], "_register_xla_backend": 7, "dist": [7, 15, 21, 30], "register_backend": 7, "allreduc": 7, "all_reduce_opt": 7, "allgath": 7, "output_tensors_list": 7, "input_tensor": 7, "opt": [7, 16], "none": [7, 8, 9, 12, 17, 25, 28, 29], "_mp_fn": [7, 15, 16], "init_process_group": [7, 15, 21, 30], "init_method": [7, 15, 30], "progress": [7, 18], "instanc": [7, 8, 12, 24, 30], "blob": [7, 10, 12, 15, 28], "distributed_c10d": 7, "_exception_logg": 7, "all_gath": [7, 12, 15], "tensor_list": 7, "async_op": 7, "fals": [7, 9, 12, 16, 24, 28], "_get_default_group": 7, "certain": [7, 17, 19, 20, 27], "remap": 7, "_functional_collect": 7, "all_reduce_inplac": 7, "eventu": 7, "reach": [7, 12], "rewrit": [7, 18, 19, 27, 28], "reduceop": 7, "group_nam": 7, "torch_library_impl": 7, "four": [7, 18], "oper": [7, 10, 11, 12, 15, 16, 17, 18, 30], "align": [7, 14], "while": [7, 9, 12, 18, 19, 21, 27], "signatur": 7, "remain": [7, 16, 18, 19, 27, 31], "restrict": 7, "appli": [7, 12, 20, 24, 25, 30], "usag": [7, 12, 17, 18, 19, 24, 25, 27, 30], "test_collective_ops_tpu": 7, "demonstr": [7, 18, 20, 25, 30], "scenario": [7, 22], "sum": [7, 12, 20, 24, 25], "reduct": [7, 12], "aggreg": 7, "all_gather_into_tensor": 7, "gather": [7, 12, 28], "reduce_scatter_tensor": 7, "reduc": [7, 12, 13, 14, 15, 16, 17, 18, 24], "across": [7, 12, 15, 16, 17, 24, 29], "all_to_all_singl": 7, "output_split_s": 7, "input_split_s": 7, "although": [7, 15, 19, 27], "accept": [7, 28], "argument": [7, 9, 10, 12, 18, 20, 22, 24], "limit": [7, 12, 15, 16], "reflect": 7, "compromis": 7, "maintain": 7, 
"constraint": [7, 15, 17], "alltoal": [7, 12], "rise": 8, "openai": [8, 10], "triton": [8, 11], "popular": 8, "order": [8, 12, 16, 17, 18, 28, 29], "pariti": 8, "continu": [8, 15, 22], "push": 8, "let": [8, 15, 16, 17, 18, 22, 29], "custom_kernel": 8, "jax_import_guard": 8, "pl": [8, 15, 16, 28], "jnp": 8, "add_vectors_kernel": 8, "x_ref": 8, "y_ref": 8, "o_ref": 8, "x": [8, 9, 10, 12, 16, 17, 18, 19, 21, 23, 24, 25, 26, 27, 28, 29], "y": [8, 10, 12, 17, 18, 19, 24, 25, 26, 27, 28], "jit": [8, 10, 22], "add_vector": 8, "arrai": [8, 12, 18, 25, 29], "pallas_cal": 8, "out_shap": 8, "shapedtypestruct": 8, "dtype": [8, 9, 10, 15, 19, 20, 26, 27], "otherwis": [8, 12, 17, 18, 19, 25, 27], "program": [8, 9, 10, 12, 17, 18, 19, 22, 27, 28, 29], "hang": 8, "lock": 8, "q": [8, 9], "randn": [8, 9, 12, 14, 15, 16, 21, 22, 26, 28, 29], "128": [8, 9, 15, 24, 26, 31], "k": [8, 9, 17], "make_kernel_from_palla": 8, "pt_kernel": 8, "lambda": [8, 24], "liner": 8, "flash": [8, 10], "attent": [8, 10], "besid": 8, "op": [8, 9, 11, 12, 14, 17, 18, 19, 20, 27, 28, 29], "suppor": 8, "flash_attent": 8, "paged_attent": 8, "queri": [8, 15], "squeez": 8, "dim": [8, 12], "key_cach": 8, "value_cach": 8, "context_len": 8, "block_tabl": 8, "pages_per_compute_block": 8, "megacore_mod": 8, "vllm": 8, "util": [8, 11, 12, 16, 17, 21, 24, 25, 26, 30], "effect": [8, 12], "memori": [8, 11, 12, 13, 17, 18, 19, 24, 27], "kv": 8, "proper": [8, 29], "jax_nightly_releas": 8, "jaxlib_nightly_releas": 8, "exported_program_to_stablehlo": 9, "xm": [9, 12, 14, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 29], "torchvis": [9, 14, 22], "xla_devic": [9, 12, 15, 16, 17, 18, 20, 21, 22, 23, 26, 29], "resnet18": [9, 14, 22], "sampl": [9, 12, 15, 17], "tupl": [9, 12, 19, 23, 25, 27, 29], "sample_input": 9, "224": [9, 14], "stablehlo_program": 9, "callabl": [9, 12, 24], "get_stablehlo_text": 9, "get_stablehlo_bytecod": [9, 12], "sample_input_xla": 9, "output2": 9, "allclos": 9, "atol": 9, "1e": [9, 17, 22], "One": [9, 12, 13, 
18, 24], "tmp": [9, 16, 17, 24], "stablehlo_dir": 9, "empti": [9, 12], "doesn": [9, 16, 17, 19, 25, 27], "load": [9, 10, 12, 15, 17, 21, 24, 26, 30], "stablehlographmodul": 9, "stablehlo_program2": 9, "output3": 9, "server": [9, 12, 15, 18], "env": [9, 12, 15, 28], "nightli": [9, 17, 18, 24, 28], "resnet_tf": 9, "p": [9, 15, 17, 19, 27], "8500": 9, "mount": [9, 16], "model_nam": 9, "accomplish": 9, "tf_saved_model_integr": 9, "save_torch_module_as_tf_saved_model": 9, "nn": [9, 12, 15, 16, 21, 22, 24, 26, 28], "trace": [9, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 27, 28], "exported_model": 9, "exportedprogram": 9, "pathlik": 9, "stablehloexportopt": 9, "alias": [9, 17, 20], "save_torch_model_as_stablehlo": 9, "torchmodel": 9, "arg": [9, 12, 16, 18, 23, 24], "constant": [9, 17, 18, 28], "ndarrai": [9, 12], "human": 9, "readabl": [9, 18], "mlir": 9, "form": [9, 15, 17, 19, 27, 31], "posit": [9, 12], "meta": 9, "l__fn___layers_15_feed_forward_w2": 9, "l__fn___layers_13_feed_forward_w1": 9, "l__fn___layers_3_attention_wo": 9, "l__fn___layers_12_ffn_norm_weight": 9, "l__fn___layers_25_attention_wo": 9, "serial": [9, 15, 16], "stablehlofunc": 9, "stage": 9, "guarante": [9, 12], "plan": [9, 13, 15], "major": 9, "agre": [9, 18], "scaled_dot_product_attent": 9, "decompos": 9, "low": [9, 13, 17], "dure": [9, 12, 17, 18, 22, 24, 28], "lower": [9, 11, 17, 18, 19, 20, 27], "captur": [9, 12, 17, 18], "downstream": [9, 20], "ml": [9, 29], "crucial": 9, "geneart": 9, "pattern": [9, 17, 19, 22, 27], "bunch": 9, "challeng": 9, "error": [9, 12, 16, 17], "prone": 9, "robust": 9, "outlin": [9, 26], "stablehlocompositebuild": 9, "arbitari": 9, "region": [9, 12, 14, 17, 20, 28], "non": [9, 12, 14, 19, 20, 27, 29], "hardcod": [9, 28], "store": [9, 10, 12, 17], "attribut": 9, "retriev": [9, 12, 16, 19, 22, 27, 28], "pratic": 9, "scaled_product_attent": 9, "mark_pattern_util": 9, "__init__": [9, 21, 26], "super": [9, 21, 22], "q_proj": 9, "linear": [9, 12, 15, 16, 20, 21, 26], "bia": 9, 
"k_proj": 9, "v_proj": 9, "builder": 9, "b": [9, 12, 15, 18, 19, 20, 22, 27, 29], "sdpa": 9, "25": [9, 13], "other_attr": 9, "val": 9, "mark_input": 9, "attn_out": 9, "mark_output": 9, "input_arg": 9, "10": [9, 12, 15, 16, 17, 18, 19, 21, 22, 23, 26, 27, 30], "stablehlo_gm": 9, "shown": [9, 15, 19, 27], "irtohlo": 9, "56": 9, "mhlo": 9, "cross_program_prefetch": 9, "input_output_alia": 9, "is_dynam": 9, "use_auto_spmd_partit": 9, "func": 9, "arg0": 9, "10x8x128xf32": 9, "arg1": 9, "128x128xf32": 9, "arg2": 9, "arg3": 9, "9": [9, 18, 19, 21, 24, 27], "composite_attribut": 9, "500000e": 9, "01": [9, 10], "f32": 9, "decomposit": 9, "11": [9, 17, 19, 27], "privat": [9, 15], "actual": [9, 14, 18, 19, 21, 27, 28], "encapsul": 9, "propag": [9, 17], "high": [10, 13, 18, 21, 26], "deep": [10, 11, 17], "learn": [10, 15], "languag": 10, "empow": 10, "full": [10, 12, 16, 17, 24], "potenti": [10, 12, 15, 17, 25], "given": [10, 12, 17, 18, 19, 21, 24, 27, 29], "add_kernel": 10, "x_ptr": 10, "pointer": 10, "y_ptr": 10, "output_ptr": 10, "n_element": 10, "block_siz": 10, "tl": 10, "constexpr": 10, "element": [10, 12, 19, 25, 27, 28], "tutori": [10, 17, 18, 21, 28], "l28": 10, "pid": 10, "program_id": 10, "axi": [10, 12, 25], "block_start": 10, "offset": 10, "arang": 10, "mask": [10, 17, 19, 27], "xla_triton": 10, "16": [10, 16, 18, 24, 26, 29], "int64": 10, "empty_lik": 10, "grid": 10, "cdiv": 10, "triton_cal": 10, "itself": [10, 12, 24], "kwarg": [10, 12, 24, 28], "payload": [10, 12, 15], "regard": [10, 16, 22], "buffer": [10, 12], "_xla_gpu_custom_cal": 10, "dep": 10, "connect": [11, 12, 15, 28], "overview": [11, 29], "eager": [11, 12, 19, 21, 26, 27], "mode": [11, 12, 19, 21, 26, 27, 28, 30], "troubleshoot": 11, "palla": 11, "stablehlo": [11, 12], "mix": [11, 12, 29], "precis": 11, "advanc": [11, 29], "topic": [11, 29], "distribut": [11, 16, 17, 21, 24, 25, 28, 29], "checkpoint": [11, 15, 18, 24, 29], "distributeddataparallel": [11, 15], "ddp": [11, 15], "torchdynamo": 11, 
"while_loop": 11, "shard": [11, 12, 29, 30], "quantiz": 11, "recompil": [11, 13, 14, 16, 17, 18], "hardwar": [11, 12, 17, 18, 20], "plugin": [11, 15], "bazel": 11, "int": [12, 15, 19, 27, 28], "device_count": [12, 28], "address": [12, 15, 28, 31], "wait": [12, 17, 18], "pend": [12, 14], "whether": [12, 16, 20], "block": [12, 18, 24, 28], "finish": [12, 18], "full_graph": 12, "num_different_graphs_allow": 12, "lazytensor": [12, 14, 18], "repres": [12, 15, 19, 27], "happen": [12, 14, 15, 16, 17, 18, 19, 27], "decid": [12, 17, 19, 27], "funciton": 12, "act": [12, 16], "context": [12, 15, 17, 19, 20, 27], "throw": [12, 16], "info": [12, 17, 19, 27, 29], "exit": [12, 17, 20, 21], "pt_xla_debug": 12, "messag": [12, 17], "dump": [12, 17], "allow": [12, 16, 17, 18, 20, 28, 29, 30], "rais": [12, 17], "exceed": 12, "foo": 12, "sin": 12, "co": 12, "foo2": 12, "compiled_foo2": 12, "manual_se": [12, 15], "seed": 12, "random": [12, 14, 15, 18, 26], "integ": [12, 17], "rng": [12, 15], "device_typ": 12, "local_process_count": 12, "local_device_count": 12, "total": [12, 19, 27, 29], "addressable_device_count": 12, "visibl": [12, 19, 27], "global_device_count": 12, "global_runtime_device_count": [12, 25, 28, 29], "especi": [12, 15, 18, 22, 28], "world_siz": [12, 15, 20, 21, 24, 28], "particip": [12, 15], "job": [12, 18, 22], "global_ordin": [12, 15, 16, 21, 24], "global": [12, 15, 16, 28, 30], "ordin": [12, 16], "thread": [12, 15, 16, 17, 30], "predict": 12, "relationship": [12, 16, 17], "worker": [12, 15, 16, 18, 24, 30], "id": [12, 15, 17, 18], "nor": 12, "contigu": [12, 16, 17], "local_ordin": 12, "get_master_ip": 12, "master": [12, 15, 16, 30], "discoveri": 12, "use_spmd": [12, 28, 29, 30], "forc": [12, 15, 17, 19, 23, 27], "mean": [12, 15, 16, 17, 18, 19, 21, 25, 27, 28], "replic": [12, 28, 29], "spmd_advanc": 12, "md": [12, 15], "initialize_cach": [12, 16], "readonli": [12, 16], "persist": [12, 16, 30], "devkind": 12, "cuda": [12, 15, 16, 18, 19, 20, 26, 27, 31], "deprec": 12, 
"xla_device_hw": 12, "union": 12, "real": [12, 22], "is_master_ordin": 12, "multi": [12, 13, 28, 31], "num_host": 12, "boolean": 12, "indic": [12, 17, 18, 19, 27], "reduce_typ": 12, "float": [12, 19, 20, 27], "pin_layout": 12, "reduce_sum": 12, "reduce_mul": 12, "reduce_and": 12, "reduce_or": 12, "reduce_min": 12, "reduce_max": 12, "replica": [12, 15], "layout": [12, 26], "pine": 12, "prevent": [12, 18, 20, 22, 28], "corrupt": 12, "unpin": 12, "hlomodul": 12, "constrain": [12, 15], "hold": [12, 28, 29], "along": [12, 24], "dimens": [12, 13, 28, 29], "all_to_al": 12, "split_dimens": 12, "concat_dimens": 12, "split_count": 12, "www": 12, "org": [12, 15, 24], "operation_semant": 12, "upon": 12, "split": 12, "concat": 12, "count": [12, 17], "add_step_closur": 12, "closur": 12, "run_async": 12, "step": [12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 27, 28, 30], "mani": [12, 15, 17, 18, 19, 27, 31], "report": 12, "consol": 12, "post": [12, 17], "tensorboard": [12, 18], "etc": [12, 14, 17, 19, 27, 28], "intermediari": 12, "inspect": 12, "typic": 12, "barrier": [12, 15, 16, 18], "materi": [12, 17, 18, 19, 27, 28], "queu": 12, "though": [12, 16, 21], "advis": 12, "throttl": 12, "event": 12, "asynchron": [12, 28, 30], "wait_device_op": 12, "async": [12, 22], "whose": [12, 13], "optimizer_step": [12, 16, 18, 20, 21, 24], "optimizer_arg": 12, "dict": [12, 24], "gradid": 12, "parallelload": [12, 28], "dataparallel": 12, "loader": [12, 17, 18, 22], "dictionari": 12, "gradient": [12, 16, 20, 24, 30], "save": [12, 17, 24, 30], "file_or_path": 12, "textio": 12, "master_onli": [12, 24], "global_mast": 12, "transfer": [12, 15, 17, 18, 28], "care": [12, 16, 17, 19, 27], "taken": [12, 16, 17, 19, 21, 27, 30], "view": [12, 16, 17], "recreat": [12, 16], "destin": [12, 16], "nest": [12, 24], "locat": 12, "control": [12, 13, 16, 17, 28], "obj_to_sav": 12, "path_to_sav": 12, "rendezv": 12, "tag": [12, 15], "byte": 12, "mesh": [12, 15, 25], "xla_rendezv": 12, "sent": [12, 17], "exchang": 
12, "mesh_reduc": 12, "reduce_fn": 12, "toxlatensorarena": 12, "receiv": 12, "copi": [12, 15, 16, 17, 18], "np": [12, 25, 29], "accuraci": [12, 21, 24], "test_accuraci": 12, "set_rng_stat": 12, "get_rng_stat": 12, "get_memory_info": 12, "memoryinfo": 12, "bytes_us": 12, "290816": 12, "bytes_limit": 12, "34088157184": 12, "peak_bytes_us": 12, "500816": 12, "get_stablehlo": 12, "var": [12, 28], "xla_hlo_debug": [12, 17], "root": [12, 19, 27], "bytecod": [12, 22], "parallel_load": [12, 15, 16, 17], "mpdeviceload": [12, 16, 18, 28], "dataload": [12, 16, 18, 21, 28, 30], "background": [12, 30], "upload": [12, 18, 28], "per_device_load": [12, 28], "constructor": 12, "train_device_load": 12, "train_load": [12, 16, 28], "xla_multiprocess": 12, "spawn": [12, 15, 16, 18], "fn": 12, "nproc": [12, 15], "daemon": 12, "start_method": 12, "moment": 12, "maximum": [12, 13, 18, 26], "valueerror": 12, "mark_shard": [12, 25, 28, 29], "xlashardedtensor": [12, 30], "partition_spec": [12, 28, 29], "annot": [12, 28, 29], "partit": [12, 28], "spec": [12, 28], "intern": [12, 15, 16, 17, 19, 27, 28, 31], "spmdpartition": [12, 28], "topologi": [12, 16, 28, 29], "device_mesh": [12, 28], "mesh_shap": [12, 25, 28, 29], "ax": [12, 28, 29], "impact": [12, 15, 17, 19, 21, 27], "dynamo_custom_op": 12, "dynamo": [12, 18, 22, 26], "recogniz": 12, "num_devic": [12, 25, 28, 29], "device_id": [12, 25, 28, 29], "32": [12, 17, 18], "clear_shard": 12, "clear": 12, "cast": [12, 20], "t1": [12, 16, 17, 29], "get_1d_mesh": 12, "set_global_mesh": 12, "get_global_mesh": 12, "axis_nam": [12, 28], "v4": [12, 14, 15, 16, 18, 22, 28], "ravel": 12, "reshap": 12, "fill": 12, "assign": [12, 16, 18], "Its": 12, "length": [12, 19, 27], "len": [12, 18], "get_xla_supported_devic": 12, "get_logical_mesh": 12, "ordereddict": [12, 28, 29], "hybridmesh": [12, 28], "ici_mesh_shap": [12, 28], "dcn_mesh_shap": [12, 28], "hybrid": 12, "ici": 12, "dcn": [12, 28], "increas": 12, "intens": 12, "mdl": 12, "inner": [12, 24, 28], 
"outer": [12, 24, 25, 28], "slice": [12, 18, 28], "fsdp": [12, 24, 25, 28, 29], "eager_mod": [12, 14], "wa": [12, 15, 17, 18, 19, 27, 30], "d": [12, 13, 19, 20, 27], "eagerli": [12, 14, 16, 17, 19, 27], "metric": [12, 21], "metrics_report": [12, 17], "short_metrics_report": [12, 17], "counter_nam": 12, "metric_nam": 12, "activ": [12, 16, 17, 21, 24, 25, 26], "counter_valu": 12, "metric_data": 12, "total_sampl": 12, "accumul": 12, "retain": 12, "circular": 12, "natur": 13, "in_tensor": 13, "randint": [13, 26], "out_tensor": 13, "nonzero": [13, 17, 18, 19, 27], "word": [13, 19, 27], "further": [13, 18, 21], "categor": 13, "unbound": 13, "alloc": 13, "infinit": [13, 25], "phase": 13, "layer": [13, 14, 24, 25, 28], "perceptron": 13, "mlp": 13, "xla_experiment": 13, "masked_select": 13, "masked_scatt": 13, "your_script": [13, 18], "100": [13, 17, 24], "29": [13, 21, 22], "49": [13, 22], "20": [13, 16, 17, 21, 26], "03": 13, "102": 13, "hit": [13, 19, 27], "198": 13, "1953": 13, "motiv": 13, "excess": 13, "half": 13, "drop": [13, 17], "try": [13, 17, 18, 19, 27], "python3": [13, 15, 16, 17, 18, 24], "test_dynamic_shape_model": 13, "testdynamicshapemodel": 13, "test_backward_pass_with_dynamic_input": 13, "expand": [13, 22], "feel": [13, 17, 21], "review": [13, 25], "rfc": [13, 28, 31], "64": [14, 22, 24], "mark_step": [14, 15, 16, 17, 18, 21], "drawback": 14, "approach": [14, 19, 21, 24, 27], "often": [14, 17, 19, 27], "confus": 14, "preprocess": [14, 26], "small": [14, 17, 18, 19, 21, 22, 27], "leak": 14, "expens": [14, 17, 19, 27], "hard": [14, 19, 21, 22, 27], "why": [14, 19, 27], "mitig": 14, "ux": 14, "mark": [14, 16], "compiled_model": 14, "right": [14, 19, 22, 27], "awai": 14, "pretti": [14, 16, 19, 21, 27], "straight": 14, "enter": 14, "reenabl": 14, "perfomr": 14, "compar": [14, 15, 16, 17, 21, 22, 23], "recommen": 14, "overhad": 14, "step_fn": 14, "loss_fn": [14, 15, 16, 20, 21, 22], "zero_grad": [14, 15, 16, 20, 21], "logit": [14, 25], "loss": [14, 15, 16, 20, 
22, 24, 25], "ask": [14, 17, 19, 27], "refactor": 14, "decod": 14, "much": [14, 15, 16, 18, 19, 22, 27], "llama2": 14, "fake": [14, 30], "chip": [14, 15], "300": [14, 17], "observ": [14, 15, 21], "147": 14, "65": [14, 17], "45": 14, "train_decoder_only_bas": [14, 17], "perfomran": 14, "tri": [14, 18], "resnet50": [14, 15, 16, 22, 24], "exepct": 14, "loop": [14, 16, 17, 18, 19, 25, 27, 30], "meant": 14, "encount": [15, 17, 18], "bug": [15, 17, 21], "r2": [15, 17, 28], "init": [15, 16, 21, 22, 23], "renam": 15, "torchrun": [15, 16, 31], "xpu": 15, "neuron": 15, "xrt_tpu_config": 15, "30": [15, 24], "thousand": 15, "preview": 15, "safe": 15, "section": [15, 16, 17, 18, 28], "broadcast": 15, "broadcast_master_param": 15, "pjrt_backend": 15, "diff": [15, 18], "42": 15, "gradient_as_bucket_view": [15, 21], "mseloss": [15, 21], "sgd": [15, 16, 20, 21, 22], "lr": [15, 16, 21, 22, 24, 25], "001": [15, 21], "confirm": 15, "__name__": [15, 16, 21], "__main__": [15, 16, 21], "localservic": 15, "localhost": [15, 21], "51011": 15, "master_addr": [15, 21], "master_port": [15, 21], "12355": [15, 21, 31], "Or": [15, 16, 19, 27], "overhead": [15, 21, 22], "grpc": 15, "torchbench": 15, "35": [15, 17], "tpuvm": [15, 16, 18, 28], "2048": 15, "mnist": [15, 16, 17, 20], "test_train_mp_mnist": [15, 21], "fake_data": [15, 17, 21, 31], "alpha": [15, 16], "central2": [15, 18], "git": [15, 17, 18, 24], "depth": [15, 17], "branch": [15, 17, 19, 27], "test_train_mp_imagenet": [15, 17, 21], "batch_siz": [15, 24, 31], "256": 15, "num_epoch": [15, 21, 24], "By": [15, 19, 27], "tpu_process_bound": 15, "tpu_visible_chip": 15, "r1": 15, "13": [15, 16, 21, 23], "docker_imag": 15, "gcr": 15, "io": [15, 24], "sudo": [15, 18], "rm": 15, "privileg": 15, "net": [15, 18, 20], "gpu_num_devic": 15, "nnode": [15, 31], "num_gpu_devic": 15, "pjrt_distribut": 15, "physic": [15, 28, 29], "12": [15, 17, 22, 24], "number_gpu_vm": [15, 31], "node_rank": [15, 31], "current_node_rank": 15, "nproc_per_nod": [15, 31], 
"number_local_gpu_devic": 15, "rdzv_endpoint": [15, 31], "internal_ip_address": 15, "multinode_train": 15, "endpoint": [15, 31], "machine_0": 15, "machine_1": 15, "machine_0_internal_ip_address": [15, 31], "ident": 15, "page": 15, "mostli": [15, 24], "interchang": 15, "perspect": [15, 16], "subtl": 15, "importantli": 15, "architectur": [15, 24], "thu": [15, 17], "batch": [15, 16, 17, 18, 28], "latenc": 15, "deseri": 15, "send": [15, 16, 18, 28], "direct": [15, 17], "independ": [15, 16, 17], "significantli": [15, 16, 18], "xla_dist": 15, "scp": [15, 16], "sdk": 15, "collect": [15, 21, 22, 29, 30], "enhanc": 15, "stabil": [15, 17, 20], "xmp": [15, 16, 18], "substanti": 15, "practic": [15, 19, 25, 27], "unreli": 15, "due": [15, 17, 18, 31], "inbound": 15, "could": [15, 18, 19, 27, 28], "failur": 15, "entir": [15, 24], "restart": 15, "impos": 15, "middl": [15, 18, 19, 27], "unwant": 15, "permit": 15, "subset": 15, "old": 15, "alter": 15, "synchron": [15, 16, 18, 28, 30], "consid": [15, 18], "all_gather_object": 15, "gloo": [15, 21, 30], "subgroup": 15, "monitor": 15, "_": [15, 22, 23], "altern": [15, 19, 20, 26, 27], "less": [15, 19, 22, 27], "reliabl": 15, "than": [15, 17, 19, 21, 24, 27], "strongli": 15, "_all_gath": 15, "int32": 15, "zeros_lik": 15, "get_world_s": 15, "averag": 15, "task": 15, "175": 15, "chart": 15, "breakdown": 15, "tfrt": 15, "legaci": 15, "streamexecutor": 15, "tpu_legaci": 15, "comparison": [15, 29], "regular": [16, 17, 18, 26], "t0": 16, "matrix": 16, "multipli": [16, 29], "mm": [16, 20], "neural": 16, "l_in": 16, "l_out": 16, "floattensor": 16, "highlight": [16, 18], "nllloss": 16, "momentum": 16, "switch": [16, 17, 19, 21, 27], "acquir": 16, "mp_device_load": 16, "three": 16, "multithread": [16, 17], "own": [16, 24], "onto": 16, "preload": [16, 18], "overlap": [16, 18, 22, 28], "batches_per_execut": 16, "consolid": [16, 24], "all_reduce_gradi": 16, "parent": 16, "talk": 16, "basi": 16, "howto": 16, "focu": [16, 19, 27], "train_mnist_xla": 
16, "outsid": 16, "infrastructur": 16, "awar": 16, "fakedata": 16, "But": [16, 17, 19, 27], "immedi": [16, 28], "hand": 16, "record": [16, 17, 18], "defer": 16, "fuse": [16, 18], "invis": 16, "caller": 16, "insert": [16, 18], "paper": 16, "opaqu": [16, 17], "appear": [16, 17, 18], "unlik": [16, 18], "adjust": 16, "preserv": [16, 17], "appreci": 16, "accommod": 16, "previous": 16, "state_dict": [16, 24, 30], "footprint": 16, "xser": 16, "stream": 16, "amount": [16, 17, 18, 19, 27], "restor": 16, "load_state_dict": [16, 30], "unavail": [16, 17], "consum": [16, 19, 27], "disk": 16, "occur": 16, "your_cache_path": 16, "mp_fn": 16, "xla_cache_": 16, "runnabl": [16, 21, 25], "subject": 17, "peculiar": 17, "detial": 17, "__version__": 17, "cu121": 17, "t2": [17, 29], "200": 17, "rx": 17, "conclud": 17, "diagnos": 17, "extrem": 17, "pt_xla_debug_level": 17, "slip": 17, "analyz": [17, 18], "summari": 17, "compiletim": 17, "frequent": 17, "21": 17, "transferfromdevicetim": 17, "23": 17, "hash": 17, "c74c3b91b855b2b123f833b0d5f86943": 17, "107": 17, "frame": 17, "trigger": [17, 18, 19, 27], "dk3": 17, "1055": 17, "44": 17, "__next__": 17, "train_loop_fn": 17, "48": [17, 21], "start_train": 17, "73": 17, "548000": 17, "gb": 17, "922460": 17, "547871": 17, "124478": 17, "028210": 17, "steptrac": 17, "frequenc": 17, "pair": 17, "met": 17, "spent": [17, 18], "destroi": 17, "percentil": 17, "totalsampl": 17, "202": 17, "06m09s401ms746": 17, "001u": 17, "valuer": 17, "778ms572": 17, "062u": 17, "rate": [17, 21], "425201": 17, "001ms32": 17, "778u": 17, "001ms61": 17, "283u": 17, "001ms79": 17, "236u": 17, "001ms110": 17, "973u": 17, "50": [17, 18, 23], "001ms228": 17, "773u": 17, "80": 17, "001ms339": 17, "183u": 17, "90": 17, "001ms434": 17, "305u": 17, "95": 17, "002ms921": 17, "063u": 17, "99": [17, 21], "21s102ms853": 17, "173u": 17, "cachedsynctensor": 17, "395": [17, 21], "area": 17, "rout": 17, "qualifi": 17, "33": [17, 21, 22], "_local_scalar_dens": 17, "epoch": [17, 18, 
24], "clear_al": 17, "xla_dynamo_debug": 17, "bottleneck": [17, 18], "notebook": 17, "train_resnet_benchmark": 17, "behav": 17, "evalu": [17, 18, 19, 27], "suggest": 17, "bad": 17, "degrad": [17, 18], "speedup": [17, 22], "indirect": 17, "solut": [17, 19, 26, 27], "variat": 17, "pad": [17, 18, 19, 27], "fix": [17, 18, 22, 25], "translat": 17, "item": [17, 18], "substitut": 17, "flow": 17, "clip_grad_norm": 17, "problemat": 17, "clip_grad_norm_": 17, "dramat": 17, "total_norm": 17, "zero": [17, 24, 30], "param_norm": 17, "grad": 17, "norm": 17, "norm_typ": 17, "add_": 17, "clip_coef": 17, "max_norm": 17, "mul_": 17, "data_parallel": 17, "last": 17, "dataset": [17, 21, 24], "stride": 17, "reconstruct": 17, "shallow": 17, "ty": 17, "made": [17, 18, 19, 27, 28], "_get_xla_tensors_text": [17, 19, 27], "_get_xla_tensors_hlo": 17, "prior": [17, 30], "degre": 17, "xla_ir_debug": 17, "henc": [17, 22], "respons": [17, 18, 22, 30], "xla_save_tensors_fil": 17, "realli": [17, 19, 22, 27], "big": [17, 19, 27], "left": 17, "append": 17, "sheet": 17, "xla_save_tensors_fmt": 17, "text": 17, "dot": 17, "graphviz": 17, "xla_flag": 17, "xla_dump_to": 17, "dir_nam": 17, "unoptim": 17, "optimz": 17, "xla_metrics_fil": 17, "xla_save_hlo_fil": 17, "offend": 17, "xla_sync_wait": 17, "xla_use_eager_debug_mod": 17, "bypass": 17, "overal": [17, 18], "optimizaiton": 17, "tf_cpp_log_thread_id": 17, "tf_cpp_vmodul": 17, "vlog": 17, "tf_cpp_min_log_level": 17, "turn": 17, "warn": 17, "tf_vlog": 17, "xla_dump_hlo_graph": 17, "xla_util": 17, "cc": 17, "save1": 17, "xla_graph_executor": 17, "pjrt_computation_cli": 17, "dir": 17, "pytorch_test_with_slow": 17, "test_torch": 17, "test_put_xla_uint8": 17, "torch_test_devic": 17, "pytorch_test_bas": 17, "brief": 18, "basic": [18, 19, 21, 27], "reader": 18, "modif": 18, "fetch": 18, "discuss": [18, 29], "opcod": 18, "fed": 18, "attach": [18, 28], "callback": 18, "xla_tensor_z": 18, "cut": [18, 19, 27], "transferfromdevic": 18, "tell": [18, 19, 27], 
"properti": [18, 19, 27], "illustr": [18, 29], "suppos": 18, "tensors_on_devic": 18, "z": [18, 19, 27], "subgraph": [18, 19, 27], "signal": 18, "far": 18, "suitabl": 18, "trade": [18, 19, 27], "off": 18, "spend": 18, "fusion": 18, "worth": [18, 19, 27], "latter": [18, 24], "wheel": [18, 24], "runtime_vers": 18, "project_id": 18, "accelerator_typ": 18, "tpu_nam": 18, "your_tpu_nam": 18, "subnetwork": 18, "tpusubnet": 18, "pip3": 18, "cp38": 18, "linux_x86_64": 18, "whl": 18, "apt": 18, "libopenbla": 18, "dev": [18, 21], "libgl1": 18, "guidelin": 18, "bar": 18, "rememb": 18, "txt2img": 18, "prompt": 18, "photograph": 18, "astronaut": 18, "ride": 18, "hors": 18, "relat": 18, "precision_scop": 18, "addition": [18, 20, 24], "particular": 18, "frozenclipembedd": 18, "simplic": [18, 19, 27], "ddim": 18, "top": 18, "attr": 18, "statement": [18, 19, 27], "stop": 18, "fall": [18, 25], "difficult": 18, "readi": 18, "investig": [18, 21], "cover": [18, 28], "huggingfac": 18, "sd": 18, "xl": 18, "cd": [18, 24], "text_to_imag": 18, "inference_tpu_single_devic": 18, "lora": 18, "model_id": 18, "stabilityai": 18, "pipelin": 18, "dpmsolvermultistepschedul": 18, "txt": 18, "invisible_watermark": 18, "transform": [18, 24, 29], "safetensor": 18, "licens": 18, "card": 18, "cli": 18, "_your_copied_token__": 18, "pipe": 18, "hour": 18, "wherea": 18, "likewis": 18, "gpt": 18, "15": 18, "min": 18, "subsequ": 18, "advantag": 18, "mayb": 18, "notic": 18, "piec": 18, "__call__": 18, "commit": 18, "caveat": 18, "rule": [18, 20], "thumb": 18, "durat": [18, 30], "constantli": 18, "idl": 18, "inference_tpu_": 18, "capture_profil": 18, "gap": 18, "xp": 18, "measur": 18, "portion": 18, "busi": 18, "scroll": 18, "occupi": 18, "displai": 18, "largest": 18, "zoom": 18, "timelin": 18, "period": 18, "examin": 18, "did": 18, "pipe_watermark": 18, "closer": 18, "preced": 18, "proceed": [18, 25], "watermark": 18, "cv2": 18, "pywt": 18, "leav": 18, "broken": 18, "rerun": 18, "scale_model_input": 18, "ran": 
18, "my_funct": 18, "preocess": 18, "debug_single_process": 18, "magic": [18, 19, 27], "treat": 18, "xla_no_special_scalar": 18, "hurt": [19, 27], "perf": [19, 27], "pov": [19, 27], "sai": [19, 27], "assur": [19, 27], "gone": [19, 27], "coverag": [19, 27], "aim": [19, 25, 27], "explan": [19, 27], "mainli": [19, 27], "problem": [19, 27], "beginn": [19, 27], "propos": [19, 27], "reli": [19, 27], "impract": [19, 27], "assumpt": [19, 27], "ye": [19, 26, 27], "sentenc": [19, 27], "bucket": [19, 27, 30], "kinda": [19, 27], "anti": [19, 27], "frontend": [19, 27], "matter": [19, 27], "workaround": [19, 27], "okai": [19, 27], "teach": [19, 27], "produc": [19, 20, 21, 27], "theoret": [19, 27], "sort": [19, 27], "obviou": [19, 27], "s64": [19, 27], "inde": [19, 27], "_get_xla_tensor_dimension_s": [19, 27], "commonli": [19, 27], "wrong": [19, 27], "wors": [19, 27], "probabl": [19, 27], "know": [19, 21, 27], "upper": [19, 27], "nit": [19, 27], "rand": [19, 27], "solv": [19, 27], "kept": [19, 27], "earli": [19, 27], "accessor": [19, 27], "2d": [19, 25, 27], "implicitli": [19, 27], "doubl": [19, 27], "overload": [19, 27], "explod": [19, 27], "convers": [19, 27], "cheap": [19, 27], "ve": [19, 27], "hoc": [19, 27], "think": [19, 27], "verison": [19, 27], "bla": [19, 27], "blabla": [19, 27], "interpret": [19, 27], "proce": [19, 27], "uglier": [19, 27], "win": [19, 27], "pars": [19, 27], "torchscript": [19, 27], "somehow": [19, 27], "merg": [19, 27], "lazili": [19, 27, 28, 30], "properli": [19, 27], "thought": [19, 27], "trivial": [19, 27], "effort": [19, 27, 28], "side": [19, 27], "bandwidth": [19, 27], "automag": [19, 27], "gold": [19, 27], "smart": [19, 27], "trick": [19, 27], "tbh": [19, 27], "longer": [19, 27], "unawar": [19, 27], "hope": [19, 27], "smash": [19, 27], "blocker": [19, 27], "ahead": [19, 27], "nnc": [19, 27], "exactli": [19, 27], "transpos": [19, 27], "brian": [19, 27], "hirsh": [19, 27], "bdhirsh": [19, 27], "question": [19, 27], "comment": [19, 27], "stick": [19, 
27], "torch_warn": [19, 27], "yea": [19, 27], "hei": [19, 27], "won": [19, 20, 27], "blaze": [19, 27], "isn": [19, 27, 30], "abil": [19, 21, 27], "devirtu": [19, 27], "sound": [19, 27], "great": [19, 27], "carri": [19, 27, 28], "truth": [19, 27], "irvalu": [19, 27], "enforc": [19, 21, 27], "discrep": [19, 27], "followup": [19, 27], "1000": [19, 27], "my": [19, 27, 30], "presenc": [19, 27], "get_dimention_s": [19, 27], "didn": [19, 27], "exponenti": [19, 27], "blowup": [19, 27], "fewer": [19, 27], "opportun": [19, 27], "recogn": [19, 22, 27], "feasibl": [19, 27], "annoi": [19, 27], "wasn": [19, 27], "materiz": [19, 27], "combo": [19, 27], "extend": 20, "float32": 20, "datatyp": 20, "float16": 20, "bfloat16": [20, 26], "syncfre": 20, "autocast": 20, "summar": 20, "elig": 20, "suppli": 20, "addmm": 20, "addmm_": 20, "prefer": 20, "float64": 20, "respect": 20, "unlist": 20, "__matmul__": 20, "addbmm": 20, "addmv": 20, "addr": 20, "baddbmm": 20, "bmm": 20, "conv1d": 20, "conv2d": [20, 24], "conv3d": 20, "conv_transpose1d": 20, "conv_transpose2d": 20, "conv_transpose3d": 20, "matmul": 20, "relu": [20, 21], "prelu": 20, "max_pool2d": 20, "batch_norm": 20, "log_softmax": 20, "binary_cross_entropy_with_logit": 20, "prod": 20, "cdist": 20, "chloeski": 20, "invers": 20, "reflection_pad": 20, "replication_pad": 20, "mse_loss": 20, "cosine_embbeding_loss": 20, "nll_loss": 20, "multilabel_margin_loss": 20, "qr": 20, "svd": 20, "triangular_solv": 20, "linalg_svd": 20, "linalg_inv_ex": 20, "widest": 20, "index_copi": 20, "scaler": [20, 26], "gradscal": 20, "_fetch_gradi": 20, "xla_use_f16": 20, "underflow": 20, "imagenet": 20, "minimum": [21, 24, 25], "nccl": 21, "new_rank": 21, "ddp_model": 21, "final": [21, 28], "launcher": 21, "demo_fn": 21, "touch": [21, 30], "five": 21, "sy": 21, "tempfil": 21, "cleanup": 21, "destroy_process_group": 21, "toymodel": 21, "net1": 21, "1000000": 21, "net2": 21, "demo_bas": 21, "graident_as_bucket_view": 21, "label": 21, "run_demo": 21, "tot": 
21, "statist": 21, "unit": 21, "median": 21, "90th": 21, "deviat": 21, "cv": 21, "418": 21, "54": 21, "419": 21, "22": 21, "430": 21, "40": 21, "76": 21, "02": 21, "97": 21, "407": 21, "60": 21, "39": 21, "seem": 21, "17864": 21, "19": [21, 22], "20108": 21, "96": 21, "24351": 21, "74": 21, "5866": 21, "83": 21, "10701": 21, "11770": 21, "00": 21, "14313": 21, "78": 21, "3102": 21, "92": 21, "41": [21, 22], "round": 21, "heavili": [21, 22], "sens": 21, "amort": 21, "logdir": 21, "converg": 21, "caution": 21, "interest": 21, "known": 21, "crash": 21, "unmodifi": 22, "hook": 22, "biggest": [22, 24], "torchfx": 22, "technologi": 22, "fx": 22, "a_xla": 22, "b_xla": 22, "compiled_cod": 22, "eval_model": 22, "xla_resnet18": 22, "eval": 22, "dynamo_resnet18": 22, "no_grad": 22, "resent18": 22, "analysi": 22, "bench": 22, "59": 22, "resnext50_32x4d": 22, "91": 22, "alexnet": 22, "28": 22, "mobilenet_v2": 22, "18": 22, "62": 22, "mnasnet1_0": 22, "68": 22, "vgg16": 22, "bert_pytorch": 22, "squeezenet1_1": 22, "timm_vision_transform": 22, "52": 22, "geomean": 22, "04": 22, "train_model": 22, "crossentropyloss": 22, "pred": 22, "train_model_main": 22, "dynamo_train_model": 22, "xla_optim": 22, "weight_decai": 22, "extract": 22, "07": 22, "43": 22, "81": 22, "87": 22, "fwd": 22, "bwd": 22, "e2": 22, "hide": 22, "larger": [22, 24], "wit": 22, "promis": 22, "tradit": 22, "excit": 22, "upcom": [22, 28], "invest": 22, "matur": 22, "stori": 22, "_higher_order_op": 23, "fori_loop": 23, "cond_fn": 23, "body_fn": 23, "bodi": 23, "iteri": 23, "init_v": 23, "functionaltensor": 23, "lvl": 23, "cumul": 23, "ten": 23, "51": 23, "xlafullyshardeddataparallel": 24, "my_modul": [24, 25], "adam": [24, 25], "0001": [24, 25], "leftov": [24, 25], "arxiv": 24, "1910": 24, "02054": 24, "reshard_after_forward": 24, "test_train_mp_mnist_fsdp_with_ckpt": 24, "test_train_mp_imagenet_fsdp": 24, "interleav": 24, "submodul": 24, "fsdpvitmodel": 24, "checkpoint_modul": [24, 25], "3524": 24, 
"auto_wrap_polici": [24, 25], "size_based_auto_wrap_polici": 24, "polici": [24, 28], "100m": 24, "transformer_auto_wrap_polici": [24, 25], "transformer_layer_cl": [24, 25], "auto_wrapper_cal": 24, "remateri": 24, "resum": 24, "get_shard_metadata": 24, "consolidate_sharded_model_checkpoint": 24, "stitch": 24, "ckpt": 24, "shard_metadata": 24, "ckpt_path": 24, "pth": 24, "consolidate_sharded_ckpt": 24, "ckpt_prefix": 24, "your_sharded_checkpoint_fil": 24, "ckpt_suffix": 24, "_rank": 24, "inspir": 24, "structur": [24, 28], "fairscal": 24, "fullyshardeddataparallel": 24, "readthedoc": 24, "en": 24, "resort": 24, "train_resnet_fsdp_auto_wrap": 24, "newer": 24, "recurs": [24, 25], "98": 24, "drop_last": 24, "use_nested_fsdp": 24, "use_gradient_checkpoint": 24, "final_ckpt": 24, "75": 24, "download": 24, "1k": 24, "datadir": 24, "test_set_batch_s": 24, "eval_interv": 24, "num_warmup_epoch": 24, "lr_scheduler_divide_every_n_epoch": 24, "lr_scheduler_divisor": 24, "residu": 24, "algorithm": [24, 25], "ronghanghu": 24, "vit_10b_fsdp_exampl": 24, "vit": 24, "fsdpv2": 25, "famou": 25, "enjoi": 25, "tabl": 25, "spmd_fully_sharded_data_parallel": 25, "spmdfullyshardeddataparallel": 25, "autowrap": 25, "decoderlay": 25, "functool": 25, "decoder_only_model": 25, "shard_output": 25, "0th": 25, "children": 25, "fork": 25, "hf": 25, "abstract": [26, 28], "blockwis": 26, "int4": 26, "analog": 26, "classifi": 26, "flexibl": 26, "choos": [26, 30], "docstr": 26, "xla_quantized_matmul": 26, "n_input_featur": 26, "n_output_featur": 26, "w_int": 26, "127": 26, "int8": 26, "matmul_output": 26, "quantized_matmul": 26, "x_xla": 26, "w_int_xla": 26, "scaler_xla": 26, "matmul_output_xla": 26, "w": 26, "f_dynamo": 26, "dynamo_out_xla": 26, "myqlinearforxlabackend": 26, "load_weight": 26, "processed_w": 26, "processed_scal": 26, "stuff": 26, "orig_model": 26, "mymodel": 26, "q_weight": 26, "q_weights_for_xla": 26, "process_for_xla": 26, "q_linear": 26, "xlaquantizedlinear": 26, "in_featur": 26, 
"out_featur": 26, "load_quantized_weight": 26, "channel": 26, "sym": 26, "asym": 26, "w8a16": 26, "w8a8": 26, "w4a8": 26, "gspmd": [28, 29], "proced": 28, "src": [28, 30], "_input_sharding_": 28, "4d": 28, "input_shard": 28, "shardingspec": 28, "input_mesh": 28, "s2": 28, "s3": 28, "s4": 28, "_after": 28, "_the": 28, "unnecessari": 28, "forth": 28, "techniqu": 28, "decis": 28, "nice": 28, "arrang": 28, "center": 28, "multislic": 28, "denot": 28, "delai": 28, "subclass": 28, "__torch_dispatch__": 28, "global_tensor": 28, "strictli": 28, "local_shard": 28, "xlashard": 28, "4e8e5511555073ce8b6d1a436bf808c9333dcac6": 28, "xla_sharded_tensor": 28, "l12": 28, "ongo": 28, "distributedtensor": 28, "proof": 28, "concept": [28, 29], "distribute_tensor": 28, "devicemesh": 28, "big_tensor": 28, "100000": 28, "88": 28, "my_dtensor": 28, "stai": 28, "dynamo_mark_shard": 28, "placement": 28, "visualize_tensor_shard": 28, "visualize_shard": 28, "rich": 28, "2x2": 28, "generated_t": 28, "use_color": 28, "style": 28, "tile": 28, "partial_repl": 28, "envvar": 28, "xla_auto_spmd": 28, "_tensor": 28, "distribute_modul": 28, "auto_polici": 28, "mymodul": 28, "sharded_model": 28, "behvaior": 28, "xla_auto_use_group_shard": 28, "reshard": 28, "xla_auto_spmd_mesh": 28, "unset": 28, "hint": 29, "strategi": 29, "th": 29, "cluster": 29, "interconnect": 29, "encourag": 29, "fist": 29, "paral": 29, "dedic": 30, "planner": 30, "spmdsaveplann": 30, "spmdloadplann": 30, "dist_cp": 30, "distributed_checkpoint": 30, "xc": 30, "storage_writ": 30, "filesystemwrit": 30, "checkpoint_dir": 30, "storage_read": 30, "filesystemread": 30, "all_step": 30, "save_async": 30, "unblock": 30, "preemption": 30, "detect": 30, "provis": 30, "queuedresourc": 30, "autocheckpoint": 30, "chkpt_on_preempt": 30, "fsspec": 30, "filesystem": 30, "prime_optim": 30, "chkpt_mgr": 30, "tracked_step": 30, "highest": 30, "best_step": 30, "prime": 30, "enumer": 30, "attempt": 30, "unprim": 30, "destruct": 30, "discov": 30, 
"nvidia": 31, "resnet": 31, "num_gpu_machin": 31, "rank_of_current_machin": 31, "machine_0_ip_address": 31, "training_or_inference_script_using_spmd": 31, "xla_use_spmd": 31, "test_train_spmd_imagenet": 31}, "objects": {"": [[12, 0, 0, "-", "torch_xla"]], "torch_xla": [[12, 1, 1, "", "compile"], [12, 1, 1, "", "device"], [12, 1, 1, "", "device_count"], [12, 1, 1, "", "devices"], [12, 0, 0, "-", "experimental"], [12, 1, 1, "", "manual_seed"], [12, 0, 0, "-", "runtime"], [12, 1, 1, "", "sync"]], "torch_xla.core": [[12, 0, 0, "-", "xla_model"]], "torch_xla.core.xla_model": [[12, 1, 1, "", "add_step_closure"], [12, 1, 1, "", "all_gather"], [12, 1, 1, "", "all_reduce"], [12, 1, 1, "", "all_to_all"], [12, 1, 1, "", "get_memory_info"], [12, 1, 1, "", "get_rng_state"], [12, 1, 1, "", "get_stablehlo"], [12, 1, 1, "", "get_stablehlo_bytecode"], [12, 1, 1, "", "is_master_ordinal"], [12, 1, 1, "", "mesh_reduce"], [12, 1, 1, "", "optimizer_step"], [12, 1, 1, "", "rendezvous"], [12, 1, 1, "", "save"], [12, 1, 1, "", "set_rng_state"], [12, 1, 1, "", "wait_device_ops"], [12, 1, 1, "", "xla_device"], [12, 1, 1, "", "xla_device_hw"]], "torch_xla.debug": [[12, 0, 0, "-", "metrics"]], "torch_xla.debug.metrics": [[12, 1, 1, "", "counter_names"], [12, 1, 1, "", "counter_value"], [12, 1, 1, "", "metric_data"], [12, 1, 1, "", "metric_names"], [12, 1, 1, "", "metrics_report"], [12, 1, 1, "", "short_metrics_report"]], "torch_xla.distributed": [[12, 0, 0, "-", "parallel_loader"], [12, 0, 0, "-", "spmd"], [12, 0, 0, "-", "xla_multiprocessing"]], "torch_xla.distributed.parallel_loader": [[12, 2, 1, "", "MpDeviceLoader"]], "torch_xla.distributed.spmd": [[12, 2, 1, "", "HybridMesh"], [12, 2, 1, "", "Mesh"], [12, 1, 1, "", "clear_sharding"], [12, 1, 1, "", "get_1d_mesh"], [12, 1, 1, "", "get_global_mesh"], [12, 1, 1, "", "mark_sharding"], [12, 1, 1, "", "set_global_mesh"]], "torch_xla.distributed.xla_multiprocessing": [[12, 1, 1, "", "spawn"]], "torch_xla.experimental": [[12, 1, 1, "", 
"eager_mode"]], "torch_xla.runtime": [[12, 1, 1, "", "addressable_device_count"], [12, 1, 1, "", "device_type"], [12, 1, 1, "", "get_master_ip"], [12, 1, 1, "", "global_device_count"], [12, 1, 1, "", "global_ordinal"], [12, 1, 1, "", "global_runtime_device_count"], [12, 1, 1, "", "initialize_cache"], [12, 1, 1, "", "is_spmd"], [12, 1, 1, "", "local_device_count"], [12, 1, 1, "", "local_ordinal"], [12, 1, 1, "", "local_process_count"], [12, 1, 1, "", "use_spmd"], [12, 1, 1, "", "world_size"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"]}, "titleterms": {"learn": [0, 1, 11], "about": [0, 1, 11], "gpu": [0, 10, 15, 20, 31], "tpu": [1, 4, 15, 16, 18, 20, 24, 28], "bazel": 2, "pytorch": [2, 3, 4, 6, 7, 8, 9, 11, 12, 16, 17, 18, 22, 24, 27, 28, 29], "xla": [2, 3, 4, 6, 7, 8, 11, 12, 16, 17, 18, 20, 22, 24, 26, 27, 28, 29], "depend": [2, 8, 10], "how": [2, 21, 26, 29], "build": 2, "librari": 2, "torch": [2, 7, 9, 15, 28], "plugin": [2, 6], "remot": 2, "cach": [2, 16], "run": [2, 3, 9, 16, 17, 18, 28, 31], "test": [2, 3, 5, 17, 23], "code": [2, 4, 18, 26], "coverag": 2, "languag": 2, "server": 2, "codegen": 3, "migrat": 3, "guid": [3, 5, 29], "befor": [3, 5], "you": [3, 5, 19, 27], "start": [3, 5, 19, 27], "file": [3, 5, 9], "structur": [3, 5], "old": 3, "op": [3, 5, 7, 26], "lower": [3, 5, 7], "step": [3, 4], "1": [3, 18, 19, 27], "identifi": 3, "2": [3, 18, 19, 25, 27], "inspect": 3, "gener": [3, 9], "lazyir": 3, "h": 3, "3": [3, 19, 27], "implement": [3, 6], "miss": 3, "ir": 3, "function": 3, "torch_xla": [3, 12, 19], "csrc": 3, "ops_xla_shape_fn": 3, "cpp": 3, "4": 3, "ops_lower_fn": 3, "5": 3, "cleanup": 3, "verifi": 3, "result": 3, "sampl": 3, "pr": 3, "configur": 4, "develop": 4, "environ": [4, 17], "visual": 4, "studio": 4, "creat": [4, 16], "connect": 4, "your": 4, "set": 4, "up": 4, "workspac": 4, "next": 
4, "understand": [5, 17], "oper": [5, 9, 19, 20, 26, 27], "unit": [5, 17], "tip": 5, "custom": [6, 8, 10], "hardwar": 6, "pjrt": [6, 15], "c": 6, "api": [6, 7, 12, 14], "packag": 6, "support": [7, 20, 26], "distribut": [7, 12, 15, 30], "collect": 7, "stack": 7, "non": 7, "dynamo": [7, 17], "case": [7, 19, 23, 27], "descript": 7, "kernel": [8, 10], "via": [8, 10], "palla": 8, "adopt": 8, "abov": 8, "compat": 8, "us": [8, 19, 21, 23, 25, 26, 27, 29], "built": 8, "flashattent": 8, "exampl": [8, 18, 20, 23, 24, 25], "usag": [8, 14, 23], "integr": [8, 22, 28], "pagedattent": 8, "export": 9, "stablehlo": 9, "save": [9, 16], "bytecod": 9, "disk": 9, "convert": [9, 18], "serv": 9, "common": [9, 17], "wrapper": 9, "i": [9, 19, 27, 29], "want": 9, "directli": 9, "tf": 9, "saved_model": 9, "format": 9, "without": [9, 19, 27], "need": 9, "an": [9, 16], "separ": 9, "command": 9, "other": 9, "produc": 9, "save_as_stablehlo": 9, "preserv": 9, "high": 9, "level": 9, "composit": 9, "triton": 10, "document": 11, "acceler": 11, "featur": [11, 22, 26], "improv": 11, "workload": 11, "perform": [11, 15, 17, 18], "contribut": 11, "runtim": [12, 15], "xla_model": 12, "spmd": [12, 25, 28, 29, 31], "experiment": [12, 26], "debug": [12, 17, 28], "dynam": [13, 19, 27], "shape": [13, 19, 27], "bound": [13, 19, 27], "eager": 14, "mode": [14, 29], "compil": [14, 16, 17, 28], "basic": 14, "infer": [14, 18, 22], "train": [14, 15, 22, 24], "benchmark": [14, 17, 21], "tl": 15, "dr": 15, "benefit": 15, "quickstart": 15, "cpu": [15, 16], "pod": [15, 16, 18, 24, 28], "docker": 15, "singl": [15, 16, 18], "node": 15, "multi": [15, 16], "differ": 15, "from": [15, 16, 19, 27], "xrt": 15, "multithread": 15, "v2": 15, "v3": [15, 24], "chang": 15, "xm": 15, "rendezv": 15, "new": 15, "devic": [16, 18, 28], "tensor": [16, 17, 19, 27], "ar": 16, "model": [16, 26], "multipl": [16, 18], "process": [16, 30], "deep": 16, "dive": 16, "lazi": 16, "memori": [16, 23], "layout": 16, "move": 16, "load": [16, 28], 
"further": [16, 29], "read": [16, 29], "troubleshoot": 17, "saniti": 17, "check": 17, "version": 17, "A": 17, "simpl": [17, 23], "calcul": 17, "resnet": [17, 24], "With": 17, "fake": [17, 21], "data": [17, 21, 24, 25, 28], "tool": [17, 28], "auto": [17, 28], "metric": 17, "analysi": [17, 18], "execut": 17, "get": 17, "report": 17, "The": 17, "clear": 17, "profil": [17, 18], "known": 17, "caveat": 17, "quirk": 17, "more": 17, "variabl": 17, "combin": 17, "reproduc": 17, "ci": 17, "cd": 17, "failur": 17, "overview": 18, "setup": 18, "stabl": 18, "diffus": 18, "lightn": 18, "hf": 18, "sourc": [19, 27], "recompil": [19, 27], "let": [19, 27], "": [19, 27], "first": [19, 27], "some": [19, 27], "fact": [19, 27], "constraint": [19, 27], "input": [19, 27], "dataset": [19, 27], "output": [19, 25, 27], "can": [19, 27], "fix": [19, 27], "when": [19, 27], "queri": [19, 27], "its": [19, 27], "real": [19, 21, 27], "dimens": [19, 27], "what": [19, 27, 29], "control": [19, 23, 27], "flow": [19, 27], "conclus": [19, 27], "appendix": [19, 27], "automat": 20, "mix": 20, "precis": 20, "amp": 20, "best": 20, "practic": 20, "do": 21, "distributeddataparallel": 21, "ddp": 21, "background": 21, "motiv": 21, "resnet50": 21, "mnist": [21, 24], "disclaim": 21, "torchdynamo": 22, "gap": 22, "take": 22, "awai": 22, "optim": [23, 28, 30], "util": 23, "while_loop": 23, "group": [23, 30], "pure": 23, "python": 23, "while": 23, "loop": 23, "fulli": [24, 25], "shard": [24, 25, 28], "parallel": [24, 25], "script": 24, "imagenet": 24, "instal": 24, "clone": 24, "repo": 24, "8": 24, "50": 24, "10": 24, "billion": 24, "paramet": 24, "gradient": 25, "checkpoint": [25, 30], "huggingfac": 25, "llama": 25, "quantiz": 26, "call": 26, "modul": 26, "swap": 26, "matrix": 26, "multipli": 26, "advanc": 28, "topic": 28, "awar": 28, "host": 28, "virtual": 28, "hybrid": 28, "mesh": [28, 29], "xlashardedtensor": 28, "dtensor": 28, "activ": 28, "user": 29, "partit": 29, "spec": 29, "checkpointmanag": 30, "restor": 30, 
"state": 30}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.todo": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"Learn about GPUs": [[0, "learn-about-gpus"]], "Learn about TPUs": [[1, "learn-about-tpus"]], "Bazel in Pytorch/XLA": [[2, "bazel-in-pytorch-xla"]], "Bazel dependencies": [[2, "bazel-dependencies"]], "How to build XLA libraries": [[2, "how-to-build-xla-libraries"]], "How to build the Torch/XLA plugin": [[2, "how-to-build-the-torch-xla-plugin"]], "Remote caching": [[2, "remote-caching"]], "Running tests": [[2, "running-tests"]], "Code coverage": [[2, "code-coverage"]], "Language Server": [[2, "language-server"]], "Building PyTorch/XLA": [[2, "building-pytorch-xla"]], "Codegen migration Guide": [[3, "codegen-migration-guide"]], "Before you start": [[3, "before-you-start"], [5, "before-you-start"]], "File structure": [[3, "file-structure"], [5, "file-structure"]], "PyTorch Codegen files": [[3, "pytorch-codegen-files"]], "PyTorch/XLA Codegen files": [[3, "pytorch-xla-codegen-files"]], "PyTorch/XLA Old Op Lowering files": [[3, "pytorch-xla-old-op-lowering-files"]], "Codegen step by step": [[3, "codegen-step-by-step"]], "1. Identify the op": [[3, "identify-the-op"]], "2. Codegen the op and inspect the generated file": [[3, "codegen-the-op-and-inspect-the-generated-file"]], "LazyIr.h": [[3, "lazyir-h"]], "3. Implement the missing IR function": [[3, "implement-the-missing-ir-function"]], "torch_xla/csrc/ops/ops_xla_shape_fn.h": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-h"]], "torch_xla/csrc/ops/ops_xla_shape_fn.cpp": [[3, "torch-xla-csrc-ops-ops-xla-shape-fn-cpp"]], "4. 
Implement the lowering function": [[3, "implement-the-lowering-function"]], "torch_xla/csrc/ops/ops_lower_fn.cpp": [[3, "torch-xla-csrc-ops-ops-lower-fn-cpp"]], "5. Cleanup": [[3, "cleanup"]], "Run the test and verify the result": [[3, "run-the-test-and-verify-the-result"]], "Sample PRs": [[3, "sample-prs"]], "Configure a development environment": [[4, "configure-a-development-environment"]], "Visual Studio Code": [[4, "visual-studio-code"]], "Creating and connecting to your TPU": [[4, "creating-and-connecting-to-your-tpu"]], "Setting up a Visual Studio Code workspace with PyTorch/XLA": [[4, "setting-up-a-visual-studio-code-workspace-with-pytorch-xla"]], "Next steps": [[4, "next-steps"]], "OP Lowering Guide": [[5, "op-lowering-guide"]], "Understanding the operation": [[5, "understanding-the-operation"]], "Unit Test": [[5, "unit-test"]], "Tips": [[5, "tips"]], "Custom Hardware Plugins": [[6, "custom-hardware-plugins"]], "Implementing a PJRT Plugin": [[6, "implementing-a-pjrt-plugin"]], "PJRT C API Implementation": [[6, "pjrt-c-api-implementation"]], "PyTorch/XLA Plugin Package": [[6, "pytorch-xla-plugin-package"]], "Support of Torch Distributed API in PyTorch/XLA": [[7, "support-of-torch-distributed-api-in-pytorch-xla"]], "Collective ops lowering": [[7, "collective-ops-lowering"]], "Collective ops lowering stack": [[7, "collective-ops-lowering-stack"]], "non-Dynamo case": [[7, "non-dynamo-case"]], "Dynamo case": [[7, "dynamo-case"]], "API description": [[7, "api-description"]], "Custom Kernels via Pallas": [[8, "custom-kernels-via-pallas"]], "Adopt the above kernel to be compatible with PyTorch/XLA": [[8, "adopt-the-above-kernel-to-be-compatible-with-pytorch-xla"]], "Use built-in kernels": [[8, "use-built-in-kernels"]], "FlashAttention": [[8, "id1"]], "Example usage": [[8, "example-usage"], [8, "id3"]], "Integration Example": [[8, "integration-example"], [8, "id4"]], "PagedAttention": [[8, "id2"]], "Dependencies": [[8, "dependencies"], [10, "dependencies"]], "Torch 
Export to StableHLO": [[9, "torch-export-to-stablehlo"]], "Saving StableHLO bytecodes to disk": [[9, "saving-stablehlo-bytecodes-to-disk"]], "Convert saved StableHLO for serving": [[9, "convert-saved-stablehlo-for-serving"]], "Common wrappers": [[9, "common-wrappers"]], "I want to save directly tf.saved_model format without needing to run an separate command.": [[9, "i-want-to-save-directly-tf-saved-model-format-without-needing-to-run-an-separate-command"]], "Other common wrappers": [[9, "other-common-wrappers"]], "Files produced by save_as_stablehlo.": [[9, "files-produced-by-save-as-stablehlo"]], "Preserving High-Level PyTorch Operations in StableHLO by generating stablehlo.composite": [[9, "preserving-high-level-pytorch-operations-in-stablehlo-by-generating-stablehlo-composite"]], "Custom GPU Kernels via Triton": [[10, "custom-gpu-kernels-via-triton"]], "PyTorch/XLA documentation": [[11, "pytorch-xla-documentation"]], "Learn about Pytorch/XLA": [[11, null]], "Learn about accelerators": [[11, null]], "PyTorch/XLA features": [[11, null]], "Improve Pytorch/XLA workload performance": [[11, null]], "Contribute to Pytorch/XLA": [[11, null]], "PyTorch/XLA API": [[12, "pytorch-xla-api"]], "torch_xla": [[12, "module-torch_xla"]], "runtime": [[12, "module-torch_xla.runtime"]], "xla_model": [[12, "module-torch_xla.core.xla_model"]], "distributed": [[12, "module-torch_xla.distributed.parallel_loader"]], "spmd": [[12, "module-torch_xla.distributed.spmd"]], "experimental": [[12, "module-torch_xla.experimental"]], "debug": [[12, "module-torch_xla.debug.metrics"]], "Dynamic shape": [[13, "dynamic-shape"]], "Bounded dynamic shape": [[13, "bounded-dynamic-shape"]], "Eager Mode + Compile API": [[14, "eager-mode-compile-api"]], "Basic Usage": [[14, "basic-usage"]], "Inference": [[14, "inference"], [22, "inference"]], "Training": [[14, "training"], [22, "training"]], "Benchmark": [[14, "benchmark"]], "PJRT Runtime": [[15, "pjrt-runtime"]], "TL;DR": [[15, "tl-dr"]], "Benefits": [[15, 
"benefits"]], "Quickstart": [[15, "quickstart"]], "CPU": [[15, "cpu"]], "TPU": [[15, "tpu"]], "Pods": [[15, "pods"]], "Docker": [[15, "docker"]], "GPU": [[15, "gpu"]], "Single-node GPU training": [[15, "single-node-gpu-training"]], "Multi-node GPU training": [[15, "multi-node-gpu-training"]], "Differences from XRT": [[15, "differences-from-xrt"]], "Multithreading on TPU v2/v3": [[15, "id3"]], "Changes to xm.rendezvous": [[15, "changes-to-xm-rendezvous"]], "PJRT and torch.distributed": [[15, "pjrt-and-torch-distributed"]], "Performance": [[15, "performance"]], "New TPU runtime": [[15, "new-tpu-runtime"]], "PyTorch on XLA Devices": [[16, "pytorch-on-xla-devices"]], "Creating an XLA Tensor": [[16, "creating-an-xla-tensor"]], "XLA Tensors are PyTorch Tensors": [[16, "xla-tensors-are-pytorch-tensors"]], "Running Models on XLA Devices": [[16, "running-models-on-xla-devices"]], "Running on a Single XLA Device": [[16, "running-on-a-single-xla-device"]], "Running on Multiple XLA Devices with Multi-processing": [[16, "running-on-multiple-xla-devices-with-multi-processing"]], "Running on TPU Pods": [[16, "running-on-tpu-pods"]], "XLA Tensor Deep Dive": [[16, "id3"]], "XLA Tensors are Lazy": [[16, "xla-tensors-are-lazy"]], "Memory Layout": [[16, "memory-layout"]], "Moving XLA Tensors to and from the CPU": [[16, "moving-xla-tensors-to-and-from-the-cpu"]], "Saving and Loading XLA Tensors": [[16, "saving-and-loading-xla-tensors"]], "Compilation Caching": [[16, "compilation-caching"]], "Further Reading": [[16, "further-reading"], [29, "further-reading"]], "Troubleshoot": [[17, "troubleshoot"]], "Sanity Check": [[17, "sanity-check"]], "Check PyTorch/XLA Version": [[17, "check-pytorch-xla-version"]], "Perform A Simple Calculation": [[17, "perform-a-simple-calculation"]], "Run Resnet With Fake Data": [[17, "run-resnet-with-fake-data"]], "Performance Debugging": [[17, "performance-debugging"]], "PyTorch/XLA Debugging Tool": [[17, "pytorch-xla-debugging-tool"]], "Perform A Auto-Metrics 
Analysis": [[17, "perform-a-auto-metrics-analysis"]], "Compilation & Execution Analysis": [[17, "compilation-execution-analysis"]], "Get A Metrics Report": [[17, "get-a-metrics-report"]], "Understand The Metrics Report": [[17, "understand-the-metrics-report"]], "Clear The Metrics Report": [[17, "clear-the-metrics-report"]], "PyTorch/XLA + Dynamo Debugging Tool": [[17, "pytorch-xla-dynamo-debugging-tool"]], "Performance Profiling": [[17, "performance-profiling"]], "Simple Benchmarking": [[17, "simple-benchmarking"]], "Known Performance Caveats": [[17, "known-performance-caveats"]], "XLA Tensor Quirks": [[17, "xla-tensor-quirks"]], "More Debugging Tools": [[17, "more-debugging-tools"]], "Environment Variables": [[17, "environment-variables"]], "Common Debugging Environment Variables Combinations": [[17, "common-debugging-environment-variables-combinations"]], "Reproducing PyTorch/XLA CI/CD unit test failures.": [[17, "reproducing-pytorch-xla-ci-cd-unit-test-failures"]], "Pytorch/XLA overview": [[18, "pytorch-xla-overview"]], "TPU Setup": [[18, "tpu-setup"]], "Converting code to PyTorch XLA": [[18, "converting-code-to-pytorch-xla"]], "Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device": [[18, "example-1-stable-diffusion-inference-in-pytorch-lightning-on-a-single-tpu-device"]], "Example 2. HF Stable Diffusion Inference": [[18, "example-2-hf-stable-diffusion-inference"]], "Running on a Single TPU device": [[18, "running-on-a-single-tpu-device"]], "Profiling and performance analysis": [[18, "profiling-and-performance-analysis"]], "Running on Multiple TPU Devices": [[18, "running-on-multiple-tpu-devices"]], "Running on Pods": [[18, "running-on-pods"]], "Source of recompilations in torch_xla": [[19, "source-of-recompilations-in-torch-xla"]], "Let\u2019s first start with some facts/constraints:": [[19, "lets-first-start-with-some-facts-constraints"], [27, "lets-first-start-with-some-facts-constraints"]], "#1. 
From input dataset.": [[19, "from-input-dataset"], [27, "from-input-dataset"]], "#2. From operator output": [[19, "from-operator-output"], [27, "from-operator-output"]], "2.1 Bounded dynamic shape can fix the case when you use the tensor with dynamic shape as a Tensor, without querying its real dimension.": [[19, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"], [27, "bounded-dynamic-shape-can-fix-the-case-when-you-use-the-tensor-with-dynamic-shape-as-a-tensor-without-querying-its-real-dimension"]], "2.2 what if real dimension is queried on a tensor with dynamic shape?": [[19, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"], [27, "what-if-real-dimension-is-queried-on-a-tensor-with-dynamic-shape"]], "#3. From control flow": [[19, "from-control-flow"], [27, "from-control-flow"]], "Conclusion:": [[19, "conclusion"], [27, "conclusion"]], "Appendix:": [[19, "appendix"], [27, "appendix"]], "Automatic Mixed Precision": [[20, "automatic-mixed-precision"]], "AMP for XLA:TPU": [[20, "amp-for-xla-tpu"]], "AMP for XLA:TPU Best Practices": [[20, "amp-for-xla-tpu-best-practices"]], "Supported Operators": [[20, "supported-operators"]], "AMP for XLA:GPU": [[20, "amp-for-xla-gpu"]], "AMP for XLA:GPU Best Practices": [[20, "amp-for-xla-gpu-best-practices"]], "Examples": [[20, "examples"]], "How to do DistributedDataParallel(DDP)": [[21, "how-to-do-distributeddataparallel-ddp"]], "Background / Motivation": [[21, "background-motivation"]], "How to use DistributedDataParallel": [[21, "how-to-use-distributeddataparallel"]], "Benchmarking": [[21, "benchmarking"]], "Resnet50 with fake data": [[21, "resnet50-with-fake-data"]], "MNIST with fake data": [[21, "mnist-with-fake-data"]], "MNIST with real data": [[21, "mnist-with-real-data"]], "Disclaimer": [[21, "disclaimer"]], "TorchDynamo integration in PyTorch XLA": [[22, "torchdynamo-integration-in-pytorch-xla"]], "Integration": [[22, 
"integration"]], "Feature gaps": [[22, "feature-gaps"]], "Take away": [[22, "take-away"]], "Optimize memory utilization using while_loop": [[23, "optimize-memory-utilization-using-while-loop"]], "while_loop": [[23, "while-loop"]], "Usage:": [[23, "usage"]], "simple example with while_loop:": [[23, "simple-example-with-while-loop"]], "Control group test case": [[23, "control-group-test-case"]], "Control group example with pure python while loop": [[23, "control-group-example-with-pure-python-while-loop"]], "Fully Sharded Data Parallel in PyTorch XLA": [[24, "fully-sharded-data-parallel-in-pytorch-xla"]], "Example training scripts on MNIST and ImageNet": [[24, "example-training-scripts-on-mnist-and-imagenet"]], "Installation": [[24, "installation"]], "Clone PyTorch/XLA repo": [[24, "clone-pytorch-xla-repo"]], "Train MNIST on v3-8 TPU": [[24, "train-mnist-on-v3-8-tpu"]], "Train ImageNet with ResNet-50 on v3-8 TPU": [[24, "train-imagenet-with-resnet-50-on-v3-8-tpu"]], "Example training scripts on TPU pod (with 10 billion parameters)": [[24, "example-training-scripts-on-tpu-pod-with-10-billion-parameters"]], "Fully Sharded Data Parallel using SPMD": [[25, "fully-sharded-data-parallel-using-spmd"]], "Sharding output": [[25, "sharding-output"]], "Gradient checkpointing": [[25, "gradient-checkpointing"]], "HuggingFace Llama 2 Example": [[25, "huggingface-llama-2-example"]], "Quantized Operations for XLA (Experimental feature)": [[26, "quantized-operations-for-xla-experimental-feature"]], "How to use:": [[26, "how-to-use"]], "Call XLA quantized op in model code": [[26, "call-xla-quantized-op-in-model-code"]], "Module Swap": [[26, "module-swap"]], "Supported Quantized Operations:": [[26, "supported-quantized-operations"]], "Matrix Multiply": [[26, "matrix-multiply"]], "Source of recompilations in Pytorch/XLA": [[27, "source-of-recompilations-in-pytorch-xla"]], "PyTorch/XLA SPMD advanced topics": [[28, "pytorch-xla-spmd-advanced-topics"]], "Sharding-Aware Host-to-Device Data 
Loading": [[28, "sharding-aware-host-to-device-data-loading"]], "Virtual Device Optimization": [[28, "virtual-device-optimization"]], "Hybrid Mesh": [[28, "hybrid-mesh"]], "Running SPMD on TPU Pod": [[28, "running-spmd-on-tpu-pod"]], "XLAShardedTensor": [[28, "xlashardedtensor"]], "DTensor Integration": [[28, "dtensor-integration"]], "Activation Sharding for torch.compile": [[28, "activation-sharding-for-torch-compile"]], "SPMD Debugging Tool": [[28, "spmd-debugging-tool"]], "Auto-Sharding": [[28, "auto-sharding"]], "PyTorch/XLA SPMD User Guide": [[29, "pytorch-xla-spmd-user-guide"]], "What is PyTorch/XLA SPMD?": [[29, "what-is-pytorch-xla-spmd"]], "How to use PyTorch/XLA SPMD?": [[29, "how-to-use-pytorch-xla-spmd"]], "SPMD Mode": [[29, "spmd-mode"]], "Mesh": [[29, "mesh"]], "Partition Spec": [[29, "partition-spec"]], "Distributed Checkpointing": [[30, "distributed-checkpointing"]], "CheckpointManager": [[30, "checkpointmanager"]], "Restoring Optimizer State": [[30, "restoring-optimizer-state"]], "Process Groups": [[30, "process-groups"]], "Running SPMD on GPU": [[31, "running-spmd-on-gpu"]]}, "indexentries": {"hybridmesh (class in torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.HybridMesh"]], "mesh (class in torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.Mesh"]], "mpdeviceloader (class in torch_xla.distributed.parallel_loader)": [[12, "torch_xla.distributed.parallel_loader.MpDeviceLoader"]], "add_step_closure() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.add_step_closure"]], "addressable_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.addressable_device_count"]], "all_gather() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_gather"]], "all_reduce() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_reduce"]], "all_to_all() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.all_to_all"]], "clear_sharding() (in 
module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.clear_sharding"]], "compile() (in module torch_xla)": [[12, "torch_xla.compile"]], "counter_names() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.counter_names"]], "counter_value() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.counter_value"]], "device() (in module torch_xla)": [[12, "torch_xla.device"]], "device_count() (in module torch_xla)": [[12, "torch_xla.device_count"]], "device_type() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.device_type"]], "devices() (in module torch_xla)": [[12, "torch_xla.devices"]], "eager_mode() (in module torch_xla.experimental)": [[12, "torch_xla.experimental.eager_mode"]], "get_1d_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.get_1d_mesh"]], "get_global_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.get_global_mesh"]], "get_master_ip() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.get_master_ip"]], "get_memory_info() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_memory_info"]], "get_rng_state() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_rng_state"]], "get_stablehlo() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_stablehlo"]], "get_stablehlo_bytecode() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.get_stablehlo_bytecode"]], "global_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_device_count"]], "global_ordinal() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_ordinal"]], "global_runtime_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.global_runtime_device_count"]], "initialize_cache() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.initialize_cache"]], "is_master_ordinal() (in module torch_xla.core.xla_model)": [[12, 
"torch_xla.core.xla_model.is_master_ordinal"]], "is_spmd() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.is_spmd"]], "local_device_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_device_count"]], "local_ordinal() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_ordinal"]], "local_process_count() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.local_process_count"]], "manual_seed() (in module torch_xla)": [[12, "torch_xla.manual_seed"]], "mark_sharding() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.mark_sharding"]], "mesh_reduce() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.mesh_reduce"]], "metric_data() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metric_data"]], "metric_names() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metric_names"]], "metrics_report() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.metrics_report"]], "module": [[12, "module-torch_xla"], [12, "module-torch_xla.core.xla_model"], [12, "module-torch_xla.debug.metrics"], [12, "module-torch_xla.distributed.parallel_loader"], [12, "module-torch_xla.distributed.spmd"], [12, "module-torch_xla.distributed.xla_multiprocessing"], [12, "module-torch_xla.experimental"], [12, "module-torch_xla.runtime"]], "optimizer_step() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.optimizer_step"]], "rendezvous() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.rendezvous"]], "save() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.save"]], "set_global_mesh() (in module torch_xla.distributed.spmd)": [[12, "torch_xla.distributed.spmd.set_global_mesh"]], "set_rng_state() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.set_rng_state"]], "short_metrics_report() (in module torch_xla.debug.metrics)": [[12, "torch_xla.debug.metrics.short_metrics_report"]], 
"spawn() (in module torch_xla.distributed.xla_multiprocessing)": [[12, "torch_xla.distributed.xla_multiprocessing.spawn"]], "sync() (in module torch_xla)": [[12, "torch_xla.sync"]], "torch_xla": [[12, "module-torch_xla"]], "torch_xla.core.xla_model": [[12, "module-torch_xla.core.xla_model"]], "torch_xla.debug.metrics": [[12, "module-torch_xla.debug.metrics"]], "torch_xla.distributed.parallel_loader": [[12, "module-torch_xla.distributed.parallel_loader"]], "torch_xla.distributed.spmd": [[12, "module-torch_xla.distributed.spmd"]], "torch_xla.distributed.xla_multiprocessing": [[12, "module-torch_xla.distributed.xla_multiprocessing"]], "torch_xla.experimental": [[12, "module-torch_xla.experimental"]], "torch_xla.runtime": [[12, "module-torch_xla.runtime"]], "use_spmd() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.use_spmd"]], "wait_device_ops() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.wait_device_ops"]], "world_size() (in module torch_xla.runtime)": [[12, "torch_xla.runtime.world_size"]], "xla_device() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.xla_device"]], "xla_device_hw() (in module torch_xla.core.xla_model)": [[12, "torch_xla.core.xla_model.xla_device_hw"]]}}) \ No newline at end of file