Skip to content

Commit 3f36c77

Browse files
Add argument "negative_prompt" huggingface#549
1 parent 57b70c5 commit 3f36c77

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,7 @@ def disable_attention_slicing(self):
110110
def __call__(
111111
self,
112112
prompt: Union[str, List[str]],
113+
negative_prompt: Optional[Union[str, List[str]]] = None,
113114
height: Optional[int] = 512,
114115
width: Optional[int] = 512,
115116
num_inference_steps: Optional[int] = 50,
@@ -127,6 +128,8 @@ def __call__(
127128
Args:
128129
prompt (`str` or `List[str]`):
129130
The prompt or prompts to guide the image generation.
131+
negative_prompt (`str` or `List[str]`, *optional*):
132+
The prompt or prompts not to guide the image generation.
130133
height (`int`, *optional*, defaults to 512):
131134
The height in pixels of the generated image.
132135
width (`int`, *optional*, defaults to 512):
@@ -203,9 +206,25 @@ def __call__(
203206
do_classifier_free_guidance = guidance_scale > 1.0
204207
# get unconditional embeddings for classifier free guidance
205208
if do_classifier_free_guidance:
209+
ucond_tokens: List[str]
210+
if negative_prompt is None:
211+
ucond_tokens = [""] * batch_size
212+
elif type(prompt) is not type(negative_prompt):
213+
raise TypeError("`negative_prompt` should be the same type to `prompt`.")
214+
elif isinstance(negative_prompt, str):
215+
ucond_tokens = [negative_prompt] * batch_size
216+
elif batch_size != len(negative_prompt):
217+
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
218+
else:
219+
ucond_tokens = negative_prompt
220+
206221
max_length = text_input.input_ids.shape[-1]
207222
uncond_input = self.tokenizer(
208-
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
223+
ucond_tokens,
224+
padding="max_length",
225+
max_length=max_length,
226+
truncation=True,
227+
return_tensors="pt",
209228
)
210229
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
211230

0 commit comments

Comments (0)