import os
import io
import chainlit as cl
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from PIL import Image
from langchain.tools import Tool, StructuredTool

os.environ["STABILITY_HOST"] = "grpc.stability.ai:443"
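# Note: the STABILITY_KEY environment variable must also be set before this
# module is used; it is read as the API key in _generate_image below.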


def get_image_name():
    image_count = cl.user_session.get("image_count")
    if image_count is None:
        image_count = 0
    else:
        image_count += 1

    cl.user_session.set("image_count", image_count)

    return f"image-{image_count}"


def _generate_image(prompt: str, init_image=None):
    # Set up our connection to the API.
    stability_api = client.StabilityInference(
        key=os.environ["STABILITY_KEY"],  # API key reference.
        verbose=True,  # Print debug messages.
        engine="stable-diffusion-xl-beta-v2-2-2",  # Set the engine to use for generation.
        # Available engines: stable-diffusion-v1 stable-diffusion-v1-5 stable-diffusion-512-v2-0 stable-diffusion-768-v2-0
        # stable-diffusion-512-v2-1 stable-diffusion-768-v2-1 stable-diffusion-xl-beta-v2-2-2 stable-inpainting-v1-0 stable-inpainting-512-v2-0
    )

    start_schedule = 0.8 if init_image else 1

    # Set up our initial generation parameters.
    answers = stability_api.generate(
        prompt=prompt,
        init_image=init_image,
        start_schedule=start_schedule,
        seed=992446758,  # If a seed is provided, the resulting image is deterministic:
        # as long as all generation parameters remain the same, you can recall the same image by generating it again.
        # Note: this isn't quite the case for CLIP Guided generations, which are covered in the CLIP Guidance documentation.
        steps=30,  # Number of inference steps performed during generation. Defaults to 30.
        cfg_scale=8.0,  # Influences how strongly the generation is guided to match the prompt.
        # Higher values increase the strength with which it tries to match the prompt.
        # Defaults to 7.0 if not specified.
        width=512,  # Generation width, defaults to 512 if not included.
        height=512,  # Generation height, defaults to 512 if not included.
        samples=1,  # Number of images to generate, defaults to 1 if not included.
        sampler=generation.SAMPLER_K_DPMPP_2M,  # Choose which sampler to denoise the generation with.
        # Defaults to k_dpmpp_2m if not specified. CLIP Guidance only supports ancestral samplers.
        # (Available samplers: ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_dpmpp_2s_ancestral, k_lms, k_dpmpp_2m, k_dpmpp_sde)
    )

    # Raise an error if the adult content classifier is tripped; otherwise,
    # store the generated image in the user session.
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                raise ValueError(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again."
                )
            if artifact.type == generation.ARTIFACT_IMAGE:
                name = get_image_name()
                cl.user_session.set(name, artifact.binary)
                cl.user_session.set("generated_image", name)
                return name

    # Responses may also contain non-image artifacts (e.g. classifications),
    # so only fail once the stream is exhausted without yielding an image.
    raise ValueError(
        "Your request did not generate an image. Please modify the prompt and try again."
    )

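# A minimal sketch (assumed usage, not part of this module) of how the bytes
# stored above could be rendered back to the user from a Chainlit handler;
# the "generated_image" key mirrors what _generate_image sets:
#
#     name = cl.user_session.get("generated_image")
#     image = cl.Image(name=name, content=cl.user_session.get(name), display="inline")
#     await cl.Message(content=name, elements=[image]).send()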


def generate_image(prompt: str):
    image_name = _generate_image(prompt)
    return f"Here is {image_name}."


def edit_image(init_image_name: str, prompt: str):
    init_image_bytes = cl.user_session.get(init_image_name)
    if init_image_bytes is None:
        raise ValueError(f"Could not find image `{init_image_name}`.")

    init_image = Image.open(io.BytesIO(init_image_bytes))
    image_name = _generate_image(prompt, init_image)

    return f"Here is {image_name} based on {init_image_name}."


generate_image_tool = Tool.from_function(
    func=generate_image,
    name="GenerateImage",
    description="Useful to create an image from a text prompt.",
    return_direct=True,
)

edit_image_tool = StructuredTool.from_function(
    func=edit_image,
    name="EditImage",
    description="Useful to edit an image with a prompt. Works well with commands such as 'replace', 'add', 'change', 'remove'.",
    return_direct=True,
)
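
# A minimal sketch (assumed usage, not confirmed by this module) of wiring both
# tools into a LangChain agent from a Chainlit app; the model and agent type are
# illustrative choices, and StructuredTool requires a structured-chat agent:
#
#     from langchain.agents import AgentType, initialize_agent
#     from langchain.chat_models import ChatOpenAI
#
#     @cl.on_chat_start
#     def start():
#         llm = ChatOpenAI(temperature=0)
#         agent = initialize_agent(
#             tools=[generate_image_tool, edit_image_tool],
#             llm=llm,
#             agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
#         )
#         cl.user_session.set("agent", agent)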