Skip to content

Commit

Permalink
feat(controller): online eval support gc
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui committed Jan 5, 2023
1 parent fa35b89 commit 5d8c8bc
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.net.UnknownHostException;
import java.util.Date;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -134,10 +135,10 @@ public String getTarget(String uri) {
}
var id = Long.parseLong(parts[1]);

// TODO add cache
if (modelServingMapper.find(id) == null) {
throw new IllegalArgumentException("can not find model serving entry " + parts[1]);
}
modelServingMapper.updateLastVisitTime(id, new Date());

var svc = ModelServingService.getServiceName(id);
var handler = "";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,27 @@
import ai.starwhale.mlops.exception.SwProcessException;
import ai.starwhale.mlops.schedule.k8s.K8sClient;
import ai.starwhale.mlops.schedule.k8s.K8sJobTemplate;
import com.google.protobuf.Api;
import io.kubernetes.client.custom.IntOrString;
import io.kubernetes.client.openapi.ApiException;
import io.kubernetes.client.openapi.models.V1ObjectMeta;
import io.kubernetes.client.openapi.models.V1OwnerReference;
import io.kubernetes.client.openapi.models.V1Service;
import io.kubernetes.client.openapi.models.V1ServicePort;
import io.kubernetes.client.openapi.models.V1ServiceSpec;
import io.kubernetes.client.openapi.models.V1StatefulSet;
import io.kubernetes.client.util.labels.EqualityMatcher;
import io.kubernetes.client.util.labels.LabelSelector;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.Objects;
import java.util.TreeMap;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;

