Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite OIDC implementation #8641

Merged
merged 4 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
*/
package io.trino.server.security;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtException;
import io.trino.spi.security.BasicPrincipal;
import io.trino.spi.security.Identity;

import javax.ws.rs.container.ContainerRequestContext;

import java.security.Principal;
import java.util.List;
import java.util.Optional;

import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static java.lang.String.format;
Expand All @@ -30,12 +29,10 @@
public abstract class AbstractBearerAuthenticator
implements Authenticator
{
private final String principalField;
private final UserMapping userMapping;

protected AbstractBearerAuthenticator(String principalField, UserMapping userMapping)
protected AbstractBearerAuthenticator(UserMapping userMapping)
{
this.principalField = requireNonNull(principalField, "principalField is null");
this.userMapping = requireNonNull(userMapping, "userMapping is null");
}

Expand All @@ -50,14 +47,14 @@ public Identity authenticate(ContainerRequestContext request, String token)
throws AuthenticationException
{
try {
Jws<Claims> claimsJws = parseClaimsJws(token);
String principal = claimsJws.getBody().get(principalField, String.class);
if (principal == null) {
Optional<Principal> principal = extractPrincipalFromToken(token);
if (principal.isEmpty()) {
throw needAuthentication(request, "Invalid credentials");
}
String authenticatedUser = userMapping.mapUser(principal);

String authenticatedUser = userMapping.mapUser(principal.get().getName());
return Identity.forUser(authenticatedUser)
.withPrincipal(new BasicPrincipal(principal))
.withPrincipal(principal.get())
.build();
}
catch (JwtException | UserMappingException e) {
Expand Down Expand Up @@ -91,7 +88,7 @@ public String extractToken(ContainerRequestContext request)
return token;
}

protected abstract Jws<Claims> parseClaimsJws(String jws);
protected abstract Optional<Principal> extractPrincipalFromToken(String token);

protected abstract AuthenticationException needAuthentication(ContainerRequestContext request, String message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,32 @@
*/
package io.trino.server.security.jwt;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SigningKeyResolver;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.spi.security.BasicPrincipal;

import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;

import java.security.Principal;
import java.util.Optional;

import static io.trino.server.security.UserMapping.createUserMapping;

public class JwtAuthenticator
extends AbstractBearerAuthenticator
{
private final JwtParser jwtParser;
private final String principalField;

@Inject
public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signingKeyResolver)
{
super(config.getPrincipalField(), createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()));
super(createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()));
principalField = config.getPrincipalField();

JwtParser jwtParser = Jwts.parser()
.setSigningKeyResolver(signingKeyResolver);
Expand All @@ -49,9 +53,12 @@ public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signin
}

@Override
protected Jws<Claims> parseClaimsJws(String jws)
protected Optional<Principal> extractPrincipalFromToken(String token)
{
return jwtParser.parseClaimsJws(jws);
return Optional.ofNullable(jwtParser.parseClaimsJws(token)
.getBody()
.get(principalField, String.class))
.map(BasicPrincipal::new);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class OAuth2AuthenticationSupportModule
protected void setup(Binder binder)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix OIDC implementation to resolve correctness issues

Can you please explain what was the issue about?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the commit message to be clearer. Specifically the commit fixes several structural issues where there were inappropriate dependencies and duplicate/confusing code. Additionally there are several security fixes including correctly validating ID and access tokens.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a new commit message yet.

Specifically the commit fixes several structural issues where there were inappropriate dependencies and duplicate/confusing code.

Let's have each of the issue handled in separate commit, currently it is hard to follow what is fixed where.

Copy link
Member

@dain dain Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worked with Nik on this PR and agreed to the current commit structure.

{
binder.bind(OAuth2TokenExchange.class).in(Scopes.SINGLETON);
binder.bind(OAuth2TokenHandler.class).to(OAuth2TokenExchange.class).in(Scopes.SINGLETON);
jaxrsBinder(binder).bind(OAuth2TokenExchangeResource.class);
install(new OAuth2ServiceModule());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@
*/
package io.trino.server.security.oauth2;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jws;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.spi.security.BasicPrincipal;

import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;

import java.net.URI;
import java.security.Principal;
import java.util.Optional;
import java.util.UUID;

import static io.trino.server.security.UserMapping.createUserMapping;
import static io.trino.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static io.trino.server.security.oauth2.OAuth2TokenExchangeResource.getInitiateUri;
import static io.trino.server.security.oauth2.OAuth2TokenExchangeResource.getTokenUri;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand All @@ -34,26 +35,36 @@ public class OAuth2Authenticator
extends AbstractBearerAuthenticator
{
private final OAuth2Service service;
private final String principalField;

@Inject
public OAuth2Authenticator(OAuth2Service service, OAuth2Config config)
{
super(config.getPrincipalField(), createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()));
super(createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile()));
this.service = requireNonNull(service, "service is null");
this.principalField = config.getPrincipalField();
}

@Override
protected Jws<Claims> parseClaimsJws(String jws)
protected Optional<Principal> extractPrincipalFromToken(String token)
{
return service.parseClaimsJws(jws);
try {
return service.convertTokenToClaims(token)
.map(claims -> claims.get(principalField))
.map(String.class::cast)
.map(BasicPrincipal::new);
}
catch (ChallengeFailedException e) {
return Optional.empty();
}
}

@Override
protected AuthenticationException needAuthentication(ContainerRequestContext request, String message)
{
UUID authId = UUID.randomUUID();
URI redirectUri = service.startRestChallenge(request.getUriInfo().getBaseUri().resolve(CALLBACK_ENDPOINT), authId);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is changing the protocol needed to solve the "correctness issues"?

Copy link
Member

@dain dain Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not change the protocol. It simply changes the x_redirect_server URL returned in the authenticate header, but the protocol is completely unchanged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Protocol is not changed as there is no needed change in the client.

However there is a new round trip. And there is new endpoint method created. Commit message does not explain the motivation behind it. Previously we were starting the challenge at at the first call, now it is postponed and it is started once the new endpoint method is called.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, and we should have always done that. It lets us use a nonce to protect the authentication code. Also in the future, we can use this to timeout clients faster... basically if a client doesn't hit the initiate in a few seconds, we kill the auth because likely the browser did not launch.

@11xor6 in the commit message mention the new use of nonce cookies for query clients.

URI initiateUri = request.getUriInfo().getBaseUri().resolve(getInitiateUri(authId));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is unrelated. It's adding an additional step to the authorization flow and it's neither a security nor a structural fix. It would be better not to piggyback a change of behaviour in a commit which purpose is already weakly defined.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest a separate commit or even better a new PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worked with Nik on this PR and agreed to the current commit structure.

URI tokenUri = request.getUriInfo().getBaseUri().resolve(getTokenUri(authId));
return new AuthenticationException(message, format("Bearer x_redirect_server=\"%s\", x_token_server=\"%s\"", redirectUri, tokenUri));
return new AuthenticationException(message, format("Bearer x_redirect_server=\"%s\", x_token_server=\"%s\"", initiateUri, tokenUri));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@

import io.airlift.log.Logger;
import io.trino.server.security.ResourceSecurity;
import io.trino.server.security.oauth2.OAuth2Service.OAuthResult;
import io.trino.server.ui.OAuth2WebUiInstalled;
import io.trino.server.ui.OAuthWebUiCookie;

import javax.inject.Inject;
import javax.ws.rs.CookieParam;
Expand All @@ -28,20 +25,14 @@
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Cookie;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.ResponseBuilder;
import javax.ws.rs.core.UriInfo;

import java.net.URI;
import java.util.Optional;
import java.util.UUID;

import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC;
import static io.trino.server.security.oauth2.NonceCookie.NONCE_COOKIE;
import static io.trino.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOCATION;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static javax.ws.rs.core.MediaType.TEXT_HTML;
import static javax.ws.rs.core.Response.Status.BAD_REQUEST;

@Path(CALLBACK_ENDPOINT)
public class OAuth2CallbackResource
Expand All @@ -51,15 +42,11 @@ public class OAuth2CallbackResource
public static final String CALLBACK_ENDPOINT = "/oauth2/callback";

private final OAuth2Service service;
private final Optional<OAuth2TokenExchange> tokenExchange;
private final boolean webUiOAuthEnabled;

@Inject
public OAuth2CallbackResource(OAuth2Service service, Optional<OAuth2TokenExchange> tokenExchange, Optional<OAuth2WebUiInstalled> webUiOAuthEnabled)
public OAuth2CallbackResource(OAuth2Service service)
{
this.service = requireNonNull(service, "service is null");
this.tokenExchange = requireNonNull(tokenExchange, "tokenExchange is null");
this.webUiOAuthEnabled = requireNonNull(webUiOAuthEnabled, "webUiOAuthEnabled is null").isPresent();
}

@ResourceSecurity(PUBLIC)
Expand All @@ -74,70 +61,21 @@ public Response callback(
@CookieParam(NONCE_COOKIE) Cookie nonce,
@Context UriInfo uriInfo)
{
Optional<UUID> authId;
11xor6 marked this conversation as resolved.
Show resolved Hide resolved
try {
authId = service.getAuthId(state);
}
catch (ChallengeFailedException e) {
LOG.debug(e, "Authentication response could not be verified: state=%s", state);
return Response.ok()
.entity(service.getInternalFailureHtml("Authentication response could not be verified"))
.build();
}

// Note: the Web UI may be disabled, so REST requests can not redirect to a success or error page inside of the Web UI

if (error != null) {
LOG.debug("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state);

if (tokenExchange.isPresent() && authId.isPresent()) {
tokenExchange.get().setTokenExchangeError(
authId.get(),
format("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state));
}
return Response.ok()
.entity(service.getCallbackErrorHtml(error))
.build();
return service.handleOAuth2Error(state, error, errorDescription, errorUri);
}

OAuthResult result;
try {
result = service.finishChallenge(
authId,
code,
uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT),
NonceCookie.read(nonce));
requireNonNull(state, "state is null");
requireNonNull(code, "code is null");
11xor6 marked this conversation as resolved.
Show resolved Hide resolved
return service.finishOAuth2Challenge(state, code, uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT), NonceCookie.read(nonce));
}
catch (ChallengeFailedException | RuntimeException e) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why RuntimeException was removed, it looks like section below is a copy paste. What am I missing?

Copy link
Member Author

@11xor6 11xor6 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We return a different status code in the two cases. There's some future work here in setting proper error codes and building better responses for failure cases. I believe the OIDC spec has some details about what these should look like, but that would be more of an overhaul that's suited to its own PR. For now the 400 helps us differentiate from "expected" failures like failed authorizations and those that are unexpected. Additionally I'm reluctant to return 500 (instead of 400) because the most likely cause of such a failure is actually a bad request of some kind.

catch (RuntimeException e) {
LOG.debug(e, "Authentication response could not be verified: state=%s", state);
if (tokenExchange.isPresent() && authId.isPresent()) {
tokenExchange.get().setTokenExchangeError(authId.get(), format("Authentication response could not be verified: state=%s", state));
}
return Response.ok()
return Response.status(BAD_REQUEST)
.cookie(NonceCookie.delete())
.entity(service.getInternalFailureHtml("Authentication response could not be verified"))
.build();
}

if (authId.isEmpty()) {
return Response
.seeOther(URI.create(UI_LOCATION))
.cookie(OAuthWebUiCookie.create(result.getAccessToken(), result.getTokenExpiration()), NonceCookie.delete())
.build();
}

if (tokenExchange.isEmpty()) {
LOG.debug("Token exchange is not active: state=%s", state);
return Response.ok()
.entity(service.getInternalFailureHtml("Client token exchange is not enabled"))
.build();
}

tokenExchange.get().setAccessToken(authId.get(), result.getAccessToken());

ResponseBuilder builder = Response.ok(service.getSuccessHtml());
if (webUiOAuthEnabled) {
builder.cookie(OAuthWebUiCookie.create(result.getAccessToken(), result.getTokenExpiration()));
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ public interface OAuth2Client
{
URI getAuthorizationUri(String state, URI callbackUri, Optional<String> nonceHash);

AccessToken getAccessToken(String code, URI callbackUri)
OAuth2Response getOAuth2Response(String code, URI callbackUri)
11xor6 marked this conversation as resolved.
Show resolved Hide resolved
throws ChallengeFailedException;

class AccessToken
class OAuth2Response
{
private final String accessToken;
private final Optional<Instant> validUntil;
private final Optional<String> idToken;

public AccessToken(String accessToken, Optional<Instant> validUntil, Optional<String> idToken)
public OAuth2Response(String accessToken, Optional<Instant> validUntil, Optional<String> idToken)
{
this.accessToken = requireNonNull(accessToken, "accessToken is null");
this.validUntil = requireNonNull(validUntil, "validUntil is null");
Expand Down
Loading