Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
distribute_tensor as distribute_tensor,
)
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import (
get_device_count as get_device_count,
)
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import (
Expand Down
3 changes: 3 additions & 0 deletions keras/api/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
distribute_tensor as distribute_tensor,
)
from keras.src.distribution.distribution_lib import distribution as distribution
from keras.src.distribution.distribution_lib import (
get_device_count as get_device_count,
)
from keras.src.distribution.distribution_lib import initialize as initialize
from keras.src.distribution.distribution_lib import list_devices as list_devices
from keras.src.distribution.distribution_lib import (
Expand Down
14 changes: 14 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ def list_devices(device_type=None):
return [f"{device.platform}:{device.id}" for device in jax_devices]


def get_device_count(device_type=None):
"""Returns the number of available JAX devices.
Args:
device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
If `None`, it defaults to counting "gpu" or "tpu" devices if
available, otherwise it counts "cpu" devices. It does not
return the sum of all device types.
Returns:
int: The total number of JAX devices for the specified type.
"""
device_type = device_type.lower() if device_type else None
return jax.device_count(device_type)


def distribute_variable(value, layout):
"""Create a distributed variable for JAX.

Expand Down
8 changes: 6 additions & 2 deletions keras/src/backend/jax/distribution_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@


@pytest.mark.skipif(
backend.backend() != "jax",
reason="Backend specific test",
backend.backend() != "jax" or len(jax.devices()) != 8,
reason="Backend specific test and requires 8 devices",
)
class JaxDistributionLibTest(testing.TestCase):
def _create_jax_layout(self, sharding):
Expand All @@ -42,6 +42,10 @@ def _create_jax_layout(self, sharding):

return sharding

def test_get_device_count(self):
self.assertEqual(backend_dlib.get_device_count(), 8)
self.assertEqual(backend_dlib.get_device_count("cpu"), 8)

def test_list_devices(self):
self.assertEqual(len(distribution_lib.list_devices()), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
Expand Down
14 changes: 14 additions & 0 deletions keras/src/distribution/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ def list_devices(device_type=None):
return distribution_lib.list_devices(device_type)


@keras_export("keras.distribution.get_device_count")
def get_device_count(device_type=None):
"""Returns the number of available JAX devices.
Args:
device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
If `None`, it defaults to counting "gpu" or "tpu" devices if
available, otherwise it counts "cpu" devices. It does not
return the sum of all device types.
Returns:
int: The total number of JAX devices for the specified type.
"""
return distribution_lib.get_device_count(device_type=device_type)


@keras_export("keras.distribution.initialize")
def initialize(job_addresses=None, num_processes=None, process_id=None):
"""Initialize the distribution system for multi-host/process setting.
Expand Down