@@ -655,50 +655,52 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
     [MODEL_NAME, "zephyr-lora"],
 )
 async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
-    # test simple list
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(batch.choices) == 2
-    assert batch.choices[0].text == batch.choices[1].text
-
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
+    # test both text and token IDs
+    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
+        # test simple list
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            max_tokens=5,
+            temperature=0.0,
+        )
+        assert len(batch.choices) == 2
+        assert batch.choices[0].text == batch.choices[1].text
 
-    # test streaming
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
-    texts = [""] * 2
-    async for chunk in batch:
-        assert len(chunk.choices) == 1
-        choice = chunk.choices[0]
-        texts[choice.index] += choice.text
-    assert texts[0] == texts[1]
+        # test n = 2
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            n=2,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body=dict(
+                # NOTE: this has to be true for n > 1 in vLLM, but not necessary
+                # for official client.
+                use_beam_search=True),
+        )
+        assert len(batch.choices) == 4
+        assert batch.choices[0].text != batch.choices[
+            1].text, "beam search should be different"
+        assert batch.choices[0].text == batch.choices[
+            2].text, "two copies of the same prompt should be the same"
+        assert batch.choices[1].text == batch.choices[
+            3].text, "two copies of the same prompt should be the same"
+
+        # test streaming
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            max_tokens=5,
+            temperature=0.0,
+            stream=True,
+        )
+        texts = [""] * 2
+        async for chunk in batch:
+            assert len(chunk.choices) == 1
+            choice = chunk.choices[0]
+            texts[choice.index] += choice.text
+        assert texts[0] == texts[1]
 
 
 @pytest.mark.asyncio