Skip to content

Commit

Permalink
Add zone ID to expiring codes
Browse files Browse the repository at this point in the history
  • Loading branch information
fhanik committed May 12, 2017
1 parent b8f1bf1 commit 2ca35f1
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Cloud Foundry
* Cloud Foundry
* Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved.
*
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
Expand All @@ -12,6 +12,7 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.codestore;

import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;

import java.sql.Timestamp;
Expand All @@ -20,7 +21,7 @@ public interface ExpiringCodeStore {

/**
* Generate and persist a one-time code with an expiry date.
*
*
* @param data JSON object to be associated with the code
* @param intent An optional key (not necessarily unique) for looking up codes
* @return code the generated one-time code
Expand All @@ -31,7 +32,7 @@ public interface ExpiringCodeStore {

/**
* Retrieve a code and delete it if it exists.
*
*
* @param code the one-time code to look for
* @return code or null if the code is not found
* @throws java.lang.NullPointerException if the code is null
Expand All @@ -40,7 +41,7 @@ public interface ExpiringCodeStore {

/**
* Set the code generator for this store.
*
*
* @param generator Code generator
*/
void setGenerator(RandomValueStringGenerator generator);
Expand All @@ -51,4 +52,16 @@ public interface ExpiringCodeStore {
* @param intent Intent of codes to remove
*/
void expireByIntent(String intent);

default String zonifyCode(String code) {
return code + "[zone[" + IdentityZoneHolder.get().getId()+"]]";
}

default String extractCode(String zoneCode) {
int endIndex = zoneCode.indexOf("[zone[" + IdentityZoneHolder.get().getId()+"]]");
if (endIndex<0) {
return zoneCode;
}
return zoneCode.substring(0, endIndex);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import org.cloudfoundry.identity.uaa.util.TimeService;
import org.cloudfoundry.identity.uaa.util.TimeServiceImpl;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -44,7 +45,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent

ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data, intent);

ExpiringCode duplicate = store.putIfAbsent(code, expiringCode);
ExpiringCode duplicate = store.putIfAbsent(zonifyCode(code), expiringCode);
if (duplicate != null) {
throw new DataIntegrityViolationException("Duplicate code: " + code);
}
Expand All @@ -58,7 +59,7 @@ public ExpiringCode retrieveCode(String code) {
throw new NullPointerException();
}

ExpiringCode expiringCode = store.remove(code);
ExpiringCode expiringCode = store.remove(zonifyCode(code));

if (expiringCode == null || isExpired(expiringCode)) {
expiringCode = null;
Expand All @@ -79,8 +80,8 @@ public void setGenerator(RandomValueStringGenerator generator) {
@Override
public void expireByIntent(String intent) {
Assert.hasText(intent);

store.values().stream().filter(c -> intent.equals(c.getIntent())).forEach(c -> store.remove(c.getCode()));
String id = IdentityZoneHolder.get().getId();
store.entrySet().stream().filter(c -> c.getKey().contains(id) && intent.equals(c.getValue().getIntent())).forEach(c -> store.remove(c.getKey()));
}

public InMemoryExpiringCodeStore setTimeService(TimeService timeService) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ public class JdbcExpiringCodeStore implements ExpiringCodeStore {

protected static final String insert = "insert into " + tableName + " (" + fields + ") values (?,?,?,?)";
protected static final String delete = "delete from " + tableName + " where code = ?";
protected static final String deleteIntent = "delete from " + tableName + " where intent = ?";
protected static final String deleteIntent = "delete from " + tableName + " where intent = ? and code LIKE ?";
protected static final String deleteExpired = "delete from " + tableName + " where expiresat < ?";

private static final JdbcExpiringCodeMapper rowMapper = new JdbcExpiringCodeMapper();
private final JdbcExpiringCodeMapper rowMapper = new JdbcExpiringCodeMapper();

protected static final String selectAllFields = "select " + fields + " from " + tableName + " where code = ?";

Expand Down Expand Up @@ -98,7 +98,7 @@ public ExpiringCode generateCode(String data, Timestamp expiresAt, String intent
count++;
String code = generator.generate();
try {
int update = jdbcTemplate.update(insert, code, expiresAt.getTime(), data, intent);
int update = jdbcTemplate.update(insert, zonifyCode(code), expiresAt.getTime(), data, intent);
if (update == 1) {
ExpiringCode expiringCode = new ExpiringCode(code, expiresAt, data, intent);
return expiringCode;
Expand All @@ -124,9 +124,9 @@ public ExpiringCode retrieveCode(String code) {
}

try {
ExpiringCode expiringCode = jdbcTemplate.queryForObject(selectAllFields, rowMapper, code);
ExpiringCode expiringCode = jdbcTemplate.queryForObject(selectAllFields, rowMapper, zonifyCode(code));
if (expiringCode != null) {
jdbcTemplate.update(delete, code);
jdbcTemplate.update(delete, zonifyCode(code));
}
if (expiringCode.getExpiresAt().getTime() < timeService.getCurrentTimeMillis()) {
expiringCode = null;
Expand All @@ -146,7 +146,7 @@ public void setGenerator(RandomValueStringGenerator generator) {
public void expireByIntent(String intent) {
Assert.hasText(intent);

jdbcTemplate.update(deleteIntent, intent);
jdbcTemplate.update(deleteIntent, intent, zonifyCode("%")+"%");
}

public int cleanExpiredEntries() {
Expand All @@ -162,11 +162,11 @@ public int cleanExpiredEntries() {
return 0;
}

protected static class JdbcExpiringCodeMapper implements RowMapper<ExpiringCode> {
protected class JdbcExpiringCodeMapper implements RowMapper<ExpiringCode> {

@Override
public ExpiringCode mapRow(ResultSet rs, int rowNum) throws SQLException {
String code = rs.getString("code");
String code = extractCode(rs.getString("code"));
Timestamp expiresAt = new Timestamp(rs.getLong("expiresat"));
String intent = rs.getString("intent");
String data = rs.getString("data");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.cloudfoundry.identity.uaa.test.TestUtils;
import org.cloudfoundry.identity.uaa.util.TimeService;
import org.cloudfoundry.identity.uaa.util.TimeServiceImpl;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -26,12 +28,15 @@
import org.springframework.dao.DataAccessException;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.test.util.ReflectionTestUtils;

import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -50,7 +55,7 @@ public ExpiringCodeStoreTests(Class expiringCodeStoreClass) {
@Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
{ InMemoryExpiringCodeStore.class }, { JdbcExpiringCodeStore.class },
{ InMemoryExpiringCodeStore.class }, { JdbcExpiringCodeStore.class },
});
}

Expand All @@ -68,6 +73,16 @@ public void initExpiringCodeStoreTests() throws Exception {
}
}

public int countCodes() {
if (expiringCodeStore instanceof InMemoryExpiringCodeStore) {
Map map = (Map) ReflectionTestUtils.getField(expiringCodeStore, "store");
return map.size();
} else {
// confirm that everything is clean prior to test.
return jdbcTemplate.queryForObject("select count(*) from expiring_code_store", Integer.class);
}
}

@Test
public void testGenerateCode() throws Exception {
String data = "{}";
Expand Down Expand Up @@ -132,6 +147,22 @@ public void testRetrieveCode() throws Exception {
Assert.assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode()));
}

@Test
public void testRetrieveCode_In_Another_Zone() throws Exception {
String data = "{}";
Timestamp expiresAt = new Timestamp(System.currentTimeMillis() + 60000);
ExpiringCode generatedCode = expiringCodeStore.generateCode(data, expiresAt, null);

IdentityZoneHolder.set(MultitenancyFixture.identityZone("other","other"));
Assert.assertNull(expiringCodeStore.retrieveCode(generatedCode.getCode()));

IdentityZoneHolder.clear();
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode(generatedCode.getCode());
Assert.assertEquals(generatedCode, retrievedCode);


}

@Test
public void testRetrieveCodeWithCodeNotFound() throws Exception {
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode("unknown");
Expand All @@ -150,7 +181,7 @@ public void testStoreLargeData() throws Exception {
Arrays.fill(oneMb, 'a');
String aaaString = new String(oneMb);
ExpiringCode expiringCode = expiringCodeStore.generateCode(aaaString, new Timestamp(
System.currentTimeMillis() + 60000), null);
System.currentTimeMillis() + 60000), null);
String code = expiringCode.getCode();
ExpiringCode actualCode = expiringCodeStore.retrieveCode(code);
Assert.assertEquals(expiringCode, actualCode);
Expand All @@ -174,10 +205,16 @@ public void testExpiredCodeReturnsNull() throws Exception {
public void testExpireCodeByIntent() throws Exception {
ExpiringCode code = expiringCodeStore.generateCode("{}", new Timestamp(System.currentTimeMillis() + 60000), "Test Intent");

Assert.assertEquals(1, countCodes());

IdentityZoneHolder.set(MultitenancyFixture.identityZone("id","id"));
expiringCodeStore.expireByIntent("Test Intent");
Assert.assertEquals(1, countCodes());

IdentityZoneHolder.clear();
expiringCodeStore.expireByIntent("Test Intent");
ExpiringCode retrievedCode = expiringCodeStore.retrieveCode(code.getCode());

Assert.assertEquals(0, countCodes());
Assert.assertNull(retrievedCode);
}

Expand Down Expand Up @@ -206,10 +243,10 @@ public void testExpirationCleaner() throws Exception {
jdbcTemplate.update(JdbcExpiringCodeStore.insert, "test", System.currentTimeMillis() - 1000, "{}", null);
((JdbcExpiringCodeStore) expiringCodeStore).cleanExpiredEntries();
jdbcTemplate.queryForObject(JdbcExpiringCodeStore.selectAllFields,
new JdbcExpiringCodeStore.JdbcExpiringCodeMapper(), "test");
(RowMapper<ExpiringCode>) ReflectionTestUtils.getField(expiringCodeStore, "rowMapper"), "test");
} else {
throw new EmptyResultDataAccessException(1);
}

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.cloudfoundry.identity.uaa.codestore.ExpiringCode;
import org.cloudfoundry.identity.uaa.codestore.ExpiringCodeStore;
import org.cloudfoundry.identity.uaa.codestore.ExpiringCodeType;
import org.cloudfoundry.identity.uaa.codestore.InMemoryExpiringCodeStore;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.mock.InjectedMockContextTest;
import org.cloudfoundry.identity.uaa.mock.util.MockMvcUtils;
Expand Down Expand Up @@ -92,6 +93,7 @@ public void setUp() throws Exception {

@After
public void cleanUpDomainList() throws Exception {
IdentityZoneHolder.clear();
IdentityProvider<UaaIdentityProviderDefinition> uaaProvider = getWebApplicationContext().getBean(JdbcIdentityProviderProvisioning.class).retrieveByOrigin(UAA, IdentityZone.getUaa().getId());
uaaProvider.getConfig().setEmailDomain(null);
getWebApplicationContext().getBean(JdbcIdentityProviderProvisioning.class).update(uaaProvider);
Expand Down Expand Up @@ -148,7 +150,7 @@ public void invite_User_In_Zone_With_DefaultZone_UaaAdmin() throws Exception {
InvitationsResponse invitationsResponse = readValue(mvcResult.getResponse().getContentAsString(), InvitationsResponse.class);
BaseClientDetails defaultClientDetails = new BaseClientDetails();
defaultClientDetails.setClientId("admin");
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), invitationsResponse, defaultClientDetails);
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), invitationsResponse, defaultClientDetails);

}

Expand Down Expand Up @@ -182,7 +184,7 @@ public void invite_User_In_Zone_With_DefaultZone_ZoneAdmin() throws Exception {
.andReturn();

InvitationsResponse invitationsResponse = readValue(mvcResult.getResponse().getContentAsString(), InvitationsResponse.class);
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), invitationsResponse, zonifiedScimInviteClientDetails);
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), invitationsResponse, zonifiedScimInviteClientDetails);

}

Expand Down Expand Up @@ -216,7 +218,7 @@ public void invite_User_In_Zone_With_DefaultZone_ScimInvite() throws Exception {
.andReturn();

InvitationsResponse invitationsResponse = readValue(mvcResult.getResponse().getContentAsString(), InvitationsResponse.class);
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), invitationsResponse, zonifiedScimInviteClientDetails);
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), invitationsResponse, zonifiedScimInviteClientDetails);

}

Expand All @@ -235,7 +237,7 @@ public void invite_User_Within_Zone() throws Exception {
String redirectUrl = "example.com";
InvitationsResponse response = sendRequestWithTokenAndReturnResponse(zonedScimInviteToken, result.getIdentityZone().getSubdomain(), zonedClientDetails.getClientId(), redirectUrl, email);

assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone().getSubdomain(), response, zonedClientDetails);
assertResponseAndCodeCorrect(new String[] {email}, redirectUrl, result.getIdentityZone(), response, zonedClientDetails);
}

