3
3
import pytest
4
4
5
5
import replicate
6
+ from replicate .exceptions import ReplicateError
6
7
from replicate .stream import ServerSentEvent
7
8
8
9
skip_if_no_token = pytest .mark .skipif (
@@ -21,22 +22,28 @@ async def test_stream(async_flag, record_mode):
21
22
22
23
events = []
23
24
24
- if async_flag :
25
- async for event in await replicate .async_stream (
26
- model ,
27
- input = input ,
28
- ):
29
- events .append (event )
30
- else :
31
- for event in replicate .stream (
32
- model ,
33
- input = input ,
34
- ):
35
- events .append (event )
25
+ try :
26
+ if async_flag :
27
+ async for event in await replicate .async_stream (
28
+ model ,
29
+ input = input ,
30
+ ):
31
+ events .append (event )
32
+ else :
33
+ for event in replicate .stream (
34
+ model ,
35
+ input = input ,
36
+ ):
37
+ events .append (event )
36
38
37
- assert len (events ) > 0
38
- assert any (event .event == ServerSentEvent .EventType .OUTPUT for event in events )
39
- assert any (event .event == ServerSentEvent .EventType .DONE for event in events )
39
+ assert len (events ) > 0
40
+ assert any (event .event == ServerSentEvent .EventType .OUTPUT for event in events )
41
+ assert any (event .event == ServerSentEvent .EventType .DONE for event in events )
42
+ except ReplicateError as e :
43
+ if e .status == 401 :
44
+ pytest .skip ("Skipping test due to authentication error" )
45
+ else :
46
+ raise e
40
47
41
48
42
49
@skip_if_no_token
@@ -50,15 +57,21 @@ async def test_stream_prediction(async_flag, record_mode):
50
57
51
58
events = []
52
59
53
- if async_flag :
54
- async for event in replicate .predictions .create (
55
- version = version , input = input , stream = True
56
- ).async_stream ():
57
- events .append (event )
58
- else :
59
- for event in replicate .predictions .create (
60
- version = version , input = input , stream = True
61
- ).stream ():
62
- events .append (event )
60
+ try :
61
+ if async_flag :
62
+ async for event in replicate .predictions .create (
63
+ version = version , input = input , stream = True
64
+ ).async_stream ():
65
+ events .append (event )
66
+ else :
67
+ for event in replicate .predictions .create (
68
+ version = version , input = input , stream = True
69
+ ).stream ():
70
+ events .append (event )
63
71
64
- assert len (events ) > 0
72
+ assert len (events ) > 0
73
+ except ReplicateError as e :
74
+ if e .status == 401 :
75
+ pytest .skip ("Skipping test due to authentication error" )
76
+ else :
77
+ raise e
0 commit comments