Skip to content

Commit cd422fc

Browse files
authored
Update integration tests to catch 401 errors (#317)
Follow-up to #316 Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent f21086e commit cd422fc

File tree

1 file changed

+39
-26
lines changed

1 file changed

+39
-26
lines changed

tests/test_stream.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
import replicate
6+
from replicate.exceptions import ReplicateError
67
from replicate.stream import ServerSentEvent
78

89
skip_if_no_token = pytest.mark.skipif(
@@ -21,22 +22,28 @@ async def test_stream(async_flag, record_mode):
2122

2223
events = []
2324

24-
if async_flag:
25-
async for event in await replicate.async_stream(
26-
model,
27-
input=input,
28-
):
29-
events.append(event)
30-
else:
31-
for event in replicate.stream(
32-
model,
33-
input=input,
34-
):
35-
events.append(event)
25+
try:
26+
if async_flag:
27+
async for event in await replicate.async_stream(
28+
model,
29+
input=input,
30+
):
31+
events.append(event)
32+
else:
33+
for event in replicate.stream(
34+
model,
35+
input=input,
36+
):
37+
events.append(event)
3638

37-
assert len(events) > 0
38-
assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events)
39-
assert any(event.event == ServerSentEvent.EventType.DONE for event in events)
39+
assert len(events) > 0
40+
assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events)
41+
assert any(event.event == ServerSentEvent.EventType.DONE for event in events)
42+
except ReplicateError as e:
43+
if e.status == 401:
44+
pytest.skip("Skipping test due to authentication error")
45+
else:
46+
raise e
4047

4148

4249
@skip_if_no_token
@@ -50,15 +57,21 @@ async def test_stream_prediction(async_flag, record_mode):
5057

5158
events = []
5259

53-
if async_flag:
54-
async for event in replicate.predictions.create(
55-
version=version, input=input, stream=True
56-
).async_stream():
57-
events.append(event)
58-
else:
59-
for event in replicate.predictions.create(
60-
version=version, input=input, stream=True
61-
).stream():
62-
events.append(event)
60+
try:
61+
if async_flag:
62+
async for event in replicate.predictions.create(
63+
version=version, input=input, stream=True
64+
).async_stream():
65+
events.append(event)
66+
else:
67+
for event in replicate.predictions.create(
68+
version=version, input=input, stream=True
69+
).stream():
70+
events.append(event)
6371

64-
assert len(events) > 0
72+
assert len(events) > 0
73+
except ReplicateError as e:
74+
if e.status == 401:
75+
pytest.skip("Skipping test due to authentication error")
76+
else:
77+
raise e

0 commit comments

Comments
 (0)