
Conversation

@molbap (Contributor) commented Mar 18, 2025

What does this PR do?

This PR introduces a tool that traces ALL model inputs and outputs into a nested JSON format. Main values, shapes, and dtypes are all recorded.

Usage

It's a simple context manager to wrap around a forward call on your inputs. Note that inference_mode is enforced inside it.

import torch
from PIL import Image
from transformers import LlavaProcessor, LlavaForConditionalGeneration
from transformers.model_debugging_utils import model_addition_debugger_context

# load pretrained model and processor
model_id = "llava-hf/llava-1.5-7b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)

# create random image input
torch.random.manual_seed(673)
random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())

# prompt
prompt = "<image>Describe this image."

# process inputs
inputs = processor(text=prompt, images=random_image, return_tensors="pt")

with model_addition_debugger_context(model=model, debug_path="model_debug"):
    output = model.forward(**inputs)
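
After the forward pass, the trace is written as JSON under debug_path. The exact file layout isn't spelled out in this description, so the sketch below just globs for whatever JSON was written and prints its top-level structure; the paths are assumptions, not guaranteed output names.

# Hedged sketch: inspect whatever JSON trace files the debugger wrote.
# Assumes debug_path="model_debug" results in one or more *.json files;
# the exact naming scheme is not specified in this PR description.
import glob
import json

for path in sorted(glob.glob("model_debug*.json") + glob.glob("model_debug/*.json")):
    with open(path) as f:
        trace = json.load(f)
    # Show the file name and the top-level keys of the nested trace.
    print(path, "->", list(trace)[:5] if isinstance(trace, dict) else type(trace))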

Why??

Because when porting models to transformers, even from Python to Python, model adders often have to do a lot of manual work: saving and loading tensors, comparing dtypes, and so on. This small tool can hopefully shave off some of that time.

Example output

Here you can see two JSON traces of the same model, where the sole difference is the epsilon of a layer normalization. With the debugger attached, simple string matching shows where the outputs start to differ. I'll probably include a difflib snippet as well to make this easier; a rough sketch follows the screenshot below.

Screenshot from 2025-03-18 16-44-16
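
A minimal sketch of what such a difflib comparison might look like; the file names below are placeholders for two debugger runs, not names produced by this PR.

# Hedged sketch: diff two JSON traces captured for the same inputs.
# The file names are placeholders for a baseline run and a modified run.
import difflib
import json

def load_lines(path):
    # Normalize the JSON so the diff is stable across key ordering.
    with open(path) as f:
        return json.dumps(json.load(f), indent=2, sort_keys=True).splitlines()

baseline = load_lines("model_debug/run_baseline.json")
modified = load_lines("model_debug/run_modified_eps.json")

# Print only the differing lines, unified-diff style.
for line in difflib.unified_diff(baseline, modified, fromfile="baseline", tofile="modified", lineterm=""):
    print(line)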

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

super nice! Just missing something like register_model_for_debug(LlamaModel) !

@molbap molbap marked this pull request as ready for review March 20, 2025 14:53
@molbap (Contributor, PR author) commented:

Can remove this or move it to another PR; with #36827 done, we should be able to get rid of this check in make fixup.

To note, this decorator enforces `torch.inference_mode()`.
## Usage
add decorator to your model class
Collaborator commented:

would put this in the model_debugging_utils

@molbap (Contributor, PR author) replied:

added in the doc :) not sure of the output because the doc builder is currently broken
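
For illustration, a hedged sketch of the decorator form referenced in the doc snippet above; the decorator name is assumed by mirroring the context manager's name and may not match the merged API.

# Hypothetical sketch of the decorator form discussed in this thread.
# The name model_addition_debugger is an assumption mirroring
# model_addition_debugger_context; check the merged code for the real API.
from transformers import LlavaForConditionalGeneration
from transformers.model_debugging_utils import model_addition_debugger  # assumed name

@model_addition_debugger  # traces forward() inputs/outputs under torch.inference_mode()
class DebuggedLlava(LlavaForConditionalGeneration):
    pass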

@molbap molbap merged commit 1d3f35f into main Mar 20, 2025
23 of 24 checks passed
@molbap molbap deleted the add_model_visual_debugger branch March 20, 2025 16:37
@molbap molbap mentioned this pull request Apr 9, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* draft of model tracer visualiser

* add context manager in addition to decorator

* add debug utils to init

* move model debugging utils to dedicated file

* add documentation

* protect some imports

* format

* move and protect imports

* format

* doc: improve errors in case of broken dummy imports.

* format

* use automatic torch backend

* update doc

* fix backend

* (TEMP) move to dummies while backend wait

* update documentation

* doc