Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions checkpoint/orbax/checkpoint/_src/arrays/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,51 @@ def _calculate_sharding_hbm_consumption(
return shardings


def _construct_minimal_sharding(
sds: jax.ShapeDtypeStruct,
devices: Sequence[jax.Device] | None = None,
) -> jax.sharding.Sharding:
"""Constructs a sharding that replicates the array as much as possible."""
del sds
devices = devices or jax.devices()
return jax.sharding.NamedSharding(
mesh=jax.sharding.Mesh(devices, ('a',)),
spec=jax.sharding.PartitionSpec(),
)


def construct_minimal_shardings(
    abstract_state: PyTree, devices: Sequence[jax.Device] | None = None
) -> PyTree:
  """Constructs a fully-replicated sharding for each array in the state.

  Every leaf is replicated across all of `devices`, so each device holds a
  full copy of each array. The expected per-device HBM consumption (the sum
  of per-leaf shard sizes, in bytes) is logged as a side effect.

  This method is subject to change and should not be considered stable.

  Args:
    abstract_state: PyTree of jax.ShapeDtypeStruct.
    devices: Devices to shard across. If None, uses all available devices.

  Returns:
    PyTree of jax.sharding.Sharding with the same structure as
    `abstract_state`.
  """
  shardings = jax.tree.map(
      lambda x: _construct_minimal_sharding(x, devices=devices), abstract_state
  )

  total_size = 0

  # Named distinctly to avoid shadowing the module-level
  # `_calculate_sharding_hbm_consumption` helper.
  def _accumulate_shard_bytes(
      sds: jax.ShapeDtypeStruct, sharding: jax.sharding.Sharding
  ):
    nonlocal total_size
    shard_shape = sharding.shard_shape(sds.shape)
    # Cast to int: np.prod returns a float (1.0) for an empty (scalar) shape,
    # which would otherwise turn the running byte count into a float.
    total_size += int(np.prod(shard_shape)) * sds.dtype.itemsize

  jax.tree.map(_accumulate_shard_bytes, abstract_state, shardings)
  logging.info('Expected per-device HBM consumption: %s', total_size)
  return shardings


def get_device_local_layout(arr: jax.Array) -> Any:
"""Returns device_local_layout of a jax.Array."""
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _partition_axis_name(offset: int) -> str:



def load_checkpoint(path: str) -> Any:
def load_checkpoint(path: str, is_replicated: bool = False) -> Any:
"""Loads a PyTree of test checkpoint from a provided path."""
logging.info('Loading checkpoint from path: %s', path)
path = epath.Path(path)
Expand All @@ -136,7 +136,10 @@ def load_checkpoint(path: str) -> Any:
abstract_state = jax.tree.map(
abstract_arrays.to_shape_dtype_struct, metadata.tree
)
shardings = sharding_utils.construct_maximal_shardings(abstract_state)
if is_replicated:
shardings = sharding_utils.construct_minimal_shardings(abstract_state)
else:
shardings = sharding_utils.construct_maximal_shardings(abstract_state)
abstract_state = jax.tree.map(
lambda sds, sharding: jax.ShapeDtypeStruct(
sds.shape, sds.dtype, sharding=sharding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ class CheckpointConfig:
spec: A dictionary defining the structure and type of the PyTree to be
generated. Example: { 'params': { 'dtype': 'float32', 'shape': [1024,
1024], 'sharding': ['data', 'model'] # PartitionSpec }, 'step': 'int' }
is_replicated: If True, the checkpoint will be generated with replicated
shardings. Default is False.
"""

path: str | None = None
random_seed: int = 0
spec: dict[str, Any] = dataclasses.field(default_factory=dict)
is_replicated: bool = False


@dataclasses.dataclass(frozen=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from absl import logging
from etils import epath
import jax
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.testing.benchmarks.core import checkpoint_generation
from orbax.checkpoint._src.testing.benchmarks.core import configs
Expand Down Expand Up @@ -166,7 +167,11 @@ def run(self, repeat_index: int | None = None) -> TestResult:
self.checkpoint_config, mesh=self.mesh
)
else:
data = checkpoint_generation.load_checkpoint(self.checkpoint_config.path)
data = checkpoint_generation.load_checkpoint(
self.checkpoint_config.path, self.checkpoint_config.is_replicated
)

logging.info("data: %s", test_utils.pretty_format_pytree(data))

with benchmark_metrics.measure(
"sync_global_processes:benchmark:setup_pytree"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def create_mesh(config: configs.MeshConfig) -> jax.sharding.Mesh:
logging.info('Creating mesh with config: %s', config)
devices = jax.devices()
num_devices = len(devices)
logging.info('num_devices: %s, devices: %s', num_devices, devices)
# Convert the user-friendly dict maps into ordered lists based on mesh_axes
ici_shape = [config.ici_parallelism.get(axis, 1) for axis in config.mesh_axes]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ You want a small, single-host TPU. It's instant and cheap.

```bash
# Create a standard v4-8
./manage_tpu.sh create --name orbax-dev-1 --type v4-8 --zone us-central2-b
./manage_tpu.sh create --tpu-name orbax-dev-1 --type v4-8 --zone us-central2-b
```

### Scenario B: "I need to test distributed scaling."
Expand All @@ -96,15 +96,15 @@ requires a "Queued Resource". Use `--node-count > 1`.
```bash
# Create a v6e-16 (2 nodes * 8 chips)
# This will poll the queue until your resource is ACTIVE.
./manage_tpu.sh create --name orbax-scale-1 --type v6e-16 --zone europe-west4-a --node-count 2
./manage_tpu.sh create --tpu-name orbax-scale-1 --type v6e-16 --zone europe-west4-a --node-count 2
```

### Scenario C: "Is it ready yet?"
You took a coffee break. Now check the status. The script automatically figures
out if you are asking about a VM or a Queue.

```bash
./manage_tpu.sh status --name orbax-scale-1
./manage_tpu.sh status --tpu-name orbax-scale-1
```

---
Expand Down Expand Up @@ -211,7 +211,7 @@ receives the configuration with minimal delay.
* **Solution**: The script acts as a facade. For operations like `status` or
`delete`, it attempts to find the resource in the QR API first. If that
fails (404), it falls back to the VM API.
* **Benefit**: a unified interface (`./manage_tpu.sh delete --name foo`) for all TPU types.
* **Benefit**: a unified interface (`./manage_tpu.sh delete --tpu-name foo`) for all TPU types.

### 3. Robust Setup (`setup_tpu.sh`)

Expand Down
17 changes: 17 additions & 0 deletions checkpoint/orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,3 +863,20 @@ def is_compression_used(

else:
return read_spec['metadata']['compressor'] is not None


def pretty_format_pytree(pytree: PyTree) -> str:
  """Returns a human-readable, one-leaf-per-line summary of a PyTree.

  Each output line has the form `dotted.key.path: <leaf summary>`, where
  jax and numpy arrays are summarized by shape/dtype (and sharding for
  jax.Array) rather than printing their full contents.
  """

  def _describe_leaf(leaf) -> str:
    # Summarize arrays abstractly instead of dumping their values.
    if isinstance(leaf, jax.Array):
      return str(
          jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=leaf.sharding)
      )
    if isinstance(leaf, np.ndarray):
      return f'np.ndarray(shape={leaf.shape}, dtype={leaf.dtype})'
    return str(leaf)

  flat = tree_utils.to_flat_dict(pytree)
  return '\n'.join(
      f'{".".join(map(str, key))}: {_describe_leaf(value)}'
      for key, value in flat.items()
  )
Loading