Skip to content

Commit 937e7ed

Browse files
committed
remove input pos keyword in generate.py
1 parent a70d7b5 commit 937e7ed

File tree

3 files changed

+4
-472
lines changed

3 files changed

+4
-472
lines changed

torchchat/cli/download.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _download_hf_snapshot(
3030
try:
3131
snapshot_download(
3232
model_config.distribution_path,
33-
cache_dir=artifact_dir,
33+
local_dir=artifact_dir,
3434
local_dir_use_symlinks=False,
3535
token=hf_token,
3636
ignore_patterns=None if "llava" in model_config.name else "*safetensors*",

torchchat/generate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def prefill(
374374
logits = model(x)
375375
else:
376376
# input_pos: [B, S]
377-
logits = model(x, input_pos=input_pos)
377+
logits = model(x, input_pos)
378378
# print(f"logits {logits.shape}")
379379

380380
# print(f"x: {x},\n input_pos: {input_pos}\n")
@@ -398,7 +398,7 @@ def decode_one_token(
398398
else:
399399
logits = model(x)
400400
else:
401-
logits = model(x, input_pos=input_pos)
401+
logits = model(x, input_pos)
402402
# print(f"x: {x},\n input_pos: {input_pos}\n")
403403
return self.sample(logits, need_probs=need_probs, **sampling_kwargs)
404404

0 commit comments

Comments
 (0)