Skip to content

Commit b71c0e3

Browse files
authored
Print axolotl art if train is called outside of cli (axolotl-ai-cloud#2627) [skip ci]
1 parent ddaebf8 commit b71c0e3

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

src/axolotl/cli/art.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@
1616
@@@@ @@@@@@@@@@@@@@@@
1717
"""
1818

19+
HAS_PRINTED_LOGO = False
20+
1921

2022
def print_axolotl_text_art():
2123
"""Prints axolotl ASCII art."""
24+
25+
global HAS_PRINTED_LOGO # pylint: disable=global-statement
26+
if HAS_PRINTED_LOGO:
27+
return
2228
if is_main_process():
29+
HAS_PRINTED_LOGO = True
2330
print(AXOLOTL_LOGO)

src/axolotl/common/datasets.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def load_datasets(
4848
*,
4949
cfg: DictDefault,
5050
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
51+
debug: bool = False,
5152
) -> TrainDatasetMeta:
5253
"""
5354
Loads one or more training or evaluation datasets, calling
@@ -56,6 +57,7 @@ def load_datasets(
5657
Args:
5758
cfg: Dictionary mapping `axolotl` config keys to values.
5859
cli_args: Command-specific CLI arguments.
60+
debug: Whether to print out the tokenization of sampled training examples.
5961
6062
Returns:
6163
Dataclass with fields for training and evaluation datasets and the computed
@@ -77,20 +79,25 @@ def load_datasets(
7779
preprocess_iterable=preprocess_iterable,
7880
)
7981

80-
if cli_args and (
81-
cli_args.debug
82-
or cfg.debug
83-
or cli_args.debug_text_only
84-
or int(cli_args.debug_num_examples) > 0
85-
):
82+
if ( # pylint: disable=too-many-boolean-expressions
83+
cli_args
84+
and (
85+
cli_args.debug
86+
or cfg.debug
87+
or cli_args.debug_text_only
88+
or int(cli_args.debug_num_examples) > 0
89+
)
90+
) or debug:
8691
LOG.info("check_dataset_labels...")
8792

88-
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
93+
num_examples = cli_args.debug_num_examples if cli_args else 1
94+
text_only = cli_args.debug_text_only if cli_args else False
95+
train_samples = sample_dataset(train_dataset, num_examples)
8996
check_dataset_labels(
9097
train_samples,
9198
tokenizer,
92-
num_examples=cli_args.debug_num_examples,
93-
text_only=cli_args.debug_text_only,
99+
num_examples=num_examples,
100+
text_only=text_only,
94101
)
95102

96103
LOG.info("printing prompters...")

src/axolotl/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
2222
from transformers.trainer import Trainer
2323

24+
from axolotl.cli.art import print_axolotl_text_art
2425
from axolotl.common.datasets import TrainDatasetMeta
2526
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
2627
fix_untrained_tokens,
@@ -516,6 +517,8 @@ def train(
516517
Returns:
517518
Tuple of (model, tokenizer) after training
518519
"""
520+
print_axolotl_text_art()
521+
519522
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
520523
(
521524
trainer,

0 commit comments

Comments
 (0)