Hello, thanks for your great work. I have a few questions.
When testing a Qwen2-based model, such as llava_qwen or lmms-lab/LongVA-7B, on the V-NIAH benchmark,
there is a call to `apply_seq_parallel_monkey_patch("zigzag_ring_attn", "llama")`.
- How can this monkey patch work, given that Qwen2 has a different architecture from LLaMA? The patch only touches the LLaMA classes:
```python
def apply_zigzag_ring_attn_monkey_patch_llama():
    transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
        new_flash_attn_forward
    )
    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
        new_decoder_forward
    )
```
- Does this replacement have any effect on the class `Qwen2ForCausalLM_RingAttn`?
- If not, how is zigzag_ring_attn actually performed during benchmarking for llava_qwen-based models?
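For context, here is what I would have expected a Qwen2-specific patch to look like. This is purely a hypothetical sketch mirroring the llama patch above: `Qwen2FlashAttention2` and `Qwen2DecoderLayer` are my assumptions about the transformers module layout (they exist in transformers versions around 4.37), and `new_flash_attn_forward` / `new_decoder_forward` refer to the repo's existing replacement functions:

```python
def apply_zigzag_ring_attn_monkey_patch_qwen2():
    # Hypothetical mirror of the llama patch; assumes a transformers
    # version where Qwen2FlashAttention2 exists (e.g. ~4.37).
    import transformers  # imported lazily so this sketch stands alone

    transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2._flash_attention_forward = (
        new_flash_attn_forward  # the repo's zigzag ring-attention forward
    )
    transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward = (
        new_decoder_forward  # the repo's replacement decoder-layer forward
    )
```

Is something like this intended, or is the "llama" patch deliberately reused for Qwen2?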