@Slf4j
Expand All @@ -77,6 +84,7 @@ public class ModelServingService {


public static final String MODEL_SERVICE_PREFIX = "model-serving";
private static final Pattern modelServingNamePattern = Pattern.compile(MODEL_SERVICE_PREFIX + "-(\\d+)");

public ModelServingService(
ModelServingMapper modelServingMapper,
Expand Down Expand Up @@ -136,16 +144,23 @@ public ModelServingVo create(
.modelVersionId(modelVersionId)
.jobStatus(JobStatus.CREATED)
.resourcePool(resourcePool)
.lastVisitTime(new Date())
.build();

modelServingMapper.add(entity);
long id;
synchronized (this) {
modelServingMapper.add(entity);

var services = modelServingMapper.list(projectId, modelVersionId, runtimeVersionId, resourcePool);
if (services.size() != 1) {
// this can not happen
throw new SwProcessException(SwProcessException.ErrorType.DB, "duplicate entries, size " + services.size());
var services = modelServingMapper.list(projectId, modelVersionId, runtimeVersionId, resourcePool);
if (services.size() != 1) {
// this can not happen
throw new SwProcessException(SwProcessException.ErrorType.DB,
"duplicate entries, size " + services.size());
}
id = services.get(0).getId();
// refresh the last visit time so that the entry is not garbage collected
modelServingMapper.updateLastVisitTime(id, new Date());
}
var id = services.get(0).getId();

log.info("Model serving job has been created. ID={}", id);

Expand Down Expand Up @@ -199,7 +214,7 @@ private void deploy(
);
var ss = k8sJobTemplate.renderModelServingOrch(envs, image, name);
try {
k8sClient.deployStatefulSet(ss);
ss = k8sClient.deployStatefulSet(ss);
} catch (ApiException e) {
if (e.getCode() != HttpServletResponse.SC_CONFLICT) {
throw e;
Expand All @@ -214,24 +229,141 @@ private void deploy(
svc.metadata(meta);
var spec = new V1ServiceSpec();
svc.spec(spec);
var selector = Map.of("app", name);
var selector = Map.of(K8sJobTemplate.LABEL_APP, name);
spec.selector(selector);
var port = new V1ServicePort();
port.name("model-serving-port");
port.protocol("TCP");
port.port(80);
port.targetPort(new IntOrString(8080));
port.targetPort(new IntOrString(K8sJobTemplate.ONLINE_EVAL_PORT_IN_POD));
spec.ports(List.of(port));

// add owner reference for svc and we can just delete the stateful-set when gc is needed
var ownerRef = new V1OwnerReference();
ownerRef.kind(ss.getKind());
Objects.requireNonNull(ss.getMetadata());
ownerRef.uid(ss.getMetadata().getUid());
meta.ownerReferences(List.of(ownerRef));

// add svc to k8s
k8sClient.deployService(svc);
// TODO add owner reference for svc
// TODO garbage collection when svc fails

// if operations of svc failed, the gc thread will delete the zombie stateful-set,
// so we do not need to delete the previous stateful-set when this fails
}

/**
 * Builds the name used for the Kubernetes resources of a model serving entry.
 *
 * <p>The name is {@code MODEL_SERVICE_PREFIX + "-" + id}. Plain concatenation is used instead of
 * {@code String.format("%d")} because the latter localizes digits according to the default
 * locale, which would produce names that the ASCII-only name pattern of this class can not
 * parse back.
 *
 * @param id id of the model serving entry
 * @return the resource name for the entry
 */
public static String getServiceName(long id) {
    return MODEL_SERVICE_PREFIX + "-" + id;
}

/**
 * Extracts the model serving entry id from a resource name.
 *
 * @param name the name to parse, expected to have the form {@code <MODEL_SERVICE_PREFIX>-<digits>}
 * @return the parsed id, or {@code null} when the name is {@code null}, does not match the
 *         expected form, or its numeric part does not fit into a {@code long}
 */
public static Long getServiceIdFromName(String name) {
    if (name == null) {
        return null;
    }
    var match = modelServingNamePattern.matcher(name);
    if (!match.matches()) {
        return null;
    }
    try {
        return Long.parseLong(match.group(1));
    } catch (NumberFormatException e) {
        // the pattern accepts any number of digits, so the value may overflow a long;
        // such a name can not belong to one of our entries
        return null;
    }
}

/**
 * Builds the gateway base URI of a model serving entry.
 *
 * <p>The URI is {@code "/gateway/" + MODEL_SERVICE_PREFIX + "/" + id}. Plain concatenation keeps
 * the id in ASCII digits regardless of the default locale, which {@code String.format("%d")}
 * does not guarantee.
 *
 * @param id id of the model serving entry
 * @return the base URI for the entry
 */
public static String getServiceBaseUri(long id) {
    return "/gateway/" + MODEL_SERVICE_PREFIX + "/" + id;
}


/**
 * Garbage collects the stateful sets that carry the online-eval workload label.
 *
 * <p>Each run lists the labelled stateful sets and, for every one whose name can be parsed:
 * <ul>
 *   <li>deletes it when there is no matching record in the database;</li>
 *   <li>deletes it when its last visit time is more than 12 hours ago;</li>
 *   <li>skips it when it was created less than 30 minutes ago or has no status yet;</li>
 *   <li>marks the run as "has pending" when it has no ready replica;</li>
 *   <li>otherwise considers it a running candidate.</li>
 * </ul>
 * When at least one pending stateful set was found, the running candidate with the oldest last
 * visit time is deleted.
 *
 * @throws ApiException if listing the stateful sets fails
 */
@Scheduled(initialDelay = 10000, fixedDelay = 10000)
public void gc() throws ApiException {
    // TODO use durations from system settings
    final long maxIdleMillis = 12L * 3600 * 1000;
    final long deployGraceMillis = 1800L * 1000;

    var labelSelector = LabelSelector.and(EqualityMatcher.equal(K8sJobTemplate.LABEL_WORKLOAD_TYPE,
            K8sJobTemplate.WORKLOAD_TYPE_ONLINE_EVAL)).toString();
    var statefulSetList = k8sClient.getStatefulSetList(labelSelector);

    boolean hasPending = false;
    // the running stateful set with the oldest last visit time seen so far; tracked directly
    // so that entries sharing the same timestamp can not shadow each other
    String oldestName = null;
    Date oldestVisitTime = null;

    for (var statefulSet : statefulSetList.getItems()) {
        var meta = statefulSet.getMetadata();
        if (meta == null || StringUtils.isEmpty(meta.getName())) {
            continue;
        }
        // parse entity id from stateful set name
        var name = meta.getName();
        var id = getServiceIdFromName(name);
        if (id == null) {
            log.warn("can not get entity id from name {}", name);
            continue;
        }

        // check if in db record; the lookup shares the monitor with the synchronized section
        // of create(), so a record is never observed between its insert and the follow-up
        // update of the last visit time
        ModelServingEntity entity;
        synchronized (this) {
            entity = modelServingMapper.find(id);
        }
        if (entity == null) {
            // delete the unknown stateful set
            log.info("delete stateful set {} when there is no entry in db", name);
            deleteStatefulSet(name);
            continue;
        }

        var lastVisitTime = entity.getLastVisitTime();
        if (lastVisitTime == null) {
            // nothing to compare against; leave the stateful set alone instead of failing
            // the whole gc run with a NullPointerException
            log.warn("ignore stateful set {} (no last visit time found)", name);
            continue;
        }

        var now = System.currentTimeMillis();
        if (now - lastVisitTime.getTime() > maxIdleMillis) {
            log.info("delete stateful set {} when it reaches the max TTL", name);
            deleteStatefulSet(name);
            // already deleted: it must neither count as pending nor become a candidate
            continue;
        }

        var createTime = meta.getCreationTimestamp();
        if (createTime != null && now - createTime.toInstant().toEpochMilli() < deployGraceMillis) {
            // just been deployed, ignore
            log.info("ignore stateful set {} (just been deployed)", name);
            continue;
        }

        var status = statefulSet.getStatus();
        if (status == null) {
            // may have just been deployed, ignore
            log.info("ignore stateful set {} (no status found)", name);
            continue;
        }
        if (status.getReadyReplicas() == null || status.getReadyReplicas() == 0) {
            hasPending = true;
            log.info("found pending stateful set {}", name);
            continue;
        }

        if (oldestVisitTime == null || lastVisitTime.before(oldestVisitTime)) {
            oldestVisitTime = lastVisitTime;
            oldestName = name;
        }
    }

    if (!hasPending) {
        log.info("no pending stateful set, done");
        return;
    }

    // kill the oldest stateful set
    if (oldestName == null) {
        log.info("no stateful set to gc");
        return;
    }
    log.info("delete stateful set {}", oldestName);
    // go through the helper so that an already removed stateful set (404) is tolerated,
    // like the other deletions of this method
    deleteStatefulSet(oldestName);
}

/**
 * Deletes the stateful set with the given name on a best-effort basis.
 *
 * <p>An {@link ApiException} is never propagated: a "not found" response is logged at info
 * level, any other failure is logged at error level together with the response body.
 *
 * @param name name of the stateful set to delete
 */
private void deleteStatefulSet(String name) {
    try {
        k8sClient.deleteStatefulSet(name);
    } catch (ApiException e) {
        boolean alreadyGone = e.getCode() == HttpServletResponse.SC_NOT_FOUND;
        if (alreadyGone) {
            log.info("stateful set {} not found", name);
        } else {
            log.error("delete stateful set {} failed, reason {}", name, e.getResponseBody(), e);
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import ai.starwhale.mlops.domain.job.po.ModelServingEntity;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import org.apache.commons.text.CaseUtils;
import org.apache.ibatis.annotations.InsertProvider;
Expand All @@ -26,6 +27,7 @@
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.apache.ibatis.annotations.SelectProvider;
import org.apache.ibatis.annotations.Update;
import org.apache.ibatis.jdbc.SQL;

@Mapper
Expand All @@ -49,6 +51,9 @@ public interface ModelServingMapper {
/**
 * Selects the model serving entry whose {@code id} column equals the given id.
 *
 * @param id id of the entry to look up
 * @return the matching entry; callers in this code base treat a {@code null} result as
 *         "no such entry"
 */
@Select("select * from " + TABLE + " where id=#{id}")
ModelServingEntity find(long id);

/**
 * Sets the {@code last_visit_time} column of the model serving entry with the given id.
 *
 * <p>The parameters are named explicitly with {@code @Param}, like the other multi-parameter
 * methods of this mapper, so that binding {@code #{id}} and {@code #{date}} does not depend on
 * the sources being compiled with {@code -parameters}.
 *
 * @param id   id of the entry to update
 * @param date the new last visit time
 */
@Update("update " + TABLE + " set last_visit_time = #{date} where id = #{id}")
void updateLastVisitTime(@Param("id") long id, @Param("date") Date date);

@SelectProvider(value = SqlProviderAdapter.class, method = "listByConditions")
List<ModelServingEntity> list(
@Param("projectId") Long projectId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ public class ModelServingEntity extends BaseEntity {
private Long runtimeVersionId;
private Integer isDeleted;
private String resourcePool;
private Date lastVisitTime;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package ai.starwhale.mlops.schedule.k8s;

import io.kubernetes.client.openapi.models.V1EnvVar;
import io.kubernetes.client.openapi.models.V1Probe;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
Expand All @@ -39,6 +40,8 @@ public class ContainerOverwriteSpec {

List<V1EnvVar> envs;

V1Probe readinessProbe;

/**
 * Creates a spec that only carries the container name; this constructor assigns no other field.
 *
 * @param name value stored in the {@code name} field
 */
public ContainerOverwriteSpec(String name) {
this.name = name;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.kubernetes.client.openapi.models.V1PodList;
import io.kubernetes.client.openapi.models.V1Service;
import io.kubernetes.client.openapi.models.V1StatefulSet;
import io.kubernetes.client.openapi.models.V1StatefulSetList;
import io.kubernetes.client.util.CallGeneratorParams;
import io.kubernetes.client.util.labels.LabelSelector;
import java.io.IOException;
Expand Down Expand Up @@ -102,6 +103,10 @@ public void deleteJob(String id) throws ApiException {
batchV1Api.deleteNamespacedJob(id, ns, null, null, 1, false, null, null);
}

/**
 * delete the stateful set with the given name with in this.ns
 *
 * <p>NOTE(review): the positional arguments {@code 1} and {@code false} presumably map to
 * gracePeriodSeconds and orphanDependents of the generated client method - confirm against the
 * Kubernetes client version in use.
 *
 * @param name name of the stateful set to delete
 * @throws ApiException if the Kubernetes API call fails
 */
public void deleteStatefulSet(String name) throws ApiException {
appsV1Api.deleteNamespacedStatefulSet(name, ns, null, null, 1, false, null, null);
}

/**
* get all jobs with in this.ns
*
Expand All @@ -111,6 +116,11 @@ public V1JobList getJobs(String labelSelector) throws ApiException {
return batchV1Api.listNamespacedJob(ns, null, null, null, null, labelSelector, null, null, null, 30, null);
}

/**
 * get all stateful sets with in this.ns that match the label selector
 *
 * @param labelSelector label selector used to filter the stateful sets
 * @return the list of matching stateful sets
 * @throws ApiException if the Kubernetes API call fails
 */
public V1StatefulSetList getStatefulSetList(String labelSelector) throws ApiException {
    // the label selector goes into the same positional slot as in getJobs (after the field
    // selector); passing it one slot earlier would send it as a field selector instead
    return appsV1Api.listNamespacedStatefulSet(ns, null, null, null, null,
            labelSelector, null, null, null, 30, null);
}

public void watchJob(ResourceEventHandler<V1Job> eventH, String selector) {
SharedIndexInformer<V1Job> jobInformer = informerFactory.sharedIndexInformerFor(
(CallGeneratorParams params) -> batchV1Api.listNamespacedJobCall(ns, null, null, null, null,
Expand Down
Loading

0 comments on commit 5d8c8bc

Please sign in to comment.