Skip to content

Commit

Permalink
remove input pos keyword in generate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Gasoonjia committed Sep 24, 2024
1 parent a70d7b5 commit 937e7ed
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 472 deletions.
2 changes: 1 addition & 1 deletion torchchat/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _download_hf_snapshot(
try:
snapshot_download(
model_config.distribution_path,
cache_dir=artifact_dir,
local_dir=artifact_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns=None if "llava" in model_config.name else "*safetensors*",
Expand Down
4 changes: 2 additions & 2 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def prefill(
logits = model(x)
else:
# input_pos: [B, S]
logits = model(x, input_pos=input_pos)
logits = model(x, input_pos)
# print(f"logits {logits.shape}")

# print(f"x: {x},\n input_pos: {input_pos}\n")
Expand All @@ -398,7 +398,7 @@ def decode_one_token(
else:
logits = model(x)
else:
logits = model(x, input_pos=input_pos)
logits = model(x, input_pos)
# print(f"x: {x},\n input_pos: {input_pos}\n")
return self.sample(logits, need_probs=need_probs, **sampling_kwargs)

Expand Down
Loading

0 comments on commit 937e7ed

Please sign in to comment.