Skip to content

Commit

Permalink
[Core] Subclass ModelRunner to support cross-attention & encoder sequ…
Browse files Browse the repository at this point in the history
…ences (towards eventual encoder/decoder model support) (vllm-project#4942)

Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
  • Loading branch information
3 people authored Aug 6, 2024
1 parent 660470e commit fd95e02
Show file tree
Hide file tree
Showing 33 changed files with 3,976 additions and 352 deletions.
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ steps:
- python3 cpu_offload.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py
- python3 offline_inference_vision_language.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py

- label: Models Test # 1hr10min
source_file_dependencies:
Expand Down Expand Up @@ -289,6 +290,7 @@ steps:
commands:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
- pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
Expand Down
99 changes: 99 additions & 0 deletions examples/offline_inference_encoder_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
'''

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists

dtype = "float"

# Create a BART encoder/decoder model instance
llm = LLM(
model="facebook/bart-large-cnn",
dtype=dtype,
)

# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()

# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
prompt="The capital of France is"))
# - Pass a single prompt to encoder/decoder model
# (implicitly encoder input prompt);
# decoder input prompt is assumed to be None

single_text_prompt_raw = text_prompt_raw # Pass a string directly
single_text_prompt = text_prompt # Pass a TextPrompt
single_tokens_prompt = tokens_prompt # Pass a TokensPrompt

# - Pass explicit encoder and decoder input prompts within one data structure.
# Encoder and decoder prompts can both independently be text or tokens, with
# no requirement that they be the same prompt type. Some example prompt-type
# combinations are shown below, note that these are not exhaustive.

enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt string directly, &
# pass decoder prompt tokens
encoder_prompt=single_text_prompt_raw,
decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
# Pass TextPrompt to encoder, and
# pass decoder prompt string directly
encoder_prompt=single_text_prompt,
decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt tokens directly, and
# pass TextPrompt to decoder
encoder_prompt=single_tokens_prompt,
decoder_prompt=single_text_prompt,
)

# - Finally, here's a useful helper function for zipping encoder and
# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
['An encoder prompt', 'Another encoder prompt'],
['A decoder prompt', 'Another decoder prompt'])

# - Let's put all of the above example prompts together into one list
# which we will pass to the encoder/decoder LLM.
prompts = [
single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list

print(prompts)

# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
min_tokens=0,
max_tokens=20,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
print(f"Encoder prompt: {encoder_prompt!r}, "
f"Decoder prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
Loading

0 comments on commit fd95e02

Please sign in to comment.