-
Notifications
You must be signed in to change notification settings - Fork 30k
Closed
Labels
Description
System Info
transformers==4.54.0.dev0
torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
python: 3.10.13
os: nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the
examples
folder (such as GLUE/SQuAD, ...) - My own task or dataset (give details below)
Reproduction
def chat():
MODEL_PATH = "google/gemma-3n-E2B-it"
processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
model = Gemma3nForConditionalGeneration.from_pretrained(
MODEL_PATH,
torch_dtype=torch.bfloat16,
# device_map="auto",
# attn_implementation="flash_attention_2" if gpu_prop.major >= 8 else None,
).to("cuda")
model = model.eval()
image_1 = Image.new("RGB", (100, 100), color="white")
image_2 = Image.new("RGB", (100, 100), color="black")
image_3 = Image.new("RGB", (100, 100), color="red")
messages = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "你是一个中文语音智能助手,不要使用特殊字符回复,请使用中文回复。",
}
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "图片是什么颜色"},
{"type": "image", "image": image_1},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "这张图片是纯白色的,没有任何内容。"}],
},
{
"role": "user",
"content": [
{"type": "text", "text": "描述下图片"},
{"type": "image", "image": image_2},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "这张图片是纯白色的,没有任何内容。"}],
},
{
"role": "user",
"content": [
{"type": "text", "text": "你叫什么名字"},
{"type": "image", "image": image_3},
],
},
]
for i in range(3):
inputs = processor.apply_chat_template(
messages[: (i + 1) * 2],
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)
for key, value in inputs.items():
print(f"{key}: {value.shape=}")
input_ids = inputs["input_ids"]
prompt = processor.decode(input_ids[0])
print(f"{prompt=}")
streamer = TextIteratorStreamer(
tokenizer=processor, skip_prompt=True, skip_special_tokens=True
)
generation_kwargs = dict(
**inputs,
# do_sample=False,
do_sample=True,
temperature=0.2,
top_k=10,
top_p=0.9,
# num_beams=1,
repetition_penalty=1.1,
max_new_tokens=1024,
use_cache=True,
streamer=streamer,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
start = perf_counter()
times = []
with torch.inference_mode():
for new_text in streamer:
times.append(perf_counter() - start)
print(new_text, end="", flush=True)
generated_text += new_text
start = perf_counter()
print(f"\n{i}. {generated_text=} TTFT: {times[0]:.2f}s total time: {sum(times):.2f}s")
2nd turn have a bug:
Exception in thread Thread-6 (generate):
Traceback (most recent call last):
File "/usr/local/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/usr/local/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2616, in generate
result = self._sample(
File "/usr/local/lib/python3.10/site-packages/transformers/generation/utils.py", line 3600, in _sample
outputs = model_forward(**model_inputs, return_dict=True)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/transformers/utils/generic.py", line 955, in wrapper
@wraps(func)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1184, in forward
return compiled_fn(full_args)
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 323, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 672, in inner_fn
outs = compiled_fn(args)
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 490, in wrapper
return compiled_fn(runtime_args)
File "/usr/local/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 466, in __call__
return self.current_callable(inputs)
File "/usr/local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1208, in run
return compiled_fn(new_inputs)
File "/usr/local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 398, in deferred_cudagraphify
fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 420, in cudagraphify
manager = get_container(device_index).get_tree_manager()
File "/usr/local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 341, in get_container
container_dict = get_obj(local, "tree_manager_containers")
File "/usr/local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 336, in get_obj
assert torch._C._is_key_in_tls(attr_name)
AssertionError
Expected behavior
support chat with history messages
zucchini-nlp