Add rocm perf yml file #418
base: rocm-jaxlib-v0.5.0
Conversation
.github/workflows/rocm-perf.yml (Outdated)
                times.append(float(m.group(1)))
        if times:
            summary[model] = {
                "median_step_time": round(float(np.median(times)), 3),
                "steps_counted": len(times)
            }
Grab the parsed steps too and include them in the summary.
Suggested change:

                times.append(float(m.group(1)))
        if times:
            step_info = [{"step": n, "time": t} for n, t in enumerate(times)]
            summary[model] = {
                "steps": step_info,
                "median_step_time": round(float(np.median(times)), 3),
                "steps_counted": len(times)
            }
sounds good
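For illustration, with the suggested change an entry in the summary would have roughly this shape (the model name and timings below are placeholders, not real results):

    summary["llama2_7b"] = {
        "steps": [
            {"step": 0, "time": 0.812},  # placeholder per-step timings
            {"step": 1, "time": 0.806},
        ],
        "median_step_time": 0.809,
        "steps_counted": 2,
    }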
      - name: Run MaxText training and save logs
        run: |
          docker exec maxtext_container bash -c "pip install -r requirements.txt"
I could see this pip install being an issue across JAX versions.
Ahh, no, it will be different for the different branches used in rocm/maxtext; e.g. currently I am using rv_jax, but I will be renaming it to jax_0.5.0 or something.
Commits d2912f5 to 952cc4f (compare)
This is a port from the other performance CI PR, right? Could you add a description and link to the original PR?
Are we considering the grok and alphafold models? @Ruturaj4 @JehandadKhan

Yes, we are definitely planning to add alphafold; however, grok testing takes too much time to download weights. If grok training can be done, or if there are ways to run grok faster, we are happy to add those as well!
Why did you choose to report median step time? I don't know the rationale for that, but in general, I'm not sure that median is the right metric here. It rejects outliers, and on its own it does not describe the distribution of values at all, but that is exactly what is important to know:

TL;DR: the mean seems like a much better metric here. For the best results, I'd report 6 values: the [0, 25, 50, 75, 100]% quantiles plus the mean (because of the last bullet point) (and God forbid stddev).
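A minimal sketch of what that could look like in the analysis script, assuming times already excludes the warmup step; the helper name and output keys here are illustrative and not part of this PR:

    import numpy as np

    def summarize_step_times(times):
        # times: list of per-step durations in seconds (step 0 already dropped)
        arr = np.asarray(times, dtype=float)
        q = np.percentile(arr, [0, 25, 50, 75, 100])
        return {
            "min_step_time": round(float(q[0]), 3),
            "p25_step_time": round(float(q[1]), 3),
            "median_step_time": round(float(q[2]), 3),
            "p75_step_time": round(float(q[3]), 3),
            "max_step_time": round(float(q[4]), 3),
            "mean_step_time": round(float(arr.mean()), 3),
            "steps_counted": int(arr.size),
        }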
        run: |
          docker exec maxtext_container bash -c "pip install -r requirements.txt"
          for config in \
            MaxText/configs/models/gpu/llama2_7b_rocm.yml \
I guess the batch sizes and other config values in the yamls are tailored for MI250 (as per the CI machine spec above)?

If so, and the branch is also going to be used anywhere else (like merged to main), then perhaps it's wise to rename the files, changing the _rocm suffix to _MI250 to make the tailoring apparent. MI250 already implies ROCm, but ROCm only implies every GPU officially supported by ROCm, which might not even be true in the case of the smaller Instincts that might just not cope with the batch sizes and whatnot used.
This PR adds a new GitHub Actions workflow that:

- Builds JAX with ROCm support inside a Docker container.
- Runs training for the following MaxText models:
- Captures stdout logs for each model and extracts per-step timing
- Ignores step 0 (warmup) when computing metrics
- Computes median_step_time per model and saves it to summary.json
- Uploads logs and metrics as workflow artifacts

A Python analysis script (analyze_maxtext_logs.py) is added under jax/build/rocm/ to parse logs and generate the summary.
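For reference, here is a minimal sketch of that log-analysis flow; the regex, the MaxText log format, and the function name below are assumptions, and the actual analyze_maxtext_logs.py may differ:

    import json
    import re
    import numpy as np

    # Placeholder pattern; the real script's regex for per-step timing lines may differ.
    STEP_TIME_RE = re.compile(r"step[_ ]time[:=]\s*([0-9.]+)")

    def summarize_logs(log_paths, out_path="summary.json"):
        # log_paths: mapping of model name -> path to its captured stdout log
        summary = {}
        for model, path in log_paths.items():
            times = []
            with open(path) as f:
                for line in f:
                    m = STEP_TIME_RE.search(line)
                    if m:
                        times.append(float(m.group(1)))
            times = times[1:]  # drop step 0 (warmup) before computing metrics
            if times:
                summary[model] = {
                    "median_step_time": round(float(np.median(times)), 3),
                    "steps_counted": len(times),
                }
        with open(out_path, "w") as f:
            json.dump(summary, f, indent=2)
        return summary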