
Conversation

@saforem2 (Owner) commented Oct 11, 2024

Summary by Sourcery

Refactor the training script to enhance logging, modularity, and data handling. Introduce new optimizers and refactor model components for better maintainability. Add helper scripts for ALCF system setup and deployment.

Enhancements:

  • Refactor the training script to improve logging and modularity, including the introduction of a logging system with different log levels and the use of decorators for profiling (a minimal sketch follows this list).
  • Enhance the data loading process by introducing a new dataset builder class and a concatenated dataset class to handle multiple datasets more efficiently.
  • Improve the optimizer setup by adding support for various optimizers, including new ones like GaLoreAdamW and SophiaG, and refactor the parameter group creation logic.
  • Refactor the model components, including the transformer layers and attention mechanisms, to improve readability and maintainability.
  • Introduce a new helper script for setting up and running training on ALCF systems, which includes functions for environment setup, job configuration, and command execution.
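
For a concrete picture of the rank-aware logging and profiling-decorator pattern described above, here is a minimal, hypothetical sketch; the names `get_logger` and `profile` are illustrative, not the actual API introduced by this PR:

```python
import functools
import logging
import os
import time

RANK = int(os.environ.get("RANK", "0"))


def get_logger(name: str, level: str = "INFO") -> logging.Logger:
    """Return a logger that only emits on rank 0 of a distributed job."""
    logging.basicConfig()
    logger = logging.getLogger(name)
    # Non-zero ranks stay quiet so messages are not duplicated once per rank.
    logger.setLevel(level if RANK == 0 else "CRITICAL")
    return logger


log = get_logger(__name__)


def profile(fn):
    """Decorator that logs the wall-clock time of each call."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        log.info("%s took %.3fs", fn.__name__, time.perf_counter() - start)
        return result
    return wrapper


@profile
def train_step():
    time.sleep(0.1)  # stand-in for real work


train_step()  # on rank 0, logs something like: "train_step took 0.100s"
```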

Build:

  • Add a new build script for data helpers to ensure that the necessary C++ extensions are compiled before running the training script.

Deployment:

  • Add a new script for launching training on Aurora using qsub, which sets up the environment and executes the training script.

saforem2 and others added 22 commits September 10, 2024 08:05
e.g.:
```bash
$ PBS_O_WORKDIR=$(pwd) LR=0.00020 OVERRIDE_CKPT_OPT_PARAM=1 bash train_aGPT_7B.sh --train-range-to-skip 43000 47000 --override-opt_param-scheduler
```

will override the `lr_scheduler` params from the checkpoint and use the specified value, `LR=0.00020`, instead.
Remove the `--train-range-to-skip` logic from `pretrain_gpt_alcf.py` and remove redundant code.
@sourcery-ai bot changed the title from "@sourcery-ai" to "Refactor training script for improved logging and modularity" on Oct 11, 2024

sourcery-ai bot commented Oct 11, 2024

Reviewer's Guide by Sourcery

This pull request refactors the training script for improved logging and modularity, introduces new features such as Llama2Tokenizer and support for additional optimizers, enhances dataset handling and logging, and updates build scripts for compatibility with ALCF systems. The changes span multiple files and include significant modifications to core training logic, data processing, and system-specific optimizations.

Sequence diagram for the training process with logging

```mermaid
sequenceDiagram
    participant User
    participant TrainingScript
    participant Logger
    User->>TrainingScript: Start training
    TrainingScript->>Logger: Initialize logging
    TrainingScript->>TrainingScript: initialize_megatron()
    TrainingScript->>TrainingScript: setup_model_and_optimizer()
    TrainingScript->>TrainingScript: train()
    TrainingScript->>Logger: Log training progress
    TrainingScript->>TrainingScript: evaluate()
    TrainingScript->>Logger: Log evaluation results
    TrainingScript->>TrainingScript: save_checkpoint_and_time()
    TrainingScript->>Logger: Log checkpoint status
    TrainingScript->>User: Training complete
```

ER diagram for dataset handling

```mermaid
erDiagram
    DATASET {
        string prefix
        string data_impl
        string splits_string
        int num_samples
        int seq_length
        int seed
        bool skip_warmup
    }
    DATASET ||--o{ DATASETBUILDER : builds
    DATASETBUILDER {
        string prefix
        string corpus
        int num_samples
        int seq_length
        int seed
        bool skip_warmup
    }
    DATASETBUILDER ||--o{ BUILDCONCATDATASET : concatenates
    BUILDCONCATDATASET {
        int num_datasets
        int num_samples
    }
```

Class diagram for the refactored training script

```mermaid
classDiagram
    class TrainingScript {
        +initialize_megatron()
        +setup_model_and_optimizer()
        +train()
        +evaluate()
        +save_checkpoint_and_time()
    }
    class DatasetBuilder {
        +Build()
    }
    class BuildConcatDataset {
        +__getitem__(int idx)
    }
    class RMSNorm {
        +__init__(int dim, float eps, bool sequence_parallel)
        +_norm(torch.Tensor x)
    }
    TrainingScript --> DatasetBuilder : uses
    TrainingScript --> BuildConcatDataset : uses
    TrainingScript --> RMSNorm : uses
```
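
The diagrams above reference `BuildConcatDataset` and its `__getitem__`. As a rough, illustrative reimplementation of the idea (a flat index mapped onto several underlying datasets), not the code from this PR, such a wrapper might look like:

```python
import numpy as np
import torch


class ConcatDatasetSketch(torch.utils.data.Dataset):
    """Present several datasets as one by mapping a flat index to (dataset, local index)."""

    def __init__(self, datasets):
        self.datasets = datasets
        # boundaries[i] is the cumulative number of samples through dataset i
        self.boundaries = np.cumsum([len(d) for d in datasets])
        self.num_samples = int(self.boundaries[-1])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        ds_idx = int(np.searchsorted(self.boundaries, idx, side="right"))
        local_idx = idx if ds_idx == 0 else idx - int(self.boundaries[ds_idx - 1])
        return self.datasets[ds_idx][local_idx]


# usage with two toy datasets
a = torch.utils.data.TensorDataset(torch.arange(3))
b = torch.utils.data.TensorDataset(torch.arange(10, 15))
ds = ConcatDatasetSketch([a, b])
print(len(ds), ds[0], ds[4])  # 8 samples; ds[4] is the second item of `b`
```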

File-Level Changes

Refactor training script for improved logging and modularity (`megatron/training.py`)
  • Replace print statements with structured logging using the logging module
  • Introduce Profile and PerfTrace classes for performance monitoring
  • Refactor the main training loop for better readability and modularity
  • Add support for skipping specific training iterations
  • Implement more flexible command-line argument handling

Enhance dataset handling and introduce new dataset classes (`megatron/data/gpt_dataset.py`, `megatron/data/blendable_dataset.py`)
  • Implement BuildConcatDataset and DatasetBuilder classes for efficient multi-dataset handling
  • Refactor BlendableDataset for improved performance and flexibility
  • Add support for corpus-specific dataset building and weighting
  • Implement distributed dataset index building for improved efficiency

Add support for Llama2Tokenizer and additional optimizers (`megatron/tokenizer/tokenizer.py`, `megatron/optimizer/__init__.py`)
  • Implement Llama2Tokenizer class
  • Add support for AdamW, SophiaG, and other optimizer variants
  • Implement GaLoreAdamW and GaLoreAdafactor optimizers
  • Add support for 8-bit Adam optimizer

Update build scripts and environment setup for ALCF systems (`ALCF/helpers.sh`, `train_llama_alcf_aurora_qsub.sh`)
  • Add helper functions for setting up the environment on different ALCF systems
  • Implement machine-specific configurations for Aurora, Sunspot, and Polaris
  • Add support for CCL (Collective Communications Library) on Aurora
  • Implement flexible DeepSpeed configuration generation

Improve checkpoint handling and learning rate management (`megatron/checkpointing.py`; a hedged sketch follows this table)
  • Implement saving and loading of learning-rate state
  • Refactor checkpoint loading and saving logic
  • Add support for resuming training from checkpoints with the correct learning rate

Enhance parallel processing and distributed training capabilities (`megatron/model/transformer.py`, `megatron/core/tensor_parallel/cross_entropy.py`)
  • Refactor parallel transformer implementation for improved efficiency
  • Implement more flexible parallel attention mechanisms
  • Add support for different sequence parallelism configurations
  • Improve handling of rotary position embeddings in distributed settings
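
As a hedged illustration of the checkpoint and learning-rate pattern above (save the scheduler state with the checkpoint; optionally override it on resume, as `--override-opt_param-scheduler` does), the sketch below uses illustrative key names and an illustrative `override_lr` argument, not the actual Megatron checkpoint format:

```python
import torch


def save_checkpoint(path, model, optimizer, lr_scheduler, iteration):
    torch.save({
        "iteration": iteration,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # learning-rate state travels with the checkpoint
        "lr_scheduler": lr_scheduler.state_dict(),
    }, path)


def load_checkpoint(path, model, optimizer, lr_scheduler, override_lr=None):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    if override_lr is None:
        # default: resume with the schedule stored in the checkpoint
        lr_scheduler.load_state_dict(state["lr_scheduler"])
    else:
        # analogous to LR=... with --override-opt_param-scheduler: use the new LR
        for group in optimizer.param_groups:
            group["lr"] = override_lr
    return state["iteration"]
```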

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time. You can also use
    this command to specify where the summary should be inserted.

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help


@sourcery-ai sourcery-ai bot left a comment


Hey @saforem2 - I've reviewed your changes and they look great!

Here's what I looked at during the review
  • 🟡 General issues: 8 issues found
  • 🟢 Security: all looks good
  • 🟢 Testing: all looks good
  • 🟡 Complexity: 6 issues found
  • 🟢 Documentation: all looks good

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

```python
first_chunk, rest_chunk = (
    layernorm_output[:first_ns],
    layernorm_output[first_ns:],
)
first_chunk = torch.nn.functional.pad(
```

suggestion: Consider more explicit handling of edge cases instead of padding

While padding is a valid approach, more explicit handling of these edge cases could improve readability and potentially performance. It might be worth exploring alternative approaches.

```python
if input_ids.size(1) < self.chunk_length:
    first_chunk = input_ids
else:
    first_chunk = input_ids[:, :self.chunk_length]
```

```python
        eps=args.adam_eps,
    )

elif args.optimizer.lower() == "galore_adamw":
```

suggestion: Consider refactoring optimizer initialization to reduce code duplication

The current implementation repeats similar initialization code for multiple optimizers. Consider creating a factory function or using a dictionary mapping to initialize optimizers, which would improve maintainability and reduce the likelihood of errors when adding new optimizers.

```python
def get_optimizer(args, model_params):
    optimizer_map = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
        "galore_adamw": GaLoreAdamW if not args.use_8bit else GaLoreAdamW8bit,
    }
    optimizer_cls = optimizer_map.get(args.optimizer.lower())
    if optimizer_cls is None:
        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
    return optimizer_cls(model_params, lr=args.lr, eps=args.adam_eps)
```

```python
        rho = args.sophiag_rho,
        weight_decay=args.weight_decay
    )
else:
```

suggestion: Improve error handling for unsupported optimizers

Instead of raising a generic TypeError, consider creating a custom exception (e.g., UnsupportedOptimizerError) and including the name of the unsupported optimizer in the error message. This would provide more informative error messages and make it easier to catch specific optimization-related errors.

```python
else:
    raise UnsupportedOptimizerError(f"Optimizer '{optimizer_name}' is not supported")


class UnsupportedOptimizerError(Exception):
    pass
```


dlp = Profile("DATASET")
class BlendableDataset(torch.utils.data.Dataset):
@dlp.log

suggestion (performance): Consider the performance impact of extensive logging

While logging is important for debugging and monitoring, excessive logging can impact performance. Consider adding a debug flag to conditionally enable detailed logging, allowing users to balance between performance and verbosity as needed.

Suggested change:

```diff
-@dlp.log
+@dlp.log(condition=__debug__)
```

```python
)


def training_log(
```

suggestion (performance): Optimize logging function to reduce performance overhead

The training_log function performs many operations and writes to multiple logging systems. Consider batching writes to TensorBoard and wandb, and use asynchronous logging where possible to minimize the impact on training performance. Additionally, consider making some of the more expensive logging operations configurable or less frequent.

```python
@dlp.log
async def training_log(
    loss_dict,
    total_loss_dict,
```

Comment on lines 11 to 15

```python
# Use glob to find all files matching the pattern
json_gz_files = glob.glob(search_pattern, recursive=True)

return json_gz_files
```

suggestion (code-quality): Inline variable that is immediately returned (inline-immediately-returned-variable)

Suggested change:

```diff
-# Use glob to find all files matching the pattern
-json_gz_files = glob.glob(search_pattern, recursive=True)
-return json_gz_files
+return glob.glob(search_pattern, recursive=True)
```

Comment on lines 20 to 21
```python
in_list = in_list + " " +str(i)
command = "cat" + in_list + " > " + output_file
```

suggestion (code-quality): Use f-string instead of string concatenation [×5] (use-fstring-for-concatenation)

Suggested change:

```diff
-in_list = in_list + " " +str(i)
-command = "cat" + in_list + " > " + output_file
+in_list = f"{in_list} {str(i)}"
+command = f"cat{in_list} > {output_file}"
```


```python
if vol + val > 4608:
    # add this item to list and reset vol, sublist
    vol = 0
    sublist.append(key)
```

issue (code-quality): We've found these issues:
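
For context, here is a self-contained sketch of the greedy, size-capped grouping that the excerpt above appears to implement; the 4608 threshold comes from the excerpt, while the function and variable names are illustrative:

```python
def pack_greedy(sizes, capacity=4608):
    """Greedily group items so that each group's total size stays at or under capacity.

    `sizes` maps item key -> size; items are taken in insertion order.
    """
    groups, current, vol = [], [], 0
    for key, val in sizes.items():
        if vol + val > capacity and current:
            # current group is full: start a new one
            groups.append(current)
            current, vol = [], 0
        current.append(key)
        vol += val
    if current:
        groups.append(current)
    return groups


print(pack_greedy({"a": 3000, "b": 2000, "c": 1000, "d": 4000}))
# [['a'], ['b', 'c'], ['d']]
```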

@saforem2 (Owner, Author) commented:

@sourcery-ai review

@sourcery-ai sourcery-ai bot left a comment


Hey @saforem2 - I've reviewed your changes and they look great!

Here's what I looked at during the review
  • 🟡 General issues: 8 issues found
  • 🟡 Security: 1 issue found
  • 🟢 Testing: all looks good
  • 🟡 Complexity: 7 issues found
  • 🟡 Documentation: 8 issues found

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

```python
elif args.optimizer.lower() == "galore_adamw":
    from galore_torch import GaLoreAdamW, GaLoreAdamW8bit
    # redefine way to call galore_adamw
    optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
```

suggestion (bug_risk): Error handling for missing GaLoreAdamW import

Consider adding a try-except block to handle potential ImportError if GaLoreAdamW is not available.

```python
try:
    from galore_torch import GaLoreAdamW
except ImportError as exc:
    raise ImportError("GaLoreAdamW is not available. Please install the required package.") from exc
optimizer = GaLoreAdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
```

(Note that the guard belongs on the import itself; the constructor call would not raise ImportError.)

```python
log.setLevel(LOG_LEVEL) if RANK == 0 else log.setLevel("CRITICAL")
# --------------------------------------------------------------------------

dlp = Profile("DATASET")
```

suggestion: Document the purpose and impact of the Profile

Add a brief comment explaining what this profiling is measuring and how it affects performance.

```python
# Profile dataset operations for performance analysis
dlp = Profile("DATASET")
```


dlp = Profile("DATASET")
class BlendableDataset(torch.utils.data.Dataset):
@dlp.log

suggestion (performance): Consider the performance impact of frequent logging

Evaluate if this logging is necessary in production or if it should be conditionally enabled for debugging.

```python
@dlp.log_if_debug
```

```diff
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
```

suggestion: Add a module docstring explaining the purpose of this script

A brief description of what this test script does and how to use it would be helpful for maintainers.

```python
#!/usr/bin/env python
"""
Test script for blendable dataset functionality.

This script performs tests on the blendable dataset implementation,
measuring performance and validating correctness of data blending operations.
Usage: Run this script directly to execute all tests.
"""
import time
start_time = time.time()
```

```diff
@@ -0,0 +1,1418 @@
+#!/bin/bash --login
```

suggestion: Consider modularizing the helpers.sh script

This script is over 1400 lines long and covers multiple concerns. Consider splitting it into smaller, focused modules (e.g., mpi_setup.sh, env_config.sh, utility_functions.sh) for better maintainability and readability.

```bash
#!/bin/bash --login

source "${BASH_SOURCE%/*}/mpi_setup.sh"
source "${BASH_SOURCE%/*}/env_config.sh"
source "${BASH_SOURCE%/*}/utility_functions.sh"
```

Comment on lines 541 to 543
```python
raise Exception(
    'Stage must have at least either encoder or decoder'
)
```

issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)

Explanation: If a piece of code raises a specific exception type rather than the generic [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) or [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), the calling code can:
  • get more information about what type of error it is
  • define specific exception handling for it

This way, callers of the code can handle the error appropriately.

How can you solve this? Instead of having code raise Exception or BaseException like

```python
if incorrect_input(value):
    raise Exception("The input is incorrect")
```

you can have code raising a specific error like

```python
if incorrect_input(value):
    raise ValueError("The input is incorrect")
```

or

```python
class IncorrectInputError(Exception):
    pass


if incorrect_input(value):
    raise IncorrectInputError("The input is incorrect")
```

```diff
 else:
-    raise Exception("Unsupported layer type, '%s'." %
-                    self.layer_type.name)
+    raise Exception("Unsupported layer type, '%s'." % self.layer_type.name)
```

issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)


```python
    capturable: bool):

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
```

suggestion (code-quality): Swap if/else branches of if expression to remove negation (swap-if-expression)

Suggested change:

```diff
-grad = grads[i] if not maximize else -grads[i]
+grad = -grads[i] if maximize else grads[i]
```


Explanation: Negated conditions are more difficult to read than positive ones, so it is best to avoid them where we can. By swapping the if and else branches we can invert the condition and make it positive.

```diff
-raise Exception('dummy timer should not be used to '
-                'calculate elapsed time')
+raise Exception("dummy timer should not be used to " "calculate elapsed time")
```

issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)


```diff
 else:
-    raise Exception('unknown timing log option {}'.format(
-        self._log_option))
+    raise Exception("unknown timing log option {}".format(self._log_option))
```

issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)


@saforem2 closed this pull request by merging all changes into main in 33962ee on Nov 15, 2024.
saforem2 added a commit that referenced this pull request Nov 15, 2024