Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def _auth_gcp(self, _client: hvac.Client) -> None:
import time
import json
import googleapiclient

with open(self.gcp_key_path, 'r') as f:
creds = json.load(f)
service_account = creds['client_email']
Expand Down Expand Up @@ -342,19 +342,11 @@ def _auth_gcp(self, _client: hvac.Client) -> None:
role=self.role_id,
jwt=jwt,
mount_point=self.auth_mount_point)
_client.auth.gcp.login(
role=self.role_id,
jwt=jwt,
mount_point=self.auth_mount_point)
else:
_client.auth.gcp.login(
role=self.role_id,
jwt=jwt)

_client.auth.gcp.login(
role=self.role_id,
jwt=jwt)


def _auth_azure(self, _client: hvac.Client) -> None:
if self.auth_mount_point:
Expand Down
183 changes: 137 additions & 46 deletions providers/tests/hashicorp/_internal_client/test_vault_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from unittest.mock import mock_open, patch

import pytest
import json
import time
from hvac.exceptions import InvalidPath, VaultError
from requests import Session
from requests.adapters import HTTPAdapter
Expand Down Expand Up @@ -230,15 +232,29 @@ def test_azure_missing_tenant_id(self, mock_hvac):
secret_id="pass",
)

@mock.patch("builtins.open", create=True)
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
@mock.patch("googleapiclient.discovery.build")
def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes, mock_google_build):
def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open):
# Mock the content of the file 'path.json'
mock_file = mock.MagicMock()
mock_file.read.return_value = '{"client_email": "service_account_email"}'
mock_open.return_value.__enter__.return_value = mock_file


mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_hvac_client.return_value = mock_client
mock_get_scopes.return_value = ["scope1", "scope2"]
mock_get_credentials.return_value = ("credentials", "project_id")

# Mock the current time to use for iat and exp
current_time = int(time.time())
iat = current_time
exp = iat + 3600 # 1 hour after iat

# Mock the signJwt API to return the expected payload
mock_sign_jwt = mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}

Expand All @@ -251,22 +267,37 @@ def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes, mock_google
role_id="TODO",
session=None,
)
client = vault_client.client
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)

# Preserve the original json.dumps
original_json_dumps = json.dumps

# Inject the mocked payload into the JWT signing process
with mock.patch("json.dumps") as mock_json_dumps:
def mocked_json_dumps(payload):
# Override the payload to inject controlled iat and exp values
payload["iat"] = iat
payload["exp"] = exp
return original_json_dumps(payload) # Use the original json.dumps

mock_json_dumps.side_effect = mocked_json_dumps

client = vault_client.client # Trigger the Vault client creation

# Validate that the HVAC client and other mocks are called correctly
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
mock_get_scopes.assert_called_with("scope1,scope2")
mock_get_credentials.assert_called_with(
key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"]
)
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
mock_sign_jwt.assert_called_with(
name="projects/project_id/serviceAccounts/service_account",
body={"payload": json.dumps({
"iat": mock.ANY,
"exp": mock.ANY,
"aud": "vault/role",
"sub": "credentials"
})}
)

# Extract the arguments passed to the mocked signJwt API
args, kwargs = mock_sign_jwt.call_args
payload = json.loads(kwargs["body"]["payload"])

# Assert iat and exp values are as expected
assert payload["iat"] == iat
assert payload["exp"] == exp
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat

client.auth.gcp.login.assert_called_with(
role="role",
Expand All @@ -275,42 +306,70 @@ def test_gcp(self, mock_hvac, mock_get_credentials, mock_get_scopes, mock_google
client.is_authenticated.assert_called_with()
assert vault_client.kv_engine_version == 2

@mock.patch("builtins.open", create=True)
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
@mock.patch("googleapiclient.discovery.build")
def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, mock_get_scopesm, mock_google_build):
def test_gcp_different_auth_mount_point(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open):
# Mock the content of the file 'path.json'
mock_file = mock.MagicMock()
mock_file.read.return_value = '{"client_email": "service_account_email"}'
mock_open.return_value.__enter__.return_value = mock_file

mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_hvac_client.return_value = mock_client
mock_get_scopes.return_value = ["scope1", "scope2"]
mock_get_credentials.return_value = ("credentials", "project_id")

mock_sign_jwt = mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}


# Generate realistic iat and exp values
current_time = int(time.time())
iat = current_time
exp = current_time + 3600 # 1 hour later

