
Commit

hiyouga committed Oct 19, 2023
1 parent cb0edd2 commit 7a11a42
Showing 4 changed files with 11 additions and 6 deletions.
Binary file modified assets/wechat.jpg
8 changes: 6 additions & 2 deletions src/llmtuner/dsets/loader.py
@@ -88,7 +88,11 @@ def get_dataset(
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
             logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
-        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
-        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
+        return interleave_datasets(
+            datasets=all_datasets,
+            probabilities=data_args.interleave_probs,
+            seed=data_args.seed,
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+        )
     else:
         raise ValueError("Unknown mixing strategy.")
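
For context, a minimal runnable sketch of how `interleave_datasets` from the HuggingFace `datasets` library behaves with the keyword arguments used in this diff; the toy datasets, probabilities, and seed below are illustrative, not taken from the repository:

    # Illustrative sketch using the HuggingFace `datasets` library; the toy
    # data, probabilities, and seed are made up for demonstration.
    from datasets import Dataset, interleave_datasets

    ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
    ds_b = Dataset.from_dict({"text": ["b1", "b2", "b3", "b4", "b5"]})

    mixed = interleave_datasets(
        datasets=[ds_a, ds_b],
        probabilities=[0.5, 0.5],           # sampling weight per dataset
        seed=42,                            # fixes the sampling order across runs
        stopping_strategy="all_exhausted",  # oversample until every dataset is fully seen
    )
    print(mixed["text"])

With `seed` fixed (the argument this commit adds), the mixing order is reproducible across runs; `stopping_strategy="first_exhausted"` (the `interleave_under` case) stops as soon as the smallest dataset is exhausted, while `"all_exhausted"` (`interleave_over`) keeps oversampling until every dataset has been fully consumed.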
5 changes: 3 additions & 2 deletions src/llmtuner/hparams/data_args.py
@@ -60,7 +60,7 @@ class DataArguments:
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing."}
+        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
     )
     interleave_probs: Optional[str] = field(
         default=None,
@@ -106,7 +106,8 @@ def __post_init__(self):
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")

-    def init_for_training(self): # support mixing multiple datasets
+    def init_for_training(self, seed: int): # support mixing multiple datasets
+        self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
         try:
             with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
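
To see the seed-threading pattern in isolation, here is a hedged, self-contained sketch; `DataArgs` and its values are hypothetical stand-ins for the real `DataArguments`:

    # Hypothetical stand-in for DataArguments, sketching how init_for_training
    # stores the training seed for later use by the dataset mixer.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DataArgs:
        dataset: Optional[str] = None
        seed: int = 42  # placeholder; overwritten by init_for_training before mixing

        def init_for_training(self, seed: int):  # mirrors the diff above
            self.seed = seed  # store the trainer's seed so mixing is reproducible
            return [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []

    args = DataArgs(dataset="alpaca_en, alpaca_zh")
    print(args.init_for_training(seed=1234))  # ['alpaca_en', 'alpaca_zh']
    print(args.seed)                          # 1234

Storing the seed on the data arguments, rather than passing it through every call, keeps `get_dataset`'s signature unchanged while still making the interleaving deterministic.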
4 changes: 2 additions & 2 deletions src/llmtuner/tuner/core/parser.py
@@ -88,8 +88,8 @@ def get_train_args(
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()

-    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
-    data_args.init_for_training()
+    # Check arguments
+    data_args.init_for_training(training_args.seed)

     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
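
For reference, a sketch of where `training_args.seed` originates, assuming (as this repository does) that trainer arguments are parsed with `HfArgumentParser` into `Seq2SeqTrainingArguments`; the CLI values below are made up:

    # Minimal sketch with made-up CLI values; shows the seed that the diff
    # above forwards via data_args.init_for_training(training_args.seed).
    from transformers import HfArgumentParser, Seq2SeqTrainingArguments

    parser = HfArgumentParser(Seq2SeqTrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses(
        args=["--output_dir", "out", "--seed", "1234"]
    )
    print(training_args.seed)  # 1234 (defaults to 42 when --seed is not given)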
