Skip to content

Commit 7a89cfd

Browse files
author
maxtext authors
committed
Merge pull request #1562 from AI-Hypercomputer:bvandermoon-xpk-path
PiperOrigin-RevId: 746215624
2 parents ef398f7 + 4710727 commit 7a89cfd

File tree

4 files changed

+8
-32
lines changed

4 files changed

+8
-32
lines changed

benchmarks/globals.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
from tempfile import gettempdir
2222
from benchmarks import xla_flags_library
2323

24-
from benchmarks.globals import PKG_DIR
25-
2624
# TODO(vbarr@) Abstract software features like checkpointing,
2725
# real data / synthetic data out of this config
2826
# TODO(vbarr@) Make slice dependent configurations to allow for a model's tuning
@@ -629,7 +627,7 @@ def _add_to_model_dictionary(
629627
"profiler": "xplane",
630628
"dataset_path": "gs://max-datasets-rogue",
631629
"dataset_type": "tfds",
632-
"tokenizer_path": os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer.llama2"),
630+
"tokenizer_path": os.path.join("assets", "tokenizer.llama2"),
633631
"sa_block_q": 1024,
634632
"sa_block_q_dkv": 2048,
635633
"sa_block_q_dq": 2048,
@@ -1605,7 +1603,7 @@ def _add_to_model_dictionary(
16051603
"reuse_example_batch": 1,
16061604
"enable_checkpointing": False,
16071605
"profiler": "xplane",
1608-
"tokenizer_path": os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer.llama2"),
1606+
"tokenizer_path": os.path.join("assets", "tokenizer.llama2"),
16091607
"sa_block_q": 2048,
16101608
"sa_block_q_dkv": 2048,
16111609
"sa_block_q_dq": 2048,
@@ -1638,7 +1636,7 @@ def _add_to_model_dictionary(
16381636
"reuse_example_batch": 1,
16391637
"enable_checkpointing": False,
16401638
"profiler": "xplane",
1641-
"tokenizer_path": os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer.llama2"),
1639+
"tokenizer_path": os.path.join("assets", "tokenizer.llama2"),
16421640
"sa_block_q": 2048,
16431641
"sa_block_q_dkv": 2048,
16441642
"sa_block_q_dq": 2048,

benchmarks/maxtext_v5e_model_configs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import os.path
1717
from benchmarks import xla_flags_library
1818
from benchmarks.maxtext_trillium_model_configs import MaxTextModel, _add_to_model_dictionary
19-
from benchmarks.globals import PKG_DIR
2019

2120

2221
v5e_model_dict = {}
@@ -162,7 +161,7 @@
162161
"remat_policy": "save_qkv_proj",
163162
"max_target_length": 2048,
164163
"use_iota_embed": True,
165-
"tokenizer_path": os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer.llama2"),
164+
"tokenizer_path": os.path.join("assets", "tokenizer.llama2"),
166165
"dataset_path": "gs://max-datasets-rogue",
167166
"dataset_type": "synthetic",
168167
"reuse_example_batch": 1,
@@ -187,7 +186,7 @@
187186
"remat_policy": "qkv_proj_offloaded",
188187
"max_target_length": 2048,
189188
"use_iota_embed": True,
190-
"tokenizer_path": os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer.llama2"),
189+
"tokenizer_path": os.path.join("assets", "tokenizer.llama2"),
191190
"dataset_path": "gs://max-datasets-rogue",
192191
"dataset_type": "synthetic",
193192
"reuse_example_batch": 1,

benchmarks/maxtext_xpk_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class WorkloadConfig:
105105
generate_metrics_and_upload_to_big_query: bool = True
106106
hardware_id: str = 'v6e'
107107
metrics_gcs_file: str = ''
108-
base_config: str = os.path.join(PKG_DIR, "configs", "base.yml")
108+
base_config: str = os.path.join("MaxText", "configs", "base.yml")
109109
topology: str = dataclasses.field(init=False)
110110
num_devices_per_slice: int = dataclasses.field(init=False)
111111
db_project: str = ""
@@ -349,7 +349,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict:
349349
args["xla_flags"] = f"'{xla_flags_str}'"
350350
args["dataset"] = dataset
351351
args["run_type"] = "maxtext-xpk"
352-
args["config_file"] = os.path.join(PKG_DIR, "configs", "base.yml")
352+
args["config_file"] = os.path.join("MaxText", "configs", "base.yml")
353353
args["topology"] = wl_config.topology
354354
args["tuning_params"] = f"'{tuning_params_str}'"
355355
args["db_project"] = wl_config.db_project
@@ -413,7 +413,7 @@ def build_user_command(
413413
'export ENABLE_PATHWAYS_PERSISTENCE=1 &&',
414414
f'export JAX_PLATFORMS={jax_platforms} &&',
415415
'export ENABLE_PJRT_COMPATIBILITY=true &&',
416-
f'python3 -m MaxText.train {os.path.join(PKG_DIR, "configs", "base.yml")}',
416+
f'python3 -m MaxText.train {os.path.join("MaxText", "configs", "base.yml")}',
417417
f'{config_tuning_params}',
418418
f'steps={wl_config.num_steps}',
419419
f'model_name={wl_config.model.model_type}',

0 commit comments

Comments
 (0)