Skip to content

Commit 87a6606

Browse files
clundin25johanblumenbergigorbernstein2
authored
fix: Resolve race condition reported in #692 (#1031)
Co-authored-by: Johan Blumenberg <johan.blumenberg@gmail.com> Co-authored-by: Igor Berntein <igorbernstein@google.com>
1 parent 43874fc commit 87a6606

File tree

2 files changed

+147
-17
lines changed

2 files changed

+147
-17
lines changed

oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import com.google.common.collect.ImmutableList;
4242
import com.google.common.collect.ImmutableMap;
4343
import com.google.common.collect.Iterables;
44+
import com.google.common.util.concurrent.AbstractFuture;
4445
import com.google.common.util.concurrent.FutureCallback;
4546
import com.google.common.util.concurrent.Futures;
4647
import com.google.common.util.concurrent.ListenableFuture;
@@ -60,7 +61,6 @@
6061
import java.util.concurrent.Callable;
6162
import java.util.concurrent.ExecutionException;
6263
import java.util.concurrent.Executor;
63-
import java.util.concurrent.Future;
6464
import javax.annotation.Nullable;
6565

6666
/** Base type for Credentials using OAuth2. */
@@ -77,7 +77,7 @@ public class OAuth2Credentials extends Credentials {
7777
// byte[] is serializable, so the lock variable can be final
7878
@VisibleForTesting final Object lock = new byte[0];
7979
private volatile OAuthValue value = null;
80-
@VisibleForTesting transient ListenableFutureTask<OAuthValue> refreshTask;
80+
@VisibleForTesting transient RefreshTask refreshTask;
8181

8282
// Change listeners are not serialized
8383
private transient List<CredentialsChangedListener> changeListeners;
@@ -258,16 +258,7 @@ public OAuthValue call() throws Exception {
258258
}
259259
});
260260

261-
task.addListener(
262-
new Runnable() {
263-
@Override
264-
public void run() {
265-
finishRefreshAsync(task);
266-
}
267-
},
268-
MoreExecutors.directExecutor());
269-
270-
refreshTask = task;
261+
refreshTask = new RefreshTask(task, new RefreshTaskListener(task));
271262

