Commit
[feature] Enable model sharding on seq_scheduler tested on gpt_neox_20B (deepjavalibrary#1086)

Co-authored-by: KexinFeng <fenkexin@amazon.com>
KexinFeng authored Sep 25, 2023
1 parent f2f93e1 commit 7dfeee2
Showing 4 changed files with 219 additions and 8 deletions.
@@ -11,14 +11,16 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

-from djl_python.scheduler import HuggingfaceBlock, BloomBlock, SearchConfig, SeqBatchScheduler
+from djl_python.scheduler import HuggingfaceBlock, BloomBlock, FalconBlock, SearchConfig, SeqBatchScheduler
+# from seq_scheduler import HuggingfaceBlock, BloomBlock, FalconBlock, SearchConfig, SeqBatchScheduler
from collections import namedtuple, defaultdict
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

import torch
import re

-MODEL_TYPE_2_BLOCK = {'bloom': BloomBlock}
+MODEL_TYPE_2_BLOCK = {'bloom': BloomBlock, 'falcon': FalconBlock}
DEFAULT_SEARCH_ALGORITHM = 'greedy'


@@ -37,6 +39,8 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
super().__init__(device, **kwargs)
self._init_model_and_tokenizer(model_id_or_path,
device=device,
+                                       multi_gpu=properties.get(
+                                           'multi_gpu', None),
**kwargs)
self._init_scheduler(properties)

@@ -91,6 +95,7 @@ def preprocess_requests(self, requests):
def _init_model_and_tokenizer(self,
model_id_or_path,
device=None,
+                                  multi_gpu=None,
**kwargs):
self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
architectures = self.config.architectures
@@ -111,17 +116,35 @@ def _init_model_and_tokenizer(self,
if 'device_map' in kwargs:
device_map = kwargs.pop('device_map')

-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_id_or_path, device_map=device_map, **kwargs)
+        if "lmi_dist_sharding" == multi_gpu:
+            if 'neox' in model_id_or_path:
+                try:
+                    from lmi_dist.models.gpt_neox import GPTNeoxSharded
+                    from lmi_dist.utils import download_and_convert_weights
+
+                    download_and_convert_weights(model_id_or_path)
+                    self.model = GPTNeoxSharded(model_id_or_path)
+                except ImportError:
+                    print(
+                        f"Running {model_id_or_path} requires package lmi_dist."
+                    )
+            else:
+                raise Exception(
+                    f"{model_id_or_path} with lmi_dist_sharding is currently unsupported."
+                )
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id_or_path, device_map=device_map, **kwargs)

self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
padding_side="left")
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

def _init_scheduler(self, properties):
-        lm_block_cls = MODEL_TYPE_2_BLOCK.get(self.config.model_type,
-                                              HuggingfaceBlock)
+        lm_block_cls = MODEL_TYPE_2_BLOCK.get(
+            'falcon' if 'falcon' in self.config.model_type else '$',
+            HuggingfaceBlock)
self.lm_block = lm_block_cls(self.model)
self.search_config = SearchConfig(
eos_token_id=self.tokenizer.eos_token,
@@ -211,7 +234,9 @@ def _construct_search_config(self, parameters):
top_p=parameters.get('top_p', self.search_config.topk),
temperature=parameters.get('temperature',
self.search_config.temperature),
-            use_lru_kv_cache=use_lru_kv_cache)
+            use_lru_kv_cache=use_lru_kv_cache,
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id)


def _get_request_ids_tensor(request_ids):
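Note: the new multi_gpu property is read in __init__ above and compared against the literal string "lmi_dist_sharding"; any other value (or its absence) keeps the existing AutoModelForCausalLM path. Below is a minimal sketch of how a caller could opt in, assuming the same constructor used by the test script later in this commit; the property values other than "multi_gpu" are illustrative, not prescribed by this change.

import torch
from djl_python.rolling_batch import SchedulerRollingBatch

properties = {
    "tensor_parallel_degree": 2,
    "dtype": "fp16",
    "max_rolling_batch_size": 8,
    "multi_gpu": "lmi_dist_sharding",  # routes loading through lmi_dist's GPTNeoxSharded
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only model ids containing "neox" are accepted by this branch; anything else raises.
rolling_batch = SchedulerRollingBatch("EleutherAI/gpt-neox-20b", device, properties)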
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/scheduler/__init__.py
@@ -11,6 +11,6 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

-from .lm_block import HuggingfaceBlock, BloomBlock
+from .lm_block import HuggingfaceBlock, BloomBlock, FalconBlock
from .search_config import SearchConfig
from .seq_batch_scheduler import SeqBatchScheduler
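FalconBlock is now exported from djl_python.scheduler so the rolling batch can map falcon-family models to it. In the _init_scheduler change above, any model_type containing "falcon" is normalized to the 'falcon' key, while the '$' sentinel never matches a dictionary key, so every other model type falls back to HuggingfaceBlock. A small standalone illustration of that lookup (a sketch, not code from this commit):

from djl_python.scheduler import HuggingfaceBlock, BloomBlock, FalconBlock

MODEL_TYPE_2_BLOCK = {'bloom': BloomBlock, 'falcon': FalconBlock}

def pick_block_cls(model_type):
    # Any model_type containing the substring "falcon" selects FalconBlock;
    # everything else (including 'bloom', under this particular lookup)
    # resolves to HuggingfaceBlock because the '$' key is never present.
    return MODEL_TYPE_2_BLOCK.get(
        'falcon' if 'falcon' in model_type else '$', HuggingfaceBlock)

assert pick_block_cls('falcon') is FalconBlock
assert pick_block_cls('gpt_neox') is HuggingfaceBlock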
@@ -0,0 +1,60 @@
from collections import defaultdict
import torch
from djl_python.rolling_batch import SchedulerRollingBatch
import torch.distributed as dist


def print_rank0(content):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(content)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

properties = {
    "tensor_parallel_degree": 2,
    "dtype": "fp16",
    "max_rolling_batch_size": 8,
    "model_loading_timeout": 7200,
    "max_rolling_batch_prefill_tokens": 10000,
    "paged_attention": "True"
}

model_id = "EleutherAI/gpt-neox-20b"
"""
{"inputs":"write a program to add two numbers in python","parameters":{"max_new_tokens":1000, "do_sample":true, "temperature":0.7}}
"""

input_str = [
    "Memories follow me left and right", "Memories follow me left and right."
]

params = [{
    "max_new_tokens": 50,
    "do_sample": False,
    "temperature": 0.000007
}, {
    "max_new_tokens": 50,
    "do_sample": False,
    "temperature": 0.000007
}]

# ===================== lmi ============================
print("=========== lmi =========")
rolling_batch = SchedulerRollingBatch(model_id, device, properties)
rolling_batch.output_formatter = None
print("reach here")

output_all = defaultdict(list)
result = rolling_batch.inference(input_str, params)
for i, res in enumerate(result):
    output_all[i].append(res['data'])

for _ in range(50):
    result = rolling_batch.inference(input_str, params)
    for i, res in enumerate(result):
        output_all[i].append(res['data'])

for i, out in enumerate(output_all.values()):
    print_rank0(input_str[i] + ''.join(out))
    print_rank0('\n====')
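The docstring near the top of this script shows the request-body shape the rolling batch serves. As a rough illustration (an assumption about the mapping, not something this script exercises), the "parameters" object of such a request corresponds to one entry of the params list passed to rolling_batch.inference:

# Hypothetical sampling request mirroring the JSON example above.
sampling_params = [{
    "max_new_tokens": 1000,
    "do_sample": True,
    "temperature": 0.7
}]
result = rolling_batch.inference(
    ["write a program to add two numbers in python"], sampling_params)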
@@ -0,0 +1,126 @@
from djl_python.scheduler.lm_block import HuggingfaceBlock
from djl_python.scheduler.seq_batch_scheduler import SeqBatchScheduler
from transformers import AutoConfig
from djl_python.scheduler.search_config import SearchConfig
import torch
from transformers import AutoTokenizer

from lmi_dist.models.gpt_neox import GPTNeoxSharded
from lmi_dist.utils import download_and_convert_weights

global_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TestSchedulerSharded:

    def test_lm_block(self):
        model_id = "EleutherAI/gpt-neox-20b"
        download_and_convert_weights(model_id)
        model = GPTNeoxSharded(model_id)

        device = model.device
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        encoding = tokenizer("Hello, my dog is cute", return_tensors="pt")
        input_ids_0 = encoding.data['input_ids']
        seq_len = input_ids_0.shape[1]

        lm_block = HuggingfaceBlock(model)

        input0 = [
            torch.repeat_interleave(input_ids_0, dim=0, repeats=2).to(device),
            torch.repeat_interleave(torch.arange(seq_len)[None, :],
                                    dim=0,
                                    repeats=2).to(device),
            torch.repeat_interleave(torch.ones(seq_len,
                                               dtype=torch.int64)[None, :],
                                    dim=0,
                                    repeats=2).to(device)
        ]

        output0 = lm_block.forward(*input0, None)

        model_config = AutoConfig.from_pretrained(model_id)
        assert len(output0.past_key_values) == model_config.num_hidden_layers

        # input with kv_cache
        # k: [32, 64, 6], v: [32, 6, 64], [batch*head, kvDim, seq]
        past_key_values = output0.past_key_values
        input_ids = torch.tensor([[404], [405]]).to(device)
        past_seq = past_key_values[0][0].shape[-2]
        position_ids = torch.tensor([[past_seq], [past_seq]]).to(device)
        attention_mask = torch.ones(2, past_seq + 1,
                                    dtype=torch.int64).to(device)
        output1 = lm_block.forward(input_ids, position_ids, attention_mask,
                                   past_key_values)
        assert len(output1.past_key_values) == model_config.num_hidden_layers

    def test_contrastive_scheduler(self):
        model_id = "EleutherAI/gpt-neox-20b"
        download_and_convert_weights(model_id)
        model = GPTNeoxSharded(model_id)

        device = model.device
        tokenizer = AutoTokenizer.from_pretrained(model_id,
                                                  padding_side='left')
        tokenizer.pad_token = tokenizer.eos_token

        lm_block = HuggingfaceBlock(model)

        search_config = SearchConfig()
        search_config.pad_token_id = tokenizer.pad_token_id
        PAD = search_config.pad_token_id
        scheduler = SeqBatchScheduler(lm_block, "contrastive", search_config)

        input_ids_0 = tokenizer.encode(
            'Memories follow me left and right. I can',
            return_tensors='pt').to(device)
        request_ids = torch.tensor([[0]])

        # Test init_forward
        scheduler.add_request(input_ids_0, request_ids)

        # Merge longer sequences
        input12 = [
            r"When your legs don't work like they used to before And I can't sweep you off",
            r"There's a time that I remember, when I did not know"
        ]
        input_ids = tokenizer(input12, return_tensors='pt',
                              padding=True).input_ids.to(device)

        request_ids = torch.tensor([[1], [2]])
        scheduler.add_request(input_ids, request_ids)

        # Forward pass
        for _ in scheduler.increment_forward(20):
            pass

        results = scheduler.results

        # Merge shorter sequences
        input_ids_1 = tokenizer.encode("When your legs don't work",
                                       return_tensors='pt')
        input_ids_2 = torch.concat([
            torch.tensor([PAD, PAD]),
            tokenizer.encode("There's a time", return_tensors='pt')[0]
        ]).view(1, -1)
        input_ids = torch.concat([input_ids_1, input_ids_2], dim=0).to(device)
        request_ids = torch.tensor([[3], [4]])

        scheduler.add_request(input_ids, request_ids)

        # Forward pass
        for _ in scheduler.increment_forward(100):
            pass

        # print
        for i, ret in results.items():
            print('\n{}:'.format(i), tokenizer.decode(ret))


if __name__ == '__main__':
    # unittest.main()

    c = TestSchedulerSharded()
    # c.test_lm_block()
    # c.test_contrastive_scheduler()
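The two tests above exercise HuggingfaceBlock and SeqBatchScheduler against the sharded 20B model. For reference, the same add_request / increment_forward / results flow can be tried without lmi_dist; the sketch below is a single-process analogue under stated assumptions (it assumes an ordinary Hugging Face causal LM such as gpt2 works with HuggingfaceBlock and uses the 'greedy' algorithm that the rolling batch uses as its default), and is not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from djl_python.scheduler import HuggingfaceBlock, SearchConfig, SeqBatchScheduler

# Small non-sharded model as a stand-in for GPTNeoxSharded (assumption).
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

lm_block = HuggingfaceBlock(model)
search_config = SearchConfig()
search_config.pad_token_id = tokenizer.pad_token_id
scheduler = SeqBatchScheduler(lm_block, "greedy", search_config)

input_ids = tokenizer.encode('Memories follow me left and right. I can',
                             return_tensors='pt')
scheduler.add_request(input_ids, torch.tensor([[0]]))

# Step the scheduler; finished sequences accumulate in scheduler.results.
for _ in scheduler.increment_forward(50):
    pass

for request_id, token_ids in scheduler.results.items():
    print('\n{}:'.format(request_id), tokenizer.decode(token_ids))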
