Modularize zero step function and make it customizable #7233

Merged · 2 commits · Jun 12, 2024
Changes from 1 commit
124 changes: 77 additions & 47 deletions torch_xla/distributed/zero_redundancy_optimizer.py
@@ -324,17 +324,20 @@ def _clip_grad_norm(
if p.grad is not None:
p.grad.detach().mul_(clip_value)

@torch.no_grad()
def step(self, closure=None, **kwargs):
"""
Performs a single optimizer step and syncs parameters across all ranks.
"""
assert self.inited, "must call init_zero() first"

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
def _get_sharding_scheme(self, **kwargs):
if "sharding_scheme" in kwargs:
¦ return kwargs["sharding_scheme"]
else:
¦ return [
¦ ¦ {
¦ ¦ ¦ "scale_factor": 1.0,
¦ ¦ ¦ "sharding_group": self.sharding_groups,
¦ ¦ ¦ "group_size": self.local_world_size,
¦ ¦ },
¦ ]
Collaborator:

You seem to have copied some formatter metadata into the PR.

Contributor Author:

Thanks, removed.
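The new _get_sharding_scheme hook above is the customization point named in the PR title: by default it returns a single collective stage over self.sharding_groups, and each entry supplies the scale_factor, sharding_group, and group_size used by one reduce-scatter/all-gather stage. A minimal sketch of overriding it for a hypothetical 4-host x 4-device topology follows; the subclass name and group layouts are illustrative assumptions, not part of this PR:

# Sketch only: a subclass returning a two-stage (intra-host, then inter-host)
# sharding scheme. Group layouts assume 16 devices arranged as 4 hosts x 4 devices.
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer


class HierarchicalZeroOptimizer(ZeroRedundancyOptimizer):

  def _get_sharding_scheme(self, **kwargs):
    intra_host = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
    inter_host = [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
    return [
        {"scale_factor": 1.0, "sharding_group": intra_host, "group_size": 4},
        {"scale_factor": 1.0, "sharding_group": inter_host, "group_size": 4},
    ]

Each stage is applied in order during _reduce_gradients and in reverse during the all-gather, so for the gradient shards to line up with the parameter shards, the product of the group sizes should equal the local_world_size used when init_zero() sharded the parameters.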


def _reduce_gradients(self, **kwargs):
sharding_scheme = self._get_sharding_scheme(**kwargs)

# sync to base optimizer
self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
@@ -356,30 +359,34 @@ def step(self, closure=None, **kwargs):
if self.coalesce_cc_reduce_scatter:
padded_grads.append(padded_grad)
else:
grad_shard = xm.reduce_scatter(
xm.REDUCE_SUM,
padded_grad,
scale=1.0 / self.local_world_size,
scatter_dim=0,
shard_count=self.local_world_size,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
)
grad_shard = padded_grad
for step in sharding_scheme:
grad_shard = xm.reduce_scatter(
xm.REDUCE_SUM,
grad_shard,
scale=step['scale_factor'] / step['group_size'],
scatter_dim=0,
shard_count=step['group_size'],
pin_layout=self.pin_layout,
groups=step['sharding_group'],
)
if grad_shard.dtype != self.optimizer_dtype:
grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
shard.grad = grad_shard

if self.coalesce_cc_reduce_scatter:
grad_shards = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
padded_grads,
scale=1.0 / self.local_world_size,
scatter_dim=0,
shard_count=self.local_world_size,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
)
grad_shards = padded_grads
for step in sharding_scheme:
grad_shards = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
grad_shards,
scale=step['scale_factor'] / step['group_size'],
scatter_dim=0,
shard_count=step['group_size'],
pin_layout=self.pin_layout,
groups=step['sharding_group'],
bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
)
index = 0
for param_group, sharded_param_group in zip(
self.param_groups, self.base_optimizer.param_groups):
@@ -393,10 +400,7 @@ def step(self, closure=None, **kwargs):
shard.grad = grad_shard
index += 1
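
Because the reduce-scatter stages above are chained, the scatter dimension and the gradient scale compose multiplicatively: after all stages each rank holds 1 / (g_1 * ... * g_n) of the padded gradient along dim 0, scaled by the product of scale_factor_i / group_size_i. A small self-contained check of that arithmetic (plain Python, no XLA device needed; the two-stage sizes are assumed for illustration):

# Effective scale and shard fraction produced by chained reduce-scatter stages,
# mirroring the loop above. The 4 x 4 scheme is an assumption for illustration.
scheme = [
    {"scale_factor": 1.0, "group_size": 4},  # e.g. an intra-host stage
    {"scale_factor": 1.0, "group_size": 4},  # e.g. an inter-host stage
]

effective_scale = 1.0
shard_fraction = 1.0
for stage in scheme:
  effective_scale *= stage["scale_factor"] / stage["group_size"]
  shard_fraction /= stage["group_size"]

print(effective_scale)  # 0.0625 == 1/16, i.e. a mean over 16 ranks
print(shard_fraction)   # each rank ends up with 1/16 of the padded gradient

With the default single-stage scheme (scale_factor 1.0, group_size equal to local_world_size) this reduces to the previous scale of 1.0 / self.local_world_size.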

if self.grad_clipping:
# Update unscale/clip with sub partitions
self._clip_grad_norm(max_norm=self.max_norm)

def _update_parameters(self, **kwargs):
# Step the wrapped optimizer
# Closure already executed, pass none here
self.base_optimizer.step(closure=None, **kwargs)
@@ -408,9 +412,31 @@ def step(self, closure=None, **kwargs):
# sync back
self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)

@torch.no_grad()
def step(self, closure=None, **kwargs):
"""
Performs a single optimizer step and syncs parameters across all ranks.
"""
assert self.inited, "must call init_zero() first"

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

self._reduce_gradients(**kwargs)

if self.grad_clipping:
# Update unscale/clip with sub partitions
self._clip_grad_norm(max_norm=self.max_norm)

self._update_parameters(**kwargs)

return loss
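
With step() reduced to the three phases above, a subclass can wrap a single phase without re-implementing the whole method. A sketch of one such hook, assuming a hypothetical subclass that logs the local shard gradient norm after reduction; the logging is illustrative and not part of this PR:

# Sketch only: reuse the base _reduce_gradients and add diagnostics afterwards.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer


class LoggingZeroOptimizer(ZeroRedundancyOptimizer):

  def _reduce_gradients(self, **kwargs):
    super()._reduce_gradients(**kwargs)
    # After reduction, each rank only holds gradients for its own shards.
    grads = [
        shard.grad
        for group in self.base_optimizer.param_groups
        for shard in group["params"]
        if shard.grad is not None
    ]
    if grads:
      norm = torch.linalg.vector_norm(
          torch.stack([torch.linalg.vector_norm(g) for g in grads]))
      xm.master_print(f"local shard grad norm: {norm}")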

def allgather_weights_and_update_full_parameter(self):
sharding_scheme = self._get_sharding_scheme()

# All gather the new weights across the ranks and assign them to the full parameters
sharded_data = []
for param_group, sharded_param_group in zip(
@@ -424,22 +450,26 @@ def allgather_weights_and_update_full_parameter(self):
if self.coalesce_cc_all_gather:
sharded_data.append(shard_data)
else:
padded_param = xm.all_gather(
shard_data,
dim=0,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
)
padded_param = shard_data
for step in reversed(sharding_scheme):
padded_param = xm.all_gather(
padded_param,
dim=0,
pin_layout=self.pin_layout,
groups=step['sharding_group'],
)
param.data.copy_(padded_param.data[:param.size(0)])

if self.coalesce_cc_all_gather:
padded_params = xm.all_gather_bucketized(
sharded_data,
dim=0,
pin_layout=self.pin_layout,
groups=self.sharding_groups,
bucket_cap_mb=self.bucket_cap_mb_all_gather,
)
padded_params = sharded_data
for step in reversed(sharding_scheme):
padded_params = xm.all_gather_bucketized(
padded_params,
dim=0,
pin_layout=self.pin_layout,
groups=step['sharding_group'],
bucket_cap_mb=self.bucket_cap_mb_all_gather,
)
index = 0
for param_group, sharded_param_group in zip(
self.param_groups, self.base_optimizer.param_groups):
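
On the all-gather side the scheme is walked in reverse, so the stage applied last during reduce-scatter is the first to be undone, restoring the layout of the padded parameter along dim 0. A shape-only illustration with plain tensors (no collectives; the 4 x 4 two-stage scheme is the same assumption as above):

import torch

# Shape-only sketch: split along dim 0 in scheme order, then rebuild in
# reversed order, recovering the padded length. No XLA collectives involved.
scheme = [{"group_size": 4}, {"group_size": 4}]

padded = torch.arange(16.0)  # padded full parameter
shard = padded
for stage in scheme:  # reduce-scatter direction: keep 1/group_size per stage
  shard = shard.chunk(stage["group_size"], dim=0)[0]
assert shard.numel() == 1  # 16 / (4 * 4)

rebuilt = shard
for stage in reversed(scheme):  # all-gather direction: grow by group_size per stage
  rebuilt = torch.cat([rebuilt] * stage["group_size"], dim=0)
assert rebuilt.numel() == padded.numel()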