Skip to content

Commit

Permalink
lora
Browse files Browse the repository at this point in the history
  • Loading branch information
gokayfem committed Oct 17, 2024
1 parent b0c9ddf commit 845892e
Showing 1 changed file with 78 additions and 4 deletions.
82 changes: 78 additions & 4 deletions nodes/image_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,65 @@ def generate_image(self, prompt, image_size, width, height, num_images, safety_t
print(f"Error generating image with FluxPro 1.1: {str(e)}")
return self.create_blank_image()

class FluxLora:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": ("STRING", {"default": "", "multiline": True}),
"image_size": (["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9", "custom"], {"default": "landscape_4_3"}),
"width": ("INT", {"default": 1024, "min": 512, "max": 1536, "step": 16}),
"height": ("INT", {"default": 768, "min": 512, "max": 1536, "step": 16}),
"num_inference_steps": ("INT", {"default": 28, "min": 1, "max": 50}),
"guidance_scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 20.0, "step": 0.1}),
"num_images": ("INT", {"default": 1, "min": 1, "max": 4}),
"enable_safety_checker": ("BOOLEAN", {"default": True}),
"lora_path_1": ("STRING", {"default": ""}),
"lora_scale_1": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.1}),
},
"optional": {
"seed": ("INT", {"default": -1}),
"lora_path_2": ("STRING", {"default": ""}),
"lora_scale_2": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.1}),
}
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_image"
CATEGORY = "FAL/Image"

def generate_image(self, prompt, image_size, width, height, num_inference_steps, guidance_scale, num_images, enable_safety_checker, lora_path_1, lora_scale_1, seed=-1, lora_path_2="", lora_scale_2=0.7):
arguments = {
"prompt": prompt,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"num_images": num_images,
"enable_safety_checker": enable_safety_checker,
}
if image_size == "custom":
arguments["image_size"] = {"width": width, "height": height}
else:
arguments["image_size"] = image_size
if seed != -1:
arguments["seed"] = seed

# Add LoRAs
loras = []
if lora_path_1:
loras.append({"path": lora_path_1, "scale": lora_scale_1})
if lora_path_2:
loras.append({"path": lora_path_2, "scale": lora_scale_2})
if loras:
arguments["loras"] = loras

try:
handler = submit("fal-ai/flux-lora", arguments=arguments)
result = handler.get()
return self.process_result(result)
except Exception as e:
print(f"Error generating image with FluxLora: {str(e)}")
return self.create_blank_image()

class FluxGeneral:
@classmethod
def INPUT_TYPES(cls):
Expand Down Expand Up @@ -285,14 +344,18 @@ def INPUT_TYPES(cls):
"control_mask": ("MASK",),
"ip_adapter_image": ("IMAGE",),
"ip_adapter_mask": ("MASK",),
"lora_path_1": ("STRING", {"default": ""}),
"lora_scale_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}),
"lora_path_2": ("STRING", {"default": ""}),
"lora_scale_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}),
}
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_image"
CATEGORY = "FAL/Image"

def generate_image(self, prompt, image_size, width, height, num_inference_steps, guidance_scale, real_cfg_scale, num_images, enable_safety_checker, use_real_cfg, sync_mode, seed=-1, loras="None", lora_scale=1.0, ip_adapter_scale=0.6, controlnet_conditioning_scale=0.6, controlnet_union_control_mode="canny", ip_adapters="None", controlnets="None", controlnet_unions="None", control_image=None, control_mask=None, ip_adapter_image=None, ip_adapter_mask=None):
def generate_image(self, prompt, image_size, width, height, num_inference_steps, guidance_scale, real_cfg_scale, num_images, enable_safety_checker, use_real_cfg, sync_mode, seed=-1, loras="None", lora_scale=1.0, ip_adapter_scale=0.6, controlnet_conditioning_scale=0.6, controlnet_union_control_mode="canny", ip_adapters="None", controlnets="None", controlnet_unions="None", control_image=None, control_mask=None, ip_adapter_image=None, ip_adapter_mask=None, lora_path_1="", lora_scale_1=1.0, lora_path_2="", lora_scale_2=1.0):
arguments = {
"prompt": prompt,
"num_inference_steps": num_inference_steps,
Expand Down Expand Up @@ -404,6 +467,15 @@ def generate_image(self, prompt, image_size, width, height, num_inference_steps,
if mask_image_url:
arguments["ip_adapters"][0]["mask_image_url"] = mask_image_url

# Add LoRAs if provided
loras = []
if lora_path_1:
loras.append({"path": lora_path_1, "scale": lora_scale_1})
if lora_path_2:
loras.append({"path": lora_path_2, "scale": lora_scale_2})
if loras:
arguments["loras"] = loras

try:
handler = submit("fal-ai/flux-general", arguments=arguments)
result = handler.get()
Expand Down Expand Up @@ -451,7 +523,8 @@ def create_blank_image(self):
"FluxDev_fal": FluxDev,
"FluxSchnell_fal": FluxSchnell,
"FluxPro11_fal": FluxPro11,
"FluxGeneral_fal": FluxGeneral
"FluxGeneral_fal": FluxGeneral,
"FluxLora_fal": FluxLora
}

# Node display name mappings
Expand All @@ -460,5 +533,6 @@ def create_blank_image(self):
"FluxDev_fal": "Flux Dev (fal)",
"FluxSchnell_fal": "Flux Schnell (fal)",
"FluxPro11_fal": "Flux Pro 1.1 (fal)",
"FluxGeneral_fal": "Flux General (fal)"
}
"FluxGeneral_fal": "Flux General (fal)",
"FluxLora_fal": "Flux LoRA (fal)"
}

0 comments on commit 845892e

Please sign in to comment.