Skip to content

Commit

Permalink
feat(controller): diffrent sorting methods of project list (#2006)
Browse files Browse the repository at this point in the history
* feat: different sorting methods of projects

* unit test

* move project visited to get project info api
  • Loading branch information
dreamlandliu authored Mar 29, 2023
1 parent 7c66016 commit d6eb6e3
Show file tree
Hide file tree
Showing 20 changed files with 551 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import ai.starwhale.mlops.api.protocol.user.ProjectMemberVo;
import com.github.pagehelper.PageInfo;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import io.swagger.v3.oas.annotations.tags.Tag;
Expand Down Expand Up @@ -54,8 +57,11 @@ ResponseEntity<ResponseMessage<PageInfo<ProjectVo>>> listProject(
@RequestParam(value = "projectName", required = false) String projectName,
@Valid @RequestParam(value = "pageNum", required = false, defaultValue = "1") Integer pageNum,
@Valid @RequestParam(value = "pageSize", required = false, defaultValue = "10") Integer pageSize,
@Valid @RequestParam(value = "sort", required = false) String sort,
@Valid @RequestParam(value = "order", required = false, defaultValue = "1") Integer order);
@Parameter(
in = ParameterIn.PATH,
description = "The sort type of project list. (Default=visited)",
schema = @Schema(allowableValues = {"visited", "latest", "oldest"}))
@Valid @RequestParam(value = "sort", required = false) String sort);


