|
23 | 23 |
|
24 | 24 | from openlayer import Openlayer, AsyncOpenlayer, APIResponseValidationError
|
25 | 25 | from openlayer._types import Omit
|
26 |
| -from openlayer._utils import maybe_transform |
27 | 26 | from openlayer._models import BaseModel, FinalRequestOptions
|
28 |
| -from openlayer._constants import RAW_RESPONSE_HEADER |
29 | 27 | from openlayer._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
|
30 | 28 | from openlayer._base_client import (
|
31 | 29 | DEFAULT_TIMEOUT,
|
|
35 | 33 | DefaultAsyncHttpxClient,
|
36 | 34 | make_request_options,
|
37 | 35 | )
|
38 |
| -from openlayer.types.inference_pipelines.data_stream_params import DataStreamParams |
39 | 36 |
|
40 | 37 | from .utils import update_env
|
41 | 38 |
|
@@ -724,82 +721,49 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
|
724 | 721 |
|
725 | 722 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
|
726 | 723 | @pytest.mark.respx(base_url=base_url)
|
727 |
| - def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 724 | + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Openlayer) -> None: |
728 | 725 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
|
729 | 726 | side_effect=httpx.TimeoutException("Test timeout error")
|
730 | 727 | )
|
731 | 728 |
|
732 | 729 | with pytest.raises(APITimeoutError):
|
733 |
| - self.client.post( |
734 |
| - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
735 |
| - body=cast( |
736 |
| - object, |
737 |
| - maybe_transform( |
738 |
| - dict( |
739 |
| - config={ |
740 |
| - "input_variable_names": ["user_query"], |
741 |
| - "output_column_name": "output", |
742 |
| - "num_of_token_column_name": "tokens", |
743 |
| - "cost_column_name": "cost", |
744 |
| - "timestamp_column_name": "timestamp", |
745 |
| - }, |
746 |
| - rows=[ |
747 |
| - { |
748 |
| - "user_query": "what is the meaning of life?", |
749 |
| - "output": "42", |
750 |
| - "tokens": 7, |
751 |
| - "cost": 0.02, |
752 |
| - "timestamp": 1610000000, |
753 |
| - } |
754 |
| - ], |
755 |
| - ), |
756 |
| - DataStreamParams, |
757 |
| - ), |
758 |
| - ), |
759 |
| - cast_to=httpx.Response, |
760 |
| - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
761 |
| - ) |
| 730 | + client.inference_pipelines.data.with_streaming_response.stream( |
| 731 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 732 | + config={"output_column_name": "output"}, |
| 733 | + rows=[ |
| 734 | + { |
| 735 | + "user_query": "bar", |
| 736 | + "output": "bar", |
| 737 | + "tokens": "bar", |
| 738 | + "cost": "bar", |
| 739 | + "timestamp": "bar", |
| 740 | + } |
| 741 | + ], |
| 742 | + ).__enter__() |
762 | 743 |
|
763 | 744 | assert _get_open_connections(self.client) == 0
|
764 | 745 |
|
765 | 746 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
|
766 | 747 | @pytest.mark.respx(base_url=base_url)
|
767 |
| - def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 748 | + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Openlayer) -> None: |
768 | 749 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
|
769 | 750 | return_value=httpx.Response(500)
|
770 | 751 | )
|
771 | 752 |
|
772 | 753 | with pytest.raises(APIStatusError):
|
773 |
| - self.client.post( |
774 |
| - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
775 |
| - body=cast( |
776 |
| - object, |
777 |
| - maybe_transform( |
778 |
| - dict( |
779 |
| - config={ |
780 |
| - "input_variable_names": ["user_query"], |
781 |
| - "output_column_name": "output", |
782 |
| - "num_of_token_column_name": "tokens", |
783 |
| - "cost_column_name": "cost", |
784 |
| - "timestamp_column_name": "timestamp", |
785 |
| - }, |
786 |
| - rows=[ |
787 |
| - { |
788 |
| - "user_query": "what is the meaning of life?", |
789 |
| - "output": "42", |
790 |
| - "tokens": 7, |
791 |
| - "cost": 0.02, |
792 |
| - "timestamp": 1610000000, |
793 |
| - } |
794 |
| - ], |
795 |
| - ), |
796 |
| - DataStreamParams, |
797 |
| - ), |
798 |
| - ), |
799 |
| - cast_to=httpx.Response, |
800 |
| - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
801 |
| - ) |
802 |
| - |
| 754 | + client.inference_pipelines.data.with_streaming_response.stream( |
| 755 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 756 | + config={"output_column_name": "output"}, |
| 757 | + rows=[ |
| 758 | + { |
| 759 | + "user_query": "bar", |
| 760 | + "output": "bar", |
| 761 | + "tokens": "bar", |
| 762 | + "cost": "bar", |
| 763 | + "timestamp": "bar", |
| 764 | + } |
| 765 | + ], |
| 766 | + ).__enter__() |
803 | 767 | assert _get_open_connections(self.client) == 0
|
804 | 768 |
|
805 | 769 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
|
@@ -1652,82 +1616,53 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
|
1652 | 1616 |
|
1653 | 1617 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
|
1654 | 1618 | @pytest.mark.respx(base_url=base_url)
|
1655 |
| - async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1619 | + async def test_retrying_timeout_errors_doesnt_leak( |
| 1620 | + self, respx_mock: MockRouter, async_client: AsyncOpenlayer |
| 1621 | + ) -> None: |
1656 | 1622 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
|
1657 | 1623 | side_effect=httpx.TimeoutException("Test timeout error")
|
1658 | 1624 | )
|
1659 | 1625 |
|
1660 | 1626 | with pytest.raises(APITimeoutError):
|
1661 |
| - await self.client.post( |
1662 |
| - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
1663 |
| - body=cast( |
1664 |
| - object, |
1665 |
| - maybe_transform( |
1666 |
| - dict( |
1667 |
| - config={ |
1668 |
| - "input_variable_names": ["user_query"], |
1669 |
| - "output_column_name": "output", |
1670 |
| - "num_of_token_column_name": "tokens", |
1671 |
| - "cost_column_name": "cost", |
1672 |
| - "timestamp_column_name": "timestamp", |
1673 |
| - }, |
1674 |
| - rows=[ |
1675 |
| - { |
1676 |
| - "user_query": "what is the meaning of life?", |
1677 |
| - "output": "42", |
1678 |
| - "tokens": 7, |
1679 |
| - "cost": 0.02, |
1680 |
| - "timestamp": 1610000000, |
1681 |
| - } |
1682 |
| - ], |
1683 |
| - ), |
1684 |
| - DataStreamParams, |
1685 |
| - ), |
1686 |
| - ), |
1687 |
| - cast_to=httpx.Response, |
1688 |
| - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1689 |
| - ) |
| 1627 | + await async_client.inference_pipelines.data.with_streaming_response.stream( |
| 1628 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 1629 | + config={"output_column_name": "output"}, |
| 1630 | + rows=[ |
| 1631 | + { |
| 1632 | + "user_query": "bar", |
| 1633 | + "output": "bar", |
| 1634 | + "tokens": "bar", |
| 1635 | + "cost": "bar", |
| 1636 | + "timestamp": "bar", |
| 1637 | + } |
| 1638 | + ], |
| 1639 | + ).__aenter__() |
1690 | 1640 |
|
1691 | 1641 | assert _get_open_connections(self.client) == 0
|
1692 | 1642 |
|
1693 | 1643 | @mock.patch("openlayer._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
|
1694 | 1644 | @pytest.mark.respx(base_url=base_url)
|
1695 |
| - async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1645 | + async def test_retrying_status_errors_doesnt_leak( |
| 1646 | + self, respx_mock: MockRouter, async_client: AsyncOpenlayer |
| 1647 | + ) -> None: |
1696 | 1648 | respx_mock.post("/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream").mock(
|
1697 | 1649 | return_value=httpx.Response(500)
|
1698 | 1650 | )
|
1699 | 1651 |
|
1700 | 1652 | with pytest.raises(APIStatusError):
|
1701 |
| - await self.client.post( |
1702 |
| - "/inference-pipelines/182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e/data-stream", |
1703 |
| - body=cast( |
1704 |
| - object, |
1705 |
| - maybe_transform( |
1706 |
| - dict( |
1707 |
| - config={ |
1708 |
| - "input_variable_names": ["user_query"], |
1709 |
| - "output_column_name": "output", |
1710 |
| - "num_of_token_column_name": "tokens", |
1711 |
| - "cost_column_name": "cost", |
1712 |
| - "timestamp_column_name": "timestamp", |
1713 |
| - }, |
1714 |
| - rows=[ |
1715 |
| - { |
1716 |
| - "user_query": "what is the meaning of life?", |
1717 |
| - "output": "42", |
1718 |
| - "tokens": 7, |
1719 |
| - "cost": 0.02, |
1720 |
| - "timestamp": 1610000000, |
1721 |
| - } |
1722 |
| - ], |
1723 |
| - ), |
1724 |
| - DataStreamParams, |
1725 |
| - ), |
1726 |
| - ), |
1727 |
| - cast_to=httpx.Response, |
1728 |
| - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1729 |
| - ) |
1730 |
| - |
| 1653 | + await async_client.inference_pipelines.data.with_streaming_response.stream( |
| 1654 | + inference_pipeline_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", |
| 1655 | + config={"output_column_name": "output"}, |
| 1656 | + rows=[ |
| 1657 | + { |
| 1658 | + "user_query": "bar", |
| 1659 | + "output": "bar", |
| 1660 | + "tokens": "bar", |
| 1661 | + "cost": "bar", |
| 1662 | + "timestamp": "bar", |
| 1663 | + } |
| 1664 | + ], |
| 1665 | + ).__aenter__() |
1731 | 1666 | assert _get_open_connections(self.client) == 0
|
1732 | 1667 |
|
1733 | 1668 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
|
|
0 commit comments