Skip to content

Conversation

acsweet
Copy link

@acsweet acsweet commented Apr 22, 2025

Proposed changes

These changes are an attempt to improve thread safety for the metal backend. This is related to #2067
Please let me know what you think.

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@awni
Copy link
Member

awni commented Apr 22, 2025

This looks interesting. But in general I'm not convinced we need to go this route of fine-grained locking. It might work just as well, maybe even better, and be a lot cleaner/ faster to have a lock where we do the task submission in the main eval loop.

A couple higher level comments:

  • It would be good to add tests which show the problem and that the solution works
  • We'll need to do some performance benchmarking to be sure there are no regressions from any solution we come up with

@acsweet
Copy link
Author

acsweet commented Apr 22, 2025

Thank you! Let me see what I can do along those lines.

The tests definitely make sense, I wasn't sure what performance benchmark made sense along these lines. If the existing ones are fine and see what impact the changes have?

@awni
Copy link
Member

awni commented Apr 22, 2025

If the existing ones are fine and see what impact the changes have?

I would do more end-to-end benchmarks. And focuse on more latency sensitive ones since this type of change matters there. So for example LM inference with a smallish LM (like 4-bit 1-3B in size) would be a good place to start (you can use mlx-lm for that).

@acsweet acsweet force-pushed the metal-thread-safe branch from 82a117e to 28902ec Compare April 30, 2025 23:19
@acsweet
Copy link
Author

acsweet commented Apr 30, 2025

I'm still working on some good simple tests, I ran in to a few more errors with the prior proposed changes. But I wanted to ask what you thought of this approach, I appreciate any feedback.

I've spot checked the default model with mlx_lm.generate (mlx-community/Llama-3.2-3B-Instruct-4bit), and didn't see any noticeable differences, but I'll do a more robust benchmark with a wider range of model sizes too like you'd suggested.

@awni
Copy link
Member

awni commented May 1, 2025

I like this new approach as it's much simpler. Though I do wonder about the possibility of deadlock. Say we have two streams:

Stream A is waiting on the output of Stream B
Stream A is holding the metal lock in synchronize
Stream B gets stuck waiting to get the lock so it can run the eval

Something like that seems plausible in a multi-threaded setup. I'm not sure it's necessarily a dealbreaker because sharing graphs across threads is not a good idea for other reasons. But it would be good to setup up a few C++ tests to really exercise the multi-threaded cases we expect this to work for.

@acsweet
Copy link
Author

acsweet commented May 7, 2025

I've added a few tests, how do they look to you?

The changes caused one test around buffers to very occasionally fail (tests/array_tests.cpp "test array shared buffer"), I think related to how the deleter was handled with doctest and the test ending. I added a synchronize call to that test, if that makes sense there.

@acsweet
Copy link
Author

acsweet commented May 13, 2025

I ran a few benchmarks, apologies for the delay! @awni
Results are below, prompt tps and generated tps.

Prompt TPS

Model This PR Current MLX Release
mlx-community/Llama-3.2-1B-Instruct-4bit 1942.43 (±10.34) 1954.47 (±8.67)
mlx-community/Llama-3.2-3B-Instruct-4bit 696.70 (±26.15) 758.26 (±19.69)
mlx-community/Qwen3-0.6B-4bit 2592.38 (±21.81) 2605.61 (±24.54)
mlx-community/Qwen3-0.6B-6bit 2491.78 (±12.08) 2482.57 (±18.17)
mlx-community/Qwen3-0.6B-8bit 2526.31 (±20.74) 2541.08 (±10.92)
mlx-community/Qwen3-1.7B-3bit 1121.43 (±3.49) 1098.90 (±35.14)
mlx-community/Qwen3-1.7B-4bit 1124.30 (±13.53) 1134.61 (±4.48)

Generation TPS

