
Commit a462a03

Add initial implementation of TT backend (Worker, ModelRunner, ModelLoader) with basic llama generation example (#1)
Signed-off-by: Salar Hosseini <skhorasgani@tenstorrent.com>
1 parent e1b0048 commit a462a03

9 files changed: +589 −1 lines changed

examples/offline_inference_tt.py

Lines changed: 31 additions & 0 deletions
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tt_metal.models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration

from vllm import LLM, SamplingParams
from vllm import ModelRegistry
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)

# Sample prompts.
# prompts = [
#     "Hello, my name is",
#     "The president of the United States is",
#     "The capital of France is",
#     "The future of AI is",
# ]
prompts = ["Hello, my name is"] * 32
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="meta-llama/Meta-Llama-3.1-70B")

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

requirements/tt.txt

Lines changed: 2 additions & 0 deletions
# Common dependencies
-r common.txt

setup.py

Lines changed: 11 additions & 0 deletions
@@ -429,6 +429,10 @@ def _no_device() -> bool:
     return VLLM_TARGET_DEVICE == "empty"
 
 
+def _is_tt() -> bool:
+    return VLLM_TARGET_DEVICE == "tt"
+
+
 def _is_cuda() -> bool:
     has_cuda = torch.version.cuda is not None
     return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
@@ -544,6 +548,8 @@ def get_vllm_version() -> str:
     if _no_device():
         if envs.VLLM_TARGET_DEVICE == "empty":
             version += f"{sep}empty"
+    elif _is_tt():
+        version += f"{sep}tt"
     elif _is_cuda():
         if envs.VLLM_USE_PRECOMPILED:
             version += f"{sep}precompiled"
@@ -625,6 +631,8 @@ def _read_requirements(filename: str) -> list[str]:
         requirements = _read_requirements("cpu.txt")
     elif _is_xpu():
         requirements = _read_requirements("xpu.txt")
+    elif _is_tt():
+        requirements = _read_requirements("tt.txt")
     else:
         raise ValueError(
             "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, "
@@ -665,6 +673,9 @@ def _read_requirements(filename: str) -> list[str]:
 
 if _no_device():
     ext_modules = []
+
+if _is_tt():
+    ext_modules = []
 
 if not ext_modules:
     cmdclass = {}

tt_metal/README.md

Lines changed: 29 additions & 0 deletions
## Environment Creation

To set up the tt-metal environment with vLLM, follow the instructions in `setup-metal.sh`.

## Accessing the Meta-Llama-3.1 Hugging Face Model

To run Meta-Llama-3.1, you must have access to the model on Hugging Face.
Steps:
1. Request access at [https://huggingface.co/meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B).
2. Once you have received access, create and copy your access token from the settings tab on Hugging Face.
3. Run the following in Python and paste your access token when prompted:
```python
from huggingface_hub import notebook_login
notebook_login()
```
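If you are running outside a notebook, a minimal alternative sketch (not part of this commit) uses `huggingface_hub.login()` instead:

```python
# Assumption: huggingface_hub is installed in the vLLM environment.
from huggingface_hub import login

login()  # prompts for the access token created in step 2
```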
## Importing the tt-metal models

Create a symbolic link to the tt-metal models folder inside vLLM:
```sh
cd tt_metal
ln -s <path/to/tt-metal>/models ./models
```

## Running the offline inference example
```sh
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml python examples/offline_inference_tt.py
```

tt_metal/setup-metal.sh

Lines changed: 13 additions & 0 deletions
export PYTHON_ENV_DIR="${TT_METAL_HOME}/build/python_env_vllm"
export VLLM_TARGET_DEVICE="tt"

# To create the vLLM env (first time):
# 1. set up the tt-metal env vars
# 2. source $vllm_dir/tt_metal/setup-metal.sh (this script)
# 3. build and create the tt-metal env as usual
# 4. source $PYTHON_ENV_DIR/bin/activate
# 5. pip3 install --upgrade pip
# 6. cd $vllm_dir && pip install -e .

# To activate (after the first time):
# 1. source $vllm_dir/tt_metal/setup-metal.sh && source $PYTHON_ENV_DIR/bin/activate

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -2070,7 +2070,7 @@ def __init__(self, device: str = "auto") -> None:
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
             self.device = torch.device("cpu")
-        elif self.device_type in ["tpu"]:
+        elif self.device_type in ["tpu"] or self.device_type in ["tt"]:
             self.device = None
         else:
             # Set device with device type
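A quick sanity-check sketch of the effect of this change, assuming `device_type` resolves to the literal device string passed in (that resolution is not shown in this hunk):

```python
from vllm.config import DeviceConfig

# With the TT target, vLLM leaves the torch device unset; the TT model runner
# builds its input tensors on the CPU and hands them to the TT model itself.
cfg = DeviceConfig(device="tt")
assert cfg.device is None
```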
vllm/model_executor/model_loader/tt_loader.py

Lines changed: 26 additions & 0 deletions

from typing import Optional

from torch import nn

from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import get_model_architecture
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)


class TTModelLoader(BaseModelLoader):

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model with the given configurations."""

        # For TT models, prepend "TT" to the architecture name,
        # e.g. "TTLlamaForCausalLM".
        arch_names = model_config.hf_config.architectures
        assert len(model_config.hf_config.architectures) == 1
        arch_names[0] = "TT" + arch_names[0]

        model_class, _ = get_model_architecture(model_config)
        model = model_class.initialize_vllm_model(model_config.hf_config,
                                                  device_config.device)
        return model
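Taken together with the `ModelRegistry.register_model("TTLlamaForCausalLM", ...)` call in `examples/offline_inference_tt.py` and the calls made by `TTModelRunner` below, the loader assumes a TT model class shaped roughly like the following sketch. The class name, method bodies, and vocab size are hypothetical; only the method names, arguments, and output shape come from this commit:

```python
import torch

from vllm import ModelRegistry


class TTFooForCausalLM:
    """Hypothetical TT model wrapper matching the interface the TT backend assumes."""

    @classmethod
    def initialize_vllm_model(cls, hf_config, device):
        # Called by TTModelLoader.load_model() after "TT" is prepended to the
        # HF architecture name and the class is resolved from the model registry.
        return cls()

    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        # Called by TTModelRunner.execute_model(); must return logits shaped
        # [batch_size, seq_len, vocab_size].
        batch_size, seq_len = tokens.shape
        vocab_size = 32000  # placeholder value for the sketch
        return torch.zeros(batch_size, seq_len, vocab_size)


# Register under "TT" + the HF architecture name, mirroring how the example
# registers "TTLlamaForCausalLM" for "LlamaForCausalLM".
ModelRegistry.register_model("TTFooForCausalLM", TTFooForCausalLM)
```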

vllm/worker/tt_model_runner.py

Lines changed: 211 additions & 0 deletions
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig,
                         ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.tt_loader import TTModelLoader
from vllm.sequence import (IntermediateTensors, SequenceGroupMetadata, Logprob,
                           SequenceOutput, CompletionSequenceGroupOutput)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)


@dataclass(frozen=True)
class TTModelInput(ModelRunnerInputBase):
    """
    Used by the TTModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    prompt_lens: Optional[torch.Tensor] = None
    seq_groups: Optional[List[List[int]]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "prompt_lens": self.prompt_lens,
            "seq_groups": self.seq_groups,
        }

        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["TTModelInput"],
        tensor_dict: Dict[str, Any],
    ) -> "TTModelInput":
        return cls(**tensor_dict)


class TTModelRunner(ModelRunnerBase[TTModelInput]):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, TT worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config

        self.device = self.device_config.device

        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

    def load_model(self) -> None:
        # Note: using custom TT loader instead of selecting from default vllm loaders
        loader = TTModelLoader(self.load_config)
        self.model = loader.load_model(model_config=self.model_config,
                                       device_config=self.device_config,
                                       parallel_config=self.parallel_config,
                                       scheduler_config=self.scheduler_config,
                                       cache_config=self.cache_config)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> TTModelInput:
        return TTModelInput.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> TTModelInput:

        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt  # prefill if True, otherwise decode
        assert all(x.is_prompt == is_prompt for x in seq_group_metadata_list), \
            "Currently only supporting all prefills or all decodes in seq group"

        batch_size = len(seq_group_metadata_list)
        assert batch_size > 0

        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1  # Only support one sequence per request group
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]

            if is_prompt:
                # tokens
                prompt_tokens = seq_data.get_token_ids()
                input_tokens.extend(prompt_tokens)

                # positions
                prompt_len = len(prompt_tokens)
                prompt_lens.append(prompt_len)
                input_positions.extend(list(range(prompt_len)))
            else:
                # tokens
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                # positions
                position = seq_data.get_len() - 1
                input_positions.append(position)

            # TODO: Get block table using seq_group_metadata.block_tables[seq_id]

        input_tokens = torch.tensor(input_tokens, dtype=torch.int32, device="cpu")
        input_positions = torch.tensor(input_positions, dtype=torch.int32, device="cpu")
        if is_prompt:
            prompt_lens = torch.tensor(prompt_lens,
                                       dtype=torch.int32,
                                       device="cpu")
        else:
            prompt_lens = None

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]

        return TTModelInput(input_tokens, input_positions, prompt_lens, seq_groups)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: TTModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "TT worker does not support multi-step execution.")

        is_prompt = model_input.prompt_lens is not None  # prefill if True, otherwise decode

        if is_prompt:
            input_position = 0
            # Currently only support same prompt length
            assert torch.all(model_input.prompt_lens == model_input.prompt_lens[0]), \
                "Currently only supporting same prompt lengths for prefill"
            batch_size = model_input.prompt_lens.shape[0]
        else:
            # Currently only support same decode positions
            input_position = model_input.input_positions[0].item()
            assert torch.all(model_input.input_positions == input_position), \
                "Currently only supporting same input positions for decode"
            batch_size = model_input.input_tokens.shape[0]

        input_tokens = model_input.input_tokens.view(batch_size, -1)

        execute_model_kwargs = {
            "tokens": input_tokens,
            "start_pos": input_position,
            # TODO: Add block table and maybe kv cache
        }

        logits = self.model.forward(**execute_model_kwargs)  # [batch_size, seq_len, vocab_size]

        # Note: for other devices, vLLM applies
        # vllm.model_executor.layers.logits_processor::LogitsProcessor::_apply_logits_processors
        # on logits; we don't use this.
        # Note: for other devices, vLLM applies
        # vllm.model_executor.layers.sampler::Sampler for sampling tokens; we don't use this.
        next_logits = logits[:, -1, :]  # batch, vocab of last token
        next_token_ids = self._sample_tokens(next_logits)

        # Minimal code to construct the sampler outputs, based on tpu_model_runner.py.
        # TT backend does not support advanced sampling parameters such as logprobs.
        zero_logprob = Logprob(0.0)
        sampler_outputs = []
        for batch_idx, seq_ids in enumerate(model_input.seq_groups):
            assert len(seq_ids) == 1  # Only support one sequence per request group
            next_token_id = next_token_ids[batch_idx]
            seq_outputs = [SequenceOutput(seq_ids[0], next_token_id,
                                          {next_token_id: zero_logprob})]
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return [SamplerOutput(sampler_outputs)]

    def _sample_tokens(self, logits):
        # TODO: Add other sampling methods; currently only using greedy sampling.
        return torch.argmax(logits, dim=-1)
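As an illustration of the input format produced by `prepare_model_input`, here is a minimal sketch (with made-up sequence IDs and token IDs) of the `TTModelInput` for a two-sequence decode step, including the broadcast round-trip the class supports:

```python
import torch

from vllm.worker.tt_model_runner import TTModelInput

# Hypothetical decode step: two request groups (sequence ids 3 and 7), each
# contributing its last generated token, both at position 5.
model_input = TTModelInput(
    input_tokens=torch.tensor([101, 202], dtype=torch.int32),
    input_positions=torch.tensor([5, 5], dtype=torch.int32),
    prompt_lens=None,       # None marks a decode step in execute_model()
    seq_groups=[[3], [7]],  # one sequence per request group
)

# Round-trip through the dict used to broadcast model inputs to workers.
restored = TTModelInput.from_broadcasted_tensor_dict(
    model_input.as_broadcastable_tensor_dict())
assert torch.equal(restored.input_tokens, model_input.input_tokens)
```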
