From 98ea814bcebc0bd8f159fc669f7d75612c366350 Mon Sep 17 00:00:00 2001 From: Jialei Date: Fri, 23 Dec 2022 13:26:55 +0800 Subject: [PATCH] chore(sdk): support base uri for serving behind controller gateway --- client/starwhale/api/_impl/service.py | 21 ++++++++----- .../starwhale/mlops/common/ProxyServlet.java | 7 ++++- .../mlops/domain/job/ModelServingService.java | 7 +++-- .../mlops/common/ProxyServletTest.java | 9 +++--- .../domain/job/ModelServingServiceTest.java | 31 ++++++++++--------- 5 files changed, 46 insertions(+), 29 deletions(-) diff --git a/client/starwhale/api/_impl/service.py b/client/starwhale/api/_impl/service.py index 651dd0fd5c..d810581de8 100644 --- a/client/starwhale/api/_impl/service.py +++ b/client/starwhale/api/_impl/service.py @@ -1,3 +1,4 @@ +import os import json import typing as t import functools @@ -145,6 +146,11 @@ def _check_uri_reserved(uri: str) -> None: raise InvalidUriException(f"{uri} is reserved, try using another URI") +def _with_base_uri(uri: str) -> str: + base = os.environ.get("SW_MODEL_SERVING_BASE_URI", "").strip("/") + return "/" + "/".join(filter(bool, [base, uri.lstrip("/")])) + + class Service: def __init__(self) -> None: self.apis: t.Dict[str, Api] = {} @@ -176,7 +182,8 @@ def get_spec(self) -> OpenApi: spec = i.input.spec() resp = i.output.spec().responses spec.responses = resp - paths[i.uri if i.uri.startswith("/") else "/" + i.uri] = {"post": spec} + uri = i.uri if i.uri.startswith("/") else "/" + i.uri + paths[_with_base_uri(uri)] = {"post": spec} return OpenApi( openapi="3.0.0", info={ @@ -192,7 +199,7 @@ def _api_spec_handler(self) -> bytes: @staticmethod def _doc_handler() -> str: - return """ + return f""" @@ -209,12 +216,12 @@ def _doc_handler() -> str:
""" diff --git a/server/controller/src/main/java/ai/starwhale/mlops/common/ProxyServlet.java b/server/controller/src/main/java/ai/starwhale/mlops/common/ProxyServlet.java index ac4f28e653..14c9942c50 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/common/ProxyServlet.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/common/ProxyServlet.java @@ -16,6 +16,8 @@ package ai.starwhale.mlops.common; +import static ai.starwhale.mlops.domain.job.ModelServingService.MODEL_SERVICE_PREFIX; + import ai.starwhale.mlops.domain.job.ModelServingService; import ai.starwhale.mlops.domain.job.mapper.ModelServingMapper; import java.io.IOException; @@ -40,7 +42,6 @@ public class ProxyServlet extends HttpServlet { protected ModelServingMapper modelServingMapper; protected HttpClient httpClient; - public static final String MODEL_SERVICE_PREFIX = "model-serving"; public ProxyServlet(ModelServingMapper modelServingMapper) { this.modelServingMapper = modelServingMapper; @@ -98,6 +99,10 @@ protected HttpRequest generateRequest(HttpServletRequest req, String uri) throws protected void generateResponse(HttpResponse origin, HttpServletResponse resp) throws IOException { var code = origin.getStatusLine().getStatusCode(); resp.setStatus(code); + var headers = origin.getAllHeaders(); + for (var header : headers) { + resp.addHeader(header.getName(), header.getValue()); + } var entity = origin.getEntity(); if (entity == null) { return; diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java index 427d8c2615..abab3f0777 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java @@ -70,6 +70,8 @@ public class ModelServingService { private final ModelServingTokenValidator modelServingTokenValidator; + public static final String MODEL_SERVICE_PREFIX = "model-serving"; + public ModelServingService( ModelServingMapper modelServingMapper, RuntimeDao runtimeDao, @@ -168,7 +170,8 @@ private void deploy(RuntimeVersionEntity runtime, ModelVersionEntity model, Stri "SW_PROJECT", project, "SW_PYPI_INDEX_URL", runTimeProperties.getPypi().getIndexUrl(), "SW_PYPI_EXTRA_INDEX_URL", runTimeProperties.getPypi().getExtraIndexUrl(), - "SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost() + "SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost(), + "SW_MODEL_SERVING_BASE_URI", String.format("/gateway/%s/%d", MODEL_SERVICE_PREFIX, id) ); var ss = k8sJobTemplate.renderModelServingOrch(envs, image, name); k8sClient.deployStatefulSet(ss); @@ -193,6 +196,6 @@ private void deploy(RuntimeVersionEntity runtime, ModelVersionEntity model, Stri } public static String getServiceName(long id) { - return String.format("model-serving-%d", id); + return String.format("%s-%d", MODEL_SERVICE_PREFIX, id); } } diff --git a/server/controller/src/test/java/ai/starwhale/mlops/common/ProxyServletTest.java b/server/controller/src/test/java/ai/starwhale/mlops/common/ProxyServletTest.java index 20806f7bef..a64a90d856 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/common/ProxyServletTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/common/ProxyServletTest.java @@ -16,6 +16,7 @@ package ai.starwhale.mlops.common; +import static ai.starwhale.mlops.domain.job.ModelServingService.MODEL_SERVICE_PREFIX; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.intThat; @@ -55,17 +56,17 @@ public void testGetTarget() { long id = 1L; when(modelServingMapper.find(id)).thenReturn(ModelServingEntity.builder().build()); - var uri = String.format("/%s/%d/ppl", ProxyServlet.MODEL_SERVICE_PREFIX, id); + var uri = String.format("/%s/%d/ppl", MODEL_SERVICE_PREFIX, id); var rt = proxyServlet.getTarget(uri); Assertions.assertEquals("http://model-serving-1/ppl", rt); - var tooShort = ProxyServlet.MODEL_SERVICE_PREFIX + "/1"; + var tooShort = MODEL_SERVICE_PREFIX + "/1"; Assertions.assertThrows(IllegalArgumentException.class, () -> proxyServlet.getTarget(tooShort)); var wrongStartsWith = "/foo/1/ppl"; Assertions.assertThrows(IllegalArgumentException.class, () -> proxyServlet.getTarget(wrongStartsWith)); - var notFound = String.format("/%s/%d/ppl", ProxyServlet.MODEL_SERVICE_PREFIX, id + 1); + var notFound = String.format("/%s/%d/ppl", MODEL_SERVICE_PREFIX, id + 1); Assertions.assertThrows(IllegalArgumentException.class, () -> proxyServlet.getTarget(notFound)); } @@ -74,7 +75,7 @@ public void testService() throws ServletException, IOException { proxyServlet.init(); var req = mock(HttpServletRequest.class); - var uri = String.format("/%s/1/ppl", ProxyServlet.MODEL_SERVICE_PREFIX); + var uri = String.format("/%s/1/ppl", MODEL_SERVICE_PREFIX); when(req.getPathInfo()).thenReturn(uri); when(req.getMethod()).thenReturn("GET"); var inputStream = mock(ServletInputStream.class); diff --git a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java index 8c53522564..580ed78625 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/domain/job/ModelServingServiceTest.java @@ -50,20 +50,20 @@ public class ModelServingServiceTest { private ModelServingService svc; - private ModelServingMapper modelServingMapper = mock(ModelServingMapper.class); - private RuntimeDao runtimeDao = mock(RuntimeDao.class); - private ProjectManager projectManager = mock(ProjectManager.class); - private ModelDao modelDao = mock(ModelDao.class); - private UserService userService = mock(UserService.class); - private K8sClient k8sClient = mock(K8sClient.class); - private K8sJobTemplate k8sJobTemplate = mock(K8sJobTemplate.class); - private RuntimeMapper runtimeMapper = mock(RuntimeMapper.class); - private RuntimeVersionMapper runtimeVersionMapper = mock(RuntimeVersionMapper.class); - private ModelMapper modelMapper = mock(ModelMapper.class); - private ModelVersionMapper modelVersionMapper = mock(ModelVersionMapper.class); - private SystemSettingService systemSettingService = mock(SystemSettingService.class); - private RunTimeProperties runTimeProperties = mock(RunTimeProperties.class); - private ModelServingTokenValidator modelServingTokenValidator = mock(ModelServingTokenValidator.class); + private final ModelServingMapper modelServingMapper = mock(ModelServingMapper.class); + private final RuntimeDao runtimeDao = mock(RuntimeDao.class); + private final ProjectManager projectManager = mock(ProjectManager.class); + private final ModelDao modelDao = mock(ModelDao.class); + private final UserService userService = mock(UserService.class); + private final K8sClient k8sClient = mock(K8sClient.class); + private final K8sJobTemplate k8sJobTemplate = mock(K8sJobTemplate.class); + private final RuntimeMapper runtimeMapper = mock(RuntimeMapper.class); + private final RuntimeVersionMapper runtimeVersionMapper = mock(RuntimeVersionMapper.class); + private final ModelMapper modelMapper = mock(ModelMapper.class); + private final ModelVersionMapper modelVersionMapper = mock(ModelVersionMapper.class); + private final SystemSettingService systemSettingService = mock(SystemSettingService.class); + private final RunTimeProperties runTimeProperties = mock(RunTimeProperties.class); + private final ModelServingTokenValidator modelServingTokenValidator = mock(ModelServingTokenValidator.class); @BeforeEach public void setUp() { @@ -131,7 +131,8 @@ public void testCreate() throws ApiException { "SW_TOKEN", "token", "SW_INSTANCE_URI", "inst", "SW_MODEL_VERSION", "md/version/9", - "SW_RUNTIME_VERSION", "rt/version/8" + "SW_RUNTIME_VERSION", "rt/version/8", + "SW_MODEL_SERVING_BASE_URI", "/gateway/model-serving/7" ), "img", "model-serving-7"); verify(k8sClient).deployService(any());