Model This PR Current MLX Release
mlx-community/Llama-3.2-1B-Instruct-4bit 273.97 (±0.32) 274.73 (±0.44)
mlx-community/Llama-3.2-3B-Instruct-4bit 111.19 (±0.68) 111.66 (±0.25)
mlx-community/Qwen3-0.6B-4bit 290.42 (±4.89) 283.89 (±10.23)
mlx-community/Qwen3-0.6B-6bit 287.13 (±6.04) 264.96 (±14.38)
mlx-community/Qwen3-0.6B-8bit 261.41 (±1.57) 258.06 (±0.56)
mlx-community/Qwen3-1.7B-3bit 210.59 (±0.25) 205.58 (±2.31)
mlx-community/Qwen3-1.7B-4bit 185.28 (±0.44) 184.66 (±0.38)

The benchmark was pretty simple, prompt was very short (could make it longer). I set the max tokens to 1000 (which the qwen models sometimes reached in my benchmark). Here's the code too for reference. More trials could be run, and with a longer prompt, but hopefully this gives a decent idea on the time difference.

from mlx_lm import load, stream_generate
import pandas as pd

max_tokens = 1_000
verbose = False
warmup_count = 3
num_trials = 10
df_results = pd.DataFrame()

checkpoints = [
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    "mlx-community/Qwen3-0.6B-4bit",
    "mlx-community/Qwen3-0.6B-6bit",
    "mlx-community/Qwen3-0.6B-8bit",
    "mlx-community/Qwen3-1.7B-3bit",
    "mlx-community/Qwen3-1.7B-4bit",
]

for checkpoint in checkpoints:
    model, tokenizer = load(path_or_hf_repo=checkpoint)

    prompt = "Hello! I'm teaching a science class on our solar system and wanted to ask for your help! " \
        "Could you tell what the planets in our solar system are called, and a little about each one?"
    conversation = [{"role": "user", "content": prompt}]

    prompt = tokenizer.apply_chat_template(
        conversation=conversation, add_generation_prompt=True
    )

    for _ in range(warmup_count):
        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text

    for i in range(num_trials):

        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text

        response_dict = {
            'model': checkpoint,
            'trial': i,
            'prompt_tokens': response.prompt_tokens,
            'prompt_tps': response.prompt_tps,
            'generation_tokens': response.generation_tokens,
            'generation_tps': response.generation_tps,
            'peak_memory': response.peak_memory,
        }

        df_trial = pd.DataFrame(response_dict, index=[0])
        df_results = pd.concat([df_results, df_trial], ignore_index=True)


print(df_results.head())
df_results.to_csv('trial_runs.csv', index=False)

As usual, any feedback is greatly appreciated!

@altaic
Copy link

altaic commented May 13, 2025

The Llama-3.2-3B-Instruct-4bit prompt processing results are a bit concerning since it's the only result where throughput reduction is outside of the variability in your testing, and it also happens to be the largest model. It seems like it'd be good to benchmark some larger models to see if a trend develops.

@acsweet
Copy link
Author

acsweet commented May 14, 2025

It does seem to be slightly slower (with some high variance on the prompt tps). Not sure where to go from here. If this makes the PR a no go, if it's possible I can try relaxing some of the mutex locks, or if it's something with my benchmark.

Prompt TPS

Model This PR Current MLX Release
mlx-community/Llama-3.2-3B-Instruct-4bit 733.51 (±41.78) 768.42 (±1.88)
mlx-community/gemma-3-4b-it-4bit 478.45 (±3.77) 484.88 (±7.67)
mlx-community/gemma-3-12b-it-4bit 170.33 (±0.59) 170.33 (±4.11)
mlx-community/gemma-3-12b-it-8bit 125.07 (±36.62) 154.84 (±0.94)

Generation TPS

Model This PR Current MLX Release
mlx-community/Llama-3.2-3B-Instruct-4bit 111.53 (±0.84) 111.91 (±0.14)
mlx-community/gemma-3-4b-it-4bit 85.00 (±0.06) 86.37 (±0.20)
mlx-community/gemma-3-12b-it-4bit 31.84 (±0.03) 32.33 (±0.05)
mlx-community/gemma-3-12b-it-8bit 17.43 (±0.69) 18.25 (±0.01)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants