Unable to parse huggingface models #887

Open
JSeam2 opened this issue Dec 10, 2024 · 3 comments
Labels
bug Something isn't working


@JSeam2
Collaborator

JSeam2 commented Dec 10, 2024

Describe the bug

We are currently unable to parse Hugging Face models exported via the optimum-cli flow and run gen-settings on them.

Expected behavior

We should be able to parse a Hugging Face model directly, i.e. export it with

optimum-cli export onnx -m sshleifer/tiny-gpt2 --optimize O1 --device cpu --opset 14 --sequence_length 64 ./tinygpt2

and then run

ezkl gen-settings -M model.onnx

Steps to reproduce the bug

  1. Download a model from Hugging Face via the CLI. We use tiny-gpt2, which is fairly small (about 500KB) and should be usable on most devices running ezkl.
optimum-cli export onnx -m sshleifer/tiny-gpt2 --optimize O1 --device cpu --opset 14 --sequence_length 64 ./tinygpt2

  2. Execute the following, which fails with the error below:
ezkl gen-settings -M model.onnx

[E] [2024-12-10 07:15:32:345, ezkl] - [graph] [tract] Undetermined symbol in expression:sequence_length
  3. Hardcoding the values directly in the ONNX model (see the script below) led to further errors, suggesting a deeper incompatibility with tract.
    Examples of errors:
[E] [2024-12-10 07:15:48:138, ezkl] - [graph] [tract] Can not broadcast 128 against 256
[E] [2024-12-10 07:18:30:675, ezkl] - [graph] [tract] Failed analyse for node #96 "/transformer/h.0/attn/Concat_3" InferenceConcat

Script used to perform surgery on the ONNX model:

import onnx
from onnx import helper
import numpy as np

def print_tensor_shapes(model, prefix=""):
    def dims(value_info):
        # Each dim is a oneof: a concrete dim_value or a symbolic dim_param.
        # hasattr() is always True on protobuf messages, so check the oneof instead.
        return [d.dim_value if d.HasField("dim_value") else d.dim_param
                for d in value_info.type.tensor_type.shape.dim]

    print(f"\n{prefix} Tensor shapes:")
    for inp in model.graph.input:
        print(f"Input {inp.name}: {dims(inp)}")
    for out in model.graph.output:
        print(f"Output {out.name}: {dims(out)}")

def hardcode_sequence_lengths(model_path, past_sequence_length, sequence_length, batch_size, output_path):
    """
    Modify ONNX model to replace both past_sequence_length and sequence_length with fixed values
    
    Args:
        model_path: Path to input ONNX model
        past_sequence_length: Integer value to replace past_sequence_length
        sequence_length: Integer value to replace sequence_length
        batch_size: Integer value for batch size
        output_path: Path to save modified model
    """
    # Load the model
    model = onnx.load(model_path)
    
    # Print original shapes
    print_tensor_shapes(model, "Before modification")

    # Update input shapes
    for inp in model.graph.input:
        tensor_type = inp.type.tensor_type
        
        # Handle different input types
        if inp.name == 'input_ids':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length

        elif inp.name == 'attention_mask':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length + past_sequence_length

        elif inp.name == 'position_ids':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length

        elif 'past_key_values' in inp.name:
            tensor_type.shape.dim[0].dim_value = batch_size
            # dim[1] is num_heads (2)
            tensor_type.shape.dim[2].dim_value = past_sequence_length
            # dim[3] (head dimension) is already static in the export; keep it
            # rather than hardcoding 64, which is GPT-2 base's head size
    
    # Update output shapes
    for out in model.graph.output:
        tensor_type = out.type.tensor_type

        if out.name == 'logits':
            tensor_type.shape.dim[0].dim_value = batch_size
            tensor_type.shape.dim[1].dim_value = sequence_length
            # dim[2] is vocab_size (50257)

        elif 'present' in out.name:
            tensor_type.shape.dim[0].dim_value = batch_size
            # dim[1] is num_heads (2)
            tensor_type.shape.dim[2].dim_value = sequence_length + past_sequence_length
            # dim[3] (head dimension) is already static in the export; keep it

    # Print modified shapes
    print_tensor_shapes(model, "After modification")
    
    # Check model validity
    onnx.checker.check_model(model)
    
    # Save the modified model
    onnx.save(model, output_path)

if __name__ == "__main__":
    sequence_length = 128      # Your desired sequence length
    past_sequence_length = 128 # Your desired past sequence length
    batch_size = 1            # Your desired batch size
    hardcode_sequence_lengths(
        "model.onnx",
        past_sequence_length,
        sequence_length,
        batch_size,
        "surgery.onnx"
    )
JSeam2 added the bug label on Dec 10, 2024
@JSeam2
Collaborator Author

JSeam2 commented Dec 10, 2024

Related to #181

If we can address these issues, we can offer support for Hugging Face models, which would be very powerful.

JSeam2 closed this as completed on Dec 10, 2024
JSeam2 reopened this on Dec 10, 2024
@Jot-De

Jot-De commented Dec 20, 2024

This would be a great feature! @JSeam2
As a user, is there any way to work around this issue and use Hugging Face models in ezkl 16.2?

@JSeam2
Collaborator Author

JSeam2 commented Dec 25, 2024

@Jot-De For now, you can run optimum-cli directly, find the model's input shapes, and create a random input.json that matches them.
