How to convert a model with a control flow? #2008

Open
Ykid opened this issue Oct 11, 2023 · 0 comments
Labels
question Response providing clarification needed. Will not be assigned to a release. (type)

Comments


Ykid commented Oct 11, 2023

Goal

  • My goal is to add a KV cache to GPT2 by putting the two inference modes (with and without cache) into the two branches of an if-else; the motivation is mainly to improve inference speed by feeding in the cached key/value tensors.

What I tried

  • I'm basically following the Mix Tracing and Scripting section of the coremltools documentation (a minimal toy sketch of that pattern follows this list)
  • It seems both branches of the cond operator must return results of the same shape, so I combined the output tuple into a single tensor and padded it to a fixed size
  • It seems tuples are not a valid input type for Core ML models, so I created a wrapper that passes the cache as a single stacked tensor
  • The MIL optimization passes seemed to mess up the conversion, so I disabled them by passing pass_pipeline=coremltools.PassPipeline.EMPTY
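
For reference, here is that sketch (dummy ReLU/sigmoid branches, not the real GPT-2 wrappers; both branches return tensors of the same shape so the cond outputs match):

import numpy as np
import torch
import torch.nn as nn
import coremltools


class BranchA(nn.Module):
    def forward(self, x):
        return torch.relu(x)


class BranchB(nn.Module):
    def forward(self, x):
        return torch.sigmoid(x)


class Switch(nn.Module):
    def __init__(self):
        super().__init__()
        example = torch.rand(1, 8)
        # Trace the purely data-flow branches individually ...
        self.a = torch.jit.trace(BranchA(), example)
        self.b = torch.jit.trace(BranchB(), example)

    def forward(self, x, flag):
        # ... and script the module that holds the control flow.
        if flag[0, 0] <= 0:
            return self.a(x)
        else:
            return self.b(x)


scripted = torch.jit.script(Switch())
mlmodel = coremltools.converters.convert(
    scripted,
    inputs=[
        coremltools.TensorType(name="x", shape=(1, 8)),
        coremltools.TensorType(name="flag", shape=(1, 1), dtype=np.int32),
    ],
    convert_to="mlprogram",
)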

The issue I have now

The conversion itself seemed to run without errors, but I ran into an error when I tried to run a prediction. The error is:

RuntimeError: Error compiling model: "Failed to parse the model specification. 
Error: Unable to parse ML Program: in operation x_4: Param 'weight' must be 
const".
  • I checked the converted proto (the Model class). There are 178 operations in the generated program; 5 of them are not of type const, and only one of them produces a variable named x_4, which appears in an operation of type cond (I enumerated them roughly as shown below).
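
For reference, this is roughly how I enumerated the operations; the proto field names here are my best reading of the MIL program layout in coremltools 7, so treat them as approximate:

spec = mlmodel.get_spec()
main_fn = spec.mlProgram.functions["main"]
for block in main_fn.block_specializations.values():
    for op in block.operations:
        # Top-level ops only; the branches nested inside the cond op
        # would need their own walk over op.blocks.
        if op.type != "const":
            print(op.type, [out.name for out in op.outputs])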

What can be done to fix the issue? Or is there a recommended approach to achieve the same goal? Thanks!

The script

import pdb
import sys
import traceback
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel


class GPT2WrapperNet(nn.Module):
    """
    Wrap the GPT2LMHeadModel so that passing in boolean is not an issue
    """

    def __init__(self, model: GPT2LMHeadModel, has_cache: bool, decoder_layers: int):
        super().__init__()
        self.has_cache = has_cache
        self.model = model
        self.decoder_layers = decoder_layers

    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor,
                past_key_values: Optional[torch.Tensor] = None):
        default_inputs = {
            'return_dict': False,
            'use_cache': True,
            'output_attentions': False,
            'output_hidden_states': False
        }
        # no_cache
        if not self.has_cache:
            default_inputs['input_ids'] = input_ids
            default_inputs['position_ids'] = position_ids
            default_inputs['attention_mask'] = attention_mask
            # logit shape: [batch_size x beam_size,  prefix/seq length, vocab_size]
            logits, cache = self.model(**default_inputs)
            # Tuple x 6 (num_layers)
            # --> Tuple x 2 (1 for K, 1 for V)
            # -----> [batch_size x beam_size, num_attn_heads, prefix_length, n_embd // n_head]
            cache_list = [
                cache[0][0], cache[0][1],
                cache[1][0], cache[1][1],
                cache[2][0], cache[2][1],
                cache[3][0], cache[3][1],
                cache[4][0], cache[4][1],
                cache[5][0], cache[5][1]
            ]
            # cache shape: [
            #      num_layers x 2, batch_size x beam_size,
            #      num_attn_heads, prefix_length, n_embd // n_head
            # ]
            cache = torch.stack(cache_list)
            # (364128, )
            flattened = torch.cat([torch.flatten(logits), torch.flatten(cache)])
            # TODO: update this to be dynamic
            pad_size = 500000 - flattened.shape[0]
            # need to pad the tensor to return the tensor with the same shape!
            return F.pad(flattened, (0, pad_size), "constant", 0.0)
        else:
            # has_cache
            past_key_values_tuple = tuple((
                (past_key_values[2 * i], past_key_values[2 * i + 1])
                for i in range(0, self.decoder_layers)  # change to decoder block count
            ))
            default_inputs['input_ids'] = input_ids
            default_inputs['position_ids'] = position_ids
            default_inputs['attention_mask'] = attention_mask
            default_inputs['past_key_values'] = past_key_values_tuple
            # logit shape: [batch_size x beam_size,  1, vocab_size]
            logits, cache = self.model(**default_inputs)
            # Tuple x 6 (num_layers)
            # --> Tuple x 2 (1 for K, 1 for V)
            # -----> [batch_size x beam_size, num_attn_heads, prefix_length_for_cache + 1, n_embd // n_head]
            cache_list = [
                cache[0][0], cache[0][1],
                cache[1][0], cache[1][1],
                cache[2][0], cache[2][1],
                cache[3][0], cache[3][1],
                cache[4][0], cache[4][1],
                cache[5][0], cache[5][1]
            ]
            # [
            #   num_layers x 2, batch_size x beam_size,
            #   num_attn_heads, prefix_length_for_cache + 1, n_embd // n_head
            # ]
            cache = torch.stack(cache_list)
            # (173976,)
            flattened = torch.cat([torch.flatten(logits), torch.flatten(cache)])

            # TODO: update this to be dynamic
            pad_size = 500000 - flattened.shape[0]
            # need to pad the tensor to return the tensor with the same shape!
            return F.pad(flattened, (0, pad_size), "constant", 0.0)


class ControlFlowNet(nn.Module):
    def __init__(self, model: GPT2LMHeadModel, model2: GPT2LMHeadModel):
        super().__init__()
        # later can load this from model config
        mock_batch_size = 1
        mock_beam_size = 3
        mock_sequence_length = 4
        decoder_layers = 6
        num_attention_heads = 12
        n_embedding = 768
        mock_vocab_size = 21128

        # Model that doesn't take caching into account
        input_without_cache = [
            # input_ids
            torch.randint(low=0, high=mock_vocab_size,
                          size=(mock_batch_size * mock_beam_size, mock_sequence_length))
            .to(torch.int64),
            # position_ids
            (
                    torch.zeros(mock_batch_size * mock_beam_size, mock_sequence_length) +
                    torch.unsqueeze(torch.tensor(list(range(0, mock_sequence_length))), dim=0)
            ).to(torch.int64),
            # attention_mask
            torch.abs(
                torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64))
        ]
        wrapper_gpt2_model_without_cache = GPT2WrapperNet(model=model, has_cache=False,
                                                          decoder_layers=decoder_layers)

        self.gpt2_model_without_cache = torch.jit.trace(wrapper_gpt2_model_without_cache, input_without_cache)

        past_key_values_2 = torch.rand(
            decoder_layers * 2,
            mock_batch_size * mock_beam_size,
            num_attention_heads,
            mock_sequence_length - 1,
            n_embedding // num_attention_heads
        )

        # Model that takes caching into account
        input_with_cache = [
            # input_ids, last dim is 1 because we only need the last token or
            # tokens (in the case of batch inference)
            torch.abs(
                torch.randint(low=0, high=mock_vocab_size, size=(mock_batch_size * mock_beam_size, 1)).to(
                    torch.int64)),
            # position_ids
            torch.abs(
                (torch.ones(mock_batch_size * mock_beam_size, 1) * (mock_sequence_length - 1)).to(torch.int64)
            ),
            # attention_mask
            torch.abs(
                torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64)),
            past_key_values_2
        ]

        wrapper_gpt2_model_with_cache = GPT2WrapperNet(model=model2, has_cache=True,
                                                       decoder_layers=decoder_layers)
        self.gpt2_model_with_cache = torch.jit.trace(wrapper_gpt2_model_with_cache, input_with_cache)

    def forward(self,
                input_ids_cache: torch.Tensor,
                attention_mask_cache: torch.Tensor,
                position_ids_cache: torch.Tensor,
                past_key_values: torch.Tensor,
                past_key_values_flag: torch.Tensor,  # [[1]] only if use past_key_values,
                input_ids_wo_cache: torch.Tensor,
                attention_mask_wo_cache: torch.Tensor,
                position_ids_wo_cache: torch.Tensor
                ):
        if past_key_values_flag[0, 0] <= 0:
            # Model that takes caching into account
            # input_ids: shape is [batch_size x beam_size, 1]
            # past_key_values: [
            #     no. of decoder layers x 2,
            #     batch_size x beam_size,
            #     no. attention_heads,
            #     seq_length,
            #     embedding // attention_heads
            # ],
            # attention_mask: shape is [batch_size x beam_size, sequence_length]
            # position: shape is [batch_size x beam_size, 1]
            return self.gpt2_model_with_cache(
                input_ids=input_ids_cache, position_ids=position_ids_cache,
                attention_mask=attention_mask_cache, past_key_values=past_key_values
            )
        else:
            # Model doesn't take caching into account
            # input_ids: shape is [batch_size x beam_size, sequence_length]
            # position_ids: shape is [batch_size x beam_size, sequence_length]
            # attention_mask: shape [batch_size x beam_size, sequence_length]
            # return -> [batch_size, sequence_length, vocab_size], KV_cache
            return self.gpt2_model_without_cache(
                input_ids=input_ids_wo_cache,
                position_ids=position_ids_wo_cache,
                attention_mask=attention_mask_wo_cache)


def main():
    hugging_face_model_name = "uer/gpt2-distil-chinese-cluecorpussmall"
    tokenizer = AutoTokenizer.from_pretrained(hugging_face_model_name)
    # GPT2LMHeadModel
    gpt2 = AutoModelForCausalLM.from_pretrained(hugging_face_model_name)

    # TODO: remove this later
    gpt2_copy = AutoModelForCausalLM.from_pretrained(hugging_face_model_name)
    wrapper_net = ControlFlowNet(model=gpt2, model2=gpt2_copy)
    wrapper_net.eval()
    scripted_model = torch.jit.script(wrapper_net)
    import coremltools

    mock_batch_size = 1
    mock_beam_size = 3
    batch_beam_size = mock_batch_size * mock_beam_size
    mock_sequence_length = 4
    decoder_layers = 6
    num_attention_heads = 12
    n_embedding = 768
    mock_vocab_size = 21128

    # input with cache
    input_ids_cache = coremltools.TensorType(name='input_ids_cache', shape=(batch_beam_size, 1),
                                             dtype=coremltools.converters.mil.mil.types.int64)
    attention_mask_cache = coremltools.TensorType(name='attention_mask_cache',
                                                  shape=(batch_beam_size, mock_sequence_length),
                                                  dtype=coremltools.converters.mil.mil.types.int64)
    position_ids_cache = coremltools.TensorType(name='position_ids_cache', shape=(batch_beam_size, 1),
                                                dtype=coremltools.converters.mil.mil.types.int64)
    past_key_values_shape = (
        decoder_layers * 2,
        batch_beam_size,
        num_attention_heads,
        mock_sequence_length - 1,
        n_embedding // num_attention_heads
    )
    past_key_values = coremltools.TensorType(name='past_key_values', shape=past_key_values_shape, dtype=float)

    # input without cache
    input_ids_wo_cache = coremltools.TensorType(name='input_ids_wo_cache', shape=(3, mock_sequence_length),
                                                dtype=coremltools.converters.mil.mil.types.int64)
    attention_mask_wo_cache = coremltools.TensorType(name='attention_mask_wo_cache',
                                                     shape=(3, mock_sequence_length),
                                                     dtype=coremltools.converters.mil.mil.types.int64)
    position_ids_wo_cache = coremltools.TensorType(name='position_ids_wo_cache', shape=(3, mock_sequence_length),
                                                   dtype=coremltools.converters.mil.mil.types.int64)

    # a flag to switch cases
    past_key_values_flag = coremltools.TensorType(name='past_key_values_flag', shape=(1, 1), dtype=np.int32)
    mlmodel = coremltools.converters.convert(
        scripted_model,
        inputs=[
            input_ids_cache,
            attention_mask_cache,
            position_ids_cache,
            past_key_values,
            past_key_values_flag,
            input_ids_wo_cache,
            attention_mask_wo_cache,
            position_ids_wo_cache
        ],
        pass_pipeline=coremltools.PassPipeline.EMPTY,
        minimum_deployment_target=coremltools.target.iOS15,
        convert_to='mlprogram'
    )
    print("The model is converted. Now checking its validity by doing inference")
    # don't use int64 even though it is used by torch
    # https://github.com/apple/coremltools/issues/194
    mlmodel.predict(data={
        "input_ids_wo_cache": torch
        .randint(low=0, high=mock_vocab_size, size=(mock_batch_size * mock_beam_size, mock_sequence_length))
        .to(torch.int64)
        .numpy().astype(np.int32),
        "attention_mask_wo_cache":
            torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64)
        .numpy().astype(np.int32),
        "position_ids_wo_cache": (
                torch.zeros(mock_batch_size * mock_beam_size, mock_sequence_length) +
                torch.unsqueeze(torch.tensor(list(range(0, mock_sequence_length))), dim=0)
        ).to(torch.int64).numpy().astype(np.int32),
        "past_key_values_flag": np.array([[1]], dtype=np.int32),  # use cache <= 0
        "past_key_values": torch
        .rand(
            decoder_layers * 2,
            mock_batch_size * mock_beam_size,
            num_attention_heads,
            mock_sequence_length - 1,
            n_embedding // num_attention_heads
        )
        .numpy(),
        "input_ids_cache": torch.randint(low=0, high=mock_vocab_size, size=(mock_batch_size * mock_beam_size, 1))
        .to(torch.int64).numpy().astype(np.int32),
        "attention_mask_cache":
            torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64).numpy().astype(
                np.int32),
        "position_ids_cache":
            (torch.ones(mock_batch_size * mock_beam_size, 1) * (mock_sequence_length - 1)).to(
                torch.int64).numpy().astype(np.int32)
    })
    print("The model can perform inference correctly")

    from pathlib import Path
    current_dir = Path(__file__).parent
    mlmodel.save(str(current_dir / "saved" / "gpt2.mlpackage"))
    print("The model is saved")


if __name__ == '__main__':

    try:
        main()
    except:
        extype, value, tb = sys.exc_info()
        traceback.print_exc()
        pdb.post_mortem(tb)

Dependencies

transformers==4.30.2
coremltools==7.0
numpy==1.24.3
torch==2.0.0
Ykid added the question label on Oct 11, 2023