Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ dependencies = [
"jax>=0.6.0",
]

[project.entry-points.jax_plugins]
mpibackend4jax = "mpibackend4jax.plugin"

[tool.hatch.build.targets.wheel]
packages = ["src/mpibackend4jax"]

Expand Down
16 changes: 0 additions & 16 deletions src/mpibackend4jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,12 @@
"""

import os
from pathlib import Path

# Import the cluster to register it automatically
from .mpitrampoline_cluster import MPITrampolineLocalCluster

__version__ = "0.1.0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__version__ = "0.1.0"
__version__ = "0.1.1"


# Get the package installation directory
_package_dir = Path(__file__).parent
_mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so"

# Set environment variables for MPITrampoline
if _mpiwrapper_lib.exists():
os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute())
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"

print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}")
print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi")
else:
print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}")
print("Please ensure the package was installed correctly.")


# Convenience function to check if MPITrampoline is properly configured
def is_configured():
Expand Down
20 changes: 20 additions & 0 deletions src/mpibackend4jax/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os
from pathlib import Path


def initialize():
# Get the package installation directory
_package_dir = Path(__file__).parent
_mpiwrapper_lib = _package_dir / "lib" / "libmpiwrapper.so"

# Set environment variables for MPITrampoline
if _mpiwrapper_lib.exists():
if "MPITRAMPOLINE_LIB" not in os.environ.keys():
os.environ["MPITRAMPOLINE_LIB"] = str(_mpiwrapper_lib.absolute())
print(f"mpibackend4jax: Set MPITRAMPOLINE_LIB={_mpiwrapper_lib.absolute()}")
if "JAX_CPU_COLLECTIVES_IMPLEMENTATION" not in os.environ.keys():
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"
print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi")
Comment on lines +15 to +17
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has no effect, would need to change it to jax.config.update('jax_cpu_collectives_implementation', 'mpi')

else:
print(f"Warning: MPIWrapper library not found at {_mpiwrapper_lib}")
print("Please ensure the package was installed correctly.")