vault_client = _VaultClient(
auth_type="gcp",
gcp_key_path="path.json",
gcp_scopes="scope1,scope2",
role_id="role",
url="http://localhost:8180",
auth_mount_point="other",
session=None,
)
client = vault_client.client
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)

# Preserve the original json.dumps
original_json_dumps = json.dumps

# Inject the mocked payload into the JWT signing process
with mock.patch("json.dumps") as mock_json_dumps:
def mocked_json_dumps(payload):
# Override the payload to inject controlled iat and exp values
payload["iat"] = iat
payload["exp"] = exp
return original_json_dumps(payload) # Use the original json.dumps

mock_json_dumps.side_effect = mocked_json_dumps

client = vault_client.client # Trigger the Vault client creation

# Assertions
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
mock_get_scopes.assert_called_with("scope1,scope2")
mock_get_credentials.assert_called_with(
key_path="path.json", keyfile_dict=None, scopes=["scope1", "scope2"]
)
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
mock_sign_jwt.assert_called_with(
name="projects/project_id/serviceAccounts/service_account",
body={"payload": json.dumps({
"iat": mock.ANY,
"exp": mock.ANY,
"aud": "vault/test_role",
"sub": "credentials"
})}
)
# Extract the arguments passed to the mocked signJwt API
args, kwargs = mock_sign_jwt.call_args
payload = json.loads(kwargs["body"]["payload"])

# Assert iat and exp values are as expected
assert payload["iat"] == iat
assert payload["exp"] == exp
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat

client.auth.gcp.login.assert_called_with(
role="role",
jwt="mocked_jwt",
Expand All @@ -319,38 +378,70 @@ def test_gcp_different_auth_mount_point(self, mock_hvac, mock_get_credentials, m
client.is_authenticated.assert_called_with()
assert vault_client.kv_engine_version == 2

@mock.patch("builtins.open", create=True)
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac")
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
@mock.patch("googleapiclient.discovery.build")
def test_gcp_dict(self, mock_hvac, mock_get_credentials, mock_get_scopes):
def test_gcp_dict(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open):
# Mock the content of the file 'path.json'
mock_file = mock.MagicMock()
mock_file.read.return_value = '{"client_email": "service_account_email"}'
mock_open.return_value.__enter__.return_value = mock_file

# Mock the content of the keyfile dict
mock_client = mock.MagicMock()
mock_hvac.Client.return_value = mock_client
mock_hvac_client.return_value = mock_client
mock_get_scopes.return_value = ["scope1", "scope2"]
mock_get_credentials.return_value = ("credentials", "project_id")

mock_sign_jwt = mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}

# Generate realistic iat and exp values
current_time = int(time.time())
iat = current_time
exp = current_time + 3600 # 1 hour later

vault_client = _VaultClient(
auth_type="gcp",
gcp_keyfile_dict={"key": "value"},
gcp_scopes="scope1,scope2",
role_id="role",
url="http://localhost:8180",
session=None,
)
client = vault_client.client
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)

# Preserve the original json.dumps
original_json_dumps = json.dumps

# Inject the mocked payload into the JWT signing process
with mock.patch("json.dumps") as mock_json_dumps:
def mocked_json_dumps(payload):
# Override the payload to inject controlled iat and exp values
payload["iat"] = iat
payload["exp"] = exp
return original_json_dumps(payload) # Use the original json.dumps

mock_json_dumps.side_effect = mocked_json_dumps

client = vault_client.client # Trigger the Vault client creation

# Assertions
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
mock_get_scopes.assert_called_with("scope1,scope2")
mock_get_credentials.assert_called_with(
key_path=None, keyfile_dict={"key": "value"}, scopes=["scope1", "scope2"]
)
mock_hvac.Client.assert_called_with(url="http://localhost:8180", session=None)
mock_sign_jwt.assert_called_with(
name="projects/project_id/serviceAccounts/service_account",
body={"payload": json.dumps({
"iat": mock.ANY,
"exp": mock.ANY,
"aud": "vault/test_role",
"sub": "credentials"
})}
)
# Extract the arguments passed to the mocked signJwt API
args, kwargs = mock_sign_jwt.call_args
payload = json.loads(kwargs["body"]["payload"])

# Assert iat and exp values are as expected
assert payload["iat"] == iat
assert payload["exp"] == exp
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat

client.auth.gcp.login.assert_called_with(
role="role",
jwt="mocked_jwt"
Expand Down
Loading