fix: api key optional & add client retry logic #1

Open · wants to merge 4 commits into main

Changes from all commits

1 change: 1 addition & 0 deletions .gitignore
@@ -32,3 +32,4 @@ MANIFEST
# Installer logs
pip-log.txt
/venv
pypi_token
72 changes: 61 additions & 11 deletions README.md
@@ -29,7 +29,7 @@ prediction = client.run(
"num_images": 1,
"seed": -1,
"enable_safety_checker": True
}
},
)

# Print the generated image URLs
@@ -100,15 +100,66 @@ prediction = client.create(
print(f"Prediction created with ID: {prediction.id}")
print(f"Initial status: {prediction.status}")

-# Later, you can wait for the prediction to complete
-result = prediction.wait()
-print(f"Final status: {result.status}")
+# Wait for the prediction to complete
+prediction = prediction.wait()
+print(f"Final status: {prediction.status}")

# Print the generated image URLs
-for i, img_url in enumerate(result.outputs):
+for i, img_url in enumerate(prediction.outputs):
print(f"Image {i+1}: {img_url}")
```

### Using Webhooks

You can use webhooks to receive a notification when your prediction is complete, instead of polling the API:

```python
from wavespeed import WaveSpeed, Options

# Initialize the client with your API key
client = WaveSpeed(api_key="YOUR_API_KEY")

# Create options with a webhook URL
options = Options(webhook="https://your-webhook-url.com/callback")

# Create a prediction with webhook notification
prediction = client.create(
modelId="wavespeed-ai/flux-dev",
input={
"prompt": "A futuristic cityscape with flying cars and neon lights",
"size": "1024*1024",
"num_inference_steps": 28,
"guidance_scale": 5.0,
"num_images": 1,
"seed": -1,
"enable_safety_checker": True
},
options=options
)

print(f"Prediction created with ID: {prediction.id}")
print(f"Status: {prediction.status}")
print(f"You will receive a webhook notification at {options.webhook} when the prediction is complete or failed")

# Your webhook endpoint will receive a POST request with the prediction data
# Example webhook handler (using Flask):
"""
@app.route('/callback', methods=['POST'])
def webhook_callback(prediction: Prediction):
prediction_id = prediction.id
status = prediction.status

if status == 'completed':
outputs = prediction.outputs
# Process the generated images
for img_url in outputs:
# Download or process the image
pass

return '', 200
"""
```

## Command Line Examples

The package includes several example scripts that you can use to generate images:
@@ -167,31 +218,31 @@ WaveSpeed(api_key)
#### run

```python
-run(modelId, input, **kwargs) -> Prediction
+run(modelId, input, options=None) -> Prediction
```

Generate an image and wait for the result.
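
For example, a minimal blocking call (a sketch; it assumes input fields other than `prompt` are optional, as in the quick-start example above):

```python
from wavespeed import WaveSpeed

client = WaveSpeed(api_key="YOUR_API_KEY")

# Blocks until the prediction completes, then returns the final Prediction
prediction = client.run(
    modelId="wavespeed-ai/flux-dev",
    input={"prompt": "A watercolor fox in a snowy forest"},
)

for img_url in prediction.outputs:
    print(img_url)
```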

#### async_run

```python
-async_run(modelId, input, **kwargs) -> Prediction
+async_run(modelId, input, options=None) -> Prediction
```

Asynchronously generate an image and wait for the result.
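
A sketch of the async variant, assuming `async_run` is awaitable and takes the same arguments as `run`:

```python
import asyncio

from wavespeed import WaveSpeed

async def main():
    client = WaveSpeed(api_key="YOUR_API_KEY")
    # Awaits until the prediction completes
    prediction = await client.async_run(
        modelId="wavespeed-ai/flux-dev",
        input={"prompt": "A watercolor fox in a snowy forest"},
    )
    print(prediction.outputs)

asyncio.run(main())
```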

#### create

```python
-create(modelId, input, **kwargs) -> Prediction
+create(modelId, input, options=None) -> Prediction
```

Create a prediction without waiting for it to complete.

#### async_create

```python
-async_create(modelId, input, **kwargs) -> Prediction
+async_create(modelId, input, options=None) -> Prediction
```

Asynchronously create a prediction without waiting for it to complete.
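
For instance, pairing `async_create` with the prediction's async helpers (a sketch; it assumes `Prediction.async_wait` mirrors the synchronous `wait` shown earlier):

```python
import asyncio

from wavespeed import WaveSpeed

async def main():
    client = WaveSpeed(api_key="YOUR_API_KEY")
    # Returns immediately with a pending Prediction
    prediction = await client.async_create(
        modelId="wavespeed-ai/flux-dev",
        input={"prompt": "A neon city at dusk"},
    )
    # Poll asynchronously until the prediction completes
    prediction = await prediction.async_wait()
    print(prediction.status, prediction.outputs)

asyncio.run(main())
```
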
@@ -227,8 +278,7 @@ await prediction.async_reload() -> Prediction
## Environment Variables

- `WAVESPEED_API_KEY`: Your WaveSpeed API key
-- `WAVESPEED_POLL_INTERVAL`: Interval in seconds for polling prediction status (default: 1)
-- `WAVESPEED_TIMEOUT`: Timeout in seconds for API requests (default: 60)
+- `WAVESPEED_POLL_INTERVAL`: Interval in seconds for polling prediction status (default: 0.5)
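
With `WAVESPEED_API_KEY` set, the `api_key` argument can be omitted; this PR makes the key optional, so the sketch below assumes the client falls back to the environment variable:

```python
import os

from wavespeed import WaveSpeed

# Normally exported in your shell rather than set in code
os.environ["WAVESPEED_API_KEY"] = "YOUR_API_KEY"

# No api_key argument: the client is assumed to read WAVESPEED_API_KEY
client = WaveSpeed()
```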

## License

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "wavespeed"
version = "0.0.3"
version = "0.0.5"
description = "Python client for WaveSpeed AI"
readme = "README.md"
license = { file = "LICENSE" }
184 changes: 183 additions & 1 deletion tests/test_client.py
@@ -2,13 +2,19 @@
Tests for the Wavespeed client.
"""

import unittest
import pytest
import respx
from httpx import Response
import httpx
from datetime import datetime
import unittest.mock

-from wavespeed.client import WaveSpeed
+from wavespeed.client import WaveSpeed, RetryTransport
from wavespeed.schemas.prediction import Prediction, PredictionUrls
import logging

logging.basicConfig(level=logging.INFO)


@pytest.fixture
@@ -112,6 +118,86 @@ def mock_prediction_completed_response():
}


@pytest.fixture
def mock_transport():
"""Create a mock transport for testing retry functionality."""
class MockTransport(httpx.BaseTransport):
def __init__(self):
self.request_count = 0
self.requests = []

def handle_request(self, request):
self.request_count += 1
self.requests.append(request)

# Return 503 for the first two requests, then 200
if self.request_count <= 2:
response = Response(503, json={"error": "Service Unavailable"})
response._request = request # Set the request on the response
return response
else:
response = Response(200, json={
"code": 200,
"message": "Success",
"data": {
"id": "test_prediction_id",
"model": "wavespeed-ai/flux-dev",
"input": {"prompt": "A test prompt"},
"outputs": ["https://example.com/generated_image.jpg"],
"urls": {"get": "https://api.wavespeed.ai/api/v2/predictions/test_prediction_id/result"},
"has_nsfw_contents": [False],
"status": "completed",
"created_at": datetime.now().isoformat(),
"error": "",
"executionTime": 1000
}
})
response._request = request # Set the request on the response
return response

return MockTransport()


@pytest.fixture
def mock_async_transport():
"""Create a mock async transport for testing retry functionality."""
class MockAsyncTransport(httpx.AsyncBaseTransport):
def __init__(self):
self.request_count = 0
self.requests = []

async def handle_async_request(self, request):
self.request_count += 1
self.requests.append(request)

# Return 503 for the first two requests, then 200
if self.request_count <= 2:
response = Response(503, json={"error": "Service Unavailable"})
response._request = request # Set the request on the response
return response
else:
response = Response(200, json={
"code": 200,
"message": "Success",
"data": {
"id": "test_prediction_id",
"model": "wavespeed-ai/flux-dev",
"input": {"prompt": "A test prompt"},
"outputs": ["https://example.com/generated_image.jpg"],
"urls": {"get": "https://api.wavespeed.ai/api/v2/predictions/test_prediction_id/result"},
"has_nsfw_contents": [False],
"status": "completed",
"created_at": datetime.now().isoformat(),
"error": "",
"executionTime": 1000
}
})
response._request = request # Set the request on the response
return response

return MockAsyncTransport()


@respx.mock
def test_run(client, mock_prediction_response, mock_prediction_completed_response):
"""Test the run method."""
@@ -274,6 +360,7 @@ def test_prediction_wait(client, mock_prediction_in_progress_response, mock_pred
assert result.status == "completed"
assert len(result.outputs) == 1
assert result.outputs[0] == "https://example.com/generated_image.jpg"
assert result.has_nsfw_contents == [False]


@respx.mock
@@ -313,3 +400,98 @@ async def test_prediction_async_wait(async_client, mock_prediction_in_progress_r
assert result.status == "completed"
assert len(result.outputs) == 1
assert result.outputs[0] == "https://example.com/generated_image.jpg"
assert result.has_nsfw_contents == [False]

@unittest.mock.patch("httpx.HTTPTransport.handle_request")
def test_client_retry_integration(mock_send):
"""Test that the client's retry logic works correctly when integrated with httpx."""
# Create a counter to track the number of requests
request_count = {"count": 0}

# Configure the mock to return 502 for the first request, 503 for the second, then 200
def side_effect(request, **kwargs):
request_count["count"] += 1
if request_count["count"] == 1:
response = Response(502, json={"error": "Bad Gateway"})
elif request_count["count"] == 2:
response = Response(503, json={"error": "Service Unavailable"})
else:
response = Response(200, json={
"code": 200,
"message": "Success",
"data": {
"id": "test_prediction_id",
"model": "wavespeed-ai/flux-dev",
"input": {"prompt": "A test prompt"},
"outputs": ["https://example.com/generated_image.jpg"],
"urls": {"get": "https://api.wavespeed.ai/api/v2/predictions/test_prediction_id/result"},
"has_nsfw_contents": [False],
"status": "completed",
"created_at": datetime.now().isoformat(),
"error": "",
"executionTime": 1000
}
})
response._request = request
return response
mock_send.side_effect = side_effect

# Create a test client
client = WaveSpeed(api_key="test_api_key")

prediction = client.get_prediction("test_prediction_id")
assert prediction.status == "completed"
# Verify that the transport was called the expected number of times (3 total: 2 failures + 1 success)
assert request_count["count"] == 3
assert mock_send.call_count == 3


@pytest.mark.asyncio
@unittest.mock.patch("httpx.AsyncHTTPTransport.handle_async_request")
async def test_client_async_retry_integration(mock_send):
"""Test that the client's async retry logic works correctly when integrated with httpx."""
# Create a counter to track the number of requests
request_count = {"count": 0}

# Configure the mock to return 502 for the first request, 503 for the second, then 200
async def side_effect(request, **kwargs):
request_count["count"] += 1
if request_count["count"] == 1:
response = Response(502, json={"error": "Bad Gateway"})
elif request_count["count"] == 2:
response = Response(503, json={"error": "Service Unavailable"})
else:
response = Response(200, json={
"code": 200,
"message": "Success",
"data": {
"id": "test_prediction_id",
"model": "wavespeed-ai/flux-dev",
"input": {"prompt": "A test prompt"},
"outputs": ["https://example.com/generated_image.jpg"],
"urls": {"get": "https://api.wavespeed.ai/api/v2/predictions/test_prediction_id/result"},
"has_nsfw_contents": [False],
"status": "completed",
"created_at": datetime.now().isoformat(),
"error": "",
"executionTime": 1000
}
})
response._request = request
return response
mock_send.side_effect = side_effect

# Create a test client
client = WaveSpeed(api_key="test_api_key")

# Send the request through the client
prediction = await client.async_get_prediction(
predictionId="test_prediction_id",
)

# Verify the response is successful
assert prediction.status == "completed"

# Verify that the transport was called the expected number of times (3 total: 2 failures + 1 success)
assert request_count["count"] == 3
assert mock_send.call_count == 3