Skip to content

Commit

Permalink
fix(controller): allow recovering project if project is deleted (#3115)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored Jan 8, 2024
1 parent 554441f commit a5da552
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@
import ai.starwhale.mlops.domain.user.bo.Role;
import ai.starwhale.mlops.domain.user.bo.User;
import ai.starwhale.mlops.exception.StarwhaleException;
import ai.starwhale.mlops.exception.SwNotFoundException;
import ai.starwhale.mlops.exception.SwNotFoundException.ResourceType;
import ai.starwhale.mlops.exception.SwValidationException;
import io.jsonwebtoken.Claims;
import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
Expand All @@ -58,40 +57,50 @@ public class JwtTokenFilter extends OncePerRequestFilter {
private final List<JwtClaimValidator> jwtClaimValidators;

private static final String AUTH_HEADER = "Authorization";

public JwtTokenFilter(JwtTokenUtil jwtTokenUtil, UserService userService, ProjectService projectService,
List<JwtClaimValidator> jwtClaimValidators) {
private static final List<Pattern> WHITE_LIST_FOR_DELETED_PROJECTS = List.of(
Pattern.compile("/api/v1/project/[^/]+/recover")
);

public JwtTokenFilter(
JwtTokenUtil jwtTokenUtil,
UserService userService,
ProjectService projectService,
List<JwtClaimValidator> jwtClaimValidators
) {
this.jwtTokenUtil = jwtTokenUtil;
this.userService = userService;
this.projectService = projectService;
this.jwtClaimValidators = jwtClaimValidators;
}

boolean allowAnonymous(HttpServletRequest request) {
try {
var projects = getProjects(request);
// only for public project
return projects.stream().allMatch(p -> p.getPrivacy() == Project.Privacy.PUBLIC);
} catch (SwNotFoundException e) {
return false;
}
boolean allowAnonymous(Set<Project> projects) {
// only for public project
return projects.stream().allMatch(p -> p.getPrivacy() == Project.Privacy.PUBLIC);
}

@Override
protected void doFilterInternal(HttpServletRequest httpServletRequest,
protected void doFilterInternal(
HttpServletRequest httpServletRequest,
@NonNull HttpServletResponse httpServletResponse,
@NonNull FilterChain filterChain) throws ServletException, IOException {
@NonNull FilterChain filterChain
) throws ServletException, IOException {
String header = httpServletRequest.getHeader(AUTH_HEADER);

if (!checkHeader(header)) {
if (isInvalidAuthHeader(header)) {
header = httpServletRequest.getParameter(AUTH_HEADER);
}
if (!checkHeader(header)) {

var projects = getProjects(httpServletRequest);
if (!verifyProjectsExist(httpServletRequest, httpServletResponse, projects)) {
return;
}

if (isInvalidAuthHeader(header)) {
// check whether the uri allow anonymous in public project
if (allowAnonymous(httpServletRequest)) {
if (allowAnonymous(projects)) {
// Build jwt token with anonymous user
JwtLoginToken jwtLoginToken = new JwtLoginToken(null, "", List.of(
Role.builder().roleCode(Role.CODE_ANONYMOUS).roleName(Role.NAME_ANONYMOUS).build()));
Role.builder().roleCode(Role.CODE_ANONYMOUS).roleName(Role.NAME_ANONYMOUS).build()));
jwtLoginToken.setDetails(new WebAuthenticationDetails(httpServletRequest));
SecurityContextHolder.getContext().setAuthentication(jwtLoginToken);
} else {
Expand Down Expand Up @@ -123,12 +132,8 @@ protected void doFilterInternal(HttpServletRequest httpServletRequest,
role -> role.getAuthority().equals(Role.CODE_OWNER)).collect(Collectors.toSet());
// Get project roles
try {
Set<Project> projects = getProjects(httpServletRequest);
Set<Role> rolesOfUser = userService.getProjectsRolesOfUser(user, projects);
roles.addAll(rolesOfUser);
} catch (SwNotFoundException e) {
error(httpServletResponse, HttpStatus.NOT_FOUND.value(), Code.validationException, e.getMessage());
return;
} catch (StarwhaleException e) {
logger.error(e.getMessage());
}
Expand All @@ -142,23 +147,40 @@ protected void doFilterInternal(HttpServletRequest httpServletRequest,
}

@NotNull
private Set<Project> getProjects(HttpServletRequest httpServletRequest) throws SwNotFoundException {
private Set<Project> getProjects(HttpServletRequest httpServletRequest) {
@SuppressWarnings("unchecked")
Set<Project> projects = ((Set<String>) httpServletRequest
.getAttribute(ProjectDetectionFilter.ATTRIBUTE_PROJECT))
var projectIds = (Set<String>) httpServletRequest.getAttribute(ProjectDetectionFilter.ATTRIBUTE_PROJECT);
if (projectIds == null) {
return Set.of();
}

return projectIds
.stream()
.map((String projectUrl) -> {
var p = projectService.findProject(projectUrl);
if (p.isDeleted()) {
throw new SwNotFoundException(ResourceType.PROJECT, "Project is deleted");
}
return p;
})
.map(projectService::findProject)
.collect(Collectors.toSet());
return projects;
}

private boolean checkHeader(String header) {
return StringUtils.hasText(header) && header.startsWith("Bearer ");
private boolean isInvalidAuthHeader(String header) {
return !StringUtils.hasText(header) || !header.startsWith("Bearer ");
}

private boolean verifyProjectsExist(HttpServletRequest request, HttpServletResponse response, Set<Project> projects)
throws IOException {
// never check for root path
var uri = request.getRequestURI();
if (!StringUtils.hasText(uri)) {
return true;
}
if (projects.isEmpty()) {
return true;
}
if (projects.stream().noneMatch(Project::isDeleted)) {
return true;
}
if (WHITE_LIST_FOR_DELETED_PROJECTS.stream().anyMatch(p -> p.matcher(request.getRequestURI()).matches())) {
return true;
}
error(response, HttpStatus.NOT_FOUND.value(), Code.validationException, "Project is deleted");
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package ai.starwhale.mlops.configuration.security;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -129,6 +132,7 @@ public void testDeletedProject() throws ServletException, IOException {
HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getHeader("Authorization")).thenReturn("Bearer a");
when(request.getAttribute("PROJECT")).thenReturn(Set.of("deleted"));
when(request.getRequestURI()).thenReturn("/api/v1/project/1/jobs");
HttpServletResponse response = mock(HttpServletResponse.class);
FilterChain filterchain = mock(FilterChain.class);
when(jwtTokenUtil.getUsername(any())).thenReturn("foo");
Expand All @@ -137,6 +141,16 @@ public void testDeletedProject() throws ServletException, IOException {
jwtTokenFilter.doFilterInternal(request, response, filterchain);
httpUtilMockedStatic.verify(
() -> HttpUtil.error(response, HttpStatus.NOT_FOUND.value(), Code.validationException,
"Resource is not found Project\nProject is deleted"), times(1));
"Project is deleted"), times(1));

// test project recover
for (var uri : List.of("/api/v1/project/1/recover", "/api/v1/project/abc/recover")) {
when(request.getRequestURI()).thenReturn(uri);
httpUtilMockedStatic.clearInvocations();
jwtTokenFilter.doFilterInternal(request, response, filterchain);
httpUtilMockedStatic.verify(
() -> HttpUtil.error(any(HttpServletResponse.class), anyInt(), any(Code.class), anyString()),
never());
}
}
}

0 comments on commit a5da552

Please sign in to comment.