Skip to content

Commit cd12cf4

Browse files
committed
Correctly handle concatenated output when not streaming
1 parent dd64e91 commit cd12cf4

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

replicate/use.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,18 @@ def output(self) -> O:
368368
),
369369
)
370370

371-
# For non-iterator types, wait for completion and process output
371+
# For non-streaming, wait for completion and process output
372372
self._prediction.wait()
373373

374374
if self._prediction.status == "failed":
375375
raise ModelError(self._prediction)
376376

377+
# Handle concatenate iterators - return joined string
378+
if _has_concatenate_iterator_output_type(self._schema):
379+
if isinstance(self._prediction.output, list):
380+
return "".join(str(item) for item in self._prediction.output)
381+
return self._prediction.output
382+
377383
# Process output for file downloads based on schema
378384
return _process_output_with_schema(self._prediction.output, self._schema)
379385

@@ -529,12 +535,18 @@ async def output(self) -> O:
529535
),
530536
)
531537

532-
# For non-iterator types, wait for completion and process output
538+
# For non-streaming, wait for completion and process output
533539
await self._prediction.async_wait()
534540

535541
if self._prediction.status == "failed":
536542
raise ModelError(self._prediction)
537543

544+
# Handle concatenate iterators - return joined string
545+
if _has_concatenate_iterator_output_type(self._schema):
546+
if isinstance(self._prediction.output, list):
547+
return "".join(str(item) for item in self._prediction.output)
548+
return self._prediction.output
549+
538550
# Process output for file downloads based on schema
539551
return _process_output_with_schema(self._prediction.output, self._schema)
540552

tests/test_use.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,53 @@ async def async_iterator():
625625
assert str(result) == "['Hello', ' ', 'World']" # str() gives list representation
626626

627627

628+
@pytest.mark.asyncio
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
@respx.mock
async def test_use_concatenate_iterator_without_streaming_returns_string(client_mode):
    """Without streaming=True, a concatenate-iterator model resolves to one joined string.

    The model's output schema is an iterator of strings marked with
    ``x-cog-array-display: concatenate``; calling the model in the default
    (non-streaming) mode should return the final concatenated text rather
    than the raw list of chunks.
    """
    # Schema for an iterator-of-strings output that should be concatenated.
    concatenate_schema = {
        "openapi_schema": {
            "components": {
                "schemas": {
                    "Output": {
                        "type": "array",
                        "items": {"type": "string"},
                        "x-cog-array-type": "iterator",
                        "x-cog-array-display": "concatenate",
                    }
                }
            }
        }
    }
    mock_model_endpoints(versions=[create_mock_version(concatenate_schema)])

    # First response: prediction created; second: it has succeeded with chunked output.
    completed_prediction = create_mock_prediction(
        {"status": "succeeded", "output": ["Hello", " ", "world", "!"]}
    )
    mock_prediction_endpoints(
        predictions=[create_mock_prediction(), completed_prediction]
    )

    use_async = client_mode == ClientMode.ASYNC
    # Call use with "acme/hotdog-detector" WITHOUT streaming=True (default behavior)
    model_fn = replicate.use("acme/hotdog-detector", use_async=use_async)

    if use_async:
        result = await model_fn(prompt="hello world")
    else:
        result = model_fn(prompt="hello world")

    assert result == "Hello world!"
673+
674+
628675
@pytest.mark.asyncio
629676
@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC])
630677
@respx.mock

0 commit comments

Comments
 (0)