Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions temporallib/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
from macaroonbakery.bakery import Macaroon, b64decode, macaroon_to_dict
from macaroonbakery.httpbakery.agent import Agent, AgentInteractor, AuthInfo
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict


class MacaroonAuthOptions(BaseSettings):
macaroon_url: str = Field(None, alias="TEMPORAL_CANDID_URL")
username: str
keys: Optional[KeyPair]

class Config:
env_prefix = "TEMPORAL_CANDID_"
populate_by_name = True
model_config = SettingsConfigDict(
env_prefix="TEMPORAL_CANDID_", populate_by_name=True
)


class GoogleAuthOptions(BaseSettings):
Expand All @@ -39,25 +39,23 @@ class GoogleAuthOptions(BaseSettings):
None, alias="TEMPORAL_OIDC_CLIENT_CERT_URL"
)

class Config:
env_prefix = "TEMPORAL_OIDC_"
populate_by_name = True
model_config = SettingsConfigDict(
env_prefix="TEMPORAL_OIDC_", populate_by_name=True
)


class KeyPair(BaseSettings):
private: str = Field(None, alias="TEMPORAL_CANDID_PRIVATE_KEY")
public: str = Field(None, alias="TEMPORAL_CANDID_PUBLIC_KEY")

class Config:
populate_by_name = True
model_config = SettingsConfigDict(populate_by_name=True)


class AuthOptions(BaseSettings):
config: Optional[Union[MacaroonAuthOptions, GoogleAuthOptions]] = None
provider: str

class Config:
env_prefix = "TEMPORAL_AUTH_"
model_config = SettingsConfigDict(env_prefix="TEMPORAL_AUTH_")


class AuthHeaderProvider:
Expand Down
16 changes: 10 additions & 6 deletions temporallib/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from typing import Callable, Iterable, Mapping, Optional, Union

from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
from temporalio.client import Client as TemporalClient
from temporalio.client import Interceptor, OutboundInterceptor
from temporalio.common import QueryRejectCondition
Expand All @@ -29,8 +29,7 @@ class Options(BaseSettings):
auth: Optional[AuthOptions] = None
prometheus_port: Optional[str] = None

class Config:
env_prefix = "TEMPORAL_"
model_config = SettingsConfigDict(env_prefix="TEMPORAL_")


Options.model_rebuild()
Expand Down Expand Up @@ -171,7 +170,9 @@ async def connect(

self._client = await TemporalClient.connect(
self._client_opts.host,
namespace=self._client_opts.namespace or os.getenv("TEMPORAL_NAMESPACE") or "default",
namespace=self._client_opts.namespace
or os.getenv("TEMPORAL_NAMESPACE")
or "default",
data_converter=self._data_converter,
interceptors=self._interceptors,
default_workflow_query_reject_condition=self._default_workflow_query_reject_condition,
Expand All @@ -185,15 +186,18 @@ async def connect(
)

asyncio.create_task(self.reconnect_loop())

return self._client

@classmethod
async def _reconnect(self):
# Refresh the auth headers before reconnecting
if self._client_opts.auth:
auth_header_provider = AuthHeaderProvider(self._client_opts.auth)
self._client.rpc_metadata = {**self._client.rpc_metadata, **auth_header_provider.get_headers()}
self._client.rpc_metadata = {
**self._client.rpc_metadata,
**auth_header_provider.get_headers(),
}

logging.debug("Testing Temporal server connection")
await self._client.count_workflows()
5 changes: 2 additions & 3 deletions temporallib/encryption/data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import binascii
from typing import Iterable, List, Optional

from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
from temporalio.api.common.v1 import Payload
from temporalio.converter import PayloadCodec

Expand All @@ -13,8 +13,7 @@ class EncryptionOptions(BaseSettings):
key: Optional[str] = None
compress: Optional[bool] = False

class Config:
env_prefix = "TEMPORAL_ENCRYPTION_"
model_config = SettingsConfigDict(env_prefix="TEMPORAL_ENCRYPTION_")


class EncryptionPayloadCodec(PayloadCodec):
Expand Down
6 changes: 3 additions & 3 deletions temporallib/worker/sentry_interceptor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Temporal client worker Sentry interceptor."""

import os
from dataclasses import asdict, is_dataclass
from typing import Any, Optional, Type, Union

from pydantic import validator
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
from temporalio import activity, workflow
from temporalio.worker import (
ActivityInboundInterceptor,
Expand All @@ -30,8 +31,7 @@ class SentryOptions(BaseSettings):
sample_rate: Optional[float] = 1.0
redact_params: Optional[bool] = False

class Config:
env_prefix = "TEMPORAL_SENTRY_"
model_config = SettingsConfigDict(env_prefix="TEMPORAL_SENTRY_")

@validator("sample_rate", pre=True, always=True)
def validate_sample_rate(cls, v):
Expand Down