Skip to content

Commit

Permalink
Enable PJRT plugins by default (#7268)
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar authored Jun 13, 2024
1 parent 192dec6 commit 1cad403
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
3 changes: 2 additions & 1 deletion torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ def _init_xla_lazy_backend():
from .experimental import plugins
from ._internal import neuron, xpu # Additional built-in plugins

if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS',
'0' if _XLAC._has_cuda_support() else '1') == '1':
plugins.use_dynamic_plugins()
plugins.register_installed_plugins()

Expand Down
4 changes: 3 additions & 1 deletion torch_xla/_internal/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def configure_multiprocess(self, local_rank, local_world_size):
return configure_topology(local_rank, local_world_size)

def physical_chip_count(self):
return num_available_chips()
# HACK: We may reduce the number of processes we spawn depending on TPU
# topology settings
return num_local_processes()

def client_create_options(self):
return {
Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2461,6 +2461,13 @@ void InitXlaModuleBindings(py::module m) {
return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
/*is_tpu=*/true);
});
m.def("_has_cuda_support", []() {
#ifdef GOOGLE_CUDA
return true;
#else
return false;
#endif
});
m.def("_xla_gpu_custom_call",
[](const std::vector<at::Tensor>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
Expand Down

0 comments on commit 1cad403

Please sign in to comment.