@Operation(summary = "Create or Recover a new project")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import ai.starwhale.mlops.api.protocol.user.ProjectMemberVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.common.OrderParams;
import ai.starwhale.mlops.common.PageParams;
import ai.starwhale.mlops.domain.member.MemberService;
import ai.starwhale.mlops.domain.project.ProjectService;
import ai.starwhale.mlops.domain.project.bo.Project;
Expand Down Expand Up @@ -65,17 +64,12 @@ public ProjectController(ProjectService projectService, UserService userService,

@Override
public ResponseEntity<ResponseMessage<PageInfo<ProjectVo>>> listProject(String projectName,
Integer pageNum, Integer pageSize, String sort, Integer order) {
Integer pageNum, Integer pageSize, String sort) {
User user = userService.currentUserDetail();
PageInfo<ProjectVo> projects = projectService.listProject(
projectName,
PageParams.builder()
.pageNum(pageNum)
.pageSize(pageSize)
.build(),
OrderParams.builder()
.sort(sort)
.order(order)
.build(),
user);

Expand Down Expand Up @@ -118,6 +112,7 @@ public ResponseEntity<ResponseMessage<String>> recoverProject(String projectId)

@Override
public ResponseEntity<ResponseMessage<ProjectVo>> getProjectByUrl(String projectUrl) {
projectService.visit(projectUrl);
ProjectVo vo = projectService.getProjectVo(projectUrl);
return ResponseEntity.ok(Code.success.asResponse(vo));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,10 @@

package ai.starwhale.mlops.common;

import cn.hutool.db.sql.Direction;
import cn.hutool.db.sql.Order;
import java.util.Map;
import javax.validation.ValidationException;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.springframework.util.StringUtils;

@Builder
@Getter
Expand All @@ -36,14 +31,4 @@ public class OrderParams extends BaseParams {

private int order;

public String getOrderSql(Map<String, String> fieldMap) throws ValidationException {
if (StringUtils.hasText(sort)) {
if (fieldMap == null || !fieldMap.containsKey(sort)) {
throw new ValidationException();
}
return new Order(fieldMap.get(sort), order < 0 ? Direction.DESC : Direction.ASC).toString();
} else {
return "";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public class SecurityConfiguration extends WebSecurityConfigurerAdapter {
@Resource
private ContentCachingFilter contentCachingFilter;


public SecurityConfiguration() {
super();
// Inherit security context ,so async function calls can effect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public JobRepo(DataStore store,
}

private List<String> tableNames() {
var projects = projectService.listProjects(null, null, null);
var projects = projectService.listProjects();
return projects.stream()
.map(Project::getId)
.map(this::tableName)
Expand Down Expand Up @@ -238,7 +238,7 @@ public List<JobFlattenEntity> findJobByStatusIn(List<JobStatus> jobStatuses) {

List<JobFlattenEntity> results = new ArrayList<>();
// find all projects
var projects = projectService.listProjects(null, null, null);
var projects = projectService.listProjects();
for (Project project : projects) {
results.addAll(this.getJobEntities(project, filter));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
import ai.starwhale.mlops.api.protocol.user.UserVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.common.OrderParams;
import ai.starwhale.mlops.common.PageParams;
import ai.starwhale.mlops.common.util.PageUtil;
import ai.starwhale.mlops.domain.member.MemberService;
import ai.starwhale.mlops.domain.member.bo.ProjectMember;
import ai.starwhale.mlops.domain.project.bo.Project;
import ai.starwhale.mlops.domain.project.bo.Project.Privacy;
import ai.starwhale.mlops.domain.project.mapper.ProjectMapper;
import ai.starwhale.mlops.domain.project.mapper.ProjectVisitedMapper;
import ai.starwhale.mlops.domain.project.po.ObjectCountEntity;
import ai.starwhale.mlops.domain.project.po.ProjectEntity;
import ai.starwhale.mlops.domain.project.po.ProjectObjectCounts;
import ai.starwhale.mlops.domain.project.sort.Sort;
import ai.starwhale.mlops.domain.user.UserService;
import ai.starwhale.mlops.domain.user.bo.Role;
import ai.starwhale.mlops.domain.user.bo.User;
Expand All @@ -42,30 +42,36 @@
import ai.starwhale.mlops.exception.SwValidationException.ValidSubject;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
import cn.hutool.core.util.StrUtil;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.google.common.base.Joiner;
import com.google.common.collect.Maps;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.Assert;

@Slf4j
@Service
public class ProjectService implements ProjectAccessor {
public class ProjectService implements ProjectAccessor, ApplicationContextAware {

public static final String PROJECT_NAME_REGEX = "^[a-zA-Z][a-zA-Z\\d_-]{2,80}$";


private final ProjectMapper projectMapper;

private final ProjectVisitedMapper projectVisitedMapper;
private final ProjectDao projectDao;

private final MemberService memberService;
Expand All @@ -76,18 +82,29 @@ public class ProjectService implements ProjectAccessor {

private static final String DELETE_SUFFIX = ".deleted";

private final Map<Long, Visit> visitedProjectCacheMap = new ConcurrentHashMap<>();
private static final long storageInterval = 1000;
private ApplicationContext applicationContext;

public ProjectService(ProjectMapper projectMapper,
ProjectVisitedMapper projectVisitedMapper,
ProjectDao projectDao,
MemberService memberService,
IdConverter idConvertor,
UserService userService) {
this.projectMapper = projectMapper;
this.projectVisitedMapper = projectVisitedMapper;
this.projectDao = projectDao;
this.memberService = memberService;
this.idConvertor = idConvertor;
this.userService = userService;
}

@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
}

public Project findProject(Long id) {
ProjectEntity entity = projectDao.findById(id);
return toProject(entity);
Expand Down Expand Up @@ -125,31 +142,24 @@ public ProjectVo getProjectVo(String projectUrl) {
return ProjectVo.fromBo(findProject(projectUrl), idConvertor);
}

/**
* Get the list of projects.
*
* @param projectName Search by project name prefix if the project name is set.
* @param pageParams Paging parameters.
* @return A list of ProjectVo objects
*/
public PageInfo<ProjectVo> listProject(String projectName, PageParams pageParams, OrderParams orderParams,
User user) {
Long userId = user.getId();
public PageInfo<ProjectVo> listProject(String projectName, OrderParams orderParams, User user) {
boolean showAll = false;
List<Role> sysRoles = userService.getProjectRolesOfUser(user, Project.system());
for (Role sysRole : sysRoles) {
if (sysRole.getAuthority().equals("OWNER")) {
userId = null;
showAll = true;
break;
}
}

PageHelper.startPage(pageParams.getPageNum(), pageParams.getPageSize());
List<ProjectEntity> entities = listProjects(projectName, userId, orderParams);
Sort sort = getSort(orderParams.getSort());

List<ProjectEntity> entities = sort.list(projectName, user, showAll);
List<Long> ids = entities.stream().map(ProjectEntity::getId).collect(Collectors.toList());
Map<Long, ProjectObjectCounts> countMap = getObjectCountsOfProjects(
ids);

return PageUtil.toPageInfo(entities, entity -> {
return new PageInfo<>(entities.stream().map(entity -> {
ProjectVo vo = ProjectVo.fromEntity(entity, idConvertor,
userService.findUserById(entity.getOwnerId()));
ProjectObjectCounts count = countMap.get(entity.getId());
Expand All @@ -163,16 +173,20 @@ public PageInfo<ProjectVo> listProject(String projectName, PageParams pageParams
.build());
}
return vo;
});
}).collect(Collectors.toList()));
}

public List<Project> listProjects(String projectName, Long userId, String order) {
List<ProjectEntity> list = projectMapper.list(projectName, userId, order);
return list.stream().map(this::toProject).collect(Collectors.toList());
private Sort getSort(String type) {
try {
return this.applicationContext.getBean(StrUtil.isNotEmpty(type) ? type : "visited", Sort.class);
} catch (NoSuchBeanDefinitionException e) {
throw new SwValidationException(ValidSubject.PROJECT, "Unknown sort type. " + type);
}
}

private List<ProjectEntity> listProjects(String projectName, Long userId, OrderParams orderParams) {
return projectMapper.list(projectName, userId, null);
public List<Project> listProjects() {
List<ProjectEntity> list = projectMapper.listOfUser(null, null, null);
return list.stream().map(this::toProject).collect(Collectors.toList());
}

/**
Expand Down Expand Up @@ -338,8 +352,11 @@ private Boolean existProject(String projectName, Long userId) {
}

private Map<Long, ProjectObjectCounts> getObjectCountsOfProjects(List<Long> projectIds) {
String ids = Joiner.on(",").join(projectIds);
Map<Long, ProjectObjectCounts> map = Maps.newHashMap();
if (projectIds.isEmpty()) {
return map;
}
String ids = Joiner.on(",").join(projectIds);
for (Long projectId : projectIds) {
map.put(projectId, new ProjectObjectCounts());
}
Expand Down Expand Up @@ -408,4 +425,33 @@ public Boolean addProjectMember(String projectUrl, Long userId, Long roleId) {
Long projectId = getProjectId(projectUrl);
return memberService.addProjectMember(projectId, userId, roleId);
}


public void visit(String projectUrl) {
if (!Objects.equals("0", projectUrl)) {
Long projectId = getProjectId(projectUrl);
Long userId = userService.currentUserDetail().getId();

Visit visit = new Visit(projectId, System.currentTimeMillis());
Visit lastVisit = visitedProjectCacheMap.get(userId);

if (lastVisit == null || needStorage(visit, lastVisit)) {
visitedProjectCacheMap.put(userId, visit);
projectVisitedMapper.insert(userId, projectId);
}
}
}


private boolean needStorage(Visit current, Visit previous) {
return !Objects.equals(current.projectId, previous.projectId)
&& current.time > previous.time + storageInterval;
}

@AllArgsConstructor
static class Visit {

Long projectId;
long time;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ ProjectEntity findByNameForUpdateAndOwner(@NotNull @Param("projectName") String
@NotNull @Param("ownerId") Long ownerId);

@SelectProvider(value = ProjectProvider.class, method = "listSql")
List<ProjectEntity> list(@Param("projectName") String projectName,
List<ProjectEntity> listOfUser(@Param("projectName") String projectName,
@Param("userId") Long userId,
@Param("order") String order);

@SelectProvider(value = ProjectProvider.class, method = "listAllSql")
List<ProjectEntity> listAll(@Param("projectName") String projectName,
@Param("order") String order);

@SelectProvider(value = ProjectProvider.class, method = "listRemovedSql")
List<ProjectEntity> listRemovedProjects(@Param("projectName") String projectName,
@Param("ownerId") Long ownerId);
Expand Down Expand Up @@ -165,6 +169,26 @@ public String listSql(@NotNull @Param("userId") Long userId,
}.toString();
}

public String listAllSql(
@Param("projectName") String projectName,
@Param("order") String order) {
return new SQL() {
{
SELECT(COLUMNS);
FROM("project_info");
WHERE("is_deleted = 0");
if (StrUtil.isNotEmpty(projectName)) {
WHERE("project_name like concat(#{projectName}, '%')");
}
if (StrUtil.isNotEmpty(order)) {
ORDER_BY(order);
} else {
ORDER_BY("id desc");
}
}
}.toString();
}

public String listRemovedSql(@Param("projectName") String projectName, @Param("ownerId") Long ownerId) {
return new SQL() {
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.domain.project.mapper;

import java.util.List;
import org.apache.ibatis.annotations.Insert;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;

@Mapper
public interface ProjectVisitedMapper {

@Select("select project_id from project_visited where user_id = #{userId} order by id desc")
List<Long> listVisitedProjects(@Param("userId") Long userId);

@Insert("replace into project_visited(user_id, project_id) values (#{userId}, #{projectId})")
int insert(@Param("userId") Long userId, @Param("projectId") Long projectId);
}
Loading

0 comments on commit d6eb6e3

Please sign in to comment.