Commit

rename examples to test_examples (#255)
* rename examples to test_examples

* rename

* fix tests

* fix tests

* fix tests
aniketmaurya authored Aug 30, 2024
1 parent f2f70ea commit 8ec911d
Showing 17 changed files with 49 additions and 49 deletions.
6 changes: 3 additions & 3 deletions src/litserve/__init__.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from litserve.__about__ import * # noqa: F401, F403
+from litserve.__about__ import * # noqa: F403
from litserve.api import LitAPI
from litserve.server import LitServer, Request, Response
-from litserve import examples
+from litserve import test_examples
from litserve.specs.openai import OpenAISpec

-__all__ = ["LitAPI", "LitServer", "Request", "Response", "examples", "OpenAISpec"]
+__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec"]
2 changes: 1 addition & 1 deletion src/litserve/server.py
@@ -128,7 +128,7 @@ def __init__(
_msg = (
"middlewares must be a list of tuples"
" where each tuple contains a middleware and its arguments. For example:\n"
"server = ls.LitServer(ls.examples.SimpleLitAPI(), "
"server = ls.LitServer(ls.test_examples.SimpleLitAPI(), "
'middlewares=[(RequestIdMiddleware, {"length": 5})])'
)
raise ValueError(_msg)
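The error message above spells out the expected middlewares format. As a sketch of a valid call (the RequestIdMiddleware here is hypothetical, modelled on the one defined in tests/test_lit_server.py later in this diff):

import litserve as ls
from starlette.middleware.base import BaseHTTPMiddleware

class RequestIdMiddleware(BaseHTTPMiddleware):
    # Hypothetical middleware, for illustration only.
    def __init__(self, app, length: int):
        super().__init__(app)
        self.length = length

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        response.headers["X-Request-Id"] = "0" * self.length
        return response

# Each middleware is passed as a (class, init_kwargs) tuple.
server = ls.LitServer(
    ls.test_examples.SimpleLitAPI(),
    middlewares=[(RequestIdMiddleware, {"length": 5})],
)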
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/e2e/default_api.py
@@ -14,6 +14,6 @@
import litserve as ls

if __name__ == "__main__":
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api)
server.run(port=8000)
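Once an e2e script like this is running, the default /predict route can be exercised with a plain HTTP client. A sketch, assuming the requests package and the run(port=8000) call above; the expected output follows from the squaring behaviour asserted in tests/test_examples.py:

import requests

resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0})
print(resp.status_code)  # 200
print(resp.json())       # {"output": 16.0} for SimpleLitAPI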
2 changes: 1 addition & 1 deletion tests/e2e/default_batching.py
@@ -14,6 +14,6 @@
import litserve as ls

if __name__ == "__main__":
-api = ls.examples.SimpleBatchedAPI()
+api = ls.test_examples.SimpleBatchedAPI()
server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.05)
server.run(port=8000)
2 changes: 1 addition & 1 deletion tests/e2e/default_openai_with_batching.py
@@ -1,4 +1,4 @@
-from litserve.examples.openai_spec_example import OpenAIBatchContext
+from litserve.test_examples.openai_spec_example import OpenAIBatchContext

import litserve as ls

2 changes: 1 addition & 1 deletion tests/e2e/default_openaispec.py
@@ -1,5 +1,5 @@
from litserve import OpenAISpec
-from litserve.examples.openai_spec_example import TestAPI
+from litserve.test_examples.openai_spec_example import TestAPI
import litserve as ls

if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/e2e/default_openaispec_response_format.py
@@ -1,6 +1,6 @@
import litserve as ls
from litserve import OpenAISpec
-from litserve.examples.openai_spec_example import TestAPIWithStructuredOutput
+from litserve.test_examples.openai_spec_example import TestAPIWithStructuredOutput

if __name__ == "__main__":
server = ls.LitServer(TestAPIWithStructuredOutput(), spec=OpenAISpec())
2 changes: 1 addition & 1 deletion tests/e2e/default_openaispec_tools.py
@@ -1,6 +1,6 @@
import litserve as ls
from litserve import OpenAISpec
-from litserve.examples.openai_spec_example import TestAPI
+from litserve.test_examples.openai_spec_example import TestAPI
from litserve.specs.openai import ChatMessage


2 changes: 1 addition & 1 deletion tests/e2e/default_spec.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from litserve.examples.openai_spec_example import TestAPI
+from litserve.test_examples.openai_spec_example import TestAPI
from litserve.specs.openai import OpenAISpec

import litserve as ls
2 changes: 1 addition & 1 deletion tests/test_batch.py
@@ -185,7 +185,7 @@ def test_collate_requests(batch_timeout, batch_size):
],
)
def test_collate_requests(batch_timeout, batch_size):
-api = ls.examples.SimpleBatchedAPI()
+api = ls.test_examples.SimpleBatchedAPI()
api.request_timeout = 5
request_queue = Queue()
for i in range(batch_size):
14 changes: 7 additions & 7 deletions tests/test_examples.py
@@ -3,20 +3,20 @@
from asgi_lifespan import LifespanManager
from httpx import AsyncClient

-from litserve.examples.openai_spec_example import (
+from litserve.test_examples.openai_spec_example import (
OpenAIWithUsage,
OpenAIWithUsageEncodeResponse,
OpenAIBatchingWithUsage,
OpenAIBatchContext,
)
-from litserve.examples.simple_example import SimpleStreamAPI
+from litserve.test_examples.simple_example import SimpleStreamAPI
from litserve.utils import wrap_litserve_start
import litserve as ls


@pytest.mark.asyncio()
async def test_simple_pytorch_api():
-api = ls.examples.SimpleTorchAPI()
+api = ls.test_examples.SimpleTorchAPI()
server = ls.LitServer(api, accelerator="cpu")
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
@@ -26,7 +26,7 @@ async def test_simple_pytorch_api():

@pytest.mark.asyncio()
async def test_simple_batched_api():
-api = ls.examples.SimpleBatchedAPI()
+api = ls.test_examples.SimpleBatchedAPI()
server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.1)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
@@ -36,7 +36,7 @@ async def test_simple_batched_api():

@pytest.mark.asyncio()
async def test_simple_api():
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
@@ -46,15 +46,15 @@ async def test_simple_api():

@pytest.mark.asyncio()
async def test_simple_api_without_server():
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
api.setup(None)
assert api.model is not None, "Model should be loaded after setup"
assert api.predict(4) == 16, "Model should be able to predict"


@pytest.mark.asyncio()
async def test_simple_pytorch_api_without_server():
-api = ls.examples.SimpleTorchAPI()
+api = ls.test_examples.SimpleTorchAPI()
api.setup("cpu")
assert api.model is not None, "Model should be loaded after setup"
assert isinstance(api.model, torch.nn.Module)
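As the tests above illustrate, the renamed example APIs can also be used without starting a server. A minimal sketch based on the assertions in test_simple_api_without_server:

import litserve as ls

api = ls.test_examples.SimpleLitAPI()
api.setup(None)               # loads the toy model
assert api.model is not None
assert api.predict(4) == 16   # the example model squares its input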
30 changes: 15 additions & 15 deletions tests/test_lit_server.py
@@ -203,7 +203,7 @@ def test_server_run(mock_uvicorn):
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
@patch("litserve.server.uvicorn")
def test_start_server(mock_uvicon):
-server = LitServer(ls.examples.TestAPI(), spec=ls.OpenAISpec())
+server = LitServer(ls.test_examples.TestAPI(), spec=ls.OpenAISpec())
sockets = MagicMock()
server._start_server(8000, 1, "info", sockets, "process")
mock_uvicon.Server.assert_called()
@@ -213,7 +213,7 @@ def test_start_server(mock_uvicon):
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
@patch("litserve.server.uvicorn")
def test_server_run_with_api_server_worker_type(mock_uvicorn):
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api, devices=1)
with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"):
server.run(api_server_worker_type="invalid")
@@ -247,7 +247,7 @@ def test_server_run_with_api_server_worker_type(mock_uvicorn):
@pytest.mark.skipif(sys.platform != "win32", reason="Test is only for Windows")
@patch("litserve.server.uvicorn")
def test_server_run_windows(mock_uvicorn):
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api)
server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
server._start_server = MagicMock()
@@ -273,7 +273,7 @@ def test_server_terminate():
mock_manager.shutdown.assert_called()


-class IdentityAPI(ls.examples.SimpleLitAPI):
+class IdentityAPI(ls.test_examples.SimpleLitAPI):
def predict(self, x, context):
context["input"] = x
return self.model(x)
@@ -283,7 +283,7 @@ def encode_response(self, output, context):
return {"output": input}


-class IdentityBatchedAPI(ls.examples.SimpleBatchedAPI):
+class IdentityBatchedAPI(ls.test_examples.SimpleBatchedAPI):
def predict(self, x_batch, context):
for c, x in zip(context, x_batch):
c["input"] = x
@@ -294,7 +294,7 @@ def encode_response(self, output, context):
return {"output": input}


-class IdentityBatchedStreamingAPI(ls.examples.SimpleBatchedAPI):
+class IdentityBatchedStreamingAPI(ls.test_examples.SimpleBatchedAPI):
def predict(self, x_batch, context):
for c, x in zip(context, x_batch):
c["input"] = x
@@ -305,7 +305,7 @@ def encode_response(self, output_stream, context):
yield [{"output": ctx["input"]} for ctx in context]


-class PredictErrorAPI(ls.examples.SimpleLitAPI):
+class PredictErrorAPI(ls.test_examples.SimpleLitAPI):
def predict(self, x, y, context):
context["input"] = x
return self.model(x)
@@ -354,16 +354,16 @@ def dummy_load_and_raise(resp):

def test_custom_api_path():
with pytest.raises(ValueError, match="api_path must start with '/'. "):
-LitServer(ls.examples.SimpleLitAPI(), api_path="predict")
+LitServer(ls.test_examples.SimpleLitAPI(), api_path="predict")

-server = LitServer(ls.examples.SimpleLitAPI(), api_path="/v1/custom_predict")
+server = LitServer(ls.test_examples.SimpleLitAPI(), api_path="/v1/custom_predict")
url = server.api_path
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post(url, json={"input": 4.0})
assert response.status_code == 200, "Server response should be 200 (OK)"


-class TestHTTPExceptionAPI(ls.examples.SimpleLitAPI):
+class TestHTTPExceptionAPI(ls.test_examples.SimpleLitAPI):
def decode_request(self, request):
raise HTTPException(501, "decode request is bad")

@@ -389,7 +389,7 @@ async def dispatch(self, request, call_next):


def test_custom_middleware():
-server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=[(RequestIdMiddleware, {"length": 5})])
+server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[(RequestIdMiddleware, {"length": 5})])
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
@@ -407,7 +407,7 @@ def test_starlette_middlewares():
),
HTTPSRedirectMiddleware,
]
-server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=middlewares)
+server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=middlewares)
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0}, headers={"Host": "localhost"})
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
@@ -421,11 +421,11 @@ def test_middlewares_inputs():
server = ls.LitServer(SimpleLitAPI(), middlewares=[])
assert len(server.middlewares) == 1, "Default middleware should be present"

-server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=[], max_payload_size=1000)
+server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=[], max_payload_size=1000)
assert len(server.middlewares) == 2, "Default middleware should be present"

-server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=None)
+server = ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=None)
assert len(server.middlewares) == 1, "Default middleware should be present"

with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
-ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))
+ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))
28 changes: 14 additions & 14 deletions tests/test_litapi.py
@@ -121,33 +121,33 @@ def test_batch_unbatch_stream():

def test_decode_request():
request = {"input": 4.0}
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
assert api.decode_request(request) == 4.0, "Decode request should return the input 4.0"


def test_decode_request_with_openai_spec():
-api = ls.examples.TestAPI()
+api = ls.test_examples.TestAPI()
api._sanitize(max_batch_size=1, spec=ls.OpenAISpec())
request = ChatCompletionRequest(messages=[{"role": "system", "content": "Hello"}])
decoded_request = api.decode_request(request)
assert decoded_request[0]["content"] == "Hello", "Decode request should return the input message"


def test_decode_request_with_openai_spec_wrong_request():
-api = ls.examples.TestAPI()
+api = ls.test_examples.TestAPI()
api._sanitize(max_batch_size=1, spec=ls.OpenAISpec())
with pytest.raises(AttributeError, match="object has no attribute 'messages'"):
api.decode_request({"input": "Hello"})


def test_encode_response():
response = 4.0
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
assert api.encode_response(response) == {"output": 4.0}, 'Encode response returns encoded output {"output": 4.0}'


def test_encode_response_with_openai_spec():
-api = ls.examples.TestAPI()
+api = ls.test_examples.TestAPI()
api._sanitize(max_batch_size=1, spec=ls.OpenAISpec())
response = "This is a LLM generated text".split()
generated_tokens = []
@@ -164,7 +164,7 @@ def predict():
yield {"content": token, "prompt_tokens": 4, "completion_tokens": 4, "total_tokens": 8}

generated_tokens = []
-api = ls.examples.TestAPI()
+api = ls.test_examples.TestAPI()
api._sanitize(max_batch_size=1, spec=ls.OpenAISpec())

for output in api.encode_response(predict()):
@@ -174,12 +174,12 @@ def predict():


def test_encode_response_with_custom_spec_api():
-class CustomSpecAPI(ls.examples.TestAPI):
+class CustomSpecAPI(ls.test_examples.TestAPI):
def encode_response(self, output_stream):
for output in output_stream:
yield {"content": output}

-api = ls.examples.TestAPI()
+api = ls.test_examples.TestAPI()
api._sanitize(max_batch_size=1, spec=CustomSpecAPI())
response = "This is a LLM generated text".split()
generated_tokens = []
@@ -189,7 +189,7 @@ def encode_response(self, output_stream):


def test_encode_response_with_openai_spec_invalid_input():
-api = ls.examples.TestAPI()
+api = ls.test_examples.TestAPI()
api._sanitize(max_batch_size=1, spec=ls.OpenAISpec())
response = 10
with pytest.raises(TypeError, match="object is not iterable"):
@@ -200,14 +200,14 @@ def test_encode_response_with_openai_spec_invalid_predict_output():
def predict():
yield {"hello": "world"}

-api = ls.examples.TestAPI()
+api = ls.test_examples.TestAPI()
api._sanitize(max_batch_size=1, spec=ls.OpenAISpec())
with pytest.raises(HTTPException, match=r"Malformed output from LitAPI.predict"):
next(api.encode_response(predict()))


def test_format_encoded_response():
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
sample = {"output": 4.0}
msg = "Format encoded response should return the encoded response as a string"
assert api.format_encoded_response(sample) == '{"output": 4.0}\n', msg
@@ -225,18 +225,18 @@ class Sample(BaseModel):


def test_batch_torch():
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
x = [torch.Tensor([1, 2, 3, 4]), torch.Tensor([5, 6, 7, 8])]
assert torch.all(api.batch(x) == torch.stack(x)), "Batch should stack torch tensors"


def test_batch_numpy():
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
x = [np.asarray([1, 2, 3, 4]), np.asarray([5, 6, 7, 8])]
assert np.all(api.batch(x) == np.stack(x)), "Batch should stack Numpy array"


def test_device_property():
-api = ls.examples.SimpleLitAPI()
+api = ls.test_examples.SimpleLitAPI()
api.device = "cpu"
assert api.device == "cpu"
2 changes: 1 addition & 1 deletion tests/test_specs.py
@@ -17,7 +17,7 @@
from asgi_lifespan import LifespanManager
from fastapi import HTTPException
from httpx import AsyncClient
-from litserve.examples.openai_spec_example import (
+from litserve.test_examples.openai_spec_example import (
OpenAIBatchingWithUsage,
OpenAIWithUsage,
OpenAIWithUsageEncodeResponse,
