From ea2a6f774299abea15450670049429655829784c Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Wed, 12 Jun 2024 16:09:42 -0700
Subject: [PATCH] Revert "Modularize zero step function and make it customizable" (#7259)

---
 .../distributed/zero_redundancy_optimizer.py | 124 +++++++-----------
 1 file changed, 47 insertions(+), 77 deletions(-)

diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index 72ea3e7e12a..7e1e7b6cc10 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -324,20 +324,17 @@ def _clip_grad_norm(
         if p.grad is not None:
           p.grad.detach().mul_(clip_value)
 
-  def _get_sharding_scheme(self, **kwargs):
-    if "sharding_scheme" in kwargs:
-      return kwargs["sharding_scheme"]
-    else:
-      return [
-          {
-              "scale_factor": 1.0,
-              "sharding_group": self.sharding_groups,
-              "group_size": self.local_world_size,
-          },
-      ]
-
-  def _reduce_gradients(self, **kwargs):
-    sharding_scheme = self._get_sharding_scheme(**kwargs)
+  @torch.no_grad()
+  def step(self, closure=None, **kwargs):
+    """
+    Performs a single optimizer step and syncs parameters across all ranks.
+    """
+    assert self.inited, "must call init_zero() first"
+
+    loss = None
+    if closure is not None:
+      with torch.enable_grad():
+        loss = closure()
 
     # sync to base optimizer
     self._sync_param_groups(self.param_groups, self.base_optimizer.param_groups)
@@ -359,34 +356,30 @@ def _reduce_gradients(self, **kwargs):
           if self.coalesce_cc_reduce_scatter:
             padded_grads.append(padded_grad)
           else:
-            grad_shard = padded_grad
-            for step in sharding_scheme:
-              grad_shard = xm.reduce_scatter(
-                  xm.REDUCE_SUM,
-                  grad_shard,
-                  scale=step['scale_factor'] / step['group_size'],
-                  scatter_dim=0,
-                  shard_count=step['group_size'],
-                  pin_layout=self.pin_layout,
-                  groups=step['sharding_group'],
-              )
+            grad_shard = xm.reduce_scatter(
+                xm.REDUCE_SUM,
+                padded_grad,
+                scale=1.0 / self.local_world_size,
+                scatter_dim=0,
+                shard_count=self.local_world_size,
+                pin_layout=self.pin_layout,
+                groups=self.sharding_groups,
+            )
             if grad_shard.dtype != self.optimizer_dtype:
               grad_shard = grad_shard.to(dtype=self.optimizer_dtype)
             shard.grad = grad_shard
 
     if self.coalesce_cc_reduce_scatter:
-      grad_shards = padded_grads
-      for step in sharding_scheme:
-        grad_shards = xm.reduce_scatter_bucketized(
-            xm.REDUCE_SUM,
-            grad_shards,
-            scale=step['scale_factor'] / step['group_size'],
-            scatter_dim=0,
-            shard_count=step['group_size'],
-            pin_layout=self.pin_layout,
-            groups=step['sharding_group'],
-            bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
-        )
+      grad_shards = xm.reduce_scatter_bucketized(
+          xm.REDUCE_SUM,
+          padded_grads,
+          scale=1.0 / self.local_world_size,
+          scatter_dim=0,
+          shard_count=self.local_world_size,
+          pin_layout=self.pin_layout,
+          groups=self.sharding_groups,
+          bucket_cap_mb=self.bucket_cap_mb_reduce_scatter,
+      )
       index = 0
       for param_group, sharded_param_group in zip(
           self.param_groups, self.base_optimizer.param_groups):
@@ -400,7 +393,10 @@ def _reduce_gradients(self, **kwargs):
             shard.grad = grad_shard
             index += 1
 
-  def _update_parameters(self, **kwargs):
+    if self.grad_clipping:
+      # Update unscale/clip with sub partitions
+      self._clip_grad_norm(max_norm=self.max_norm)
+
     # Step the wrapped optimizer
     # Closure already executed, pass none here
     self.base_optimizer.step(closure=None, **kwargs)
@@ -412,31 +408,9 @@ def _update_parameters(self, **kwargs):
     # sync back
     self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups)
 
-  @torch.no_grad()
-  def step(self, closure=None, **kwargs):
-    """
-    Performs a single optimizer step and syncs parameters across all ranks.
-    """
-    assert self.inited, "must call init_zero() first"
-
-    loss = None
-    if closure is not None:
-      with torch.enable_grad():
-        loss = closure()
-
-    self._reduce_gradients(**kwargs)
-
-    if self.grad_clipping:
-      # Update unscale/clip with sub partitions
-      self._clip_grad_norm(max_norm=self.max_norm)
-
-    self._update_parameters(**kwargs)
-
     return loss
 
   def allgather_weights_and_update_full_parameter(self):
-    sharding_scheme = self._get_sharding_scheme(**kwargs)
-
     # All gather the new weights across the ranks and assign them to the full parameters
     sharded_data = []
     for param_group, sharded_param_group in zip(
@@ -450,26 +424,22 @@ def allgather_weights_and_update_full_parameter(self):
           if self.coalesce_cc_all_gather:
             sharded_data.append(shard_data)
           else:
-            padded_param = shard_data
-            for step in reversed(sharding_scheme):
-              padded_param = xm.all_gather(
-                  padded_param,
-                  dim=0,
-                  pin_layout=self.pin_layout,
-                  groups=step['sharding_group'],
-              )
+            padded_param = xm.all_gather(
+                shard_data,
+                dim=0,
+                pin_layout=self.pin_layout,
+                groups=self.sharding_groups,
+            )
             param.data.copy_(padded_param.data[:param.size(0)])
 
     if self.coalesce_cc_all_gather:
-      padded_params = sharded_data
-      for step in reversed(sharding_scheme):
-        padded_params = xm.all_gather_bucketized(
-            padded_params,
-            dim=0,
-            pin_layout=self.pin_layout,
-            groups=step['sharding_group'],
-            bucket_cap_mb=self.bucket_cap_mb_all_gather,
-        )
+      padded_params = xm.all_gather_bucketized(
+          sharded_data,
+          dim=0,
+          pin_layout=self.pin_layout,
+          groups=self.sharding_groups,
+          bucket_cap_mb=self.bucket_cap_mb_all_gather,
+      )
       index = 0
       for param_group, sharded_param_group in zip(
          self.param_groups, self.base_optimizer.param_groups):
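
A minimal usage sketch follows (not part of the patch), assuming a working torch_xla install and an XLA device. It shows how the restored monolithic step() is typically driven from a training loop: after this revert, a single optimizer.step() call again performs the gradient reduce-scatter, optional clipping, the sharded base-optimizer step, and the all-gather of updated shards. The model, the base optimizer, and the grad_clipping/max_norm constructor keywords are illustrative assumptions and may differ between torch_xla releases.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = nn.Linear(128, 16).to(device)

# Wrap a regular optimizer class; its state and parameters are sharded across ranks.
# grad_clipping/max_norm are assumed constructor keywords that enable the
# _clip_grad_norm path shown in the hunk above.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,
    lr=1e-2,
    grad_clipping=True,
    max_norm=1.0,
)

inputs = torch.randn(8, 128, device=device)
loss = model(inputs).sum()
loss.backward()
optimizer.step()  # reduce-scatter grads -> clip -> base optimizer step -> all-gather weights
xm.mark_step()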