
add separate prefill detokenization thread #152

Merged
1 commit merged into main on Nov 19, 2024

Conversation

zhihaoshan-google
Collaborator

Add a separate prefill detokenization thread to make sure the detokenization of prefill and decode do not block each other.

@zhihaoshan-google zhihaoshan-google requested review from jwyang-google and vipannalla and removed request for vipannalla November 18, 2024 18:39
@zhihaoshan-google zhihaoshan-google force-pushed the prefill_detokenization branch 2 times, most recently from 676a90e to 2088c2c on November 18, 2024 23:40
Collaborator

@patemotter patemotter left a comment


Overall LGTM, just some clarifying questions and suggestions.

@zhihaoshan-google
Collaborator Author

Thank you, Pate!

@patemotter
Collaborator

jetstream/core/orchestrator.py:274:74: C0303: Trailing whitespace (trailing-whitespace)

:-/

@zhihaoshan-google
Collaborator Author

I will make sure the tests pass before submission. Sorry for the mistake.

@vipannalla
Collaborator

Looks good. What improvement do you see with this change?

@zhihaoshan-google
Collaborator Author

> Looks good. What improvement do you see with this change?

It's in the analysis doc.

@zhihaoshan-google zhihaoshan-google merged commit d462ca9 into main Nov 19, 2024
3 checks passed
@zhihaoshan-google zhihaoshan-google deleted the prefill_detokenization branch November 19, 2024 06:56
@wangkuiyi

wangkuiyi commented Dec 12, 2024

Hi Zhihao, Pate, and Vipan,

I am afraid this PR may disrupt the correct detokenization order for a given request. Please let me know if my understanding is incorrect.

The proposed change creates separate detokenization backlogs: P backlogs for prefill engines and G backlogs for generation engines. For each active request, the first token (from the prefill thread) and subsequent tokens (from the generate thread) end up in different backlogs. This separation could lead to incorrect ordering.

To maintain proper token ordering, all tokens from a single request should be processed through the same backlog before being returned via the request's channel.

Consider this edge case: if Python's thread scheduler never executes the detokenization threads that read from the prefill-detokenization backlogs, only the tokens from generate threads would be returned. This would result in missing first tokens for all requests.
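
Here is a minimal sketch of the hazard (hypothetical names, not the actual orchestrator code): if the prefill-side consumer is scheduled late, the generate-side token reaches the return channel first.

```python
import queue
import threading
import time

# Hypothetical stand-ins for the split backlogs and a request's
# return channel; this is not the actual orchestrator code.
prefill_backlog = queue.Queue()
generate_backlog = queue.Queue()
returned = []

def prefill_detokenizer():
    time.sleep(0.1)  # simulate the scheduler starving this thread
    returned.append(prefill_backlog.get())

def generate_detokenizer():
    returned.append(generate_backlog.get())

# The request's first token goes to the prefill backlog,
# later tokens to the generate backlog.
prefill_backlog.put("first")
generate_backlog.put("second")

threads = [threading.Thread(target=prefill_detokenizer),
           threading.Thread(target=generate_detokenizer)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(returned)  # ['second', 'first'] -- the first token arrives late
```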

cc. @changlan

@zhihaoshan-google
Collaborator Author

zhihaoshan-google commented Dec 12, 2024

Thank you for the comment, Yi!

Yes, you are right. Sorry for missing the potential breakage.

From my understanding, the TTFT problem before this PR was the long time it takes to convert from a jax Array to a numpy array.
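
For context, a toy timing sketch (a generic large matmul, not JetStream code): JAX dispatch is asynchronous, so the jax-to-numpy conversion is where the device latency shows up, because np.asarray blocks until the result is ready and copied to host.

```python
import time
import jax.numpy as jnp
import numpy as np

x = jnp.ones((4096, 4096))
y = x @ x  # dispatched asynchronously; this line returns immediately

start = time.perf_counter()
y_host = np.asarray(y)  # blocks until compute finishes, then copies to host
print(f"jax->numpy conversion took {time.perf_counter() - start:.3f}s")
```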

I am currently thinking of two options:

  1. Still use two threads: add a per-request condition variable to enforce ordering (wait for the first token to be added to the request's return channel before adding the second and subsequent tokens); see the sketch after this list.

  2. Change back to one thread: wait for the prefill and generation/decode jitted functions to complete inside the prefill and generation/decode threads directly, instead of in the detokenization thread. To mitigate the runtime latency of dispatching the device program, we can also schedule several decode/generate steps together. (This might be the most common way to address this issue.)
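
A minimal sketch of option 1 (hypothetical helper names; a per-request threading.Event stands in for the condition variable, and return_channel is any queue-like object feeding the request's response stream):

```python
import threading

# Hypothetical per-request gate: set once the first (prefill) token
# has been pushed to the request's return channel.
first_token_done = {}

def register_request(request_id):
    first_token_done[request_id] = threading.Event()

def emit_prefill_token(request_id, token, return_channel):
    # Runs in the prefill detokenization thread: return the first
    # token, then unblock the generate side.
    return_channel.put(token)
    first_token_done[request_id].set()

def emit_generate_token(request_id, token, return_channel):
    # Runs in the generate detokenization thread: block until the
    # first token has been returned, preserving per-request order.
    first_token_done[request_id].wait()
    return_channel.put(token)
```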

I might need to do a quick evaluation and get back to you.

Let me do a quick revert tomorrow morning.

@wangkuiyi

wangkuiyi commented Dec 12, 2024

Thank you, Zhihao, for the clarification! It is very helpful!

How about merging the P prefill detokenization backlogs and the G generation detokenization backlogs into a single set of N detokenization backlogs, each served by its own detokenization thread?

For any active request r, any prefill or generation thread processing r simply adds the output token IDs to the i-th of these N detokenization backlogs, where i is the hash of r modulo N.

This ensures that all response tokens for r go to the same detokenization backlog and therefore arrive in the correct order.
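
A minimal sketch of the routing (hypothetical names; queue.Queue stands in for a backlog):

```python
import queue

# Hypothetical sketch: N shared backlogs, each drained by one
# detokenization thread.
N = 4
backlogs = [queue.Queue() for _ in range(N)]

def backlog_for(request_id):
    # Both prefill and generate threads route through this function,
    # so every token of a request lands in the same backlog and is
    # processed in arrival order by a single thread.
    return backlogs[hash(request_id) % N]

backlog_for("req-42").put(("req-42", "first"))   # from a prefill thread
backlog_for("req-42").put(("req-42", "second"))  # from a generate thread
```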

@zhihaoshan-google
Collaborator Author

Thank you for the quick response and suggestion, Yi! This is definitely a great idea!

I will try it. Currently, I have observed that the JetStream runtime occasionally gets stuck for several seconds and then returns to normal (when I run disaggregated serving with 1 prefill engine and 1 generate/decode engine). I ran py-spy and saw that the GIL was held by one thread during the stall without seeming to do any real work. I am just afraid that introducing more threads will make it worse (since N seems to scale with the QPS throughput).

Maybe the above option 2, with detokenization running in a process instead of a thread, is better. I will evaluate both.

@wangkuiyi

Thank you, Zhihao, for the context and the quick reaction! Looking forward to your evaluation results!

@changlan

changlan commented Jan 31, 2025

Hi Zhihao and Vipan, just following up on this issue - do we have a conclusion for this, or has the change been reverted already? It is blocking our update due to potential incorrect results. Thanks!

@vipannalla
Collaborator

@sixiang-google, can you take a look?
