 logger = init_logger(__name__)
 
 _S = TypeVar("_S", str, list[int])
-_PromptSeq = Union[str, list[int]]
+
+PromptSeq = Union[str, list[int]]
+"""A token sequence (list of token IDs) or text."""
 
 
 @dataclass
 class PromptReplacementDetails:
-    full: _PromptSeq
+    """Details about the replacement token sequence or text."""
+
+    full: PromptSeq
     """The full replacement."""
 
-    features: _PromptSeq
+    features: PromptSeq
     """
-    The part of the replacement that corresponds to placeholder feature tokens.
+    The part of the replacement that corresponds to feature placeholders;
+    this will be replaced by the output of the vision encoder during model
+    inference.
     """
 
     @staticmethod
-    def from_seq(seq: _PromptSeq) -> "PromptReplacementDetails":
+    def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
         return PromptReplacementDetails(full=seq, features=seq)
 
 
-_PromptRepl = Union[_PromptSeq, PromptReplacementDetails]
+PromptRepl = Union[PromptSeq, PromptReplacementDetails]
+"""
+The replacement token sequence or text.
+
+If only part of the replacement corresponds to feature placeholders, you can
+use :class:`PromptReplacementDetails` to specify which part.
+"""
 
 
 @dataclass
 class PromptReplacement:
     """
     Defines how to replace portions of an input prompt with placeholder tokens.
+
+    Example:
+
+        For each image, replace one ``<image>`` input placeholder in the prompt
+        with a number of ``<image>`` feature placeholders
+        equal to the feature size of the vision encoder:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement="<image>" * image_feature_size,
+            )
+
+        As above, but further pad the feature placeholders with ``<image_bos>``
+        and ``<image_eos>``, which are not supposed to be passed to the vision
+        encoder:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement=PromptReplacementDetails(
+                    full="".join([
+                        "<image_bos>",
+                        "<image>" * image_feature_size,
+                        "<image_eos>",
+                    ]),
+                    features="<image>" * image_feature_size,
+                ),
+            )
+
+        To avoid unnecessary tokenization during prompt replacement,
+        we recommend passing token sequences instead of text:
+
+        .. code-block:: python
+
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=PromptReplacementDetails(
+                    full=([image_bos_id] + [image_token_id] * image_feature_size
+                          + [image_eos_id]),
+                    features=[image_token_id] * image_feature_size,
+                ),
+            )
     """
 
     modality: str
     """The modality for which the replacement is made."""
 
-    target: _PromptSeq
+    target: PromptSeq
     """The token sequence (or text) to find and replace."""
 
-    replacement: Union[Callable[[int], _PromptRepl],
-                       _PromptRepl] = field(repr=False)
+    replacement: Union[Callable[[int], PromptRepl],
+                       PromptRepl] = field(repr=False)
     """
     Given the index of the processed item within :attr:`modality`,
     output the replacement token sequence (or text).
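
Since ``replacement`` may also be a callable of the item index, a processor can
emit a different replacement for each item of the same modality. A minimal
sketch of that callable form, assuming the public names are importable from
``vllm.multimodal.processing`` as this diff suggests; the token IDs and
per-item feature sizes below are hypothetical placeholders:

    from vllm.multimodal.processing import (PromptReplacement,
                                            PromptReplacementDetails)

    # Hypothetical values; a real processor derives these from the tokenizer
    # and from the size of each processed image.
    image_token_id = 32000
    image_bos_id = 32001
    image_eos_id = 32002
    image_feature_sizes = [256, 144]  # one entry per image in the prompt

    def get_image_replacement(item_idx: int) -> PromptReplacementDetails:
        num_features = image_feature_sizes[item_idx]
        return PromptReplacementDetails(
            full=([image_bos_id] + [image_token_id] * num_features
                  + [image_eos_id]),
            features=[image_token_id] * num_features,
        )

    PromptReplacement(
        modality="image",
        target=[image_token_id],
        replacement=get_image_replacement,  # called once per image item
    )
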
@@ -126,6 +186,10 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
 
 @dataclass
 class _BoundPromptSequence:
+    """
+    A :data:`PromptSeq` bound to a tokenizer to automatically
+    convert between token sequence and text representations.
+    """
     tokenizer: AnyTokenizer = field(repr=False)
 
     _text: Optional[str]
@@ -134,7 +198,7 @@ class _BoundPromptSequence:
     @staticmethod
     def from_seq(
         tokenizer: AnyTokenizer,
-        seq: _PromptSeq,
+        seq: PromptSeq,
     ) -> "_BoundPromptSequence":
         return _BoundPromptSequence(
             tokenizer=tokenizer,
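
The point of binding a ``PromptSeq`` to a tokenizer is that downstream matching
code can request either view (text or token IDs) regardless of which form the
caller originally supplied. A standalone sketch of that idea, not the actual
vLLM implementation; the class name, lazy-conversion strategy, and tokenizer
interface here are assumptions:

    from dataclasses import dataclass, field
    from typing import Any, Optional

    @dataclass
    class BoundSeqSketch:
        # Stands in for AnyTokenizer; assumed to expose encode()/decode().
        tokenizer: Any = field(repr=False)
        _text: Optional[str] = None
        _token_ids: Optional[list[int]] = None

        @property
        def text(self) -> str:
            # Decode lazily the first time the text view is requested.
            if self._text is None:
                self._text = self.tokenizer.decode(self._token_ids)
            return self._text

        @property
        def token_ids(self) -> list[int]:
            # Encode lazily the first time the token view is requested.
            if self._token_ids is None:
                self._token_ids = self.tokenizer.encode(self._text)
            return self._token_ids
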
@@ -180,9 +244,9 @@ class BoundPromptReplacement:
     tokenizer: AnyTokenizer = field(repr=False)
     modality: str
 
-    _target: _PromptSeq
-    _replacement: Union[Callable[[int], _PromptRepl],
-                        _PromptRepl] = field(repr=False)
+    _target: PromptSeq
+    _replacement: Union[Callable[[int], PromptRepl],
+                        PromptRepl] = field(repr=False)
 
     def __post_init__(self) -> None:
         self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
@@ -350,7 +414,7 @@ def find_text_matches(
 
 
 def _resolve_matches(
-    prompt: _PromptSeq,
+    prompt: PromptSeq,
     mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
 ) -> list[_PromptReplacementMatch]:
     """