refactor_logging_utils #183

Open · wants to merge 1 commit into main
80 changes: 40 additions & 40 deletions training/utils/dist_checkpoint_utils.py
@@ -22,43 +22,44 @@ def load_checkpoint(pipe, args):
     try:
         with open(os.path.join(checkpoint_step_path, 'meta.json')) as f:
             meta = json.load(f)
-    except:
-        print('failed to load meta.')
+    except FileNotFoundError:
+        print(f"Checkpoint metadata file not found at {os.path.join(checkpoint_step_path, 'meta.json')}")
+        return  # Or handle appropriately
+    except Exception as e:
+        print(f"Failed to load meta.json: {e}")
+        # Decide if you want to return or raise

     pipe.global_step = latest_step

+    model_path = os.path.join(checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_checkpoint.pt')
     try:
         pipe.model.model.load_state_dict(
-            torch.load(
-                os.path.join(
-                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_checkpoint.pt'
-                ), map_location=torch.device('cpu')
-            )
+            torch.load(model_path, map_location=torch.device('cpu'))
         )
-    except:
-        print('failed to load model params.')
+    except FileNotFoundError:
+        print(f"Model checkpoint file not found: {model_path}")
+    except Exception as e:
+        print(f"Failed to load model params from {model_path}: {e}")

+    optimizer_path = os.path.join(checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_optimizer.pt')
     try:
         pipe.optimizer.load_state_dict(
-            torch.load(
-                os.path.join(
-                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_optimizer.pt'
-                ), map_location=torch.device('cpu')
-            )
+            torch.load(optimizer_path, map_location=torch.device('cpu'))
         )
-    except:
-        print('failed to load optim states.')
+    except FileNotFoundError:
+        print(f"Optimizer checkpoint file not found: {optimizer_path}")
+    except Exception as e:
+        print(f"Failed to load optim states from {optimizer_path}: {e}")

+    scheduler_path = os.path.join(checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt')
     try:
         pipe.scheduler.load_state_dict(
-            torch.load(
-                os.path.join(
-                    checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt'
-                )
-            )
+            torch.load(scheduler_path)
         )
-    except:
-        print('failed to load scheduler states.')
+    except FileNotFoundError:
+        print(f"Scheduler checkpoint file not found: {scheduler_path}")
+    except Exception as e:
+        print(f"Failed to load scheduler states from {scheduler_path}: {e}")


 def save_checkpoint(pipe, args) -> str:
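
For context, the refactored load_checkpoint builds each artifact path once and reports missing files per artifact instead of swallowing everything in a bare except. Below is a minimal sketch of the on-disk layout it expects; the file names come from the diff above, while the directory, step, and rank values are only illustrative.

import os

# Illustrative values; in the real code these come from args.checkpoint_path,
# the 'latest' marker file, and get_pipeline_parallel_rank().
checkpoint_step_path = os.path.join("model_checkpoints", "checkpoint_100")
rank = 0

expected_files = [
    os.path.join(checkpoint_step_path, "meta.json"),
    os.path.join(checkpoint_step_path, f"prank_{rank}_checkpoint.pt"),
    os.path.join(checkpoint_step_path, f"prank_{rank}_optimizer.pt"),
    os.path.join(checkpoint_step_path, f"prank_{rank}_scheduler.pt"),
]
for path in expected_files:
    # Each missing artifact now gets its own FileNotFoundError message.
    print(path, "->", "found" if os.path.exists(path) else "missing")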
@@ -109,29 +110,28 @@ def save_stream_dataloader_state_dict(dataloader, pipe, args):
     latest_step = pipe.global_step
     checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}")

-    os.system(f"mkdir -p {checkpoint_step_path}")
+    os.makedirs(checkpoint_step_path, exist_ok=True)

-    torch.save(
-        dataloader.dataset.state_dict(),
-        os.path.join(
-            checkpoint_step_path, f'dataset_state_dict.pt'
-        )
-    )
+    dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
+    try:
+        torch.save(
+            dataloader.dataset.state_dict(),
+            dataset_state_dict_path
+        )
+    except Exception as e:
+        print(f"Failed to save dataset state_dict to {dataset_state_dict_path}: {e}")

 def load_stream_dataloader_state_dict(dataloader, pipe, args):

     latest_step = pipe.global_step
-    checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}")
+    # checkpoint_step_path is already defined in load_checkpoint, but if this function can be called independently:
+    checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}")

+    dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
     try:
-        state_dict = torch.load(
-            os.path.join(
-                checkpoint_step_path, f'dataset_state_dict.pt'
-            )
-        )
-
-        dataloader.data.load_state_dict(state_dict)
-
+        state_dict = torch.load(dataset_state_dict_path)
+        dataloader.dataset.load_state_dict(state_dict)  # Corrected: dataloader.dataset.load_state_dict
+    except FileNotFoundError:
+        print(f"Dataset state_dict file not found: {dataset_state_dict_path}")
     except Exception as e:
-
-        print('failed to load dataset state_dict.')
+        print(f"Failed to load dataset state_dict from {dataset_state_dict_path}: {e}")
21 changes: 11 additions & 10 deletions training/utils/logging_utils.py
@@ -24,8 +24,9 @@ def init_train_logger(args):
     if train_log_backend == 'print':
         pass
     elif train_log_backend == 'loguru':
-        os.system("mkdir -p logs")
-        loguru.logger.add("logs/file_{time}.log")
+        log_file_path = getattr(args, 'log_file_path', "logs/file_{time}.log")
+        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
+        loguru.logger.add(log_file_path)
     elif train_log_backend == 'wandb':

         assert _has_wandb
@@ -34,11 +35,13 @@ def init_train_logger(args):
             import re
             args.project_name = "test-" + \
                 re.sub('[^a-zA-Z0-9 \n\.]', '_', args.task_name)
-
-        wandb.init(
-            project=args.project_name,
-            config=args,
-        )
+        try:
+            wandb.init(
+                project=args.project_name,
+                config=args,
+            )
+        except Exception as e:
+            print(f"Error initializing wandb: {e}")

     else:
         raise Exception('Unknown logging backend.')
@@ -52,6 +55,4 @@ def train_log(x, *args, **kargs):
     elif train_log_backend == 'wandb':
         wandb.log(x, *args, **kargs)
     else:
-        raise Exception('Unknown logging backend.')
-
-
+        raise Exception('Unknown logging backend.')
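
With this change the loguru backend reads an optional log_file_path attribute from args via getattr, falling back to "logs/file_{time}.log", and creates the parent directory first; wandb.init failures are caught and printed instead of aborting setup. A minimal caller-side sketch, with the caveat that the argparse flag spellings and the train_log_backend attribute wiring are assumptions for illustration, not part of this PR:

import argparse

parser = argparse.ArgumentParser()
# Only the attribute names matter to init_train_logger; the flag spellings are hypothetical.
parser.add_argument('--train-log-backend', dest='train_log_backend', default='loguru')
parser.add_argument('--log-file-path', dest='log_file_path', default='logs/file_{time}.log')
args = parser.parse_args([])

# init_train_logger(args) would then create os.path.dirname(args.log_file_path)
# with exist_ok=True and register the loguru sink at that path.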
230 changes: 230 additions & 0 deletions training/utils/test_dist_checkpoint_utils.py
@@ -0,0 +1,230 @@
import unittest
from unittest.mock import patch, MagicMock, mock_open, call
import os
import shutil
import json
import torch  # Keep torch for torch.device and potentially other utilities
import sys

# Import the functions to be tested
from training.utils.dist_checkpoint_utils import (
    load_checkpoint,
    save_checkpoint,
    save_stream_dataloader_state_dict,
    load_stream_dataloader_state_dict
)

# Mock get_pipeline_parallel_rank as it's used to construct file names
@patch('training.utils.dist_checkpoint_utils.get_pipeline_parallel_rank', MagicMock(return_value=0))
class TestDistCheckpointUtils(unittest.TestCase):

    def setUp(self):
        self.test_checkpoint_dir = "test_checkpoints"
        # Ensure a clean state for each test
        if os.path.exists(self.test_checkpoint_dir):
            shutil.rmtree(self.test_checkpoint_dir)
        os.makedirs(self.test_checkpoint_dir, exist_ok=True)

        self.args = MagicMock()
        self.args.checkpoint_path = self.test_checkpoint_dir

        self.pipe = MagicMock()
        self.pipe.global_step = 10
        self.pipe.model = MagicMock()
        self.pipe.model.model = MagicMock()  # Mocking model.model attribute for state_dict
        self.pipe.optimizer = MagicMock()
        self.pipe.optimizer.state_dict = MagicMock(return_value={"opt_param": 1})
        self.pipe.scheduler = MagicMock()
        self.pipe.scheduler.state_dict = MagicMock(return_value={"sched_param": 1})
        self.pipe.model.model.state_dict = MagicMock(return_value={"model_param": 1})

        # Suppress print statements from the module
        self.original_stdout = sys.stdout
        sys.stdout = MagicMock()

    def tearDown(self):
        if os.path.exists(self.test_checkpoint_dir):
            shutil.rmtree(self.test_checkpoint_dir)
        sys.stdout = self.original_stdout  # Restore stdout

    def _create_dummy_checkpoint_files(self, step, create_meta=True, create_model=True, create_optimizer=True, create_scheduler=True):
        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{step}")
        os.makedirs(checkpoint_step_path, exist_ok=True)

        if create_meta:
            with open(os.path.join(checkpoint_step_path, 'meta.json'), 'w') as f:
                json.dump({'step': step}, f)

        if create_model:
            torch.save({}, os.path.join(checkpoint_step_path, 'prank_0_checkpoint.pt'))
        if create_optimizer:
            torch.save({}, os.path.join(checkpoint_step_path, 'prank_0_optimizer.pt'))
        if create_scheduler:
            torch.save({}, os.path.join(checkpoint_step_path, 'prank_0_scheduler.pt'))

        with open(os.path.join(self.test_checkpoint_dir, 'latest'), 'w') as f:
            f.write(str(step))
        return checkpoint_step_path

    @patch('torch.load')
    def test_load_checkpoint_success(self, mock_torch_load):
        step = 10
        self._create_dummy_checkpoint_files(step)
        mock_torch_load.return_value = {"dummy_state": "value"}

        load_checkpoint(self.pipe, self.args)

        self.assertEqual(self.pipe.global_step, step)
        self.pipe.model.model.load_state_dict.assert_called_once_with({"dummy_state": "value"})
        self.pipe.optimizer.load_state_dict.assert_called_once_with({"dummy_state": "value"})
        self.pipe.scheduler.load_state_dict.assert_called_once_with({"dummy_state": "value"})

    def test_load_checkpoint_no_latest_file(self):
        # No 'latest' file
        load_checkpoint(self.pipe, self.args)
        self.assertTrue(any("no checkpoint available, skipping" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))

    def test_load_checkpoint_meta_file_not_found(self):
        step = 5
        self._create_dummy_checkpoint_files(step, create_meta=False)

        load_checkpoint(self.pipe, self.args)
        expected_msg = f"Checkpoint metadata file not found at {os.path.join(self.test_checkpoint_dir, f'checkpoint_{step}', 'meta.json')}"
        self.assertTrue(any(expected_msg in call_args[0][0] for call_args in sys.stdout.write.call_args_list))

    @patch('torch.load', side_effect=FileNotFoundError("File not found"))
    def test_load_checkpoint_model_file_not_found(self, mock_torch_load):
        step = 15
        self._create_dummy_checkpoint_files(step, create_model=False)  # Model file won't be there but mock matters more

        # We need meta.json to proceed to loading attempts
        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{step}")
        with open(os.path.join(checkpoint_step_path, 'meta.json'), 'w') as f:
            json.dump({'step': step}, f)
        with open(os.path.join(self.test_checkpoint_dir, 'latest'), 'w') as f:
            f.write(str(step))

        load_checkpoint(self.pipe, self.args)
        model_path = os.path.join(checkpoint_step_path, 'prank_0_checkpoint.pt')
        self.assertTrue(any(f"Model checkpoint file not found: {model_path}" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))

    @patch('torch.load', side_effect=RuntimeError("Torch load error"))
    def test_load_checkpoint_torch_load_generic_error(self, mock_torch_load):
        step = 20
        self._create_dummy_checkpoint_files(step)  # All files are there

        load_checkpoint(self.pipe, self.args)
        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{step}")
        model_path = os.path.join(checkpoint_step_path, 'prank_0_checkpoint.pt')
        self.assertTrue(any(f"Failed to load model params from {model_path}: Torch load error" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))


    @patch('torch.save')
    @patch('os.makedirs')
    def test_save_checkpoint_directory_creation_and_save(self, mock_os_makedirs, mock_torch_save):
        self.pipe.global_step = 25

        save_checkpoint(self.pipe, self.args)

        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, "checkpoint_25")
        mock_os_makedirs.assert_called_once_with(checkpoint_step_path, exist_ok=True)

        self.assertEqual(mock_torch_save.call_count, 3)  # model, optimizer, scheduler

        # Check meta.json content
        meta_path = os.path.join(checkpoint_step_path, 'meta.json')
        self.assertTrue(os.path.exists(meta_path))
        with open(meta_path, 'r') as f:
            meta = json.load(f)
        self.assertEqual(meta['step'], 25)

        # Check latest file content
        latest_path = os.path.join(self.test_checkpoint_dir, 'latest')
        self.assertTrue(os.path.exists(latest_path))
        with open(latest_path, 'r') as f:
            latest_step_str = f.read()
        self.assertEqual(latest_step_str, "25")


    @patch('torch.save')
    @patch('os.makedirs')
    def test_save_stream_dataloader_state_dict_creation_and_save(self, mock_os_makedirs, mock_torch_save):
        self.pipe.global_step = 30
        mock_dataloader = MagicMock()
        mock_dataloader.dataset.state_dict.return_value = {"dataset_state": "some_state"}

        save_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)

        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, "checkpoint_30")
        mock_os_makedirs.assert_called_once_with(checkpoint_step_path, exist_ok=True)

        dataset_state_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
        mock_torch_save.assert_called_once_with({"dataset_state": "some_state"}, dataset_state_path)

    @patch('torch.save', side_effect=Exception("Failed to save dataset"))
    @patch('os.makedirs')
    def test_save_stream_dataloader_state_dict_save_error(self, mock_os_makedirs, mock_torch_save):
        self.pipe.global_step = 35
        mock_dataloader = MagicMock()
        mock_dataloader.dataset.state_dict.return_value = {"dataset_state": "some_state"}

        save_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)

        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, "checkpoint_35")
        dataset_state_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
        self.assertTrue(any(f"Failed to save dataset state_dict to {dataset_state_path}: Failed to save dataset" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))


    @patch('torch.load')
    def test_load_stream_dataloader_state_dict_success(self, mock_torch_load):
        self.pipe.global_step = 40
        # We need to ensure the checkpoint directory for this step exists for path construction
        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{self.pipe.global_step}")
        os.makedirs(checkpoint_step_path, exist_ok=True)
        # Dummy file for torch.load to be "successful"
        dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
        with open(dataset_state_dict_path, 'w') as f: f.write("dummy data")


        mock_dataloader = MagicMock()
        mock_dataloader.dataset = MagicMock()  # Ensure dataset attribute exists

        mock_torch_load.return_value = {"loaded_state": "value"}

        load_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)

        mock_torch_load.assert_called_once_with(dataset_state_dict_path)
        mock_dataloader.dataset.load_state_dict.assert_called_once_with({"loaded_state": "value"})

    @patch('torch.load', side_effect=FileNotFoundError("Dataset state file not found"))
    def test_load_stream_dataloader_state_dict_file_not_found(self, mock_torch_load):
        self.pipe.global_step = 45
        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{self.pipe.global_step}")
        # No need to create the dummy file, as we are testing FileNotFoundError

        mock_dataloader = MagicMock()
        mock_dataloader.dataset = MagicMock()

        load_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)

        dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
        self.assertTrue(any(f"Dataset state_dict file not found: {dataset_state_dict_path}" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))

    @patch('torch.load', side_effect=RuntimeError("Torch load error for dataset"))
    def test_load_stream_dataloader_state_dict_generic_error(self, mock_torch_load):
        self.pipe.global_step = 50
        checkpoint_step_path = os.path.join(self.test_checkpoint_dir, f"checkpoint_{self.pipe.global_step}")
        os.makedirs(checkpoint_step_path, exist_ok=True)
        dataset_state_dict_path = os.path.join(checkpoint_step_path, 'dataset_state_dict.pt')
        with open(dataset_state_dict_path, 'w') as f: f.write("dummy data")  # File exists

        mock_dataloader = MagicMock()
        mock_dataloader.dataset = MagicMock()

        load_stream_dataloader_state_dict(mock_dataloader, self.pipe, self.args)

        self.assertTrue(any(f"Failed to load dataset state_dict from {dataset_state_dict_path}: Torch load error for dataset" in call_args[0][0] for call_args in sys.stdout.write.call_args_list))


if __name__ == '__main__':
    unittest.main()
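
Assuming the tests are run from the repository root with training/ and training/utils/ importable as packages, a minimal way to execute just this module with the standard-library runner:

# Equivalent to: python -m unittest training.utils.test_dist_checkpoint_utils -v
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "training.utils.test_dist_checkpoint_utils"
)
unittest.TextTestRunner(verbosity=2).run(suite)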