1
1
# SPDX-License-Identifier: Apache-2.0
2
2
3
+ from collections .abc import Iterable
3
4
from dataclasses import dataclass
4
5
from functools import cached_property
5
- from typing import (TYPE_CHECKING , Any , Dict , Generic , Iterable , List , Literal ,
6
- Optional , Tuple , Union , cast )
6
+ from typing import TYPE_CHECKING , Any , Generic , Literal , Optional , Union , cast
7
7
8
8
import torch
9
9
from typing_extensions import NotRequired , TypedDict , TypeVar , assert_never
@@ -26,7 +26,7 @@ class TextPrompt(TypedDict):
26
26
if the model supports it.
27
27
"""
28
28
29
- mm_processor_kwargs : NotRequired [Dict [str , Any ]]
29
+ mm_processor_kwargs : NotRequired [dict [str , Any ]]
30
30
"""
31
31
Optional multi-modal processor kwargs to be forwarded to the
32
32
multimodal input mapper & processor. Note that if multiple modalities
@@ -38,10 +38,10 @@ class TextPrompt(TypedDict):
38
38
class TokensPrompt (TypedDict ):
39
39
"""Schema for a tokenized prompt."""
40
40
41
- prompt_token_ids : List [int ]
41
+ prompt_token_ids : list [int ]
42
42
"""A list of token IDs to pass to the model."""
43
43
44
- token_type_ids : NotRequired [List [int ]]
44
+ token_type_ids : NotRequired [list [int ]]
45
45
"""A list of token type IDs to pass to the cross encoder model."""
46
46
47
47
multi_modal_data : NotRequired ["MultiModalDataDict" ]
@@ -50,7 +50,7 @@ class TokensPrompt(TypedDict):
50
50
if the model supports it.
51
51
"""
52
52
53
- mm_processor_kwargs : NotRequired [Dict [str , Any ]]
53
+ mm_processor_kwargs : NotRequired [dict [str , Any ]]
54
54
"""
55
55
Optional multi-modal processor kwargs to be forwarded to the
56
56
multimodal input mapper & processor. Note that if multiple modalities
@@ -115,7 +115,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
115
115
116
116
decoder_prompt : Optional [_T2_co ]
117
117
118
- mm_processor_kwargs : NotRequired [Dict [str , Any ]]
118
+ mm_processor_kwargs : NotRequired [dict [str , Any ]]
119
119
120
120
121
121
PromptType = Union [SingletonPrompt , ExplicitEncoderDecoderPrompt ]
@@ -136,10 +136,10 @@ class TokenInputs(TypedDict):
136
136
type : Literal ["token" ]
137
137
"""The type of inputs."""
138
138
139
- prompt_token_ids : List [int ]
139
+ prompt_token_ids : list [int ]
140
140
"""The token IDs of the prompt."""
141
141
142
- token_type_ids : NotRequired [List [int ]]
142
+ token_type_ids : NotRequired [list [int ]]
143
143
"""The token type IDs of the prompt."""
144
144
145
145
prompt : NotRequired [str ]
@@ -164,12 +164,12 @@ class TokenInputs(TypedDict):
164
164
Placeholder ranges for the multi-modal data.
165
165
"""
166
166
167
- multi_modal_hashes : NotRequired [List [str ]]
167
+ multi_modal_hashes : NotRequired [list [str ]]
168
168
"""
169
169
The hashes of the multi-modal data.
170
170
"""
171
171
172
- mm_processor_kwargs : NotRequired [Dict [str , Any ]]
172
+ mm_processor_kwargs : NotRequired [dict [str , Any ]]
173
173
"""
174
174
Optional multi-modal processor kwargs to be forwarded to the
175
175
multimodal input mapper & processor. Note that if multiple modalities
@@ -179,14 +179,14 @@ class TokenInputs(TypedDict):
179
179
180
180
181
181
def token_inputs (
182
- prompt_token_ids : List [int ],
183
- token_type_ids : Optional [List [int ]] = None ,
182
+ prompt_token_ids : list [int ],
183
+ token_type_ids : Optional [list [int ]] = None ,
184
184
prompt : Optional [str ] = None ,
185
185
multi_modal_data : Optional ["MultiModalDataDict" ] = None ,
186
186
multi_modal_inputs : Optional ["MultiModalKwargs" ] = None ,
187
- multi_modal_hashes : Optional [List [str ]] = None ,
187
+ multi_modal_hashes : Optional [list [str ]] = None ,
188
188
multi_modal_placeholders : Optional ["MultiModalPlaceholderDict" ] = None ,
189
- mm_processor_kwargs : Optional [Dict [str , Any ]] = None ,
189
+ mm_processor_kwargs : Optional [dict [str , Any ]] = None ,
190
190
) -> TokenInputs :
191
191
"""Construct :class:`TokenInputs` from optional values."""
192
192
inputs = TokenInputs (type = "token" , prompt_token_ids = prompt_token_ids )
@@ -255,7 +255,7 @@ def prompt(self) -> Optional[str]:
255
255
assert_never (inputs ) # type: ignore[arg-type]
256
256
257
257
@cached_property
258
- def prompt_token_ids (self ) -> List [int ]:
258
+ def prompt_token_ids (self ) -> list [int ]:
259
259
inputs = self .inputs
260
260
261
261
if inputs ["type" ] == "token" or inputs ["type" ] == "multimodal" :
@@ -264,7 +264,7 @@ def prompt_token_ids(self) -> List[int]:
264
264
assert_never (inputs ) # type: ignore[arg-type]
265
265
266
266
@cached_property
267
- def token_type_ids (self ) -> List [int ]:
267
+ def token_type_ids (self ) -> list [int ]:
268
268
inputs = self .inputs
269
269
270
270
if inputs ["type" ] == "token" or inputs ["type" ] == "multimodal" :
@@ -294,7 +294,7 @@ def multi_modal_data(self) -> "MultiModalDataDict":
294
294
assert_never (inputs ) # type: ignore[arg-type]
295
295
296
296
@cached_property
297
- def multi_modal_inputs (self ) -> Union [Dict , "MultiModalKwargs" ]:
297
+ def multi_modal_inputs (self ) -> Union [dict , "MultiModalKwargs" ]:
298
298
inputs = self .inputs
299
299
300
300
if inputs ["type" ] == "token" :
@@ -306,7 +306,7 @@ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
306
306
assert_never (inputs ) # type: ignore[arg-type]
307
307
308
308
@cached_property
309
- def multi_modal_hashes (self ) -> List [str ]:
309
+ def multi_modal_hashes (self ) -> list [str ]:
310
310
inputs = self .inputs
311
311
312
312
if inputs ["type" ] == "token" :
@@ -331,7 +331,7 @@ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
331
331
assert_never (inputs ) # type: ignore[arg-type]
332
332
333
333
@cached_property
334
- def mm_processor_kwargs (self ) -> Dict [str , Any ]:
334
+ def mm_processor_kwargs (self ) -> dict [str , Any ]:
335
335
inputs = self .inputs
336
336
337
337
if inputs ["type" ] == "token" :
@@ -355,7 +355,7 @@ def mm_processor_kwargs(self) -> Dict[str, Any]:
355
355
def build_explicit_enc_dec_prompt (
356
356
encoder_prompt : _T1 ,
357
357
decoder_prompt : Optional [_T2 ],
358
- mm_processor_kwargs : Optional [Dict [str , Any ]] = None ,
358
+ mm_processor_kwargs : Optional [dict [str , Any ]] = None ,
359
359
) -> ExplicitEncoderDecoderPrompt [_T1 , _T2 ]:
360
360
if mm_processor_kwargs is None :
361
361
mm_processor_kwargs = {}
@@ -368,9 +368,9 @@ def build_explicit_enc_dec_prompt(
368
368
def zip_enc_dec_prompts (
369
369
enc_prompts : Iterable [_T1 ],
370
370
dec_prompts : Iterable [Optional [_T2 ]],
371
- mm_processor_kwargs : Optional [Union [Iterable [Dict [str , Any ]],
372
- Dict [str , Any ]]] = None ,
373
- ) -> List [ExplicitEncoderDecoderPrompt [_T1 , _T2 ]]:
371
+ mm_processor_kwargs : Optional [Union [Iterable [dict [str , Any ]],
372
+ dict [str , Any ]]] = None ,
373
+ ) -> list [ExplicitEncoderDecoderPrompt [_T1 , _T2 ]]:
374
374
"""
375
375
Zip encoder and decoder prompts together into a list of
376
376
:class:`ExplicitEncoderDecoderPrompt` instances.
@@ -380,12 +380,12 @@ def zip_enc_dec_prompts(
380
380
provided, it will be zipped with the encoder/decoder prompts.
381
381
"""
382
382
if mm_processor_kwargs is None :
383
- mm_processor_kwargs = cast (Dict [str , Any ], {})
383
+ mm_processor_kwargs = cast (dict [str , Any ], {})
384
384
if isinstance (mm_processor_kwargs , dict ):
385
385
return [
386
386
build_explicit_enc_dec_prompt (
387
387
encoder_prompt , decoder_prompt ,
388
- cast (Dict [str , Any ], mm_processor_kwargs ))
388
+ cast (dict [str , Any ], mm_processor_kwargs ))
389
389
for (encoder_prompt ,
390
390
decoder_prompt ) in zip (enc_prompts , dec_prompts )
391
391
]
@@ -399,7 +399,7 @@ def zip_enc_dec_prompts(
399
399
400
400
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, Optional[_T2]]]:
    """
    Flatten each :class:`ExplicitEncoderDecoderPrompt` into an
    ``(encoder_prompt, decoder_prompt)`` tuple, preserving input order.
    """
    pairs: list[tuple[_T1, Optional[_T2]]] = []
    for enc_dec_prompt in enc_dec_prompts:
        encoder = enc_dec_prompt["encoder_prompt"]
        decoder = enc_dec_prompt["decoder_prompt"]
        pairs.append((encoder, decoder))
    return pairs
0 commit comments