272263
return new AsyncRefreshResult(refreshTask, true);
273264
}
@@ -290,7 +281,7 @@ private void finishRefreshAsync(ListenableFuture<OAuthValue> finishedTask) {
290281
} catch (Exception e) {
291282
// noop
292283
} finally {
293-
if (this.refreshTask == finishedTask) {
284+
if (this.refreshTask != null && this.refreshTask.getTask() == finishedTask) {
294285
this.refreshTask = null;
295286
}
296287
}
@@ -307,7 +298,7 @@ private void finishRefreshAsync(ListenableFuture<OAuthValue> finishedTask) {
307298
* thread of whatever executor the async call used. This doesn't affect correctness and is
308299
* extremely unlikely.
309300
*/
310-
private static <T> T unwrapDirectFuture(Future<T> future) throws IOException {
301+
private static <T> T unwrapDirectFuture(ListenableFuture<T> future) throws IOException {
311302
try {
312303
return future.get();
313304
} catch (InterruptedException e) {
@@ -567,10 +558,10 @@ public void onFailure(Throwable throwable) {
567558
* task is newly created, it is the caller's responsibility to execute it.
568559
*/
569560
static class AsyncRefreshResult {
570-
private final ListenableFutureTask<OAuthValue> task;
561+
private final RefreshTask task;
571562
private final boolean isNew;
572563

573-
AsyncRefreshResult(ListenableFutureTask<OAuthValue> task, boolean isNew) {
564+
AsyncRefreshResult(RefreshTask task, boolean isNew) {
574565
this.task = task;
575566
this.isNew = isNew;
576567
}
@@ -582,6 +573,57 @@ void executeIfNew(Executor executor) {
582573
}
583574
}
584575

576+
@VisibleForTesting
577+
class RefreshTaskListener implements Runnable {
578+
private ListenableFutureTask<OAuthValue> task;
579+
580+
RefreshTaskListener(ListenableFutureTask<OAuthValue> task) {
581+
this.task = task;
582+
}
583+
584+
@Override
585+
public void run() {
586+
finishRefreshAsync(task);
587+
}
588+
}
589+
590+
class RefreshTask extends AbstractFuture<OAuthValue> implements Runnable {
591+
private final ListenableFutureTask<OAuthValue> task;
592+
private final RefreshTaskListener listener;
593+
594+
RefreshTask(ListenableFutureTask<OAuthValue> task, RefreshTaskListener listener) {
595+
this.task = task;
596+
this.listener = listener;
597+
598+
// Update Credential state first
599+
task.addListener(listener, MoreExecutors.directExecutor());
600+
601+
// Then notify the world
602+
Futures.addCallback(
603+
task,
604+
new FutureCallback<OAuthValue>() {
605+
@Override
606+
public void onSuccess(OAuthValue result) {
607+
RefreshTask.this.set(result);
608+
}
609+
610+
@Override
611+
public void onFailure(Throwable t) {
612+
RefreshTask.this.setException(t);
613+
}
614+
},
615+
MoreExecutors.directExecutor());
616+
}
617+
618+
public ListenableFutureTask<OAuthValue> getTask() {
619+
return this.task;
620+
}
621+
622+
public void run() {
623+
task.run();
624+
}
625+
}
626+
585627
public static class Builder {
586628

587629
private AccessToken accessToken;

oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import com.google.auth.http.AuthHttpConstants;
4848
import com.google.auth.oauth2.GoogleCredentialsTest.MockTokenServerTransportFactory;
4949
import com.google.auth.oauth2.OAuth2Credentials.OAuthValue;
50+
import com.google.auth.oauth2.OAuth2Credentials.RefreshTask;
51+
import com.google.auth.oauth2.OAuth2Credentials.RefreshTaskListener;
5052
import com.google.common.collect.ImmutableList;
5153
import com.google.common.collect.ImmutableMap;
5254
import com.google.common.util.concurrent.ListenableFutureTask;
@@ -58,6 +60,7 @@
5860
import java.util.ArrayList;
5961
import java.util.Arrays;
6062
import java.util.Date;
63+
import java.util.HashMap;
6164
import java.util.List;
6265
import java.util.Map;
6366
import java.util.concurrent.Callable;
@@ -590,7 +593,7 @@ public AccessToken refreshAccessToken() {
590593
creds.getRequestMetadata(CALL_URI, realExecutor, callback);
591594
TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN);
592595
assertNotNull(creds.refreshTask);
593-
ListenableFutureTask<OAuthValue> refreshTask = creds.refreshTask;
596+
RefreshTask refreshTask = creds.refreshTask;
594597

595598
// Fast forward to expiration, which will hang cause the callback to hang
596599
testClock.setCurrentTime(clientExpired.toEpochMilli());
@@ -873,6 +876,91 @@ public void serialize() throws IOException, ClassNotFoundException {
873876
assertSame(deserializedCredentials.clock, Clock.SYSTEM);
874877
}
875878

879+
@Test
880+
public void updateTokenValueBeforeWake() throws IOException, InterruptedException {
881+
final SettableFuture<AccessToken> refreshedTokenFuture = SettableFuture.create();
882+
AccessToken refreshedToken = new AccessToken("2/MkSJoj1xsli0AccessToken_NKPY2", null);
883+
refreshedTokenFuture.set(refreshedToken);
884+
885+
final ListenableFutureTask<OAuthValue> task =
886+
ListenableFutureTask.create(
887+
new Callable<OAuthValue>() {
888+
@Override
889+
public OAuthValue call() throws Exception {
890+
return OAuthValue.create(refreshedToken, new HashMap<>());
891+
}
892+
});
893+
894+
OAuth2Credentials creds =
895+
new OAuth2Credentials() {
896+
@Override
897+
public AccessToken refreshAccessToken() {
898+
synchronized (this) {
899+
// Wake up the main thread. This is done now because the child thread (t) is known to
900+
// have the refresh task. Now we want the main thread to wake up and create a future
901+
// in order to wait for the refresh to complete.
902+
this.notify();
903+
}
904+
RefreshTaskListener listener =
905+
new RefreshTaskListener(task) {
906+
@Override
907+
public void run() {
908+
try {
909+
// Sleep before setting accessToken to new accessToken. Refresh should not
910+
// complete before this, and the accessToken is `null` until it is.
911+
Thread.sleep(300);
912+
super.run();
913+
} catch (Exception e) {
914+
fail("Unexpected error. Exception: " + e);
915+
}
916+
}
917+
};
918+
919+
this.refreshTask = new RefreshTask(task, listener);
920+
921+
try {
922+
// Sleep for 100 milliseconds to give parent thread time to create a refresh future.
923+
Thread.sleep(100);
924+
return refreshedTokenFuture.get();
925+
} catch (Exception e) {
926+
throw new RuntimeException(e);
927+
}
928+
}
929+
};
930+
931+
Thread t =
932+
new Thread(
933+
new Runnable() {
934+
@Override
935+
public void run() {
936+
try {
937+
creds.refresh();
938+
assertNotNull(creds.getAccessToken());
939+
} catch (Exception e) {
940+
fail("Unexpected error. Exception: " + e);
941+
}
942+
}
943+
});
944+
t.start();
945+
946+
synchronized (creds) {
947+
// Grab a lock on creds object. This thread (the main thread) will wait here until the child
948+
// thread (t) calls `notify` on the creds object.
949+
creds.wait();
950+
}
951+
952+
AccessToken token = creds.getAccessToken();
953+
assertNull(token);
954+
955+
creds.refresh();
956+
token = creds.getAccessToken();
957+
// Token should never be NULL after a refresh that succeeded.
958+
// Previously the token could be NULL due to an internal race condition between the future
959+
// completing and the task listener updating the value of the access token.
960+
assertNotNull(token);
961+
t.join();
962+
}
963+
876964
private void waitForRefreshTaskCompletion(OAuth2Credentials credentials)
877965
throws TimeoutException, InterruptedException {
878966
for (int i = 0; i < 100; i++) {

0 commit comments

Comments
 (0)