Skip to content

feat: change enterprise login provider from auth0 to workOS #2877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/concepts/cli.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,6 @@ def crew(self) -> Crew:
Deploy the crew or flow to [CrewAI Enterprise](https://app.crewai.com).

- **Authentication**: You need to be authenticated to deploy to CrewAI Enterprise.
```shell Terminal
crewai signup
```
If you already have an account, you can login with:
```shell Terminal
crewai login
```
Expand Down
5 changes: 1 addition & 4 deletions docs/enterprise/guides/deploy-crew.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,8 @@ The CLI provides the fastest way to deploy locally developed crews to the Enterp
First, you need to authenticate your CLI with the CrewAI Enterprise platform:

```bash
# If you already have a CrewAI Enterprise account
# If you already have a CrewAI Enterprise account, or want to create one to use the CLI:
crewai login

# If you're creating a new account
crewai signup
```

When you run either command, the CLI will:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ dependencies = [
"openpyxl>=3.1.5",
"pyvis>=0.3.2",
# Authentication and Security
"auth0-python>=4.7.1",
"python-dotenv>=1.0.0",
"auth0-python>=4.7.1",
"pyjwt>=2.9.0",
# Configuration and Utils
"click>=8.1.7",
"appdirs>=1.4.4",
Expand Down
7 changes: 7 additions & 0 deletions src/crewai/cli/authentication/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
ALGORITHMS = ["RS256"]
WORKOS_DOMAIN = "innovative-soap-49-staging.authkit.app"
WORKOS_CLIENT_ID = "client_01JTK5C207TZBSXQWCVWB1X6HK"
WORKOS_TOKEN_URL = f"https://{WORKOS_DOMAIN}/oauth2/token"
WORKOS_AUTHORIZE_URL = f"https://{WORKOS_DOMAIN}/oauth2/authorize"
WORKOS_ENVIRONMENT_ID = "client_01JNJQWB4HG8T5980R5VHP057C"

# Legacy Auth0 constants
AUTH0_DOMAIN = "crewai.us.auth0.com"
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
272 changes: 258 additions & 14 deletions src/crewai/cli/authentication/main.py

Large diffs are not rendered by default.

21 changes: 17 additions & 4 deletions src/crewai/cli/authentication/token.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from .utils import TokenManager
from .utils import TokenManager, get_auth_token_with_refresh_token


def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
"""Get the authentication token. Uses refresh token to fetch a new token if current one is expired."""
access_token = TokenManager().get_token("access_token")
refresh_token = TokenManager().get_token("refresh_token")

# Token could be expired, so we use the refresh token to fetch a new one.
# Skip if refresh token is not available.
if not access_token and refresh_token:
data = get_auth_token_with_refresh_token(refresh_token)
access_token = data.get("access_token")
refresh_token = data.get("refresh_token")

if access_token and refresh_token:
TokenManager().save_access_token(access_token, data["expires_in"])
TokenManager().save_refresh_token(refresh_token)

if not access_token:
raise Exception("No token found, make sure you are logged in")
raise Exception("Session expired. Please sign in again with 'crewai login'.")
return access_token
188 changes: 160 additions & 28 deletions src/crewai/cli/authentication/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,113 @@
from pathlib import Path
from typing import Optional

from auth0.authentication.token_verifier import (
AsymmetricSignatureVerifier,
TokenVerifier,
)
import requests
from cryptography.fernet import Fernet
import jwt
from jwt import PyJWKClient

from .constants import (
WORKOS_CLIENT_ID,
WORKOS_DOMAIN,
WORKOS_ENVIRONMENT_ID,
WORKOS_TOKEN_URL,
AUTH0_CLIENT_ID, # Legacy Import for old auth
AUTH0_DOMAIN, # Legacy Import for old auth
)

from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
from auth0.authentication.token_verifier import (
AsymmetricSignatureVerifier, # Legacy Import for old auth
TokenVerifier, # Legacy Import for old auth
)


def validate_token(id_token: str) -> None:
def get_auth_token_with_refresh_token(refresh_token: str) -> dict:
"""
Verify the token and its precedence
Get an access token using a refresh token.

:param id_token:
:param refresh_token: The refresh token to use.
:return: A dictionary containing the access token, its expiration time, and a new refresh token, or an empty dictionary if the attempt to get a new access token failed.
"""
jwks_url = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json"
issuer = f"https://{AUTH0_DOMAIN}/"
signature_verifier = AsymmetricSignatureVerifier(jwks_url)
token_verifier = TokenVerifier(
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID

response = requests.post(
WORKOS_TOKEN_URL,
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": WORKOS_CLIENT_ID,
},
timeout=10,
)
token_verifier.verify(id_token)

if response.status_code != 200:
return {}

data = response.json()
try:
validate_token(data.get("access_token"))
except Exception:
return {}

return {
"access_token": data.get("access_token"),
"refresh_token": data.get("refresh_token"),
"expires_in": data.get("expires_in"),
}


def validate_token(jwt_token: str, token_type: str = "access_token") -> dict:
"""
Verify the token's signature and claims using PyJWT.

:param jwt_token: The JWT (JWS) string to validate.
:return: The decoded token.
:raises Exception: If the token is invalid for any reason (e.g., signature mismatch,
expired, incorrect issuer/audience, JWKS fetching error,
missing required claims).
"""

supported_audiences = {
"access_token": WORKOS_ENVIRONMENT_ID,
"id_token": WORKOS_CLIENT_ID,
}

jwks_url = f"https://{WORKOS_DOMAIN}/oauth2/jwks"
expected_issuer = f"https://{WORKOS_DOMAIN}"
expected_audience = supported_audiences[token_type]
decoded_token = None

try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)

decoded_token = jwt.decode(
jwt_token,
signing_key.key,
algorithms=["RS256"],
audience=expected_audience,
issuer=expected_issuer,
options={
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": True,
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
return decoded_token

except jwt.ExpiredSignatureError:
raise Exception("Token has expired.")
except jwt.InvalidAudienceError:
raise Exception(f"Invalid token audience. Expected: '{expected_audience}'")
except jwt.InvalidIssuerError:
raise Exception(f"Invalid token issuer. Expected: '{expected_issuer}'")
except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {str(e)}")
except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {str(e)}")
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}")


class TokenManager:
Expand Down Expand Up @@ -56,37 +141,43 @@ def _get_or_create_key(self) -> bytes:
self.save_secure_file(key_filename, new_key)
return new_key

def save_tokens(self, access_token: str, expires_in: int) -> None:
def save_access_token(self, access_token: str, expires_in: int) -> None:
"""
Save the access token and its expiration time.

:param access_token: The access token to save.
:param expires_in: The expiration time of the access token in seconds.
"""
expiration_time = datetime.now() + timedelta(seconds=expires_in)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
self._save_token("access_token", access_token, expires_in)

def save_refresh_token(self, refresh_token: str) -> None:
"""
Save the refresh token and its expiration time.

def get_token(self) -> Optional[str]:
:param refresh_token: The refresh token to save.

Refresh tokens don't have an expiration time, so the expiration time is set to 100 years from now.
"""
self._save_token("refresh_token", refresh_token, 3153600000)

def get_token(self, token_type: str = "access_token") -> Optional[str]:
"""
Get the access token if it is valid and not expired.
Get the specified token if it exists and is valid (not expired).

:return: The access token if valid and not expired, otherwise None.
:return: The specified token if it exists and hasn't expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)

decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
all_tokens = json.loads(decrypted_data)
if not (token_data := all_tokens.get(token_type)):
return None

expiration = datetime.fromisoformat(data["expiration"])
expiration = datetime.fromisoformat(token_data["expiration"])
if expiration <= datetime.now():
return None

return data["access_token"]
return token_data["value"]

def get_secure_storage_path(self) -> Path:
"""
Expand Down Expand Up @@ -142,3 +233,44 @@ def read_secure_file(self, filename: str) -> Optional[bytes]:

with open(file_path, "rb") as f:
return f.read()

def _save_token(self, token_type: str, token: str, expires_in: int) -> None:
"""
Save the token and its expiration time, updating the existing token file.
"""
all_tokens = {}
raw_existing_data = self.read_secure_file(self.file_path)

if raw_existing_data:
try:
decrypted_data = self.fernet.decrypt(raw_existing_data)
all_tokens = json.loads(decrypted_data.decode())
except Exception:
print("Error decrypting existing token file. Creating new file.")
all_tokens = {}

expiration_time = datetime.now() + timedelta(seconds=expires_in)

all_tokens[token_type] = {
"value": token,
"expiration": expiration_time.isoformat(),
}

updated_encrypted_data = self.fernet.encrypt(json.dumps(all_tokens).encode())
self.save_secure_file(self.file_path, updated_encrypted_data)


# Legacy Authentication code below
def old_validate_token(id_token: str) -> None:
"""
Verify the token and its precedence

:param id_token:
"""
jwks_url = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json"
issuer = f"https://{AUTH0_DOMAIN}/"
signature_verifier = AsymmetricSignatureVerifier(jwks_url)
token_verifier = TokenVerifier(
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
)
token_verifier.verify(id_token)
34 changes: 22 additions & 12 deletions src/crewai/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import click

from crewai.cli.config import Settings
from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
Expand Down Expand Up @@ -138,8 +139,12 @@ def log_tasks_outputs() -> None:
@click.option("-s", "--short", is_flag=True, help="Reset SHORT TERM memory")
@click.option("-e", "--entities", is_flag=True, help="Reset ENTITIES memory")
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@click.option("-akn", "--agent-knowledge", is_flag=True, help="Reset AGENT KNOWLEDGE storage")
@click.option("-k","--kickoff-outputs",is_flag=True,help="Reset LATEST KICKOFF TASK OUTPUTS")
@click.option(
"-akn", "--agent-knowledge", is_flag=True, help="Reset AGENT KNOWLEDGE storage"
)
@click.option(
"-k", "--kickoff-outputs", is_flag=True, help="Reset LATEST KICKOFF TASK OUTPUTS"
)
@click.option("-a", "--all", is_flag=True, help="Reset ALL memories")
def reset_memories(
long: bool,
Expand All @@ -154,13 +159,23 @@ def reset_memories(
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs, knowledge, agent_knowledge). This will delete all the data saved.
"""
try:
memory_types = [long, short, entities, knowledge, agent_knowledge, kickoff_outputs, all]
memory_types = [
long,
short,
entities,
knowledge,
agent_knowledge,
kickoff_outputs,
all,
]
if not any(memory_types):
click.echo(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(long, short, entities, knowledge, agent_knowledge, kickoff_outputs, all)
reset_memories_command(
long, short, entities, knowledge, agent_knowledge, kickoff_outputs, all
)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)

Expand Down Expand Up @@ -210,16 +225,11 @@ def update():
update_crew()


@crewai.command()
def signup():
"""Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup()


@crewai.command()
def login():
"""Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup()
"""Sign Up/Login to CrewAI Enterprise."""
Settings().clear()
AuthenticationCommand().login()


# DEPLOY CREWAI+ COMMANDS
Expand Down
4 changes: 2 additions & 2 deletions src/crewai/cli/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def __init__(self, telemetry):
except Exception:
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
"Please sign up/login to CrewAI Enterprise before using the CLI.",
style="bold red",
)
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
console.print("Run 'crewai login' to sign up/login.", style="bold green")
raise SystemExit

def _validate_response(self, response: requests.Response) -> None:
Expand Down
4 changes: 4 additions & 0 deletions src/crewai/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
merged_data = {**file_data, **data}
super().__init__(config_path=config_path, **merged_data)

def clear(self) -> None:
"""Clear all settings"""
self.config_path.unlink(missing_ok=True)

def dump(self) -> None:
"""Save current settings to settings.json"""
if self.config_path.is_file():
Expand Down
Loading
Loading