Optimize broadcast for npu llama (#12028)
yangw1234 authored Sep 6, 2024
1 parent e5581e6 commit 58555bd
Showing 2 changed files with 8 additions and 6 deletions.
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -118,6 +118,7 @@ def from_pretrained(cls, *args, **kwargs):
         ignore_argument(kwargs, "pipeline_parallel_stages")
         optimize_model = kwargs.pop("optimize_model", False)
         max_output_len = kwargs.pop("max_output_len", 1024)
+        max_output_len = max_output_len - 1
         max_prompt_len = kwargs.pop("max_prompt_len", 512)
         inter_pp = kwargs.pop("inter_pp", None)
         intra_pp = kwargs.pop("intra_pp", None)
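For context, the kwargs popped above are the NPU-specific options a caller hands to from_pretrained, and the new line simply trims one slot off the user-supplied max_output_len before it is used downstream. A minimal usage sketch follows; it assumes the usual ipex_llm NPU entry point (AutoModelForCausalLM from ipex_llm.transformers.npu_model), and the checkpoint name and concrete values are illustrative only.

    # Hypothetical usage sketch, not part of this commit.
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # illustrative checkpoint
        optimize_model=True,              # default above is False
        max_output_len=1024,              # internally reduced by one by the new line
        max_prompt_len=512,
        inter_pp=2,                       # defaults above are None
        intra_pp=2,
    )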
13 changes: 7 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -530,7 +530,7 @@ def run_decode(
     with torch.inference_mode():
         while True:
 
-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
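Making async_op=False explicit documents that the worker side must block here: each decoder process waits for the control value before deciding what to do next. A stripped-down sketch of that loop, assuming the control codes are -2 for shutdown, -1 for a cache refresh, and anything else for a decode step (only the first two branches are visible in the diff):

    # Illustrative worker-side control loop; not the full run_decode body.
    import torch
    import torch.distributed as dist

    control = torch.tensor(0, dtype=torch.int)
    with torch.inference_mode():
        while True:
            # Blocking receive of the control value broadcast from rank 0.
            dist.broadcast(control, src=0, async_op=False)
            if control.item() == -2:       # shutdown
                break
            elif control.item() == -1:     # assumed: refresh the KV cache from the queue
                pass
            else:                          # assumed: run one decode step
                pass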
@@ -595,6 +595,8 @@ def __init__(self, model, max_seq_len, intra_pp=2, inter_pp=2, transpose_value_c
         self.output_queues = []
         self.decoder_processes = []
 
+        self.forward_signal = torch.tensor(0, dtype=torch.int)
+
         for rank in range(1, world_size):
             input_q = mp.Queue()
             output_q = mp.Queue()
@@ -643,21 +645,20 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
 
-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        t0 = time.perf_counter()
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
+        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
+        t2 = time.perf_counter()
         return hidden_states, past_key_value
 
     def shutdown(self):
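The forward-path change is the optimization named in the commit title: instead of allocating a fresh control tensor and issuing a blocking broadcast on every decode step, rank 0 now broadcasts the pre-allocated self.forward_signal (added in __init__ above) with async_op=True, so the float16 cast and the send of the hidden states can proceed while the signal is still in flight. A minimal sketch of the pattern, assuming an already-initialized torch.distributed process group; ranks and tensor shapes are illustrative, and the commit itself does not wait on the returned work handle:

    # Overlapping a control broadcast with local work on rank 0 (sketch).
    import torch
    import torch.distributed as dist

    forward_signal = torch.tensor(0, dtype=torch.int)  # allocated once, reused every step

    def rank0_decode_step(hidden_states, world_size):
        # async_op=True returns a work handle immediately; rank 0 is the source,
        # so it can keep doing local work while the collective completes.
        work = dist.broadcast(forward_signal, src=0, async_op=True)
        hidden_states = hidden_states.to(torch.float16)  # overlapped with the broadcast
        dist.send(hidden_states, dst=1)                  # hand off to the first decoder rank
        dist.recv(hidden_states, src=world_size - 1)     # result from the last decoder rank
        work.wait()                                      # optional; shown for completeness
        return hidden_states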
