Skip to content

Commit

Permalink
feat: Intra server to server communication (feast-dev#4433)
Browse files Browse the repository at this point in the history
Intra server communication

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
  • Loading branch information
tmihalac authored Aug 29, 2024
1 parent 5e753e4 commit 729c874
Show file tree
Hide file tree
Showing 9 changed files with 417 additions and 45 deletions.
2 changes: 2 additions & 0 deletions infra/charts/feast-feature-server/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ spec:
env:
- name: FEATURE_STORE_YAML_BASE64
value: {{ .Values.feature_store_yaml_base64 }}
- name: INTRA_COMMUNICATION_BASE64
value: {{ "intra-server-communication" | b64enc }}
command:
{{- if eq .Values.feast_mode "offline" }}
- "feast"
Expand Down
11 changes: 8 additions & 3 deletions sdk/python/feast/permissions/auth/kubernetes_token_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os

import jwt
from kubernetes import client, config
Expand Down Expand Up @@ -41,10 +42,14 @@ async def user_details_from_access_token(self, access_token: str) -> User:
current_user = f"{sa_namespace}:{sa_name}"
logging.info(f"Received request from {sa_name} in {sa_namespace}")

roles = self.get_roles(sa_namespace, sa_name)
logging.info(f"SA roles are: {roles}")
intra_communication_base64 = os.getenv("INTRA_COMMUNICATION_BASE64")
if sa_name is not None and sa_name == intra_communication_base64:
return User(username=sa_name, roles=[])
else:
roles = self.get_roles(sa_namespace, sa_name)
logging.info(f"SA roles are: {roles}")

return User(username=current_user, roles=roles)
return User(username=current_user, roles=roles)

def get_roles(self, namespace: str, service_account_name: str) -> list[str]:
"""
Expand Down
26 changes: 25 additions & 1 deletion sdk/python/feast/permissions/auth/oidc_token_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import os
from typing import Optional
from unittest.mock import Mock

import jwt
Expand Down Expand Up @@ -34,7 +36,7 @@ def __init__(self, auth_config: OidcAuthConfig):

async def _validate_token(self, access_token: str):
"""
Validate the token extracted from the headrer of the user request against the OAuth2 server.
Validate the token extracted from the header of the user request against the OAuth2 server.
"""
# FastAPI's OAuth2AuthorizationCodeBearer requires a Request type but actually uses only the headers field
# https://github.com/tiangolo/fastapi/blob/eca465f4c96acc5f6a22e92fd2211675ca8a20c8/fastapi/security/oauth2.py#L380
Expand All @@ -60,6 +62,11 @@ async def user_details_from_access_token(self, access_token: str) -> User:
AuthenticationError if any error happens.
"""

# check if intra server communication
user = self._get_intra_comm_user(access_token)
if user:
return user

try:
await self._validate_token(access_token)
logger.info("Validated token")
Expand Down Expand Up @@ -108,3 +115,20 @@ async def user_details_from_access_token(self, access_token: str) -> User:
except jwt.exceptions.InvalidTokenError:
logger.exception("Exception while parsing the token:")
raise AuthenticationError("Invalid token.")

def _get_intra_comm_user(self, access_token: str) -> Optional[User]:
intra_communication_base64 = os.getenv("INTRA_COMMUNICATION_BASE64")

if intra_communication_base64:
decoded_token = jwt.decode(
access_token, options={"verify_signature": False}
)
if "preferred_username" in decoded_token:
preferred_username: str = decoded_token["preferred_username"]
if (
preferred_username is not None
and preferred_username == intra_communication_base64
):
return User(username=preferred_username, roles=[])

return None
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os

import jwt

from feast.permissions.auth_model import KubernetesAuthConfig
from feast.permissions.client.auth_client_manager import AuthenticationClientManager

Expand All @@ -13,6 +15,15 @@ def __init__(self, auth_config: KubernetesAuthConfig):
self.token_file_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"

def get_token(self):
intra_communication_base64 = os.getenv("INTRA_COMMUNICATION_BASE64")
# If intra server communication call
if intra_communication_base64:
payload = {
"sub": f":::{intra_communication_base64}", # Subject claim
}

return jwt.encode(payload, "")

try:
token = self._read_token_from_file()
return token
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os

import jwt
import requests

from feast.permissions.auth_model import OidcAuthConfig
Expand All @@ -14,6 +16,15 @@ def __init__(self, auth_config: OidcAuthConfig):
self.auth_config = auth_config

def get_token(self):
intra_communication_base64 = os.getenv("INTRA_COMMUNICATION_BASE64")
# If intra server communication call
if intra_communication_base64:
payload = {
"preferred_username": f"{intra_communication_base64}", # Subject claim
}

return jwt.encode(payload, "")

# Fetch the token endpoint from the discovery URL
token_endpoint = OIDCDiscoveryService(
self.auth_config.auth_discovery_url
Expand Down
25 changes: 21 additions & 4 deletions sdk/python/feast/permissions/security_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from contextvars import ContextVar
from typing import Callable, List, Optional, Union

Expand Down Expand Up @@ -110,6 +111,10 @@ def assert_permissions_to_update(
Raises:
FeastPermissionError: If the current user is not authorized to execute all the requested actions on the given resource or on the existing one.
"""
sm = get_security_manager()
if not is_auth_necessary(sm):
return resource

actions = [AuthzedAction.DESCRIBE, AuthzedAction.UPDATE]
try:
existing_resource = getter(
Expand Down Expand Up @@ -142,10 +147,11 @@ def assert_permissions(
Raises:
FeastPermissionError: If the current user is not authorized to execute the requested actions on the given resources.
"""

sm = get_security_manager()
if sm is None:
if not is_auth_necessary(sm):
return resource
return sm.assert_permissions(
return sm.assert_permissions( # type: ignore[union-attr]
resources=[resource], actions=actions, filter_only=False
)[0]

Expand All @@ -165,10 +171,11 @@ def permitted_resources(
Returns:
list[FeastObject]]: A filtered list of the permitted resources, possibly empty.
"""

sm = get_security_manager()
if sm is None:
if not is_auth_necessary(sm):
return resources
return sm.assert_permissions(resources=resources, actions=actions, filter_only=True)
return sm.assert_permissions(resources=resources, actions=actions, filter_only=True) # type: ignore[union-attr]


"""
Expand Down Expand Up @@ -201,3 +208,13 @@ def no_security_manager():

global _sm
_sm = None


def is_auth_necessary(sm: Optional[SecurityManager]) -> bool:
intra_communication_base64 = os.getenv("INTRA_COMMUNICATION_BASE64")

return (
sm is not None
and sm.current_user is not None
and sm.current_user.username != intra_communication_base64
)
147 changes: 145 additions & 2 deletions sdk/python/tests/unit/permissions/auth/test_token_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# test_token_validator.py

import asyncio
import os
from unittest import mock
from unittest.mock import MagicMock, patch

import assertpy
Expand Down Expand Up @@ -70,6 +70,75 @@ def test_oidc_token_validation_failure(mock_oauth2, oidc_config):
)


@mock.patch.dict(os.environ, {"INTRA_COMMUNICATION_BASE64": "test1234"})
@pytest.mark.parametrize(
"intra_communication_val, is_intra_server",
[
("test1234", True),
("my-name", False),
],
)
def test_oidc_inter_server_comm(
intra_communication_val, is_intra_server, oidc_config, monkeypatch
):
async def mock_oath2(self, request):
return "OK"

monkeypatch.setattr(
"feast.permissions.auth.oidc_token_parser.OAuth2AuthorizationCodeBearer.__call__",
mock_oath2,
)
signing_key = MagicMock()
signing_key.key = "a-key"
monkeypatch.setattr(
"feast.permissions.auth.oidc_token_parser.PyJWKClient.get_signing_key_from_jwt",
lambda self, access_token: signing_key,
)

user_data = {
"preferred_username": f"{intra_communication_val}",
}

if not is_intra_server:
user_data["resource_access"] = {_CLIENT_ID: {"roles": ["reader", "writer"]}}

monkeypatch.setattr(
"feast.permissions.oidc_service.OIDCDiscoveryService._fetch_discovery_data",
lambda self, *args, **kwargs: {
"authorization_endpoint": "https://localhost:8080/realms/master/protocol/openid-connect/auth",
"token_endpoint": "https://localhost:8080/realms/master/protocol/openid-connect/token",
"jwks_uri": "https://localhost:8080/realms/master/protocol/openid-connect/certs",
},
)

monkeypatch.setattr(
"feast.permissions.auth.oidc_token_parser.jwt.decode",
lambda self, *args, **kwargs: user_data,
)

access_token = "aaa-bbb-ccc"
token_parser = OidcTokenParser(auth_config=oidc_config)
user = asyncio.run(
token_parser.user_details_from_access_token(access_token=access_token)
)

if is_intra_server:
assertpy.assert_that(user).is_not_none()
assertpy.assert_that(user.username).is_equal_to(intra_communication_val)
assertpy.assert_that(user.roles).is_equal_to([])
else:
assertpy.assert_that(user).is_not_none()
assertpy.assert_that(user).is_type_of(User)
if isinstance(user, User):
assertpy.assert_that(user.username).is_equal_to("my-name")
assertpy.assert_that(user.roles.sort()).is_equal_to(
["reader", "writer"].sort()
)
assertpy.assert_that(user.has_matching_role(["reader"])).is_true()
assertpy.assert_that(user.has_matching_role(["writer"])).is_true()
assertpy.assert_that(user.has_matching_role(["updater"])).is_false()


# TODO RBAC: Move role bindings to a reusable fixture
@patch("feast.permissions.auth.kubernetes_token_parser.config.load_incluster_config")
@patch("feast.permissions.auth.kubernetes_token_parser.jwt.decode")
Expand Down Expand Up @@ -127,3 +196,77 @@ def test_k8s_token_validation_failure(mock_jwt, mock_config):
asyncio.run(
token_parser.user_details_from_access_token(access_token=access_token)
)


@mock.patch.dict(os.environ, {"INTRA_COMMUNICATION_BASE64": "test1234"})
@pytest.mark.parametrize(
"intra_communication_val, is_intra_server",
[
("test1234", True),
("my-name", False),
],
)
def test_k8s_inter_server_comm(
intra_communication_val,
is_intra_server,
oidc_config,
request,
rolebindings,
clusterrolebindings,
monkeypatch,
):
if is_intra_server:
subject = f":::{intra_communication_val}"
else:
sa_name = request.getfixturevalue("sa_name")
namespace = request.getfixturevalue("namespace")
subject = f"system:serviceaccount:{namespace}:{sa_name}"
rolebindings = request.getfixturevalue("rolebindings")
clusterrolebindings = request.getfixturevalue("clusterrolebindings")

monkeypatch.setattr(
"feast.permissions.auth.kubernetes_token_parser.client.RbacAuthorizationV1Api.list_namespaced_role_binding",
lambda *args, **kwargs: rolebindings["items"],
)
monkeypatch.setattr(
"feast.permissions.auth.kubernetes_token_parser.client.RbacAuthorizationV1Api.list_cluster_role_binding",
lambda *args, **kwargs: clusterrolebindings["items"],
)
monkeypatch.setattr(
"feast.permissions.client.kubernetes_auth_client_manager.KubernetesAuthClientManager.get_token",
lambda self: "my-token",
)

monkeypatch.setattr(
"feast.permissions.auth.kubernetes_token_parser.config.load_incluster_config",
lambda: None,
)

monkeypatch.setattr(
"feast.permissions.auth.kubernetes_token_parser.jwt.decode",
lambda *args, **kwargs: {"sub": subject},
)

roles = rolebindings["roles"]
croles = clusterrolebindings["roles"]

access_token = "aaa-bbb-ccc"
token_parser = KubernetesTokenParser()
user = asyncio.run(
token_parser.user_details_from_access_token(access_token=access_token)
)

if is_intra_server:
assertpy.assert_that(user).is_not_none()
assertpy.assert_that(user.username).is_equal_to(intra_communication_val)
assertpy.assert_that(user.roles).is_equal_to([])
else:
assertpy.assert_that(user).is_type_of(User)
if isinstance(user, User):
assertpy.assert_that(user.username).is_equal_to(f"{namespace}:{sa_name}")
assertpy.assert_that(user.roles.sort()).is_equal_to((roles + croles).sort())
for r in roles:
assertpy.assert_that(user.has_matching_role([r])).is_true()
for cr in croles:
assertpy.assert_that(user.has_matching_role([cr])).is_true()
assertpy.assert_that(user.has_matching_role(["foo"])).is_false()
2 changes: 1 addition & 1 deletion sdk/python/tests/unit/permissions/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@pytest.mark.parametrize(
"username, can_read, can_write",
[
(None, False, False),
(None, True, True),
("r", True, False),
("w", False, True),
("rw", True, True),
Expand Down
Loading

0 comments on commit 729c874

Please sign in to comment.