19 changes: 13 additions & 6 deletions open_instruct/grpo_fast.py
```diff
@@ -776,13 +776,20 @@ def load(self, path: str, map_location=None):
 
         # Load reference policy checkpoint if available
         if hasattr(self, "ref_policy_checkpoint_path") and self.ref_policy_checkpoint_path:
-            state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device)
-            if hasattr(self.ref_policy, "module"):
-                # If wrapped by DeepSpeed
-                self.ref_policy.module.load_state_dict(state_dict)
+            try:
+                state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device)
+                if hasattr(self.ref_policy, "module"):
+                    # If wrapped by DeepSpeed
+                    self.ref_policy.module.load_state_dict(state_dict)
+                else:
+                    self.ref_policy.load_state_dict(state_dict)
+            except (OSError, RuntimeError) as err:
+                logger.warning(
+                    f"{self.rank=}: Failed to load reference policy from "
+                    f"{self.ref_policy_checkpoint_path}: {err}. Proceeding with base weights."
+                )
             else:
-                self.ref_policy.load_state_dict(state_dict)
-            logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}")
+                logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}")
```
Comment on lines +779 to +792 (Contributor, severity: medium):
The try-except block is a great addition for robustness. I have two suggestions to make it more robust still:

  1. Broader exception handling: the current `except (OSError, RuntimeError)` might not catch all possible file-corruption errors from `torch.load`, such as `pickle.UnpicklingError` or `zipfile.BadZipFile`. Since loading the reference policy is a best-effort operation, it would be safer to catch a broader `Exception` to prevent any loading-related crash.

  2. Consistent `weights_only` parameter: earlier in this file (lines 625-636), `torch.load` is monkey-patched with `weights_only=False` to handle `_pickle.UnpicklingError`. For consistency, and because the default for `weights_only` may change in future PyTorch versions, it is good practice to set `weights_only=False` explicitly in this `torch.load` call as well.
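To make the first point concrete, a quick check of Python's exception hierarchy (standard library only, no torch required) confirms that neither corruption error is a subclass of the types the PR currently catches:

```python
import pickle
import zipfile

# The PR's handler catches only these two exception types.
caught = (OSError, RuntimeError)

# Corruption errors that torch.load can surface while unpickling a checkpoint
# (pickle.UnpicklingError) or reading its zip container (zipfile.BadZipFile)
# are not subclasses of either, so they would escape the handler.
print(issubclass(pickle.UnpicklingError, caught))  # False
print(issubclass(zipfile.BadZipFile, caught))      # False

# Both are, of course, subclasses of Exception, so a broad catch works.
print(issubclass(zipfile.BadZipFile, Exception))   # True
```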

Here is a suggested change that incorporates these points.

Suggested change:

```diff
             try:
-                state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device)
+                state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device, weights_only=False)
                 if hasattr(self.ref_policy, "module"):
                     # If wrapped by DeepSpeed
                     self.ref_policy.module.load_state_dict(state_dict)
                 else:
                     self.ref_policy.load_state_dict(state_dict)
-            except (OSError, RuntimeError) as err:
+            except Exception as err:
                 logger.warning(
                     f"{self.rank=}: Failed to load reference policy from "
                     f"{self.ref_policy_checkpoint_path}: {err}. Proceeding with base weights."
                 )
             else:
                 logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}")
```
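The try/except/else pattern the review endorses can be sketched with the standard library alone: here `pickle` stands in for `torch.load`, and `load_best_effort` is a hypothetical helper for illustration, not code from the repository:

```python
import logging
import pickle
import tempfile

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("checkpoint")


def load_best_effort(path, fallback):
    """Return the unpickled checkpoint at `path`, or `fallback` if loading fails.

    The broad `except Exception` mirrors the review suggestion: loading is
    best-effort, so any corruption error downgrades to a warning instead of
    crashing the caller.
    """
    try:
        with open(path, "rb") as f:
            state = pickle.load(f)
    except Exception as err:
        logger.warning("Failed to load %s: %s. Proceeding with fallback.", path, err)
        return fallback
    else:  # runs only if the try body raised nothing
        logger.info("Loaded checkpoint from %s", path)
        return state


# A valid checkpoint round-trips; a corrupt file falls back instead of raising.
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as good:
    pickle.dump({"step": 776}, good)
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as bad:
    bad.write(b"not a pickle")

print(load_best_effort(good.name, fallback=None))   # {'step': 776}
print(load_best_effort(bad.name, fallback="base"))  # base
```

The `else:` clause keeps the success log out of the `try` body, so an error raised by `logger.info` itself could never be mistaken for a load failure — the same reason the PR logs success in the `else:` branch.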

```diff
         self.local_metrics = MetricsTracker(max_metrics=32, device=self.device)
         return optimization_steps_done
```
