|
16 | 16 |
|
17 | 17 | package org.springframework.security.web.context;
|
18 | 18 |
|
| 19 | +import java.io.IOException; |
19 | 20 | import java.lang.annotation.ElementType;
|
20 | 21 | import java.lang.annotation.Retention;
|
21 | 22 | import java.lang.annotation.RetentionPolicy;
|
22 | 23 | import java.lang.annotation.Target;
|
| 24 | +import javax.servlet.Filter; |
| 25 | +import javax.servlet.ServletException; |
23 | 26 | import javax.servlet.ServletOutputStream;
|
| 27 | +import javax.servlet.http.HttpServlet; |
24 | 28 | import javax.servlet.http.HttpServletRequest;
|
25 | 29 | import javax.servlet.http.HttpServletRequestWrapper;
|
26 | 30 | import javax.servlet.http.HttpServletResponse;
|
|
30 | 34 | import org.junit.After;
|
31 | 35 | import org.junit.Test;
|
32 | 36 |
|
| 37 | +import org.springframework.mock.web.MockFilterChain; |
33 | 38 | import org.springframework.mock.web.MockHttpServletRequest;
|
34 | 39 | import org.springframework.mock.web.MockHttpServletResponse;
|
35 | 40 | import org.springframework.mock.web.MockHttpSession;
|
36 | 41 | import org.springframework.security.authentication.AbstractAuthenticationToken;
|
37 | 42 | import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
38 | 43 | import org.springframework.security.authentication.AuthenticationTrustResolver;
|
39 | 44 | import org.springframework.security.authentication.TestingAuthenticationToken;
|
| 45 | +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; |
40 | 46 | import org.springframework.security.core.Transient;
|
41 | 47 | import org.springframework.security.core.authority.AuthorityUtils;
|
42 | 48 | import org.springframework.security.core.context.SecurityContext;
|
43 | 49 | import org.springframework.security.core.context.SecurityContextHolder;
|
| 50 | +import org.springframework.security.core.context.SecurityContextImpl; |
| 51 | +import org.springframework.security.core.userdetails.User; |
| 52 | +import org.springframework.security.core.userdetails.UserDetails; |
44 | 53 |
|
45 | 54 | import static org.assertj.core.api.Assertions.assertThat;
|
46 | 55 | import static org.mockito.ArgumentMatchers.anyBoolean;
|
@@ -174,6 +183,48 @@ public void saveContextCallsSetAttributeIfContextIsModifiedDirectlyDuringRequest
|
174 | 183 | verify(session).setAttribute(SPRING_SECURITY_CONTEXT_KEY, ctx);
|
175 | 184 | }
|
176 | 185 |
|
| 186 | + |
| 187 | + @Test |
| 188 | + public void saveContextWhenSaveNewContextThenOriginalContextThenOriginalContextSaved() throws Exception { |
| 189 | + HttpSessionSecurityContextRepository repository = new HttpSessionSecurityContextRepository(); |
| 190 | + SecurityContextPersistenceFilter securityContextPersistenceFilter = new SecurityContextPersistenceFilter( |
| 191 | + repository); |
| 192 | + |
| 193 | + UserDetails original = User.withUsername("user").password("password").roles("USER").build(); |
| 194 | + SecurityContext originalContext = createSecurityContext(original); |
| 195 | + UserDetails impersonate = User.withUserDetails(original).username("impersonate").build(); |
| 196 | + SecurityContext impersonateContext = createSecurityContext(impersonate); |
| 197 | + |
| 198 | + MockHttpServletRequest mockRequest = new MockHttpServletRequest(); |
| 199 | + MockHttpServletResponse mockResponse = new MockHttpServletResponse(); |
| 200 | + |
| 201 | + Filter saveImpersonateContext = (request, response, chain) -> { |
| 202 | + SecurityContextHolder.setContext(impersonateContext); |
| 203 | + // ensure the response is committed to trigger save |
| 204 | + response.flushBuffer(); |
| 205 | + chain.doFilter(request, response); |
| 206 | + }; |
| 207 | + Filter saveOriginalContext = (request, response, chain) -> { |
| 208 | + SecurityContextHolder.setContext(originalContext); |
| 209 | + chain.doFilter(request, response); |
| 210 | + }; |
| 211 | + HttpServlet servlet = new HttpServlet() { |
| 212 | + @Override |
| 213 | + protected void service(HttpServletRequest req, HttpServletResponse resp) |
| 214 | + throws ServletException, IOException { |
| 215 | + resp.getWriter().write("Hi"); |
| 216 | + } |
| 217 | + }; |
| 218 | + |
| 219 | + SecurityContextHolder.setContext(originalContext); |
| 220 | + MockFilterChain chain = new MockFilterChain(servlet, saveImpersonateContext, saveOriginalContext); |
| 221 | + |
| 222 | + securityContextPersistenceFilter.doFilter(mockRequest, mockResponse, chain); |
| 223 | + |
| 224 | + assertThat(mockRequest.getSession().getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY)) |
| 225 | + .isEqualTo(originalContext); |
| 226 | + } |
| 227 | + |
177 | 228 | @Test
|
178 | 229 | public void nonSecurityContextInSessionIsIgnored() {
|
179 | 230 | HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
|
@@ -668,6 +719,13 @@ public void saveContextWhenTransientAuthenticationWithCustomAnnotationThenSkippe
|
668 | 719 | assertThat(session).isNull();
|
669 | 720 | }
|
670 | 721 |
|
| 722 | + private SecurityContext createSecurityContext(UserDetails userDetails) { |
| 723 | + UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(userDetails, |
| 724 | + userDetails.getPassword(), userDetails.getAuthorities()); |
| 725 | + SecurityContext securityContext = new SecurityContextImpl(token); |
| 726 | + return securityContext; |
| 727 | + } |
| 728 | + |
671 | 729 | @Transient
|
672 | 730 | private static class SomeTransientAuthentication extends AbstractAuthenticationToken {
|
673 | 731 | SomeTransientAuthentication() {
|
|
0 commit comments