Minor fix on code style
merrymercy committed Aug 24, 2023
1 parent 0017fd9 commit d55bedd
Showing 2 changed files with 9 additions and 8 deletions.
fastchat/serve/inference.py (15 changes: 8 additions & 7 deletions)
@@ -66,6 +66,9 @@ def generate_stream(
     stream_interval: int = 2,
     judge_sent_end: bool = False,
 ):
+    if hasattr(model, "device"):
+        device = model.device
+
     # Read parameters
     prompt = params["prompt"]
     len_prompt = len(prompt)
@@ -95,12 +98,12 @@ def generate_stream(
 
     if model.config.is_encoder_decoder:
         encoder_output = model.encoder(
-            input_ids=torch.as_tensor([input_ids], device=model.device)
+            input_ids=torch.as_tensor([input_ids], device=device)
         )[0]
         start_ids = torch.as_tensor(
             [[model.generation_config.decoder_start_token_id]],
             dtype=torch.int64,
-            device=model.device,
+            device=device,
         )
 
     past_key_values = out = None
@@ -115,17 +118,15 @@
                 )
                 logits = model.lm_head(out[0])
             else:
-                out = model(
-                    torch.as_tensor([input_ids], device=model.device), use_cache=True
-                )
+                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                 logits = out.logits
             past_key_values = out.past_key_values
         else:  # decoding
             if model.config.is_encoder_decoder:
                 out = model.decoder(
                     input_ids=torch.as_tensor(
                         [[token] if not sent_interrupt else output_ids],
-                        device=model.device,
+                        device=device,
                     ),
                     encoder_hidden_states=encoder_output,
                     use_cache=True,
@@ -138,7 +139,7 @@
                 out = model(
                     input_ids=torch.as_tensor(
                         [[token] if not sent_interrupt else output_ids],
-                        device=model.device,
+                        device=device,
                     ),
                     use_cache=True,
                     past_key_values=past_key_values if not sent_interrupt else None,
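The inference.py change swaps the repeated model.device lookups for a single device variable that is overridden only when the model object actually exposes a device attribute; otherwise the caller-supplied device string keeps being used. A minimal, self-contained sketch of that fallback pattern, using a hypothetical resolve_device helper and dummy model that are not part of the commit:

import torch

class DummyModel:
    # Hypothetical stand-in for a model object that reports its own device,
    # the way a HuggingFace PreTrainedModel does via model.device.
    device = torch.device("cpu")

def resolve_device(model, device):
    # Hypothetical helper, not in FastChat: prefer the device reported by the
    # model itself when available, else fall back to the caller's device string.
    if hasattr(model, "device"):
        return model.device
    return device

input_ids = [1, 2, 3]
# Mirrors the torch.as_tensor calls touched in this commit.
tensor = torch.as_tensor([input_ids], device=resolve_device(DummyModel(), "cpu"))
print(tensor.device)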
fastchat/train/train_mem.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
 
 # Need to call this before importing transformers.
-from fastchat.train.llama2_flash_attn_monkey_patch import (
+from fastchat.train.llama_flash_attn_monkey_patch import (
     replace_llama_attn_with_flash_attn,
 )
 
