Skip to content

Commit

Permalink
Address Nick's nits and fix CUDAGraph correctness
Browse files Browse the repository at this point in the history
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
  • Loading branch information
andoorve committed Jul 2, 2024
1 parent 5a4b323 commit c92257c
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
8 changes: 4 additions & 4 deletions vllm/model_executor/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ def forward(
kv_caches[i - self.start_layer],
attn_metadata)

if get_pp_group().is_last_rank:
hidden_states = self.ln_f(hidden_states)
return hidden_states
else:
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})

hidden_states = self.ln_f(hidden_states)
return hidden_states


class GPT2LMHeadModel(nn.Module):

Expand Down
8 changes: 4 additions & 4 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,15 +311,15 @@ def forward(
residual,
)

if get_pp_group().is_last_rank:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
else:
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class LlamaForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
Expand Down
4 changes: 2 additions & 2 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,8 +1359,8 @@ def forward(
# Return the output tensor.
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"]
else:
return self.output_buffers

return self.output_buffers

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def from_broadcasted_tensor_dict(
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict.pop("virtual_engine"),
virtual_engine=tensor_dict["virtual_engine"],
)

def as_broadcastable_tensor_dict(
Expand Down

0 comments on commit c92257c

Please sign in to comment.