Commit ab7ef6b

fix(tests): fix tests which call HTTP endpoints directly with the example parameters
1 parent 2c30786 commit ab7ef6b

File tree

1 file changed

tests/test_client.py

Lines changed: 60 additions & 125 deletions
@@ -23,9 +23,7 @@
 
 from openlayer import Openlayer, AsyncOpenlayer, APIResponseValidationError
 from openlayer._types import Omit
-from openlayer._utils import maybe_transform
 from openlayer._models import BaseModel, FinalRequestOptions
-from openlayer._constants import RAW_RESPONSE_HEADER
 from openlayer._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
 from openlayer._base_client import (
     DEFAULT_TIMEOUT,
@@ -35,7 +33,6 @@
     DefaultAsyncHttpxClient,
     make_request_options,
 )
-from openlayer.types.inference_pipelines.data_stream_params import DataStreamParams
 
 from .utils import update_env

@@ -724,82 +721,49 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
 
     @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Openlayer) -> None:
         respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
             side_effect=httpx.TimeoutException("Test timeout error")
         )
 
         with pytest.raises(APITimeoutError):
-            self.client.post(
-                "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            config={
-                                "input_variable_names": ["user_query"],
-                                "output_column_name": "output",
-                                "num_of_token_column_name": "tokens",
-                                "cost_column_name": "cost",
-                                "timestamp_column_name": "timestamp",
-                            },
-                            rows=[
-                                {
-                                    "user_query": "what is the meaning of life?",
-                                    "output": "42",
-                                    "tokens": 7,
-                                    "cost": 0.02,
-                                    "timestamp": 1610000000,
-                                }
-                            ],
-                        ),
-                        DataStreamParams,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
+            client.inference_pipelines.data.with_streaming_response.stream(
+                inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+                config={"output_column_name": "output"},
+                rows=[
+                    {
+                        "user_query": "bar",
+                        "output": "bar",
+                        "tokens": "bar",
+                        "cost": "bar",
+                        "timestamp": "bar",
+                    }
+                ],
+            ).__enter__()
 
         assert _get_open_connections(self.client) == 0
 
     @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Openlayer) -> None:
         respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
             return_value=httpx.Response(500)
         )
 
         with pytest.raises(APIStatusError):
-            self.client.post(
-                "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            config={
-                                "input_variable_names": ["user_query"],
-                                "output_column_name": "output",
-                                "num_of_token_column_name": "tokens",
-                                "cost_column_name": "cost",
-                                "timestamp_column_name": "timestamp",
-                            },
-                            rows=[
-                                {
-                                    "user_query": "what is the meaning of life?",
-                                    "output": "42",
-                                    "tokens": 7,
-                                    "cost": 0.02,
-                                    "timestamp": 1610000000,
-                                }
-                            ],
-                        ),
-                        DataStreamParams,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
-
+            client.inference_pipelines.data.with_streaming_response.stream(
+                inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+                config={"output_column_name": "output"},
+                rows=[
+                    {
+                        "user_query": "bar",
+                        "output": "bar",
+                        "tokens": "bar",
+                        "cost": "bar",
+                        "timestamp": "bar",
+                    }
+                ],
+            ).__enter__()
         assert _get_open_connections(self.client) == 0
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
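The rewritten sync tests rely on two things: a client pytest fixture instead of self.client, and the fact that with_streaming_response.stream(...) returns a context manager whose __enter__ is what actually sends the POST, so calling .__enter__() inside pytest.raises is enough to exercise the retry path. Below is a minimal sketch of the same check written as a plain with block; the client fixture and the base_url value are assumptions standing in for this module's own setup, and the retry-backoff patch used by the committed tests is omitted here (it only keeps retries fast).

import httpx
import pytest
from respx import MockRouter

from openlayer import Openlayer
from openlayer._exceptions import APITimeoutError

base_url = "http://127.0.0.1:4010"  # assumption: stands in for the module-level base_url


@pytest.mark.respx(base_url=base_url)
def test_timeout_surfaces_through_streaming_helper(respx_mock: MockRouter, client: Openlayer) -> None:
    # `client` is assumed to be a pytest fixture returning an Openlayer instance
    # pointed at base_url, mirroring the fixture the committed tests use.
    respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
        side_effect=httpx.TimeoutException("Test timeout error")
    )

    with pytest.raises(APITimeoutError):
        # Entering the streaming context manager is what issues the POST request,
        # so the timeout (after retries) surfaces as APITimeoutError right here.
        with client.inference_pipelines.data.with_streaming_response.stream(
            inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
            config={"output_column_name": "output"},
            rows=[
                {
                    "user_query": "bar",
                    "output": "bar",
                    "tokens": "bar",
                    "cost": "bar",
                    "timestamp": "bar",
                }
            ],
        ):
            pass  # not reached

Calling .__enter__() directly, as the committed tests do, simply avoids an empty with body, since the error is raised before any response body is streamed.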
@@ -1652,82 +1616,53 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
 
     @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    async def test_retrying_timeout_errors_doesnt_leak(
+        self, respx_mock: MockRouter, async_client: AsyncOpenlayer
+    ) -> None:
         respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
             side_effect=httpx.TimeoutException("Test timeout error")
         )
 
         with pytest.raises(APITimeoutError):
-            await self.client.post(
-                "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            config={
-                                "input_variable_names": ["user_query"],
-                                "output_column_name": "output",
-                                "num_of_token_column_name": "tokens",
-                                "cost_column_name": "cost",
-                                "timestamp_column_name": "timestamp",
-                            },
-                            rows=[
-                                {
-                                    "user_query": "what is the meaning of life?",
-                                    "output": "42",
-                                    "tokens": 7,
-                                    "cost": 0.02,
-                                    "timestamp": 1610000000,
-                                }
-                            ],
-                        ),
-                        DataStreamParams,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
+            await async_client.inference_pipelines.data.with_streaming_response.stream(
+                inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+                config={"output_column_name": "output"},
+                rows=[
+                    {
+                        "user_query": "bar",
+                        "output": "bar",
+                        "tokens": "bar",
+                        "cost": "bar",
+                        "timestamp": "bar",
+                    }
+                ],
+            ).__aenter__()
 
         assert _get_open_connections(self.client) == 0
 
     @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    async def test_retrying_status_errors_doesnt_leak(
+        self, respx_mock: MockRouter, async_client: AsyncOpenlayer
+    ) -> None:
         respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
             return_value=httpx.Response(500)
         )
 
         with pytest.raises(APIStatusError):
-            await self.client.post(
-                "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            config={
-                                "input_variable_names": ["user_query"],
-                                "output_column_name": "output",
-                                "num_of_token_column_name": "tokens",
-                                "cost_column_name": "cost",
-                                "timestamp_column_name": "timestamp",
-                            },
-                            rows=[
-                                {
-                                    "user_query": "what is the meaning of life?",
-                                    "output": "42",
-                                    "tokens": 7,
-                                    "cost": 0.02,
-                                    "timestamp": 1610000000,
-                                }
-                            ],
-                        ),
-                        DataStreamParams,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
-
+            await async_client.inference_pipelines.data.with_streaming_response.stream(
+                inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
+                config={"output_column_name": "output"},
+                rows=[
+                    {
+                        "user_query": "bar",
+                        "output": "bar",
+                        "tokens": "bar",
+                        "cost": "bar",
+                        "timestamp": "bar",
+                    }
+                ],
+            ).__aenter__()
         assert _get_open_connections(self.client) == 0
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
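The async tests make the same change through the async_client fixture, awaiting __aenter__() on the async streaming context manager. Below is a sketch of the async-with equivalent for the 500-status case; the async_client fixture, base_url, and the pytest.mark.asyncio runner are assumptions (the module may configure its async test runner differently), and the retry-backoff patch is again left out.

import httpx
import pytest
from respx import MockRouter

from openlayer import AsyncOpenlayer
from openlayer._exceptions import APIStatusError

base_url = "http://127.0.0.1:4010"  # assumption: stands in for the module-level base_url


@pytest.mark.asyncio  # assumption: the module may instead apply an async marker globally
@pytest.mark.respx(base_url=base_url)
async def test_status_error_surfaces_through_streaming_helper(
    respx_mock: MockRouter, async_client: AsyncOpenlayer
) -> None:
    # `async_client` is assumed to be a pytest fixture returning an AsyncOpenlayer
    # instance pointed at base_url.
    respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
        return_value=httpx.Response(500)
    )

    with pytest.raises(APIStatusError):
        # Entering the async context manager sends the POST; the 500 response
        # (after retries) is raised as an APIStatusError subclass before any
        # body is streamed.
        async with async_client.inference_pipelines.data.with_streaming_response.stream(
            inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e",
            config={"output_column_name": "output"},
            rows=[
                {
                    "user_query": "bar",
                    "output": "bar",
                    "tokens": "bar",
                    "cost": "bar",
                    "timestamp": "bar",
                }
            ],
        ):
            pass  # not reached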
