Skip to content

Commit 3cc1974

Browse files
committed
Add file_encoding_strategy parameter and upload files by default
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 5f02ee9 commit 3cc1974

File tree

4 files changed

+90
-11
lines changed

4 files changed

+90
-11
lines changed

replicate/deployment.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union
23

34
from typing_extensions import Unpack, deprecated
@@ -419,8 +420,14 @@ def create(
419420
Create a new prediction with the deployment.
420421
"""
421422

423+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
422424
if input is not None:
423-
input = encode_json(input, upload_file=upload_file)
425+
input = encode_json(
426+
input,
427+
upload_file=upload_file
428+
if file_encoding_strategy == "base64"
429+
else lambda file: self._client.files.create(file).urls["get"],
430+
)
424431
body = _create_prediction_body(version=None, input=input, **params)
425432

426433
resp = self._client._request(
@@ -440,8 +447,16 @@ async def async_create(
440447
Create a new prediction with the deployment.
441448
"""
442449

450+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
443451
if input is not None:
444-
input = encode_json(input, upload_file=upload_file)
452+
input = encode_json(
453+
input,
454+
upload_file=upload_file
455+
if file_encoding_strategy == "base64"
456+
else lambda file: asyncio.get_event_loop()
457+
.run_until_complete(self._client.files.async_create(file))
458+
.urls["get"],
459+
)
445460
body = _create_prediction_body(version=None, input=input, **params)
446461

447462
resp = await self._client._async_request(
@@ -470,8 +485,14 @@ def create(
470485

471486
url = _create_prediction_url_from_deployment(deployment)
472487

488+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
473489
if input is not None:
474-
input = encode_json(input, upload_file=upload_file)
490+
input = encode_json(
491+
input,
492+
upload_file=upload_file
493+
if file_encoding_strategy == "base64"
494+
else lambda file: self._client.files.create(file).urls["get"],
495+
)
475496
body = _create_prediction_body(version=None, input=input, **params)
476497

477498
resp = self._client._request(
@@ -494,8 +515,16 @@ async def async_create(
494515

495516
url = _create_prediction_url_from_deployment(deployment)
496517

518+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
497519
if input is not None:
498-
input = encode_json(input, upload_file=upload_file)
520+
input = encode_json(
521+
input,
522+
upload_file=upload_file
523+
if file_encoding_strategy == "base64"
524+
else lambda file: asyncio.get_event_loop()
525+
.run_until_complete(self._client.files.async_create(file))
526+
.urls["get"],
527+
)
499528
body = _create_prediction_body(version=None, input=input, **params)
500529

501530
resp = await self._client._async_request(

replicate/model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload
23

34
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
@@ -394,8 +395,14 @@ def create(
394395

395396
url = _create_prediction_url_from_model(model)
396397

398+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
397399
if input is not None:
398-
input = encode_json(input, upload_file=upload_file)
400+
input = encode_json(
401+
input,
402+
upload_file=upload_file
403+
if file_encoding_strategy == "base64"
404+
else lambda file: self._client.files.create(file).urls["get"],
405+
)
399406
body = _create_prediction_body(version=None, input=input, **params)
400407

401408
resp = self._client._request(
@@ -418,8 +425,16 @@ async def async_create(
418425

419426
url = _create_prediction_url_from_model(model)
420427

428+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
421429
if input is not None:
422-
input = encode_json(input, upload_file=upload_file)
430+
input = encode_json(
431+
input,
432+
upload_file=upload_file
433+
if file_encoding_strategy == "base64"
434+
else lambda file: asyncio.get_event_loop()
435+
.run_until_complete(self._client.files.async_create(file))
436+
.urls["get"],
437+
)
423438
body = _create_prediction_body(version=None, input=input, **params)
424439

425440
resp = await self._client._async_request(

replicate/prediction.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,9 @@ class CreatePredictionParams(TypedDict):
383383
stream: NotRequired[bool]
384384
"""Enable streaming of prediction output."""
385385

386+
file_encoding_strategy: NotRequired[Literal["upload", "base64"]]
387+
"""The strategy to use for encoding files in the prediction input."""
388+
386389
@overload
387390
def create(
388391
self,
@@ -453,8 +456,14 @@ def create( # type: ignore
453456
**params,
454457
)
455458

459+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
456460
if input is not None:
457-
input = encode_json(input, upload_file=upload_file)
461+
input = encode_json(
462+
input,
463+
upload_file=upload_file
464+
if file_encoding_strategy == "base64"
465+
else lambda file: self._client.files.create(file).urls["get"],
466+
)
458467
body = _create_prediction_body(
459468
version,
460469
input,
@@ -539,8 +548,16 @@ async def async_create( # type: ignore
539548
**params,
540549
)
541550

551+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
542552
if input is not None:
543-
input = encode_json(input, upload_file=upload_file)
553+
input = encode_json(
554+
input,
555+
upload_file=upload_file
556+
if file_encoding_strategy == "base64"
557+
else lambda file: asyncio.get_event_loop()
558+
.run_until_complete(self._client.files.async_create(file))
559+
.urls["get"],
560+
)
544561
body = _create_prediction_body(
545562
version,
546563
input,
@@ -597,6 +614,7 @@ def _create_prediction_body( # pylint: disable=too-many-arguments
597614
webhook_completed: Optional[str] = None,
598615
webhook_events_filter: Optional[List[str]] = None,
599616
stream: Optional[bool] = None,
617+
**_kwargs,
600618
) -> Dict[str, Any]:
601619
body = {}
602620

replicate/training.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import (
23
TYPE_CHECKING,
34
Any,
@@ -220,6 +221,7 @@ class CreateTrainingParams(TypedDict):
220221
webhook: NotRequired[str]
221222
webhook_completed: NotRequired[str]
222223
webhook_events_filter: NotRequired[List[str]]
224+
file_encoding_strategy: NotRequired[Literal["upload", "base64"]]
223225

224226
@overload
225227
def create( # pylint: disable=too-many-arguments
@@ -277,10 +279,16 @@ def create( # type: ignore
277279
if not url:
278280
raise ValueError("model and version or shorthand version must be specified")
279281

282+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
280283
if input is not None:
281-
input = encode_json(input, upload_file=upload_file)
284+
input = encode_json(
285+
input,
286+
upload_file=upload_file
287+
if file_encoding_strategy == "base64"
288+
else lambda file: self._client.files.create(file).urls["get"],
289+
)
282290
body = _create_training_body(input, **params)
283-
291+
284292
resp = self._client._request(
285293
"POST",
286294
url,
@@ -312,8 +320,16 @@ async def async_create(
312320

313321
url = _create_training_url_from_model_and_version(model, version)
314322

323+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
315324
if input is not None:
316-
input = encode_json(input, upload_file=upload_file)
325+
input = encode_json(
326+
input,
327+
upload_file=upload_file
328+
if file_encoding_strategy == "base64"
329+
else lambda file: asyncio.get_event_loop()
330+
.run_until_complete(self._client.files.async_create(file))
331+
.urls["get"],
332+
)
317333
body = _create_training_body(input, **params)
318334

319335
resp = await self._client._async_request(
@@ -366,6 +382,7 @@ def _create_training_body(
366382
webhook: Optional[str] = None,
367383
webhook_completed: Optional[str] = None,
368384
webhook_events_filter: Optional[List[str]] = None,
385+
**_kwargs,
369386
) -> Dict[str, Any]:
370387
body = {}
371388

0 commit comments

Comments
 (0)