Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@

import static datadog.trace.api.UserEventTrackingMode.DISABLED;
import static datadog.trace.api.UserEventTrackingMode.EXTENDED;
import static datadog.trace.api.telemetry.LogCollector.SEND_TELEMETRY;

import datadog.trace.api.Config;
import datadog.trace.api.UserEventTrackingMode;
import datadog.trace.bootstrap.instrumentation.decorator.AppSecUserEventDecorator;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
Expand All @@ -17,6 +24,13 @@ public class SpringSecurityUserEventDecorator extends AppSecUserEventDecorator {

public static final SpringSecurityUserEventDecorator DECORATE =
new SpringSecurityUserEventDecorator();
private static final String SPRING_SECURITY_PACKAGE = "org.springframework.security";

private static final Logger LOGGER =
LoggerFactory.getLogger(SpringSecurityUserEventDecorator.class);

private static final Set<Class<?>> SKIPPED_AUTHS =
Collections.newSetFromMap(new ConcurrentHashMap<>());

public void onSignup(UserDetails user, Throwable throwable) {
// skip failures while signing up a user, later on, we might want to generate a separate event
Expand Down Expand Up @@ -53,8 +67,7 @@ public void onLogin(Authentication authentication, Throwable throwable, Authenti
return;
}

// For now, exclude all OAuth events. See APPSEC-12547.
if (authentication.getClass().getName().contains("OAuth")) {
if (shouldSkipAuthentication(authentication)) {
return;
}

Expand Down Expand Up @@ -98,4 +111,26 @@ public void onLogin(Authentication authentication, Throwable throwable, Authenti
}
}
}

private static boolean shouldSkipAuthentication(final Authentication authentication) {
if (authentication instanceof UsernamePasswordAuthenticationToken) {
return false;
}
if (SKIPPED_AUTHS.add(authentication.getClass())) {
final Class<?> authClass = authentication.getClass();
LOGGER.debug(
SEND_TELEMETRY, "Skipped authentication, auth={}", findRootAuthentication(authClass));
}
return true;
}

private static String findRootAuthentication(Class<?> authentication) {
while (authentication != Object.class) {
if (authentication.getName().startsWith(SPRING_SECURITY_PACKAGE)) {
return authentication.getName();
}
authentication = authentication.getSuperclass();
}
return Authentication.class.getName(); // set this a default for really custom impls
}
}
Original file line number Diff line number Diff line change
@@ -1,25 +1,34 @@
package datadog.trace.instrumentation.springsecurity5

import custom.CustomAuthenticationFilter
import custom.CustomAuthenticationProvider
import org.springframework.boot.jdbc.DataSourceBuilder
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.security.authentication.AuthenticationManager
import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer
import org.springframework.security.provisioning.JdbcUserDetailsManager
import org.springframework.security.provisioning.UserDetailsManager
import org.springframework.security.web.SecurityFilterChain
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter

import javax.sql.DataSource

import static datadog.trace.instrumentation.springsecurity5.SecurityConfig.CustomDsl.customDsl

