Skip to content

Commit

Permalink
feat: Add optional account association for Authorized User credential…
Browse files Browse the repository at this point in the history
…s. (#1458)

* feat: Add optional account association for Authorized User credentials.

* chore: Refresh system test creds.

* Fix two missed constructors.
  • Loading branch information
clundin25 authored Jan 24, 2024
1 parent 9cd6742 commit 988153d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
42 changes: 42 additions & 0 deletions google/oauth2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
granted_scopes=None,
trust_boundary=None,
universe_domain=_DEFAULT_UNIVERSE_DOMAIN,
account=None,
):
"""
Args:
Expand Down Expand Up @@ -131,6 +132,7 @@ def __init__(
trust_boundary (str): String representation of trust boundary meta.
universe_domain (Optional[str]): The universe domain. The default
universe domain is googleapis.com.
account (Optional[str]): The account associated with the credential.
"""
super(Credentials, self).__init__()
self.token = token
Expand All @@ -149,6 +151,7 @@ def __init__(
self._enable_reauth_refresh = enable_reauth_refresh
self._trust_boundary = trust_boundary
self._universe_domain = universe_domain or _DEFAULT_UNIVERSE_DOMAIN
self._account = account or ""

def __getstate__(self):
"""A __getstate__ method must exist for the __setstate__ to be called
Expand Down Expand Up @@ -189,6 +192,7 @@ def __setstate__(self, d):
self._refresh_handler = None
self._refresh_worker = None
self._use_non_blocking_refresh = d.get("_use_non_blocking_refresh", False)
self._account = d.get("_account", "")

@property
def refresh_token(self):
Expand Down Expand Up @@ -268,6 +272,11 @@ def refresh_handler(self, value):
raise TypeError("The provided refresh_handler is not a callable or None.")
self._refresh_handler = value

@property
def account(self):
"""str: The user account associated with the credential. If the account is unknown an empty string is returned."""
return self._account

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):

Expand All @@ -286,6 +295,7 @@ def with_quota_project(self, quota_project_id):
enable_reauth_refresh=self._enable_reauth_refresh,
trust_boundary=self._trust_boundary,
universe_domain=self._universe_domain,
account=self._account,
)

@_helpers.copy_docstring(credentials.CredentialsWithTokenUri)
Expand All @@ -306,6 +316,35 @@ def with_token_uri(self, token_uri):
enable_reauth_refresh=self._enable_reauth_refresh,
trust_boundary=self._trust_boundary,
universe_domain=self._universe_domain,
account=self._account,
)

def with_account(self, account):
"""Returns a copy of these credentials with a modified account.
Args:
account (str): The account to set
Returns:
google.oauth2.credentials.Credentials: A new credentials instance.
"""

return self.__class__(
self.token,
refresh_token=self.refresh_token,
id_token=self.id_token,
token_uri=self._token_uri,
client_id=self.client_id,
client_secret=self.client_secret,
scopes=self.scopes,
default_scopes=self.default_scopes,
granted_scopes=self.granted_scopes,
quota_project_id=self.quota_project_id,
rapt_token=self.rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
trust_boundary=self._trust_boundary,
universe_domain=self._universe_domain,
account=account,
)

@_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain)
Expand All @@ -326,6 +365,7 @@ def with_universe_domain(self, universe_domain):
enable_reauth_refresh=self._enable_reauth_refresh,
trust_boundary=self._trust_boundary,
universe_domain=universe_domain,
account=self._account,
)

def _metric_header_for_usage(self):
Expand Down Expand Up @@ -474,6 +514,7 @@ def from_authorized_user_info(cls, info, scopes=None):
rapt_token=info.get("rapt_token"), # may not exist
trust_boundary=info.get("trust_boundary"), # may not exist
universe_domain=info.get("universe_domain"), # may not exist
account=info.get("account", ""), # may not exist
)

@classmethod
Expand Down Expand Up @@ -518,6 +559,7 @@ def to_json(self, strip=None):
"scopes": self.scopes,
"rapt_token": self.rapt_token,
"universe_domain": self._universe_domain,
"account": self._account,
}
if self.expiry: # flatten expiry timestamp
prep["expiry"] = self.expiry.isoformat() + "Z"
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
7 changes: 7 additions & 0 deletions tests/oauth2/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,12 @@ def test_with_universe_domain(self):
new_creds = creds.with_universe_domain("dummy_universe.com")
assert new_creds.universe_domain == "dummy_universe.com"

def test_with_account(self):
creds = credentials.Credentials(token="token")
assert creds.account == ""
new_creds = creds.with_account("mock@example.com")
assert new_creds.account == "mock@example.com"

def test_with_token_uri(self):
info = AUTH_USER_INFO.copy()

Expand Down Expand Up @@ -888,6 +894,7 @@ def test_to_json(self):
assert json_asdict.get("client_secret") == creds.client_secret
assert json_asdict.get("expiry") == info["expiry"]
assert json_asdict.get("universe_domain") == creds.universe_domain
assert json_asdict.get("account") == creds.account

# Test with a `strip` arg
json_output = creds.to_json(strip=["client_secret"])
Expand Down

0 comments on commit 988153d

Please sign in to comment.