Refactor training script for improved logging and modularity #8
Conversation
Merge `alcf-tests` into `main`
Merge `polaris-cuda122` branch into main
e.g.:

```bash
$ PBS_O_WORKDIR=$(pwd) LR=0.00020 OVERRIDE_CKPT_OPT_PARAM=1 bash train_aGPT_7B.sh --train-range-to-skip 43000 47000 --override-opt_param-scheduler
```

will override the `lr_scheduler` params from the checkpoint and instead use the specified value, `LR=0.00020`.
Train skip range
Remove `--train-range-to-skip` logic from `pretrain_gpt_alcf.py` and remove redundant code.
Added Sophia Optimizer
Reviewer's Guide by Sourcery

This pull request refactors the training script for improved logging and modularity, introduces new features such as Llama2Tokenizer and support for additional optimizers, enhances dataset handling and logging, and updates build scripts for compatibility with ALCF systems. The changes span multiple files and include significant modifications to core training logic, data processing, and system-specific optimizations.

Sequence diagram for the training process with logging

```mermaid
sequenceDiagram
participant User
participant TrainingScript
participant Logger
User->>TrainingScript: Start training
TrainingScript->>Logger: Initialize logging
TrainingScript->>TrainingScript: initialize_megatron()
TrainingScript->>TrainingScript: setup_model_and_optimizer()
TrainingScript->>TrainingScript: train()
TrainingScript->>Logger: Log training progress
TrainingScript->>TrainingScript: evaluate()
TrainingScript->>Logger: Log evaluation results
TrainingScript->>TrainingScript: save_checkpoint_and_time()
TrainingScript->>Logger: Log checkpoint status
TrainingScript->>User: Training complete
```

ER diagram for dataset handling

```mermaid
erDiagram
DATASET {
string prefix
string data_impl
string splits_string
int num_samples
int seq_length
int seed
bool skip_warmup
}
DATASET ||--o{ DATASETBUILDER : builds
DATASETBUILDER {
string prefix
string corpus
int num_samples
int seq_length
int seed
bool skip_warmup
}
DATASETBUILDER ||--o{ BUILDCONCATDATASET : concatenates
BUILDCONCATDATASET {
int num_datasets
int num_samples
}
```

Class diagram for the refactored training script

```mermaid
classDiagram
class TrainingScript {
+initialize_megatron()
+setup_model_and_optimizer()
+train()
+evaluate()
+save_checkpoint_and_time()
}
class DatasetBuilder {
+Build()
}
class BuildConcatDataset {
+__getitem__(int idx)
}
class RMSNorm {
+__init__(int dim, float eps, bool sequence_parallel)
+_norm(torch.Tensor x)
}
TrainingScript --> DatasetBuilder : uses
TrainingScript --> BuildConcatDataset : uses
TrainingScript --> RMSNorm : uses
```
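For orientation, a generic RMSNorm sketch matching the interface in the diagram (a textbook-style illustration, not the PR's actual implementation; the `sequence_parallel` flag is omitted):

```python
import torch


class RMSNorm(torch.nn.Module):
    """Root-mean-square LayerNorm: normalize by the RMS of the last dimension, then scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), computed with rsqrt for efficiency
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._norm(x.float()).type_as(x) * self.weight
```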
Hey @saforem2 - I've reviewed your changes and they look great!
Here's what I looked at during the review
- 🟡 General issues: 8 issues found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 6 issues found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
```python
first_chunk, rest_chunk = (
    layernorm_output[:first_ns],
    layernorm_output[first_ns:],
)
first_chunk = torch.nn.functional.pad(
```
suggestion: Consider more explicit handling of edge cases instead of padding
While padding is a valid approach, more explicit handling of these edge cases could improve readability and potentially performance. It might be worth exploring alternative approaches.
```python
if input_ids.size(1) < self.chunk_length:
    first_chunk = input_ids
else:
    first_chunk = input_ids[:, :self.chunk_length]
```
```python
    eps=args.adam_eps,
)
elif args.optimizer.lower() == "galore_adamw":
```
suggestion: Consider refactoring optimizer initialization to reduce code duplication
The current implementation repeats similar initialization code for multiple optimizers. Consider creating a factory function or using a dictionary mapping to initialize optimizers, which would improve maintainability and reduce the likelihood of errors when adding new optimizers.
```python
def get_optimizer(args, model_params):
    optimizer_map = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
        "galore_adamw": GaLoreAdamW if not args.use_8bit else GaLoreAdamW8bit,
    }
    optimizer_cls = optimizer_map.get(args.optimizer.lower())
    if optimizer_cls is None:
        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
    return optimizer_cls(model_params, lr=args.lr, eps=args.adam_eps)
```
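A possible call site for such a factory, sketched under the assumption that `model.parameters()` is the parameter source and `args` carries the fields used above:

```python
# Hypothetical usage of the factory sketched above
optimizer = get_optimizer(args, model.parameters())
```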
```python
    rho = args.sophiag_rho,
    weight_decay=args.weight_decay
)
else:
```
suggestion: Improve error handling for unsupported optimizers
Instead of raising a generic TypeError, consider creating a custom exception (e.g., UnsupportedOptimizerError) and including the name of the unsupported optimizer in the error message. This would provide more informative error messages and make it easier to catch specific optimization-related errors.
```python
class UnsupportedOptimizerError(Exception):
    pass


# ...in the optimizer-selection code, the fallback branch becomes:
else:
    raise UnsupportedOptimizerError(f"Optimizer '{optimizer_name}' is not supported")
```
```python
dlp = Profile("DATASET")


class BlendableDataset(torch.utils.data.Dataset):
    @dlp.log
```
suggestion (performance): Consider the performance impact of extensive logging
While logging is important for debugging and monitoring, excessive logging can impact performance. Consider adding a debug flag to conditionally enable detailed logging, allowing users to balance between performance and verbosity as needed.
```diff
-    @dlp.log
+    @dlp.log(condition=__debug__)
```
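If the profiler's `log` decorator does not accept a `condition` keyword (an assumption here; the suggested change above is the reviewer's sketch rather than a documented API), a thin wrapper along these lines would give the same effect:

```python
def log_if(profiler, enabled: bool):
    """Apply profiler.log only when `enabled` is truthy; otherwise return the function unchanged."""
    def decorator(fn):
        return profiler.log(fn) if enabled else fn
    return decorator


# Usage sketch:
# @log_if(dlp, __debug__)
# def __getitem__(self, idx):
#     ...
```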
```python
)


def training_log(
```
suggestion (performance): Optimize logging function to reduce performance overhead
The training_log function performs many operations and writes to multiple logging systems. Consider batching writes to TensorBoard and wandb, and use asynchronous logging where possible to minimize the impact on training performance. Additionally, consider making some of the more expensive logging operations configurable or less frequent.
```python
@dlp.log
async def training_log(
    loss_dict,
    total_loss_dict,
```
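A minimal sketch of what batched, less frequent metric writes could look like; the `MetricBuffer` class, the `log_interval` knob, and the use of a TensorBoard `SummaryWriter` are illustrative assumptions rather than code from this PR:

```python
from torch.utils.tensorboard import SummaryWriter


class MetricBuffer:
    """Buffer scalar metrics and flush them to the writer every `log_interval` steps."""

    def __init__(self, writer: SummaryWriter, log_interval: int = 50):
        self.writer = writer
        self.log_interval = log_interval
        self.pending = []

    def add(self, name: str, value: float, step: int):
        self.pending.append((name, value, step))
        if step % self.log_interval == 0:
            self.flush()

    def flush(self):
        for name, value, step in self.pending:
            self.writer.add_scalar(name, value, step)
        self.pending.clear()
        self.writer.flush()
```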
```python
# Use glob to find all files matching the pattern
json_gz_files = glob.glob(search_pattern, recursive=True)

return json_gz_files
```
suggestion (code-quality): Inline variable that is immediately returned (inline-immediately-returned-variable)
```diff
-    # Use glob to find all files matching the pattern
-    json_gz_files = glob.glob(search_pattern, recursive=True)
-    return json_gz_files
+    return glob.glob(search_pattern, recursive=True)
```
```python
in_list = in_list + " " + str(i)
command = "cat" + in_list + " > " + output_file
```
suggestion (code-quality): Use f-string instead of string concatenation [×5] (use-fstring-for-concatenation)
```diff
-    in_list = in_list + " " + str(i)
-    command = "cat" + in_list + " > " + output_file
+    in_list = f"{in_list} {str(i)}"
+    command = f"cat{in_list} > {output_file}"
```
```python
if vol + val > 4608:
    # add this item to list and reset vol, sublist
    vol = 0
    sublist.append(key)
```
issue (code-quality): We've found these issues:

- Hoist repeated code outside conditional statement (hoist-statement-from-if)
- Use f-string instead of string concatenation [×3] (use-fstring-for-concatenation)
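For illustration, a hedged sketch of the hoisting suggestion applied to the quoted lines; the `else` branch shown here is assumed, since it is not visible in the excerpt:

```python
# Before (assumed shape): sublist.append(key) appears in both branches.
if vol + val > 4608:
    vol = 0
    sublist.append(key)
else:
    vol += val
    sublist.append(key)

# After: the shared append is hoisted out of the conditional.
if vol + val > 4608:
    vol = 0
else:
    vol += val
sublist.append(key)
```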
@sourcery-ai review |
Hey @saforem2 - I've reviewed your changes and they look great!
Here's what I looked at during the review
- 🟡 General issues: 8 issues found
- 🟡 Security: 1 issue found
- 🟢 Testing: all looks good
- 🟡 Complexity: 7 issues found
- 🟡 Documentation: 8 issues found
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
```python
elif args.optimizer.lower() == "galore_adamw":
    from galore_torch import GaLoreAdamW, GaLoreAdamW8bit
    # redefine way to call galore_adamw
    optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
```
suggestion (bug_risk): Error handling for missing GaLoreAdamW import
Consider adding a try-except block to handle potential ImportError if GaLoreAdamW is not available.
```python
try:
    from galore_torch import GaLoreAdamW, GaLoreAdamW8bit
except ImportError as exc:
    raise ImportError(
        "GaLoreAdamW is not available. Please install the required package."
    ) from exc
optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
```
```python
log.setLevel(LOG_LEVEL) if RANK == 0 else log.setLevel("CRITICAL")
# --------------------------------------------------------------------------

dlp = Profile("DATASET")
```
suggestion: Document the purpose and impact of the Profile
Add a brief comment explaining what this profiling is measuring and how it affects performance.
```python
# Profile dataset operations for performance analysis
dlp = Profile("DATASET")
```
```python
dlp = Profile("DATASET")


class BlendableDataset(torch.utils.data.Dataset):
    @dlp.log
```
suggestion (performance): Consider the performance impact of frequent logging
Evaluate if this logging is necessary in production or if it should be conditionally enabled for debugging.
```python
@dlp.log_if_debug
```
```diff
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
```
suggestion: Add a module docstring explaining the purpose of this script
A brief description of what this test script does and how to use it would be helpful for maintainers.
```python
#!/usr/bin/env python
"""
Test script for blendable dataset functionality.

This script performs tests on the blendable dataset implementation,
measuring performance and validating correctness of data blending operations.

Usage: Run this script directly to execute all tests.
"""
import time

start_time = time.time()
```
```diff
@@ -0,0 +1,1418 @@
+#!/bin/bash --login
```
suggestion: Consider modularizing the helpers.sh script
This script is over 1400 lines long and covers multiple concerns. Consider splitting it into smaller, focused modules (e.g., mpi_setup.sh, env_config.sh, utility_functions.sh) for better maintainability and readability.
```bash
#!/bin/bash --login
source "${BASH_SOURCE%/*}/mpi_setup.sh"
source "${BASH_SOURCE%/*}/env_config.sh"
source "${BASH_SOURCE%/*}/utility_functions.sh"
```
```python
raise Exception(
    'Stage must have at least either encoder or decoder'
)
```
issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)
Explanation

If a piece of code raises a specific exception type rather than the generic [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) or [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), the calling code can:

- get more information about what type of error it is
- define specific exception handling for it

This way, callers of the code can handle the error appropriately.

How can you solve this?

- Use one of the built-in exceptions of the standard library.
- Define your own error class that subclasses `Exception`.

So instead of having code raising Exception or BaseException like

```python
if incorrect_input(value):
    raise Exception("The input is incorrect")
```

you can have code raising a specific error like

```python
if incorrect_input(value):
    raise ValueError("The input is incorrect")
```

or

```python
class IncorrectInputError(Exception):
    pass

if incorrect_input(value):
    raise IncorrectInputError("The input is incorrect")
```

```diff
 else:
-    raise Exception("Unsupported layer type, '%s'." %
-                    self.layer_type.name)
+    raise Exception("Unsupported layer type, '%s'." % self.layer_type.name)
```
issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)
Explanation

Same explanation as in the previous raise-specific-error comment: prefer a specific built-in exception, or a custom subclass of `Exception`, over the generic `Exception`.
```python
    capturable: bool):

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
```
suggestion (code-quality): Swap if/else branches of if expression to remove negation (swap-if-expression)
```diff
-        grad = grads[i] if not maximize else -grads[i]
+        grad = -grads[i] if maximize else grads[i]
```

Explanation

Negated conditions are more difficult to read than positive ones, so it is best to avoid them where we can. By swapping the `if` and `else` conditions around, we can invert the condition and make it positive.
```diff
-        raise Exception('dummy timer should not be used to '
-                        'calculate elapsed time')
+        raise Exception("dummy timer should not be used to " "calculate elapsed time")
```
issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)
Explanation

Same explanation as in the previous raise-specific-error comments.
```diff
 else:
-    raise Exception('unknown timing log option {}'.format(
-        self._log_option))
+    raise Exception("unknown timing log option {}".format(self._log_option))
```
issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)
Explanation

Same explanation as in the previous raise-specific-error comments.
Summary by Sourcery
Refactor the training script to enhance logging, modularity, and data handling. Introduce new optimizers and refactor model components for better maintainability. Add helper scripts for ALCF system setup and deployment.
Enhancements:
Build:
Deployment: