Skip to content

Commit

Permalink
Add test: accessTokenScopesIT (#18322)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhichengliu12581 authored Dec 23, 2020
1 parent 251938c commit 9ec1f54
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 69 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.test.aad.converter;

import com.azure.test.oauth.SeleniumTestUtils;
import com.azure.test.utils.AppRunner;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.security.config.annotation.method.configuration.EnableGlobalMethodSecurity;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.*;

public class RefreshTokenScopesIT {

@Test
public void testRefreshTokenConverter() {
try (AppRunner app = new AppRunner(DumbApp.class)) {
SeleniumTestUtils.addProperty(app);
app.property("azure.activedirectory.authorization.office.scopes", "https://manage.office.com/ActivityFeed.Read");
app.property("azure.activedirectory.authorization.graph.scopes", "https://graph.microsoft.com/User.Read");
List<String> endPoints = new ArrayList<>();
endPoints.add("api/office");
endPoints.add("api/azure");
endPoints.add("api/graph");
endPoints.add("api/arm");
Map<String, String> result = SeleniumTestUtils.get(app, endPoints);

Assert.assertFalse(result.get("api/office").contains("profile"));
Assert.assertTrue(result.get("api/office").contains("https://manage.office.com/ActivityFeed.Read"));

Assert.assertTrue(result.get("api/azure").contains("profile"));
Assert.assertTrue(result.get("api/azure").contains("https://graph.microsoft.com/User.Read"));

Assert.assertTrue(result.get("api/graph").contains("profile"));
Assert.assertTrue(result.get("api/graph").contains("https://graph.microsoft.com/User.Read"));

Assert.assertNotEquals("error", result.get("api/arm"));
}
}

@EnableGlobalMethodSecurity(securedEnabled = true, prePostEnabled = true)
@SpringBootApplication
@RestController
public static class DumbApp {

@GetMapping(value = "api/office")
public Set<String> office(
@RegisteredOAuth2AuthorizedClient("office") OAuth2AuthorizedClient authorizedClient) {
return Optional.of(authorizedClient)
.map(OAuth2AuthorizedClient::getAccessToken)
.map(OAuth2AccessToken::getScopes)
.orElse(null);
}

@GetMapping(value = "api/azure")
public Set<String> azure(
@RegisteredOAuth2AuthorizedClient("azure") OAuth2AuthorizedClient authorizedClient) {
return Optional.of(authorizedClient)
.map(OAuth2AuthorizedClient::getAccessToken)
.map(OAuth2AccessToken::getScopes)
.orElse(null);
}

@GetMapping(value = "api/graph")
public Set<String> graph(
@RegisteredOAuth2AuthorizedClient("graph") OAuth2AuthorizedClient authorizedClient) {
return Optional.of(authorizedClient)
.map(OAuth2AuthorizedClient::getAccessToken)
.map(OAuth2AccessToken::getScopes)
.orElse(null);
}

@GetMapping(value = "api/arm")
public String arm(
@RegisteredOAuth2AuthorizedClient("arm") OAuth2AuthorizedClient authorizedClient) {
return "error";
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

package com.azure.test.aad.login;

import com.azure.test.oauth.OAuthLoginUtils;
import com.azure.test.oauth.SeleniumTestUtils;
import com.azure.test.utils.AppRunner;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -20,6 +20,7 @@
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class AADLoginIT {

Expand All @@ -29,15 +30,15 @@ public class AADLoginIT {
public void loginTest() {

try (AppRunner app = new AppRunner(DumbApp.class)) {
OAuthLoginUtils.addProperty(app);
SeleniumTestUtils.addProperty(app);
List<String> endPoints = new ArrayList<>();
endPoints.add("api/home");
endPoints.add("api/group1");
endPoints.add("api/status403");
List<String> result = OAuthLoginUtils.get(app , endPoints);
Assert.assertEquals("home", result.get(0));
Assert.assertEquals("group1", result.get(1));
Assert.assertNotEquals("error", result.get(2));
Map<String, String> result = SeleniumTestUtils.get(app , endPoints);
Assert.assertEquals("home", result.get("api/home"));
Assert.assertEquals("group1", result.get("api/group1"));
Assert.assertNotEquals("error", result.get("api/status403"));
}


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

Expand All @@ -14,14 +15,15 @@

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import static com.azure.test.oauth.OAuthUtils.*;
import static org.openqa.selenium.support.ui.ExpectedConditions.presenceOfElementLocated;

public class OAuthLoginUtils {
public class SeleniumTestUtils {

static {
final String directory = "src/test/resources/driver/";
Expand Down Expand Up @@ -56,9 +58,9 @@ public class OAuthLoginUtils {
}
}

public static List<String> get(AppRunner app, List<String> endPoints) {
public static Map<String, String> get(AppRunner app, List<String> endPoints) {

List<String> result = new ArrayList<>();
Map<String , String> result = new HashMap<>();
ChromeOptions options = new ChromeOptions();
options.addArguments("--incognito");
options.addArguments("--headless");
Expand All @@ -77,12 +79,12 @@ public static List<String> get(AppRunner app, List<String> endPoints) {
Thread.sleep(10000);
driver.findElement(By.cssSelector("input[type='submit']")).click();
Thread.sleep(10000);
result.add(driver.findElement(By.tagName("body")).getText());
result.put(endPoints.get(0) , driver.findElement(By.tagName("body")).getText());
endPoints.remove(0);
for(String endPoint : endPoints) {
driver.get(app.root() + endPoint);
Thread.sleep(1000);
result.add(driver.findElement(By.tagName("body")).getText());
result.put(endPoint ,driver.findElement(By.tagName("body")).getText());
}
return result;
} catch (InterruptedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,8 @@ public static boolean isDefaultClient(ClientRegistration clientRegistration) {
return AZURE_CLIENT_REGISTRATION_ID.equals(
clientRegistration.getClientName());
}

public static boolean isDefaultClient(String clientId) {
return AZURE_CLIENT_REGISTRATION_ID.equals(clientId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.azure.spring.aad.webapp;

import com.azure.spring.aad.AADClientRegistrationRepository;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
Expand Down Expand Up @@ -65,7 +66,7 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String id,
OAuth2AuthorizationContext.Builder contextBuilder =
OAuth2AuthorizationContext.withAuthorizedClient(fakeAuthzClient);
String[] scopes = null;
if (!AADWebAppClientRegistrationRepository.AZURE_CLIENT_REGISTRATION_ID.equals(id)) {
if (!AADClientRegistrationRepository.isDefaultClient(id)) {
scopes = repo.findByRegistrationId(id).getScopes().toArray(new String[0]);
}
OAuth2AuthorizationContext context = contextBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private MultiValueMap<String, String> convertedBodyOf(OAuth2AuthorizationCodeGra
AuthzCodeGrantRequestEntityConverter converter =
new AuthzCodeGrantRequestEntityConverter(clientRepo.getAzureClient());
RequestEntity<?> entity = converter.convert(request);
return PropertiesUtils.requestEntityConverter(entity);
return PropertiesUtils.toMultiValueMap(entity);
}

private OAuth2AuthorizationCodeGrantRequest createCodeGrantRequest(ClientRegistration client) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public static WebApplicationContextRunner getContextRunner() {


@SuppressWarnings("unchecked")
public static MultiValueMap<String, String> requestEntityConverter(RequestEntity<?> entity) {
public static MultiValueMap<String, String> toMultiValueMap(RequestEntity<?> entity) {
return (MultiValueMap<String, String>) Optional.ofNullable(entity)
.map(HttpEntity::getBody)
.orElse(null);
Expand Down

0 comments on commit 9ec1f54

Please sign in to comment.