@Test
Expand Down Expand Up @@ -346,6 +348,7 @@ public void invitations_Accept_Get_Security() throws Exception {
sendRequestWithToken(userToken, null, clientId, "example.com", "user1@"+domain);

String code = getWebApplicationContext().getBean(JdbcTemplate.class).queryForObject("SELECT code FROM expiring_code_store", String.class);
code = new InMemoryExpiringCodeStore().extractCode(code);
assertNotNull("Invite Code Must be Present", code);

MockHttpServletRequestBuilder accept = get("/invitations/accept")
Expand All @@ -371,7 +374,7 @@ public void sendRequestWithToken(String token, String subdomain, String clientId
assertThat(response.getFailedInvites().size(), is(0));
}

private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, String subdomain, InvitationsResponse response, ClientDetails clientDetails) {
private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, IdentityZone zone, InvitationsResponse response, ClientDetails clientDetails) {
for (int i = 0; i < emails.length; i++) {
assertThat(response.getNewInvites().size(), is(emails.length));
assertThat(response.getNewInvites().get(i).getEmail(), is(emails[i]));
Expand All @@ -382,8 +385,9 @@ private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, S
String link = response.getNewInvites().get(i).getInviteLink().toString();
assertFalse(contains(link, "@"));
assertFalse(contains(link, "%40"));
if (StringUtils.hasText(subdomain)) {
assertThat(link, startsWith("http://" + subdomain + ".localhost/invitations/accept"));
if (zone != null && StringUtils.hasText(zone.getSubdomain())) {
assertThat(link, startsWith("http://" + zone.getSubdomain() + ".localhost/invitations/accept"));
IdentityZoneHolder.set(zone);
} else {
assertThat(link, startsWith("http://localhost/invitations/accept"));
}
Expand All @@ -392,6 +396,7 @@ private void assertResponseAndCodeCorrect(String[] emails, String redirectUrl, S
assertThat(query, startsWith("code="));
String code = query.split("=")[1];
ExpiringCode expiringCode = codeStore.retrieveCode(code);
IdentityZoneHolder.clear();
assertThat(expiringCode.getExpiresAt().getTime(), is(greaterThan(System.currentTimeMillis())));
assertThat(expiringCode.getIntent(), is(ExpiringCodeType.INVITATION.name()));
Map<String, String> data = readValue(expiringCode.getData(), new TypeReference<Map<String, String>>() {});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package org.cloudfoundry.identity.uaa.login;

import org.cloudfoundry.identity.uaa.codestore.InMemoryExpiringCodeStore;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.message.EmailService;
import org.cloudfoundry.identity.uaa.message.util.FakeJavaMailSender;
Expand Down Expand Up @@ -244,6 +245,7 @@ public void accept_invitation_sets_your_password() throws Exception {
.andReturn();

code = getWebApplicationContext().getBean(JdbcTemplate.class).queryForObject("select code from expiring_code_store", String.class);
code = new InMemoryExpiringCodeStore().extractCode(code);
MockHttpSession session = (MockHttpSession) result.getRequest().getSession(false);
result = getMockMvc().perform(
post("/invitations/accept.do")
Expand Down
Loading

0 comments on commit 2ca35f1

Please sign in to comment.