@@ -110,6 +110,7 @@ def disable_attention_slicing(self):
110
110
def __call__ (
111
111
self ,
112
112
prompt : Union [str , List [str ]],
113
+ negative_prompt : Optional [Union [str , List [str ]]] = None ,
113
114
height : Optional [int ] = 512 ,
114
115
width : Optional [int ] = 512 ,
115
116
num_inference_steps : Optional [int ] = 50 ,
@@ -127,6 +128,8 @@ def __call__(
127
128
Args:
128
129
prompt (`str` or `List[str]`):
129
130
The prompt or prompts to guide the image generation.
131
+ negative_prompt (`str` or `List[str]`, *optional*):
132
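+ The prompt or prompts describing what the generated image should not contain; used in place of the empty unconditional prompt for classifier-free guidance, and ignored when `guidance_scale` <= 1.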
130
133
height (`int`, *optional*, defaults to 512):
131
134
The height in pixels of the generated image.
132
135
width (`int`, *optional*, defaults to 512):
@@ -203,9 +206,25 @@ def __call__(
203
206
do_classifier_free_guidance = guidance_scale > 1.0
204
207
# get unconditional embeddings for classifier free guidance
205
208
if do_classifier_free_guidance :
209
+ ucond_tokens : List [str ]
210
+ if negative_prompt is None :
211
+ ucond_tokens = ["" ] * batch_size
212
+ elif type (prompt ) is not type (negative_prompt ):
213
+ raise TypeError ("`negative_prompt` should be the same type to `prompt`." )
214
+ elif isinstance (negative_prompt , str ):
215
+ ucond_tokens = [negative_prompt ] * batch_size
216
+ elif batch_size != len (negative_prompt ):
217
+ raise ValueError ("The length of `negative_prompt` should be equal to batch_size." )
218
+ else :
219
+ ucond_tokens = negative_prompt
220
+
206
221
max_length = text_input .input_ids .shape [- 1 ]
207
222
uncond_input = self .tokenizer (
208
- ["" ] * batch_size , padding = "max_length" , max_length = max_length , return_tensors = "pt"
223
+ ucond_tokens ,
224
+ padding = "max_length" ,
225
+ max_length = max_length ,
226
+ truncation = True ,
227
+ return_tensors = "pt" ,
209
228
)
210
229
uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (self .device ))[0 ]
211
230
0 commit comments