Skip to content

Commit

Permalink
[tnx] neuron rolling batch test suite (#2172)
Browse files Browse the repository at this point in the history
  • Loading branch information
tosterberg authored Jul 13, 2024
1 parent 1038c63 commit ad268ca
Show file tree
Hide file tree
Showing 5 changed files with 585 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from collections import defaultdict
from typing import List, Dict
from dataclasses import dataclass
from djl_python import test_model, Input
from djl_python.request import Request
from djl_python.input_parser import format_input


@dataclass
class SimulationSchedule:
    """Describe the requests and the arrival pattern replayed by the simulator.

    ``prompts`` and ``params`` are paired element by element to form the
    requests. ``reqs_to_prefill`` and ``wait_steps`` are paired element by
    element to form the phases: in each phase that many new requests are
    built, then up to that many inference steps are run.
    """

    # Prompt text of each request, in submission order.
    prompts: List[str]
    # Generation parameters of each request; same length as ``prompts``.
    params: List[Dict]
    # Number of new requests to build at the start of each phase.
    reqs_to_prefill: List[int]
    # Upper bound on inference steps in each phase; same length as
    # ``reqs_to_prefill``.
    wait_steps: List[int]


class NeuronRollingBatchGenerator:
    """Drive a Neuron rolling batch step by step and collect what it emits.

    Two drivers are provided: ``simulator`` replays a ``SimulationSchedule``
    using ``Request`` objects, and ``step`` feeds raw prompt/parameter lists.
    Finished outputs of ``simulator`` accumulate in ``self.responses``.
    """

    def __init__(self):
        self.rolling_batch = None
        self._req_id = 0
        # Store the results
        self.output_all = defaultdict(list)
        self.input_all = {}
        # Partial output of each in-flight request, by position in the batch
        self.data_collector = []
        # Complete outputs, in the order the requests finished
        self.responses = []

        # Status variables
        self.input_str = []
        self.params = []
        self.req_ids = []

        # Spec_dec
        self.token_numbers = defaultdict(list)

    def init_neuron_service(self, properties: dict):
        """Initialize a TransformersNeuronXService and keep its rolling batch.

        :param properties: passed unchanged to the service's ``initialize``
        """
        # Imported lazily so this module can be imported without the Neuron stack
        from djl_python.transformers_neuronx import TransformersNeuronXService
        _service = TransformersNeuronXService()
        _service.initialize(properties)
        self.rolling_batch = _service.rolling_batch

    def get_req_id(self):
        """Return the next request id and advance the counter."""
        req_id = self._req_id
        self._req_id += 1
        return req_id

    def collect_data(self, result):
        """Accumulate one step of output and retire the finished requests.

        :param result: one item per in-flight request, each exposing ``"data"``
            (the output produced this step) and ``"last"`` (truthy when the
            request is done); positions line up with ``self.data_collector``
        :return: ascending list of the positions that finished in this step
        """
        done_requests_indices = []
        for idx, item in enumerate(result):
            if len(self.data_collector) <= idx:
                # First output seen for a request that just joined the batch
                self.data_collector.append(item["data"])
            else:
                self.data_collector[idx] += item["data"]
            if item['last']:
                done_requests_indices.append(idx)
        # Pop from the back so the remaining positions stay valid
        for idx in reversed(done_requests_indices):
            value = self.data_collector.pop(idx)
            self.responses.append(value)
            print(f"\nFinished request: {value}\n")
        return done_requests_indices

    def build_request(self, raw_input):
        """Wrap a raw ``{"inputs", "parameters"}`` payload into a ``Request``.

        :param raw_input: payload handed to ``test_model.create_json_request``
        :return: a ``Request`` carrying a freshly allocated id
        """
        inputs = test_model.create_json_request(raw_input)
        parsed_inputs = format_input(inputs)
        request = Request(parsed_inputs)
        request.id = self.get_req_id()
        return request

    def _run_inference_step(self, new_requests, current_requests):
        """Run one inference step, then drop the requests that finished."""
        generated_tokens = self.rolling_batch.inference(new_requests)
        # Each request must be handed to the rolling batch only once
        new_requests.clear()
        finished_indices = self.collect_data(generated_tokens)
        for idx in sorted(finished_indices, reverse=True):
            current_requests.pop(idx)

    def simulator(self, schedule: "SimulationSchedule"):
        """Replay ``schedule`` against the rolling batch until all requests finish.

        For every ``(reqs_to_prefill, wait_steps)`` pair, that many requests
        are built and then up to ``wait_steps`` inference steps are run. Once
        the schedule is exhausted, inference keeps running until no request is
        left in flight. Prompts beyond ``sum(reqs_to_prefill)`` are never
        submitted.

        :param schedule: the prompts, parameters and arrival pattern to replay
        :raises AssertionError: if the paired lists of ``schedule`` differ in
            length
        :raises IndexError: if the schedule asks for more requests than there
            are prompts
        """
        assert len(schedule.prompts) == len(schedule.params), \
            "prompts and params must have the same length"
        assert len(schedule.reqs_to_prefill) == len(schedule.wait_steps), \
            "reqs_to_prefill and wait_steps must have the same length"
        all_requests = [{
            "inputs": prompt,
            "parameters": params
        } for prompt, params in zip(schedule.prompts, schedule.params)]
        next_request = 0
        current_requests = []
        new_requests = []
        for batch_size, step in zip(schedule.reqs_to_prefill,
                                    schedule.wait_steps):
            for _ in range(batch_size):
                request = self.build_request(all_requests[next_request])
                next_request += 1
                # New requests are prepended, so within one phase the most
                # recently built request comes first in the submitted list
                new_requests = [request] + new_requests
                current_requests.append(request)

            for _ in range(step):
                if not current_requests:
                    break
                self._run_inference_step(new_requests, current_requests)
        # Drain: requests still pending (e.g. after a zero-step phase) are
        # submitted on the first iteration only, then the batch runs dry
        while current_requests:
            self._run_inference_step(new_requests, current_requests)

    def step(self, step=20, input_str_delta=None, params_delta=None):
        """Add optional new inputs, then run up to ``step`` inference steps.

        :param step: maximum number of inference steps to run
        :param input_str_delta: new prompts to add before stepping, if any
        :param params_delta: parameters for the new prompts; when omitted,
            one empty dict per new prompt is used
        """
        # NOTE(review): this calls ``inference(input_str, params)`` whereas
        # ``simulator`` calls ``inference(new_requests)`` -- confirm which
        # rolling batch interface this method is meant to target.
        if input_str_delta:
            if params_delta is None:
                params_delta = [{} for _ in input_str_delta]
            begin_id = max(self.input_all.keys(), default=0) + 1
            req_ids_delta = list(
                range(begin_id, begin_id + len(input_str_delta)))

            self.input_str += input_str_delta
            self.params += params_delta
            self.req_ids += req_ids_delta
            for req_id, input_s, param in zip(req_ids_delta, input_str_delta,
                                              params_delta):
                self.input_all[req_id] = (input_s, param)

        for _ in range(step):
            result = self.rolling_batch.inference(self.input_str, self.params)
            for res, req_id in zip(result, self.req_ids):
                self.output_all[req_id].append(res['data'])
                self.token_numbers[req_id].append(res.get('step_token_num', 1))
            # Keep only the entries whose request is still running
            self.req_ids = [
                req_id for req_id, res in zip(self.req_ids, result)
                if not res['last']
            ]
            self.input_str = [
                s for s, res in zip(self.input_str, result) if not res['last']
            ]
            self.params = [
                p for p, res in zip(self.params, result) if not res['last']
            ]
            if not self.req_ids:
                break

    def is_empty(self):
        """Return True when ``step`` has no request left in flight."""
        return not self.req_ids

    def reset(self):
        """Drop the rolling batch and return every field to its initial value."""
        self.rolling_batch = None
        self._req_id = 0
        # Store the results
        self.output_all = defaultdict(list)
        self.input_all = {}
        self.data_collector = []
        self.responses = []

        # Status variables, the remaining
        self.input_str = []
        self.params = []
        self.req_ids = []

        self.token_numbers = defaultdict(list)
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import unittest
import json
import os

# Probe for the Neuron stack: the tests below are skipped when either import
# fails.
try:
    import transformers_neuronx
    from djl_python.transformers_neuronx import TransformersNeuronXService
except ImportError:
    SKIP_TEST = True
else:
    SKIP_TEST = False

# Expected output prefixes for the greedy (do_sample=False) run in
# TestNeuronRollingBatch.test_models. The outer key is the model id, the inner
# key is the index of the response in gen.responses, and each value is compared
# against the start of "<prompt> <generated_text>".
# NOTE(review): the "_30" suffix presumably refers to the first ~30 generated
# tokens -- confirm.
expected_text_30 = {
"TinyLlama/TinyLlama-1.1B-Chat-v0.6": {
0:
"Hello, my name is [Your Name] and I am a [Your Job Title] at [Your Company Name]. I am interested in learning more about your company'",
1:
'The president of the United States is a man named Donald Trump.\n\n2. The president of the United States is a man named Donald Trump.\n\n3. The president',
2:
'The capital of France is Paris.\n\n2. The capital of the United States is Washington, D.C.\n\n3. The capital of Canada is Ott',
3:
"The future of AI is bright, and it's already here. With the help of AI, we can create more personalized experiences, automate repetitive tasks, and even predict the future.",
}
}


@unittest.skipIf(SKIP_TEST, "Neuron dependencies are not available")
class TestNeuronRollingBatch(unittest.TestCase):
    """Run the tnx rolling batch through a fixed staggered-arrival schedule."""

    # Prompts submitted by every test, in request order.
    PROMPTS = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    def _run_schedule(self, properties, params):
        """Load a model, replay the shared schedule and return the responses.

        The generator is reset and garbage-collected even when loading or
        simulation raises, so a failure for one model does not leave its
        rolling batch referenced while the next one is loaded.

        :param properties: service properties used to initialize the model
        :param params: generation parameters; every request gets its own copy
        :return: a copy of the generator's finished responses
        """
        # === Preparation ===
        from djl_python.tests.neuron_test_scripts.neuron_rb_generator import NeuronRollingBatchGenerator, SimulationSchedule
        import gc

        # ===================== neuron-tnx ============================
        gen = NeuronRollingBatchGenerator()
        try:
            gen.init_neuron_service(properties)

            print('========== init inference ===========')
            test_input = SimulationSchedule(
                prompts=list(self.PROMPTS),
                params=[dict(params) for _ in self.PROMPTS],
                reqs_to_prefill=[1, 2, 1],
                wait_steps=[1, 4, 5])

            gen.simulator(test_input)
            # Copy before cleanup so the result does not depend on reset()
            return list(gen.responses)
        finally:
            gen.reset()
            del gen
            gc.collect()

    def test_models(self):
        """Check generated text against the expected prefixes per model."""
        # --- Models ---
        model_names = [
            "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        ]

        # === Test ===
        for model_id in model_names:
            with self.subTest(model_id=model_id):
                properties = {
                    "tensor_parallel_degree": 2,
                    "n_positions": "128",
                    "rolling_batch": "tnx",
                    "max_rolling_batch_size": 4,
                    "model_id": model_id
                }
                responses = self._run_schedule(properties, {
                    "max_new_tokens": 100,
                    "do_sample": False,
                })

                expected_by_index = expected_text_30.get(model_id, {})
                # NOTE(review): responses are stored in the order requests
                # finished, so pairing index i with PROMPTS[i] assumes they
                # finish in prompt order -- confirm.
                for i, out in enumerate(responses):
                    out_dict = json.loads(''.join(out))
                    out_str = out_dict["generated_text"]
                    test_generation = self.PROMPTS[i] + " " + out_str
                    print(f"\n====req_id: {i}=====\n{test_generation}\n")
                    if i in expected_by_index:
                        expected_prefix_30_req_id = expected_by_index[i]
                        self.assertTrue(
                            test_generation.startswith(
                                expected_prefix_30_req_id),
                            f"{model_id} response {i}: expected prefix "
                            f"{expected_prefix_30_req_id!r}, got "
                            f"{test_generation!r}")

    def test_tiny_models(self):
        """Smoke-test the schedule on tiny checkpoints of several model types."""
        # === Preparation ===
        from djl_python.tests.neuron_test_scripts.tiny_models import artifacts
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # --- Models ---
        model_names = [
            "llama",
            "gpt2",
            "gptneox",
            "bloom",
        ]

        # === Test ===
        for model_id in model_names:
            with self.subTest(model_id=model_id):
                properties = {
                    "tensor_parallel_degree": 2,
                    "n_positions": "128",
                    "rolling_batch": "tnx",
                    "max_rolling_batch_size": 4,
                    "model_loading_timeout": 3600,
                    "model_id": artifacts(model_id)
                }
                responses = self._run_schedule(
                    properties, {
                        "max_new_tokens": 100,
                        "do_sample": False,
                        "ignore_eos": True,
                    })
                # Every submitted request must have produced one response
                self.assertEqual(len(responses), len(self.PROMPTS))


if __name__ == '__main__':
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
Loading

0 comments on commit ad268ca

Please sign in to comment.