@Configuration
@EnableWebSecurity
class SecurityConfig {

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http.authorizeHttpRequests(
http.apply(customDsl())
http
.authorizeHttpRequests(
(requests) -> requests
.requestMatchers("/", "/success", "/register", "/login").permitAll()
.requestMatchers("/", "/success", "/register", "/login", "/custom").permitAll()
.anyRequest().authenticated())
.csrf().disable()
.formLogin((form) -> form.loginPage("/login").permitAll())
Expand All @@ -40,4 +49,17 @@ class SecurityConfig {
UserDetailsManager userDetailsService() {
return new JdbcUserDetailsManager(dataSource)
}

static class CustomDsl extends AbstractHttpConfigurer<CustomDsl, HttpSecurity> {
@Override
void configure(HttpSecurity http) throws Exception {
AuthenticationManager authenticationManager = http.getSharedObject(AuthenticationManager)
http.authenticationProvider(new CustomAuthenticationProvider())
http.addFilterBefore(new CustomAuthenticationFilter(authenticationManager), UsernamePasswordAuthenticationFilter)
}

static CustomDsl customDsl() {
return new CustomDsl()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package datadog.trace.instrumentation.springsecurity5

import ch.qos.logback.classic.Logger
import ch.qos.logback.core.Appender
import com.datadog.appsec.AppSecHttpServerTest
import datadog.trace.agent.test.base.HttpServer
import datadog.trace.core.DDSpan
Expand All @@ -14,6 +16,7 @@ import spock.lang.Shared

import static datadog.trace.instrumentation.springsecurity5.TestEndpoint.LOGIN
import static datadog.trace.agent.test.utils.TraceUtils.runUnderTrace
import static datadog.trace.instrumentation.springsecurity5.TestEndpoint.CUSTOM
import static datadog.trace.instrumentation.springsecurity5.TestEndpoint.REGISTER
import static datadog.trace.instrumentation.springsecurity5.TestEndpoint.UNKNOWN
import static datadog.trace.instrumentation.springsecurity5.TestEndpoint.NOT_FOUND
Expand Down Expand Up @@ -225,6 +228,33 @@ class SpringBootBasedTest extends AppSecHttpServerTest<ConfigurableApplicationCo
span.getTags().findAll { it.key.startsWith('appsec.events.users.signup') }.isEmpty()
}

void 'test skipped authentication'() {
setup:
final appender = Mock(Appender)
final logger = SpringSecurityUserEventDecorator.LOGGER as Logger
logger.addAppender(appender)

and:
final requestCount = 3
final request = request(CUSTOM, "GET", null).addHeader('X-Custom-User', 'batman').build()

when:
final response = (1..requestCount).collect { client.newCall(request).execute() }.first()
TEST_WRITER.waitForTraces(3)
final span = TEST_WRITER.flatten().first() as DDSpan
logger.detachAppender(appender) // cant add cleanup

then:
response.code() == CUSTOM.status
span.context().resourceName.contains(CUSTOM.path)
span.getTags().findAll { key, value -> key.startsWith('appsec.events.users.login')}.isEmpty()
// single call to the appender
1 * appender.doAppend(_) >> {
assert it[0].toString().contains('Skipped authentication, auth=org.springframework.security.authentication.AbstractAuthenticationToken')
}
0 * appender._
}

@SuppressWarnings('GroovyAssignabilityCheck')
private static String randomString(int length) {
return new Random().with { random ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ enum TestEndpoint {
REGISTER("register", 200, ""),
NOT_FOUND("not-found", 404, "not found"),
UNKNOWN("", 451, null), // This needs to have a valid status code
CUSTOM("custom", 302, ""),

private final String path
private final String rawPath
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package custom;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;

public class CustomAuthenticationFilter extends AbstractAuthenticationProcessingFilter {

private static final String HEADER_NAME = "X-Custom-User";

private final AuthenticationManager authenticationManager;

public CustomAuthenticationFilter(final AuthenticationManager authenticationManager) {
super("/custom");
this.authenticationManager = authenticationManager;
}

@Override
public Authentication attemptAuthentication(
HttpServletRequest request, HttpServletResponse response)
throws AuthenticationException, IOException, ServletException {
final String user = request.getHeader(HEADER_NAME);
if (user == null) {
return null;
}
return authenticationManager.authenticate(new CustomAuthenticationToken(user));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package custom;

import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;

public class CustomAuthenticationProvider implements AuthenticationProvider {

@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
return authentication;
}

@Override
public boolean supports(Class<?> authentication) {
return CustomAuthenticationToken.class.isAssignableFrom(authentication);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package custom;

import java.util.Collections;
import org.springframework.security.authentication.AbstractAuthenticationToken;

public class CustomAuthenticationToken extends AbstractAuthenticationToken {

private final String user;

public CustomAuthenticationToken(String user) {
super(Collections.emptyList());
this.user = user;
}

@Override
public Object getCredentials() {
return user;
}

@Override
public Object getPrincipal() {
return user;
}
}