Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
181 changes: 57 additions & 124 deletions comfy_api_nodes/nodes_pika.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,23 @@

from io import BytesIO
import logging
from typing import Optional, TypeVar
from typing import Optional

import torch

from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_defs
from comfy_api_nodes.apis.client import (
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
EmptyRequest,
HttpMethod,
PollingOperation,
SynchronousOperation,
sync_op,
poll_op,
)
from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio

R = TypeVar("R")

PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
Expand All @@ -40,28 +39,18 @@


async def execute_task(
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
task_id = (await initial_operation.execute()).video_id
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=pika_defs.PikaVideoResponse,
),
completed_statuses=["finished"],
failed_statuses=["failed", "cancelled"],
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
node_id=node_id,
estimated_duration=60,
max_poll_attempts=240,
).execute()
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
Expand Down Expand Up @@ -124,23 +113,15 @@ async def execute(
resolution=resolution,
duration=duration,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)


class PikaTextToVideoNode(IO.ComfyNode):
Expand Down Expand Up @@ -183,29 +164,21 @@ async def execute(
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
auth_kwargs=auth,
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)


class PikaScenes(IO.ComfyNode):
Expand Down Expand Up @@ -309,24 +282,16 @@ async def execute(
duration=duration,
aspectRatio=aspect_ratio,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)

return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)


class PikAdditionsNode(IO.ComfyNode):
Expand Down Expand Up @@ -383,24 +348,16 @@ async def execute(
negativePrompt=negative_prompt,
seed=seed,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)

return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)


class PikaSwapsNode(IO.ComfyNode):
Expand Down Expand Up @@ -472,23 +429,15 @@ async def execute(
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)


class PikaffectsNode(IO.ComfyNode):
Expand Down Expand Up @@ -528,28 +477,20 @@ async def execute(
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)


class PikaStartEndFrameNode(IO.ComfyNode):
Expand Down Expand Up @@ -592,18 +533,11 @@ async def execute(
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
Expand All @@ -612,9 +546,8 @@ async def execute(
),
files=pika_files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
return await execute_task(initial_operation.video_id, cls)


class PikaApiNodesExtension(ComfyExtension):
Expand Down