Add rocm perf yml file #418

Open

wants to merge 1 commit into base: rocm-jaxlib-v0.5.0
Conversation

@Ruturaj4 Ruturaj4 commented May 12, 2025

This PR adds a new GitHub Actions workflow that:

  • Builds JAX with ROCm support inside a Docker container.
  • Runs training for the following MaxText models:
    • llama2_7b
    • gemma_2b
    • gpt3_6b
    • mixtral_8x1b
  • Captures stdout logs for each model and extracts per-step timings.
  • Ignores step 0 (warmup) when computing metrics.
  • Computes median_step_time per model and saves it to summary.json.
  • Uploads logs and metrics as workflow artifacts.

A Python analysis script (analyze_maxtext_logs.py) is added under jax/build/rocm/ to parse the logs and generate the summary.

@Ruturaj4 Ruturaj4 closed this May 12, 2025
@Ruturaj4 Ruturaj4 reopened this May 12, 2025
Comment on lines 103 to 108
            times.append(float(m.group(1)))
    if times:
        summary[model] = {
            "median_step_time": round(float(np.median(times)), 3),
            "steps_counted": len(times)
        }
Collaborator
grab the parsed steps too with the summary

Suggested change

            times.append(float(m.group(1)))
    if times:
        summary[model] = {
            "median_step_time": round(float(np.median(times)), 3),
            "steps_counted": len(times)
        }

            times.append(float(m.group(1)))
    if times:
        step_info = [{"step": n, "time": t} for n, t in enumerate(times)]
        summary[model] = {
            "steps": step_info,
            "median_step_time": round(float(np.median(times)), 3),
            "steps_counted": len(times)
        }
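To see what the suggested change produces, here is a minimal standalone sketch; the `times` list is illustrative (step 0 / warmup assumed already dropped):

```python
import json

import numpy as np

# Illustrative per-step timings; in the workflow these come from parsed logs.
times = [1.2, 1.1, 1.3, 1.2]

# The reviewer's suggestion: record each parsed step alongside the aggregates.
step_info = [{"step": n, "time": t} for n, t in enumerate(times)]
entry = {
    "steps": step_info,
    "median_step_time": round(float(np.median(times)), 3),
    "steps_counted": len(times),
}
print(json.dumps(entry, indent=2))
```

Each model's entry in summary.json then carries the raw per-step data as well, which makes later debugging of outlier steps possible without re-downloading the log artifacts.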

Author

sounds good


      - name: Run MaxText training and save logs
        run: |
          docker exec maxtext_container bash -c "pip install -r requirements.txt"
Collaborator

I could see this pip install being an issue across jax versions

Author

Ahh no, it will be different for the different branches used in rocm/maxtext; e.g., currently I am using rv_jax, but I will be renaming it to jax_0.5.0 or something

@Ruturaj4 Ruturaj4 force-pushed the bring_rocm_dlm_perf branch from d2912f5 to 952cc4f Compare May 13, 2025 12:13
@charleshofer
Collaborator

This is a port from the other performance CI PR, right? Could you add a description and link to the original PR?

@i-chaochen

i-chaochen commented May 21, 2025

are we considering grok and alphafold models? @Ruturaj4 @JehandadKhan

Yes, we are definitely planning to add alphafold; however, grok testing takes too much time to download the weights. If grok training can be done, or if there are ways to run grok faster, we are happy to add those as well!

@Arech8

Arech8 commented May 22, 2025

Why did you choose to report median step time?

I don't know the rationale for that, but in general, I'm not sure the median is the right metric here. It rejects outliers, and on its own it says nothing about the distribution of values, yet that is exactly what is important to know:

  • if there are significant outliers (in either direction):
    • it's important to investigate them. For example, if there's some Java-like garbage-collection step that makes an app totally unresponsive, that is a no-go in many contexts. This holds for model training/inference too.
    • outliers significantly influence total runtime and users' perception of "fast" or "slow".
  • robust statistics, of which the median is one, are the best at describing the shape of an arbitrary distribution, but no single robust metric can do that in isolation: several must be used together. At minimum min + max, though the quartiles (25% + 75%) are generally super useful as well.
    • if you want a single value, the mean is much better, since it carries equally weighted information from all samples, while the median describes only 1 or 2 samples at best, saying nothing about the rest.
    • the mean has the additional nice property that it lets you forecast total runtime in different circumstances. For example, if over 100 epochs you measured an average of 1 s per epoch, then you have some reason to expect 1000 epochs to last 1000 seconds. Do the same with the median and you can say nothing even about the next 100-epoch run.

TLDR: the mean seems a much better metric here. For the best results, I'd report 6 values: the [0, 25, 50, 75, 100]% quantiles plus the mean (b/c of the last bullet point) (and God forbid, stddev).
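The six-value proposal above is straightforward to compute with NumPy; a minimal sketch (the function name `step_time_stats` is hypothetical, and `times` would be the parsed per-step timings):

```python
import numpy as np


def step_time_stats(times):
    """Report the 0/25/50/75/100% quantiles plus the mean for a list of
    per-step times, as suggested in the review comment."""
    q = np.quantile(times, [0.0, 0.25, 0.5, 0.75, 1.0])
    return {
        "min": float(q[0]),
        "p25": float(q[1]),
        "median": float(q[2]),
        "p75": float(q[3]),
        "max": float(q[4]),
        "mean": float(np.mean(times)),
    }
```

This keeps the existing median while adding the spread and mean, so a single summary.json entry both flags outliers (min/max far from the quartiles) and supports total-runtime forecasting (mean × steps).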

        run: |
          docker exec maxtext_container bash -c "pip install -r requirements.txt"
          for config in \
            MaxText/configs/models/gpu/llama2_7b_rocm.yml \

I guess the batch sizes and other config values in the yamls are tailored for the MI250 (as per the CI machine spec above)?
If so, and if the branch is also going to be used anywhere else (like merged to main), then perhaps it's wise to rename the files, changing the _rocm suffix to _MI250 to make the tailoring apparent.
MI250 already implies ROCm, but ROCm only implies "some GPU officially supported by ROCm", which might not hold for the smaller Instincts, which may simply not cope with the batch sizes and other values used.
