Metal thread safety #2104
Conversation
This looks interesting, but in general I'm not convinced we need to go this route of fine-grained locking. It might work just as well, maybe even better, and be a lot cleaner/faster to have a lock where we do the task submission in the main. A couple of higher-level comments:
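To make the granularity point above concrete, here is a rough Python-level sketch, purely for illustration; the actual change would live in the C++ scheduler, and the names _submit_lock and locked_eval below are hypothetical, not part of MLX.

import threading
import mlx.core as mx

# Hypothetical, illustration only: one coarse lock taken at the point where
# work is submitted, rather than many fine-grained locks inside Metal objects.
_submit_lock = threading.Lock()

def locked_eval(*arrays):
    # Note: at the Python level the lock ends up covering the whole eval
    # (submission plus wait); a backend-internal lock could be scoped to
    # just the submission path.
    with _submit_lock:
        mx.eval(*arrays)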
Thank you! Let me see what I can do along those lines. The tests definitely make sense; I wasn't sure what performance benchmark would be appropriate. Are the existing ones fine to run to see what impact the changes have?
I would do more end-to-end benchmarks, and focus on latency-sensitive ones, since that is where this type of change matters. For example, LM inference with a smallish LM (say 4-bit, 1-3B in size) would be a good place to start (you can use mlx-lm for that).
Branch updated from 82a117e to 28902ec
I'm still working on some good, simple tests; I ran into a few more errors with the prior proposed changes. But I wanted to ask what you thought of this approach, and I appreciate any feedback. I've spot-checked the default model with
I like this new approach as it's much simpler, though I do wonder about the possibility of deadlock. Say we have two streams, and Stream A is waiting on the output of Stream B. Something like that seems plausible in a multi-threaded setup. I'm not sure it's necessarily a dealbreaker, because sharing graphs across threads is not a good idea for other reasons, but it would be good to set up a few C++ tests to really exercise the multi-threaded cases we expect this to work for.
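As a concrete picture of that scenario, here is a hedged Python-level sketch (not from the PR) of the kind of cross-stream dependency that could bite if a stream's submission path holds a lock while waiting on another stream's output. mx.new_stream, mx.stream, and mx.eval are existing MLX APIs; the shapes and graph are made up for illustration.

import threading
import mlx.core as mx

s_a = mx.new_stream(mx.gpu)
s_b = mx.new_stream(mx.gpu)

x = mx.random.normal((1024, 1024))

with mx.stream(s_a):
    a = x @ x      # produced on stream A
with mx.stream(s_b):
    b = a @ x      # stream B consumes stream A's output
with mx.stream(s_a):
    c = b @ x      # stream A consumes stream B's output

# Two threads evaluate overlapping tails of the same graph concurrently.
# If each stream's encode/commit path holds a lock while waiting for the
# other stream's output, this is where a deadlock could surface.
t1 = threading.Thread(target=mx.eval, args=(b,))
t2 = threading.Thread(target=mx.eval, args=(c,))
t1.start(); t2.start()
t1.join(); t2.join()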
I've added a few tests; how do they look to you? The changes caused one test around buffers to fail very occasionally.
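For reference, here is a Python-level analogue (not the actual C++ tests in this PR) of the kind of multi-threaded case being exercised: several threads each build and evaluate an independent graph on the GPU concurrently and check their results.

import math
import threading
import mlx.core as mx

def worker(results, i):
    # Each thread builds and evaluates its own small graph on the GPU.
    x = mx.full((256, 256), float(i))
    y = (x @ x).sum()
    mx.eval(y)
    results[i] = y.item()

results = [None] * 8
threads = [threading.Thread(target=worker, args=(results, i)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

for i, r in enumerate(results):
    # Each entry of x @ x is 256 * i * i, summed over 256 * 256 entries.
    assert math.isclose(r, (i * i) * 256 ** 3, rel_tol=1e-5)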
I ran a few benchmarks, apologies for the delay! @awni
Prompt TPS (results table)
Generation TPS (results table)
The benchmark was pretty simple: the prompt was very short (it could be made longer), and I set the max tokens to 1000 (which the Qwen models sometimes reached in my benchmark). More trials could be run, and with a longer prompt, but hopefully this gives a decent idea of the time difference. Here's the code for reference:

from mlx_lm import load, stream_generate
import pandas as pd

max_tokens = 1_000
verbose = False
warmup_count = 3
num_trials = 10

df_results = pd.DataFrame()
checkpoints = [
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    "mlx-community/Qwen3-0.6B-4bit",
    "mlx-community/Qwen3-0.6B-6bit",
    "mlx-community/Qwen3-0.6B-8bit",
    "mlx-community/Qwen3-1.7B-3bit",
    "mlx-community/Qwen3-1.7B-4bit",
]

for checkpoint in checkpoints:
    model, tokenizer = load(path_or_hf_repo=checkpoint)

    prompt = "Hello! I'm teaching a science class on our solar system and wanted to ask for your help! " \
             "Could you tell what the planets in our solar system are called, and a little about each one?"
    conversation = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
        conversation=conversation, add_generation_prompt=True
    )

    for _ in range(warmup_count):
        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text

    for i in range(num_trials):
        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text

        response_dict = {
            'model': checkpoint,
            'trial': i,
            'prompt_tokens': response.prompt_tokens,
            'prompt_tps': response.prompt_tps,
            'generation_tokens': response.generation_tokens,
            'generation_tps': response.generation_tps,
            'peak_memory': response.peak_memory,
        }
        df_trial = pd.DataFrame(response_dict, index=[0])
        df_results = pd.concat([df_results, df_trial], ignore_index=True)

print(df_results.head())
df_results.to_csv('trial_runs.csv', index=False)

As usual, any feedback is greatly appreciated!
It does seem to be slightly slower (with some high variance on the prompt TPS). I'm not sure where to go from here: whether this makes the PR a no-go, whether I could try relaxing some of the mutex locks (if that's possible), or whether it's something with my benchmark.
Prompt TPS (results table)
Generation TPS (results table)
Proposed changes
These changes are an attempt to improve thread safety for the Metal backend. This is related to #2067.
Please let me know what you think.
Checklist
I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes