Skip to content

Commit

Permalink
chore(sdk): support base uri for serving behind controller gateway
Browse files: browse the repository at this point in the history
  • Loading branch information
jialeicui committed Dec 23, 2022
1 parent 55f8e99 commit 98ea814
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 29 deletions.
21 changes: 14 additions & 7 deletions client/starwhale/api/_impl/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import json
import typing as t
import functools
Expand Down Expand Up @@ -145,6 +146,11 @@ def _check_uri_reserved(uri: str) -> None:
raise InvalidUriException(f"{uri} is reserved, try using another URI")


def _with_base_uri(uri: str) -> str:
base = os.environ.get("SW_MODEL_SERVING_BASE_URI", "").strip("/")
return "/" + "/".join(filter(bool, [base, uri.lstrip("/")]))


class Service:
def __init__(self) -> None:
self.apis: t.Dict[str, Api] = {}
Expand Down Expand Up @@ -176,7 +182,8 @@ def get_spec(self) -> OpenApi:
spec = i.input.spec()
resp = i.output.spec().responses
spec.responses = resp
paths[i.uri if i.uri.startswith("/") else "/" + i.uri] = {"post": spec}
uri = i.uri if i.uri.startswith("/") else "/" + i.uri
paths[_with_base_uri(uri)] = {"post": spec}
return OpenApi(
openapi="3.0.0",
info={
Expand All @@ -192,7 +199,7 @@ def _api_spec_handler(self) -> bytes:

@staticmethod
def _doc_handler() -> str:
return """
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
Expand All @@ -209,12 +216,12 @@ def _doc_handler() -> str:
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@4.5.0/swagger-ui-bundle.js" crossorigin></script>
<script>
window.onload = () => {
window.ui = SwaggerUIBundle({
url: '/api-spec',
window.onload = () => {{
window.ui = SwaggerUIBundle({{
url: '{_with_base_uri('/api-spec')}',
dom_id: '#swagger-ui',
});
};
}});
}};
</script>
</body>
</html> """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package ai.starwhale.mlops.common;

import static ai.starwhale.mlops.domain.job.ModelServingService.MODEL_SERVICE_PREFIX;

import ai.starwhale.mlops.domain.job.ModelServingService;
import ai.starwhale.mlops.domain.job.mapper.ModelServingMapper;
import java.io.IOException;
Expand All @@ -40,7 +42,6 @@
public class ProxyServlet extends HttpServlet {
protected ModelServingMapper modelServingMapper;
protected HttpClient httpClient;
public static final String MODEL_SERVICE_PREFIX = "model-serving";

public ProxyServlet(ModelServingMapper modelServingMapper) {
this.modelServingMapper = modelServingMapper;
Expand Down Expand Up @@ -98,6 +99,10 @@ protected HttpRequest generateRequest(HttpServletRequest req, String uri) throws
protected void generateResponse(HttpResponse origin, HttpServletResponse resp) throws IOException {
var code = origin.getStatusLine().getStatusCode();
resp.setStatus(code);
var headers = origin.getAllHeaders();
for (var header : headers) {
resp.addHeader(header.getName(), header.getValue());
}
var entity = origin.getEntity();
if (entity == null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ public class ModelServingService {
private final ModelServingTokenValidator modelServingTokenValidator;


public static final String MODEL_SERVICE_PREFIX = "model-serving";

public ModelServingService(
ModelServingMapper modelServingMapper,
RuntimeDao runtimeDao,
Expand Down Expand Up @@ -168,7 +170,8 @@ private void deploy(RuntimeVersionEntity runtime, ModelVersionEntity model, Stri
"SW_PROJECT", project,
"SW_PYPI_INDEX_URL", runTimeProperties.getPypi().getIndexUrl(),
"SW_PYPI_EXTRA_INDEX_URL", runTimeProperties.getPypi().getExtraIndexUrl(),
"SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost()
"SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost(),
"SW_MODEL_SERVING_BASE_URI", String.format("/gateway/%s/%d", MODEL_SERVICE_PREFIX, id)
);
var ss = k8sJobTemplate.renderModelServingOrch(envs, image, name);
k8sClient.deployStatefulSet(ss);
Expand All @@ -193,6 +196,6 @@ private void deploy(RuntimeVersionEntity runtime, ModelVersionEntity model, Stri
}

public static String getServiceName(long id) {
return String.format("model-serving-%d", id);
return String.format("%s-%d", MODEL_SERVICE_PREFIX, id);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package ai.starwhale.mlops.common;

import static ai.starwhale.mlops.domain.job.ModelServingService.MODEL_SERVICE_PREFIX;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.intThat;
Expand Down Expand Up @@ -55,17 +56,17 @@ public void testGetTarget() {
long id = 1L;
when(modelServingMapper.find(id)).thenReturn(ModelServingEntity.builder().build());

var uri = String.format("/%s/%d/ppl", ProxyServlet.MODEL_SERVICE_PREFIX, id);
var uri = String.format("/%s/%d/ppl", MODEL_SERVICE_PREFIX, id);
var rt = proxyServlet.getTarget(uri);
Assertions.assertEquals("http://model-serving-1/ppl", rt);

var tooShort = ProxyServlet.MODEL_SERVICE_PREFIX + "/1";
var tooShort = MODEL_SERVICE_PREFIX + "/1";
Assertions.assertThrows(IllegalArgumentException.class, () -> proxyServlet.getTarget(tooShort));

var wrongStartsWith = "/foo/1/ppl";
Assertions.assertThrows(IllegalArgumentException.class, () -> proxyServlet.getTarget(wrongStartsWith));

var notFound = String.format("/%s/%d/ppl", ProxyServlet.MODEL_SERVICE_PREFIX, id + 1);
var notFound = String.format("/%s/%d/ppl", MODEL_SERVICE_PREFIX, id + 1);
Assertions.assertThrows(IllegalArgumentException.class, () -> proxyServlet.getTarget(notFound));
}

Expand All @@ -74,7 +75,7 @@ public void testService() throws ServletException, IOException {
proxyServlet.init();

var req = mock(HttpServletRequest.class);
var uri = String.format("/%s/1/ppl", ProxyServlet.MODEL_SERVICE_PREFIX);
var uri = String.format("/%s/1/ppl", MODEL_SERVICE_PREFIX);
when(req.getPathInfo()).thenReturn(uri);
when(req.getMethod()).thenReturn("GET");
var inputStream = mock(ServletInputStream.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@

public class ModelServingServiceTest {
private ModelServingService svc;
private ModelServingMapper modelServingMapper = mock(ModelServingMapper.class);
private RuntimeDao runtimeDao = mock(RuntimeDao.class);
private ProjectManager projectManager = mock(ProjectManager.class);
private ModelDao modelDao = mock(ModelDao.class);
private UserService userService = mock(UserService.class);
private K8sClient k8sClient = mock(K8sClient.class);
private K8sJobTemplate k8sJobTemplate = mock(K8sJobTemplate.class);
private RuntimeMapper runtimeMapper = mock(RuntimeMapper.class);
private RuntimeVersionMapper runtimeVersionMapper = mock(RuntimeVersionMapper.class);
private ModelMapper modelMapper = mock(ModelMapper.class);
private ModelVersionMapper modelVersionMapper = mock(ModelVersionMapper.class);
private SystemSettingService systemSettingService = mock(SystemSettingService.class);
private RunTimeProperties runTimeProperties = mock(RunTimeProperties.class);
private ModelServingTokenValidator modelServingTokenValidator = mock(ModelServingTokenValidator.class);
private final ModelServingMapper modelServingMapper = mock(ModelServingMapper.class);
private final RuntimeDao runtimeDao = mock(RuntimeDao.class);
private final ProjectManager projectManager = mock(ProjectManager.class);
private final ModelDao modelDao = mock(ModelDao.class);
private final UserService userService = mock(UserService.class);
private final K8sClient k8sClient = mock(K8sClient.class);
private final K8sJobTemplate k8sJobTemplate = mock(K8sJobTemplate.class);
private final RuntimeMapper runtimeMapper = mock(RuntimeMapper.class);
private final RuntimeVersionMapper runtimeVersionMapper = mock(RuntimeVersionMapper.class);
private final ModelMapper modelMapper = mock(ModelMapper.class);
private final ModelVersionMapper modelVersionMapper = mock(ModelVersionMapper.class);
private final SystemSettingService systemSettingService = mock(SystemSettingService.class);
private final RunTimeProperties runTimeProperties = mock(RunTimeProperties.class);
private final ModelServingTokenValidator modelServingTokenValidator = mock(ModelServingTokenValidator.class);

@BeforeEach
public void setUp() {
Expand Down Expand Up @@ -131,7 +131,8 @@ public void testCreate() throws ApiException {
"SW_TOKEN", "token",
"SW_INSTANCE_URI", "inst",
"SW_MODEL_VERSION", "md/version/9",
"SW_RUNTIME_VERSION", "rt/version/8"
"SW_RUNTIME_VERSION", "rt/version/8",
"SW_MODEL_SERVING_BASE_URI", "/gateway/model-serving/7"
), "img", "model-serving-7");

verify(k8sClient).deployService(any());
Expand Down

0 comments on commit 98ea814

Please sign in to comment.