Conversation

@glistening (Contributor)

It fuses LlamaAttention from the TinyLlama model.
The fused attention works as the onert attention op.

TICO-DCO-1.0-Signed-off-by: Sanggyu Lee sg5.lee@samsung.com

@glistening (Contributor Author) commented Nov 4, 2025

The generated circle needs to run using onert.

I heard there is:

from test.utils import tag

@tag.use_onert
class Gemma3(TestModuleBase):
    def __init__(self):

But it requires a class definition, while my case does not have a class definition at all.

How can I run the generated circle using onert?

Comment on lines 28 to 67
from torch.library import Library

lib = Library("circle", "DEF")
lib.define(
    """
    attention.llama(
        Tensor hidden_states,
        Tensor wq,
        Tensor wk,
        Tensor wv,
        Tensor wo,
        Tensor position_cos,
        Tensor position_sin,
        Tensor attention_mask,
        Tensor past_key,
        Tensor past_value,
        Tensor cache_position
    ) -> Tensor
    """
)

# ATTENTION FUSER


@torch.library.register_fake("circle::attention.llama")
def attention_llama(*args, **kwargs):
    (
        hidden_states,
        q_proj,
        k_proj,
        v_proj,
        o_proj,
        position_cos,
        position_sin,
        attention_mask,
        past_key,
        past_value,
        cache_position,
    ) = args
    return hidden_states
Contributor:

You can use the @custom_op decorator instead. And the registration code should be placed in tico/utils/register_custom_op.py.
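
For reference, a minimal sketch of the decorator-based registration, assuming torch's torch.library.custom_op API; the operator name (circle::attention_llama, without the overload) and the fake returning an empty tensor shaped like hidden_states are assumptions for illustration, not the actual contents of tico/utils/register_custom_op.py:

import torch


# Sketch only: a decorator-based registration of the fusion target op.
@torch.library.custom_op("circle::attention_llama", mutates_args=())
def attention_llama(
    hidden_states: torch.Tensor,
    wq: torch.Tensor,
    wk: torch.Tensor,
    wv: torch.Tensor,
    wo: torch.Tensor,
    position_cos: torch.Tensor,
    position_sin: torch.Tensor,
    attention_mask: torch.Tensor,
    past_key: torch.Tensor,
    past_value: torch.Tensor,
    cache_position: torch.Tensor,
) -> torch.Tensor:
    # The op is a fusion target, not a compute kernel, so no eager implementation.
    raise NotImplementedError("circle::attention_llama is a fusion target only")


@attention_llama.register_fake
def _(hidden_states, wq, wk, wv, wo, position_cos, position_sin,
      attention_mask, past_key, past_value, cache_position):
    # Shape/dtype propagation for export: the output follows hidden_states.
    return torch.empty_like(hidden_states)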

Contributor Author:

@mhs4670go Thank you for the review. I already added "No Merge" after comparing with #266. I will rearrange the code following TICO's way.

Contributor:

How can I run the generated circle using onert?

I think you can refactor this script into a single class. test/modules/model/TinyLlamaWithFusedRMSNorm/model.py can be a reference.
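
For illustration, a rough sketch of such a class, assuming the TestModuleBase / @tag.use_onert / get_example_inputs interface mentioned in this thread; the import paths, checkpoint name, and return convention of get_example_inputs are assumptions, not copied from the referenced model.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from test.modules.base import TestModuleBase  # assumed import path
from test.utils import tag


@tag.use_onert
class TinyLlamaWithFusedAttention(TestModuleBase):
    def __init__(self):
        super().__init__()
        # Assumed checkpoint; the decode.py script in this PR may use another one.
        model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id).eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    def get_example_inputs(self):
        enc = self.tokenizer("Lily picked up a flower.", return_tensors="pt")
        # Assumed convention: (positional args, keyword args).
        return (enc["input_ids"], enc["attention_mask"]), {}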

@glistening (Contributor Author), Nov 5, 2025:

Yes, I already know the PR. I did not want tedious and unnecessary jobs like making a class, providing hand-made random inputs, and so on. I think there should be a better way. However, I am not familiar with TICO, and it seems to be necessary for TICO's value test.

@mhs4670go (Contributor)

When I ran the commands below, I got errors.

pip install -r test/modules/model/TinyLlamaWithFusedAttention/requirements.txt
python test/modules/model/TinyLlamaWithFusedAttention/decode.py
Traceback (most recent call last):
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3119, in inline_call_
    sub_locals = func.bind_args(parent, args, kwargs)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 233, in bind_args
    bound = inspect.signature(fake_func).bind(*args, **kwargs)
  File "/home/seongwoo/.pyenv/versions/3.10.4/lib/python3.10/inspect.py", line 3179, in bind
    return self._bind(args, kwargs)
  File "/home/seongwoo/.pyenv/versions/3.10.4/lib/python3.10/inspect.py", line 3149, in _bind
    raise TypeError('missing a required argument: {arg!r}'. \
TypeError: missing a required argument: 'past_key_value'

@glistening (Contributor Author) commented Nov 5, 2025

@mhs4670go
I have no problem with torch=2.7.1, python=3.10.12, transformers=4.51.3.
Could you tell me your torch and transformers version?

@mhs4670go (Contributor) commented Nov 5, 2025

Could you tell me your torch and transformers version?

torch: 2.6.0
transformers 4.57.1
python 3.10.4

python test/modules/model/TinyLlamaWithFusedAttention/decode.py               

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.

Lily picked up a flower. Once the sun was shining and the sun was shining. Lily was very happy. She wanted to see the sun in the sky. She walked and walked,

Traceback (most recent call last):
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/codegen.py", line 263, in __call__
    self.call_reconstruct(value)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/codegen.py", line 90, in call_reconstruct
    res = value.reconstruct(self)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 358, in reconstruct
    raise NotImplementedError
NotImplementedError

@mhs4670go (Contributor)

After installing transformers==4.49.0 and using the register_dynamic_cache function, the conversion worked well.

@@ -0,0 +1,154 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
Contributor Author:

@jinevening Are you okay with this location? You told me that you prefer a separate directory for onert-only operators, like op_attention.

Contributor:

The location can be like below.

  • Visitor: tico/serialize/operators/
  • Custom op registration: tico/utils/register_custom_op.py
  • Adapter: tico/serialize/operators/adapters/

Contributor:

Could you put the adapters in the current directory (tico/serialize/operators/adapters/onert) and place the other code according to @mhs4670go's suggestion?

The adapters may be open to users, so I think it would be good to specify "This adapter is only for onert".

@glistening (Contributor Author), Nov 7, 2025:

Could you put the adapters in the current directory (tico/serialize/operators/adapters/onert) and place the other code according to @mhs4670go's suggestion?

The adapters may be open to users, so I think it would be good to specify "This adapter is only for onert".

Yes, it is exactly the same as my understanding. Thank you for confirming.

As I wrote in #400 (comment), I will follow TICO's way like #266.

(Though I don't think it is a good way to put all operators in a single file — register_custom_op.py)
(It is 740 lines and will grow and grow as supported operators increase.)
(It would be good to give each operator its own file, and let register_custom_op.py do registering only.)

Anyway, again, I will follow TICO's current way.

@glistening (Contributor Author), Nov 11, 2025:

The adapters may be open to users, so I think it would be good to specify "This adapter is only for onert".

@jinevening I totally agree. The adapters will vary per model.
I am trying an encoder-decoder model (for a translation task).
I will suggest an adapters code structure after finishing the encoder-decoder model, which requires another adapter (i.e. TRIVMultiheadAttention).

@glistening (Contributor Author) commented Nov 7, 2025

After installing transformers==4.49.0 and using the register_dynamic_cache function, the conversion worked well.

Thank you.

  1. For your information, as @dayo09 wrote in Convert LlamaAttention as 1 op (attention) #160 (comment), dynamic_cache has been supported by transformers since 4.50.0. You don't need to explicitly call register_dynamic_cache.

  2. The latest transformers version broke the previous behavior. Specifying an exact transformers version should be sufficient (e.g., a pin like the sketch below). Time permitting, I'll try to find a working solution using the latest source, even though I don't see a need to adopt the newest version of transformers at this moment.
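
For example, a pin based on the versions reported as working earlier in this thread (the exact file contents are an illustration, not the PR's requirements.txt):

# FILE: test/modules/model/TinyLlamaWithFusedAttention/requirements.txt (illustrative)
torch==2.7.1
transformers==4.51.3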

@glistening glistening marked this pull request as draft November 11, 2025 07:50
@glistening glistening force-pushed the attention_ branch 3 times, most recently from 69f5e92 to 82460ad Compare November 12, 2025 00:07
@glistening glistening marked this pull request as ready for review November 12, 2025 02:25
@glistening glistening force-pushed the attention_ branch 3 times, most recently from a6b7de6 to 8f43620 Compare November 12, 2025 02:48
@glistening glistening changed the title Fuse LlamaAttention to attention (onert) Add TinyLlama decoder model with fused attention Nov 12, 2025
)

verify_circle(circle_model_path, opt_circle_model_path)
# verify_circle(circle_model_path, opt_circle_model_path)
Contributor Author:

@Samsung/tico_developers Can I disable verify_circle? The output model contains the attention op, which is not supported by circle2circle.

Contributor:

Well, I don't think it's possible. The workaround I can come up with as of now is as follows.

  • Add a skip decorator (before merging this PR, the test should pass without verify_circle, though).
  • Wait until circle2circle supports CircleAttention.
  • Remove the decorator.


if self._call_count == 2:
    return self.fused_model(*args, **kwargs)
else:
Contributor Author:

The TICO test framework gets expected values using the original model, then converts the model. For the 1st run, it should not use the adapter.

I tried not to use call_count but to use inspect to get the caller's name. However, that is not allowed during torch tracing.
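
A minimal sketch of the call-count dispatch described above (class and attribute names are placeholders and may differ from the PR's model.py):

import torch


class AttentionAdapterWrapper(torch.nn.Module):
    """1st call: run the original model so the test framework can collect
    expected outputs. 2nd call (the export trace): run the fused model."""

    def __init__(self, original_model, fused_model):
        super().__init__()
        self.original_model = original_model
        self.fused_model = fused_model
        self._call_count = 0

    def forward(self, *args, **kwargs):
        self._call_count += 1
        if self._call_count == 2:
            return self.fused_model(*args, **kwargs)
        else:
            return self.original_model(*args, **kwargs)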

Contributor:

Actually, if torch.ops.circle_custom.attention had its own kernel instead of returning None, this trick wouldn't be necessary.

Contributor Author:

I don't think it is a trick at all. circle_custom.attention is a virtual operator for fusing, not for computing. I found no reason to put the actual implementation here.

The call_count-related code can be eliminated if the TICO test framework is improved by separating model inference from model conversion. I see it already has a problem:

# Let's infer torch model before `export`
# WHY?
# Some model changes its state during export (e.g., EfficientFormerL1)
# See https://github.com/pytorch/pytorch/issues/155114
torch_result = infer_nnmodule(
    self.nnmodule,
    forward_args=deepcopy(self.forward_args),
    forward_kwargs=deepcopy(self.forward_kwargs),
)

@glistening (Contributor Author) commented Nov 12, 2025

@Samsung/tico_developers The output itself is valid in terms of TICO, but it is incomplete in terms of onert.

That is the reason it gets the following errors:

$ Product/x86_64-linux.debug/out/bin/onert_run ../TICO/./test/pt2_to_circle_test/artifacts/model/TinyLlamaWithFusedAttention/model/TinyLlamaWithFusedAttention.circle
Model Filename ../TICO/./test/pt2_to_circle_test/artifacts/model/TinyLlamaWithFusedAttention/model/TinyLlamaWithFusedAttention.circle
Error during model prepare : d_model must be divisible by n_head

$ Product/x86_64-linux.debug/out/bin/onert_run ../TICO/./test/pt2_to_circle_test/artifacts/model/TinyLlamaWithFusedAttention/model/TinyLlamaWithFusedAttention.circle
Model Filename ../TICO/./test/pt2_to_circle_test/artifacts/model/TinyLlamaWithFusedAttention/model/TinyLlamaWithFusedAttention.circle
Error during model prepare : Attention: shape mismatch between inputs

Before running circle, I manipulate the TICO generated model further using Samsung/ONE#16233.

./reshape.io.py input --by_shape [1,16,30,4] [1,16,32,4] | \
./transpose.io.kvcache.py 

If you want to run the test using onert in TICO, the o2o script needs to be run.

Also, it needs to be possible to choose circle2circle (Python) instead of circle2circle (currently C++).

These are beyond this PR's and this repo's responsibility.

I would like to hear your opinions.

@mhs4670go (Contributor) commented Nov 12, 2025

@glistening

since 4.50.0, dynamic_cache is supported by transformers.

Even though I installed transformers==4.50.3, it still failed for the same reason.

File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(DynamicCache)

from user code:
   File "/home/seongwoo/TICO/test/modules/model/TinyLlamaWithFusedAttention/model.py", line 38, in forward
    return self.fused_model(*args, **kwargs)

@dayo09

Could you try to run this test to check whether you meet the same error?

@glistening (Contributor Author) commented Nov 12, 2025

@mhs4670go

Even though I installed transformers==4.50.3, it still failed for the same reason.

Hmm, I added register_dynamic_cache(). I hope it helps in your environment.

@mhs4670go (Contributor)

@glistening

I would like to hear your opinions.

Here are some workarounds I've come up with. Let me know if there's something wrong.

Error during model prepare : d_model must be divisible by n_head

The first error can be resolved by generating inputs with max_length=32 instead.

Error during model prepare : Attention: shape mismatch between inputs

If you implement a kernel for CircleAttention, you can transpose the kv cache in get_example_inputs instead of doing the same thing after circle export.
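
A hypothetical illustration of that suggestion; the [1, 16, 30, 4] layout comes from the shapes quoted in this thread, while the axes to swap depend on onert's expected kv-cache layout and are only a guess here:

import torch


def get_example_inputs():
    # Build the decode-step kv cache already transposed to the runtime's
    # expected layout, so no post-export o2o transpose script is needed.
    past_key = torch.zeros(1, 16, 30, 4).transpose(2, 3).contiguous()
    past_value = torch.zeros(1, 16, 30, 4).transpose(2, 3).contiguous()
    return past_key, past_value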

@glistening (Contributor Author) commented Nov 12, 2025

The first error can be resolved by generating inputs with max_length=32 instead.

I already described the error with max_length=32. It is incorrect in terms of onert. That is the reason I used 30.

Error during model prepare : Attention: shape mismatch between inputs

If you implement a kernel for CircleAttention, you can transpose the kv cache in get_example_inputs instead of doing the same thing after circle export.

The kernel spec should not be changed because of the limitations of the frontend tool.

I don't like the approach of changing something (manual inputs, original models, runtime kernels, ...) to generate what I want using TICO.

That is the reason o2o exists.

No, the detailed manipulation needs to be done offline with an explicit tool.

I don't understand what you mean by

you can transpose the kv cache in get_example_inputs

If you're asking me to change the example inputs, it just takes a prompt string.

@mhs4670go (Contributor)

I already described the error with max_length=32. It is incorrect in terms of onert. That is the reason I used 30.

Maybe I missed something. Could you elaborate more? Because I got a model that has a [1,16,32,4] shape when I ran the test like below.

inputs = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=32,  # HERE
    truncation=True,
)

Should the inputs' shape be [1, 16, 30, 4]?

@mhs4670go (Contributor)

@glistening

As I said in #400 (comment), could you wait just a few more days until we have our meeting?

@glistening (Contributor Author) commented Nov 21, 2025

The kernel spec should not be changed because of the limitations of the frontend tool.

Well, I rather wanted the custom operator to be aligned with onert's or the ones you intended, not to change something. Because I think that's what should be done when custom operators are registered. The current CircleAttention is just an operator returning None.

But, if it bothers you, you can merge this test with a skip decorator. And I'm gonna have a meeting about this with @Samsung/tico_developers. It needs a discussion, I guess.

I was confused and am still confused by "custom":

As I understand it,

  1. TICO developers seem to consider all operators that are not defined by aten or tflite to be custom. Even an operator (e.g. rms_norm) in the circle schema seems to be considered a custom operator.

  2. At least for me, a custom operator (as I mean it) is the one in this PR, like LlamaAttention. I believe it should be possible to register a custom operator without getting a review from this repo.

I hope TICO defines "custom" precisely and supports custom operators (the 2nd case) soon.

@mhs4670go (Contributor) commented Nov 24, 2025

@glistening

As a result of the meeting, we agreed to enable the following features without testing code.

  • Attention Adapter
  • Attention Visitor
  • Custom op registration for CircleAttention (without kernel impl)

To make things easier, I separated the relevant portion into this PR. Please take a look at it. I'll close this PR soon.

Regarding the testing, a number of components are not compatible with TICO. As demonstrated in this PR, it would be more appropriate to keep the tests in the ONE repository or another suitable location.

@glistening (Contributor Author) commented Nov 24, 2025

@Samsung/tico_developers
Thank you for your time for this PR.

Regarding the testing, a number of components are not compatible with TICO. As demonstrated in Samsung/ONE#16283, it would be more appropriate to keep the tests in the ONE repository or another suitable location.

Yes, I wrote Samsung/ONE#16283 to show how the whole end-to-end process is done (you can find TICO in the middle of the process).

@mhs4670go
#419 seems to be identical to #400, except that #419 removed the changes under test/.
(It is reasonable, since having no tests for attention was the conclusion of your meeting.)
I just wonder why you created another PR. If you would like to reduce my effort, I appreciate it.
However, I would like to commit my change under my name.

@mhs4670go (Contributor)

If you would like to reduce my effort, I appreciate it.

This one:) Please feel free to modify this PR instead of mine.

@glistening (Contributor Author)

Thank you. It already has 13 commits and 46 comments (including this).
It would be better to keep the history in one place.

@glistening glistening force-pushed the attention_ branch 2 times, most recently from 543463b to a9f93c9 Compare November 24, 2025 09:46
    past_value: torch.Tensor,
    cache_position: torch.Tensor,
) -> torch.Tensor:
    return None
@dayo09 (Contributor), Nov 24, 2025:

@glistening If you'd like to enable some tests on TICO, it could be done by implementing this function and adding a simple 'attention' function with KV cache. TICO also uses the onert nightly, so it's testable.

# FILE: test/requirements_pre.txt
onert==0.2.0.dev250922
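
For illustration, a rough single-head sketch of an attention function with a KV cache (purely hypothetical: it is not the onert attention spec, skips the rotary embedding, and assumes 2-D projection weights):

import torch


def reference_attention(hidden_states, wq, wk, wv, wo,
                        position_cos, position_sin,
                        attention_mask, past_key, past_value, cache_position):
    # Project the decode-step hidden states to query/key/value.
    q = hidden_states @ wq.T
    k = hidden_states @ wk.T
    v = hidden_states @ wv.T
    # (Applying the rotary embedding with position_cos/position_sin is omitted.)
    # Append the new key/value to the cache along the sequence axis.
    k = torch.cat([past_key, k], dim=-2)
    v = torch.cat([past_value, v], dim=-2)
    # Scaled dot-product attention over the cached + new sequence.
    scores = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
    scores = scores + attention_mask
    probs = torch.softmax(scores, dim=-1)
    return (probs @ v) @ wo.T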

Contributor Author:

I have no idea how to test attention in TICO. It is beyond the scope of TICO.
It would be done in the scope of GGMA (end-to-end), perhaps in another repo.

I edited the title to specify the scope of this PR.

If it is meaningful, the prefill phase of TinyLlama (which has no fused attention) may be testable using onert in TICO without modifying or extending the TICO test framework much, as you suggested.

@glistening glistening changed the title Add TinyLlama decoder model with fused attention Add attention operator and adapter for onert Nov 25, 2025
@mhs4670go (Contributor) left a comment:

LGTM

@mhs4670go mhs4670go merged commit 7bcbfe3 into Samsung:main Nov 25, 2025
7 checks passed