Description
Goal
- My goal is to add a KV cache to GPT-2 by putting two inference modes in the two branches of an if-else statement; the main motivation is to improve inference speed by feeding in cached key/value tensors.
What I tried
- I'm basically following the Mix Tracing and Scripting section of the coremltools documentation.
- It seems the outputs of both branches must have the same shape for the cond operator, so I flattened the logits and the KV cache into a single tensor and padded it to a fixed size (see the sketch after this list).
- It seems tuples are not valid inputs to Core ML models, so I created a wrapper that packs the past key/value tuples into a single stacked tensor and unpacks it again inside the model.
- MIL optimization passes seemed to mess up the conversion process, so I disabled them by passing in pass_pipeline=coremltools.PassPipeline.EMPTY.
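To illustrate the shape-matching trick from the list above, here is a minimal, self-contained sketch with a toy module and hypothetical sizes (nothing here is taken from the real script below, which pads to 500000 elements instead of 16): both branches flatten their result and pad it to one fixed length, so the scripted if/else has a single output shape.

import torch
import torch.nn.functional as F

class TwoBranchToy(torch.nn.Module):
    # Toy stand-in for the control-flow wrapper: both branches return one fixed shape.
    def __init__(self):
        super().__init__()
        self.max_flat_size = 16  # hypothetical fixed size; the real script uses 500000

    def forward(self, x: torch.Tensor, flag: torch.Tensor):
        if flag[0, 0] <= 0:
            out = torch.flatten(x * 2.0)            # stand-in for the "with cache" path
        else:
            out = torch.flatten(torch.cat([x, x]))  # stand-in for the "without cache" path
        # Pad both results to the same length so the scripted branches agree on the output shape.
        pad_size = self.max_flat_size - out.shape[0]
        return F.pad(out, [0, pad_size], "constant", 0.0)

scripted = torch.jit.script(TwoBranchToy())
print(scripted(torch.ones(2, 3), torch.tensor([[1]])).shape)  # torch.Size([16])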
The issue I have now
The conversion ran without errors, but I got the following error when trying to run a prediction:
RuntimeError: Error compiling model: "Failed to parse the model specification. Error: Unable to parse ML Program: in operation x_4: Param 'weight' must be const".
- I inspected the converted proto class Model. There are 178 operations in the generated proto; 5 of them are not of const type, and only one of them contains a variable named x_4. x_4 appears in an operation of type cond (a sketch of this kind of inspection follows below).
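For reference, here is one way such an inspection can be done. This is only a sketch: it assumes mlmodel is the object returned by coremltools.converters.convert in the script below, and the attribute names (mlProgram, functions, block_specializations, operations) follow my reading of the MIL protobuf schema, so they may need adjusting.

spec = mlmodel.get_spec()
main_fn = spec.mlProgram.functions["main"]
block = main_fn.block_specializations[main_fn.opset]

# Count the top-level operations and filter out the const ones.
non_const_ops = [op for op in block.operations if op.type != "const"]
print("total ops:", len(block.operations), "non-const ops:", len(non_const_ops))

# Look for the operation(s) that mention x_4 or are of type cond.
for op in non_const_ops:
    output_names = [out.name for out in op.outputs]
    if op.type == "cond" or "x_4" in output_names:
        print(op.type, output_names)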
What can be done to fix this issue? Or is there a recommended approach to achieve the same goal? Thanks!
The script
import pdb
import sys
import traceback
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
class GPT2WrapperNet(nn.Module):
    """
    Wrap the GPT2LMHeadModel so that passing in a boolean is not an issue
    """

    def __init__(self, model: GPT2LMHeadModel, has_cache: bool, decoder_layers: int):
        super().__init__()
        self.has_cache = has_cache
        self.model = model
        self.decoder_layers = decoder_layers

    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor,
                past_key_values: Optional[torch.Tensor] = None):
        default_inputs = {
            'return_dict': False,
            'use_cache': True,
            'output_attentions': False,
            'output_hidden_states': False
        }
        if not self.has_cache:
            # no_cache
            default_inputs['input_ids'] = input_ids
            default_inputs['position_ids'] = position_ids
            default_inputs['attention_mask'] = attention_mask
            # logit shape: [batch_size x beam_size, prefix/seq length, vocab_size]
            logits, cache = self.model(**default_inputs)
            # Tuple x 6 (num_layers)
            # --> Tuple x 2 (1 for K, 1 for V)
            # -----> [batch_size x beam_size, num_attn_heads, prefix_length, n_embd // n_head]
            cache_list = [
                cache[0][0], cache[0][1],
                cache[1][0], cache[1][1],
                cache[2][0], cache[2][1],
                cache[3][0], cache[3][1],
                cache[4][0], cache[4][1],
                cache[5][0], cache[5][1]
            ]
            # cache shape: [
            #     num_layers x 2, batch_size x beam_size,
            #     num_attn_heads, prefix_length, n_embd // n_head
            # ]
            cache = torch.stack(cache_list)
            # (364128, )
            flattened = torch.cat([torch.flatten(logits), torch.flatten(cache)])
            # TODO: update this to be dynamic
            pad_size = 500000 - flattened.shape[0]
            # need to pad the tensor to return the tensor with the same shape!
            return F.pad(flattened, (0, pad_size), "constant", 0.0)
        else:
            # has_cache
            past_key_values_tuple = tuple((
                (past_key_values[2 * i], past_key_values[2 * i + 1])
                for i in range(0, self.decoder_layers)  # change to decoder block count
            ))
            default_inputs['input_ids'] = input_ids
            default_inputs['position_ids'] = position_ids
            default_inputs['attention_mask'] = attention_mask
            default_inputs['past_key_values'] = past_key_values_tuple
            # logit shape: [batch_size x beam_size, 1, vocab_size]
            logits, cache = self.model(**default_inputs)
            # Tuple x 6 (num_layers)
            # --> Tuple x 2 (1 for K, 1 for V)
            # -----> [batch_size x beam_size, num_attn_heads, prefix_length_for_cache + 1, n_embd // n_head]
            cache_list = [
                cache[0][0], cache[0][1],
                cache[1][0], cache[1][1],
                cache[2][0], cache[2][1],
                cache[3][0], cache[3][1],
                cache[4][0], cache[4][1],
                cache[5][0], cache[5][1]
            ]
            # [
            #     num_layers x 2, batch_size x beam_size,
            #     num_attn_heads, prefix_length_for_cache + 1, n_embd // n_head
            # ]
            cache = torch.stack(cache_list)
            # (173976,)
            flattened = torch.cat([torch.flatten(logits), torch.flatten(cache)])
            # TODO: update this to be dynamic
            pad_size = 500000 - flattened.shape[0]
            # need to pad the tensor to return the tensor with the same shape!
            return F.pad(flattened, (0, pad_size), "constant", 0.0)
class ControlFlowNet(nn.Module):
    def __init__(self, model: GPT2LMHeadModel, model2: GPT2LMHeadModel):
        super().__init__()
        # later can load this from model config
        mock_batch_size = 1
        mock_beam_size = 3
        mock_sequence_length = 4
        decoder_layers = 6
        num_attention_heads = 12
        n_embedding = 768
        mock_vocab_size = 21128
        # Model that doesn't take caching into account
        input_without_cache = [
            # input_ids
            torch.randint(low=0, high=mock_vocab_size,
                          size=(mock_batch_size * mock_beam_size, mock_sequence_length))
            .to(torch.int64),
            # position_ids
            (
                torch.zeros(mock_batch_size * mock_beam_size, mock_sequence_length) +
                torch.unsqueeze(torch.tensor(list(range(0, mock_sequence_length))), dim=0)
            ).to(torch.int64),
            # attention_mask
            torch.abs(
                torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64))
        ]
        wrapper_gpt2_model_without_cache = GPT2WrapperNet(model=model, has_cache=False,
                                                          decoder_layers=decoder_layers)
        self.gpt2_model_without_cache = torch.jit.trace(wrapper_gpt2_model_without_cache, input_without_cache)

        past_key_values_2 = torch.rand(
            decoder_layers * 2,
            mock_batch_size * mock_beam_size,
            num_attention_heads,
            mock_sequence_length - 1,
            n_embedding // num_attention_heads
        )
        # Model that takes caching into account
        input_with_cache = [
            # input_ids, last dim is 1 because we only need the last token or
            # tokens (in the case of batch inference)
            torch.abs(
                torch.randint(low=0, high=mock_vocab_size, size=(mock_batch_size * mock_beam_size, 1)).to(
                    torch.int64)),
            # position_ids
            torch.abs(
                (torch.ones(mock_batch_size * mock_beam_size, 1) * (mock_sequence_length - 1)).to(torch.int64)
            ),
            # attention_mask
            torch.abs(
                torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64)),
            past_key_values_2
        ]
        wrapper_gpt2_model_with_cache = GPT2WrapperNet(model=model2, has_cache=True,
                                                       decoder_layers=decoder_layers)
        self.gpt2_model_with_cache = torch.jit.trace(wrapper_gpt2_model_with_cache, input_with_cache)

    def forward(self,
                input_ids_cache: torch.Tensor,
                attention_mask_cache: torch.Tensor,
                position_ids_cache: torch.Tensor,
                past_key_values: torch.Tensor,
                past_key_values_flag: torch.Tensor,  # past_key_values is used only if this is <= 0
                input_ids_wo_cache: torch.Tensor,
                attention_mask_wo_cache: torch.Tensor,
                position_ids_wo_cache: torch.Tensor
                ):
        if past_key_values_flag[0, 0] <= 0:
            # Model that takes caching into account
            # input_ids: shape is [batch_size x beam_size, 1]
            # past_key_values: shape is [
            #     no. of decoder layers x 2,
            #     batch_size x beam_size,
            #     no. of attention heads,
            #     seq_length,
            #     embedding // attention_heads
            # ]
            # attention_mask: shape is [batch_size x beam_size, sequence_length]
            # position_ids: shape is [batch_size x beam_size, 1]
            return self.gpt2_model_with_cache(
                input_ids=input_ids_cache, position_ids=position_ids_cache,
                attention_mask=attention_mask_cache, past_key_values=past_key_values
            )
        else:
            # Model that doesn't take caching into account
            # input_ids: shape is [batch_size x beam_size, sequence_length]
            # position_ids: shape is [batch_size x beam_size, sequence_length]
            # attention_mask: shape is [batch_size x beam_size, sequence_length]
            # return -> logits [batch_size, sequence_length, vocab_size] and the KV cache,
            # flattened and padded into one tensor
            return self.gpt2_model_without_cache(
                input_ids=input_ids_wo_cache,
                position_ids=position_ids_wo_cache,
                attention_mask=attention_mask_wo_cache)
def main():
    hugging_face_model_name = "uer/gpt2-distil-chinese-cluecorpussmall"
    tokenizer = AutoTokenizer.from_pretrained(hugging_face_model_name)
    # GPT2LMHeadModel
    gpt2 = AutoModelForCausalLM.from_pretrained(hugging_face_model_name)
    # TODO: remove this later
    gpt2_copy = AutoModelForCausalLM.from_pretrained(hugging_face_model_name)
    wrapper_net = ControlFlowNet(model=gpt2, model2=gpt2_copy)
    wrapper_net.eval()
    scripted_model = torch.jit.script(wrapper_net)

    import coremltools

    mock_batch_size = 1
    mock_beam_size = 3
    batch_beam_size = mock_batch_size * mock_beam_size
    mock_sequence_length = 4
    decoder_layers = 6
    num_attention_heads = 12
    n_embedding = 768
    mock_vocab_size = 21128
    # input with cache
    input_ids_cache = coremltools.TensorType(name='input_ids_cache', shape=(batch_beam_size, 1),
                                             dtype=coremltools.converters.mil.mil.types.int64)
    attention_mask_cache = coremltools.TensorType(name='attention_mask_cache',
                                                  shape=(batch_beam_size, mock_sequence_length),
                                                  dtype=coremltools.converters.mil.mil.types.int64)
    position_ids_cache = coremltools.TensorType(name='position_ids_cache', shape=(batch_beam_size, 1),
                                                dtype=coremltools.converters.mil.mil.types.int64)
    past_key_values_shape = (
        decoder_layers * 2,
        batch_beam_size,
        num_attention_heads,
        mock_sequence_length - 1,
        n_embedding // num_attention_heads
    )
    past_key_values = coremltools.TensorType(name='past_key_values', shape=past_key_values_shape, dtype=float)
    # input without cache
    input_ids_wo_cache = coremltools.TensorType(name='input_ids_wo_cache', shape=(3, mock_sequence_length),
                                                dtype=coremltools.converters.mil.mil.types.int64)
    attention_mask_wo_cache = coremltools.TensorType(name='attention_mask_wo_cache',
                                                     shape=(3, mock_sequence_length),
                                                     dtype=coremltools.converters.mil.mil.types.int64)
    position_ids_wo_cache = coremltools.TensorType(name='position_ids_wo_cache', shape=(3, mock_sequence_length),
                                                   dtype=coremltools.converters.mil.mil.types.int64)
    # a flag to switch cases
    past_key_values_flag = coremltools.TensorType(name='past_key_values_flag', shape=(1, 1), dtype=np.int32)

    mlmodel = coremltools.converters.convert(
        scripted_model,
        inputs=[
            input_ids_cache,
            attention_mask_cache,
            position_ids_cache,
            past_key_values,
            past_key_values_flag,
            input_ids_wo_cache,
            attention_mask_wo_cache,
            position_ids_wo_cache
        ],
        pass_pipeline=coremltools.PassPipeline.EMPTY,
        minimum_deployment_target=coremltools.target.iOS15,
        convert_to='mlprogram'
    )
    print("The model is converted. Now checking its validity by doing inference")
    # don't use int64 even though it is used by torch
    # https://github.com/apple/coremltools/issues/194
    mlmodel.predict(data={
        "input_ids_wo_cache": torch
        .randint(low=0, high=mock_vocab_size, size=(mock_batch_size * mock_beam_size, mock_sequence_length))
        .to(torch.int64)
        .numpy().astype(np.int32),
        "attention_mask_wo_cache":
            torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64)
            .numpy().astype(np.int32),
        "position_ids_wo_cache": (
            torch.zeros(mock_batch_size * mock_beam_size, mock_sequence_length) +
            torch.unsqueeze(torch.tensor(list(range(0, mock_sequence_length))), dim=0)
        ).to(torch.int64).numpy().astype(np.int32),
        "past_key_values_flag": np.array([[1]], dtype=np.int32),  # use cache <= 0
        "past_key_values": torch
        .rand(
            decoder_layers * 2,
            mock_batch_size * mock_beam_size,
            num_attention_heads,
            mock_sequence_length - 1,
            n_embedding // num_attention_heads
        )
        .numpy(),
        "input_ids_cache": torch.randint(low=0, high=mock_vocab_size, size=(mock_batch_size * mock_beam_size, 1))
        .to(torch.int64).numpy().astype(np.int32),
        "attention_mask_cache":
            torch.ones(mock_batch_size * mock_beam_size, mock_sequence_length).to(torch.int64).numpy().astype(
                np.int32),
        "position_ids_cache":
            (torch.ones(mock_batch_size * mock_beam_size, 1) * (mock_sequence_length - 1)).to(
                torch.int64).numpy().astype(np.int32)
    })
    print("The model can perform inference correctly")

    from pathlib import Path
    current_dir = Path(__file__).parent
    mlmodel.save(str(current_dir / "saved" / "gpt2.mlpackage"))
    print("The model is saved")
if __name__ == '__main__':
    try:
        main()
    except:
        extype, value, tb = sys.exc_info()
        traceback.print_exc()
        pdb.post_mortem(tb)
Dependencies
transformers==4.30.2
coremltools==7.0
numpy==1.24.3
torch==2.0.0