Added Accelerators and Examples for Integrating Snowflake Cortex Services as Custom LLM Providers #219

Status: Open · wants to merge 1 commit into base: main
16 changes: 16 additions & 0 deletions snowflake_cortex/.env.example
@@ -0,0 +1,16 @@
######## SNOWFLAKE CONFIGURATION ########

# Snowflake Account
SNOWFLAKE_ACCOUNT="[yoursnowflakeaccount]"

# Snowflake Cortex Service User
SNOWFLAKE_SERVICE_USER="[yoursnowflakecortexuser]"

# Snowflake Cortex Custom LLM
SNOWFLAKE_CORTEXT_CUSTOM_LLM_BASE_URL="https://[yoursnowflakeaccount].snowflakecomputing.com/api/v2/databases/[yourdatabase]/schemas/PUBLIC/cortex-search-services/[yourcustomcortexservice]:query"

# Snowflake Cortex Inference LLM
SNOWFLAKE_CORTEXT_INFERENCE_LLM_BASE_URL="https://[yoursnowflakeaccount].snowflakecomputing.com/api/v2/cortex/inference:complete"

# Snowflake Cortex Custom LLM Private Key
SNOWFLAKE_CUSTOMLLM_PRIVATE_KEY="[yourencryptedprivatesnowflakekey]"
164 changes: 164 additions & 0 deletions snowflake_cortex/README.md
@@ -0,0 +1,164 @@
# 🚀 Snowflake Cortex Service Helpers for CrewAI and LiteLLM
This document introduces two helper classes to interact with Snowflake Cortex APIs for language model tasks. These classes streamline integration with both custom Cortex services and built-in foundation models deployed in Snowflake.

## ⚙️ 1. Snowflake Cortex Custom LLM Helper
The Snowflake Cortex Custom LLM Helper class is designed to integrate with custom Cortex services you’ve deployed in Snowflake. This class provides a flexible interface to connect with your custom APIs, handle requests, and manage configurations for completing language model tasks using CrewAI and LiteLLM.

### 🧩 Features
- **Custom Cortex Service Integration:** Connect to your custom endpoints within Snowflake Cortex.
- **Flexible Configuration:** Supports configurable parameters such as `base_url`, `api_key`, and `timeout`.
- **Dynamic JWT Authentication:** Automates JWT token generation from your private key, ensuring secure communication.
- **Token Counting:** Built-in token counting to help manage large inputs efficiently.
- **Error Handling:** Robust error handling and retries for API requests.
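The retry-with-backoff behavior described above can be sketched with a small stdlib-only wrapper. This is a hypothetical illustration of the pattern, not the helper class's actual implementation:

```python
import time

def with_retries(fn, max_retries=3, backoff=0.0,
                 retryable=(ConnectionError, TimeoutError)):
    """Call fn(), retrying transient failures with exponential backoff."""
    for attempt in range(max_retries):
        try:
            return fn()
        except retryable:
            if attempt == max_retries - 1:
                raise  # Out of retries: surface the last error
            time.sleep(backoff * (2 ** attempt))
```

In the helper, `fn` would be the HTTP request to the Cortex endpoint, and `retryable` would cover transient network and HTTP 5xx failures.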

# Example Usage for a Snowflake Cortex Custom Service

# Initialize the Snowflake Cortex Custom Service LLM with the Snowflake API details
```python
snowflake_custom_cortex_llm = SnowflakeCortexCustomServiceLLM(
    model="snowflake-cortex-custom-service-llm/snowflake-cortex-custom-service-llm",
    base_url=os.environ.get("SNOWFLAKE_CORTEXT_CUSTOM_LLM_BASE_URL"),
    api_key=os.environ.get("SNOWFLAKE_CUSTOMLLM_PRIVATE_KEY"),
    snowflakeAccount=os.environ.get("SNOWFLAKE_ACCOUNT"),
    snowflakeServiceUser=os.environ.get("SNOWFLAKE_SERVICE_USER"),
    snowflakePromptTemplate=snowflakeServicePayload,
)
```
# Register the Snowflake Cortex Custom Service LLM as a provider in LiteLLM
```python
litellm.custom_provider_map = [ # 👈 KEY STEP - REGISTER HANDLER
{"provider": "snowflake-cortex-custom-service-llm", "custom_handler": snowflake_custom_cortex_llm}
]
```
# Create an LLM instance using the Snowflake Cortex Custom Service LLM
```python
custom_llm = LLM(
    model="snowflake-cortex-custom-service-llm/snowflake-cortex-custom-service-llm",
    api_base=os.environ.get("SNOWFLAKE_CORTEXT_CUSTOM_LLM_BASE_URL"),
    api_key=os.environ.get("SNOWFLAKE_CUSTOMLLM_PRIVATE_KEY"),
)
```
# Use the Snowflake Cortex Custom Service LLM to generate a response
```python
response = snowflake_custom_cortex_llm.completion(messages=[{"role": "user", "content": "Show me information about cases"}])
```
# Use LiteLLM to generate a response
```python
response = custom_llm.call(messages=[{"role": "user", "content": "Show me information about cases"}])
```
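Under the hood, LiteLLM resolves the `provider/model` string by matching the prefix before the first slash against the registered providers. The routing idea can be sketched in plain Python (a simplified illustration, not LiteLLM's internals):

```python
# Simplified illustration of provider-prefix routing (not LiteLLM's actual code)
custom_provider_map = [
    {"provider": "snowflake-cortex-custom-service-llm", "custom_handler": "handler-object"},
]

def resolve_handler(model: str):
    # "provider/model" -> take the prefix before the first slash
    provider = model.split("/", 1)[0]
    for entry in custom_provider_map:
        if entry["provider"] == provider:
            return entry["custom_handler"]
    raise ValueError(f"No custom handler registered for provider {provider!r}")
```

This is why the `model` string passed to `LLM(...)` must start with the same provider name used in `litellm.custom_provider_map`.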

## 🧩 2. Snowflake Cortex Inference Service Helper
The Snowflake Cortex Inference Service Helper class enables interaction with Snowflake's built-in foundation models (e.g., Mistral, Llama) through the Cortex Inference API. This class simplifies access to pre-deployed models without needing to manage your own custom endpoints.

### 🧰 Features
- **Foundation Model Access:** Use pre-deployed LLMs provided by Snowflake, such as Mistral, Llama, and more.
- **Token Management:** Automatically tracks prompt tokens and completion tokens for API responses.
- **Easy Configuration:** Minimal setup required with `base_url`, `api_key`, and model selection.
- **Secure JWT Authentication:** Supports JWT-based authorization for secure requests.
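The token tracking described above can be sketched as a small accumulator. This is a hypothetical illustration of the idea, not the helper's actual field names:

```python
class TokenUsage:
    """Accumulates prompt/completion token counts across API responses."""

    def __init__(self):
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def record(self, usage: dict):
        # Expects a usage dict like {"prompt_tokens": 12, "completion_tokens": 34}
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)

    @property
    def total_tokens(self):
        return self.prompt_tokens + self.completion_tokens
```

Each API response's `usage` payload would be passed to `record`, giving a running total across calls.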

# Example Usage for Snowflake Cortex Inference Services
# Initialize the Snowflake Cortex Inference Service with the Snowflake API details
```python
snowflake_cortex_inference_llm = SnowflakeCortexInferenceService(
    model="snowflake-cortex-inference-service/snowflake-cortex-inference-service",
    base_url=os.environ.get("SNOWFLAKE_CORTEXT_INFERENCE_LLM_BASE_URL"),
    api_key=os.environ.get("SNOWFLAKE_CUSTOMLLM_PRIVATE_KEY"),
    snowflakeAccount=os.environ.get("SNOWFLAKE_ACCOUNT"),
    snowflakeServiceUser=os.environ.get("SNOWFLAKE_SERVICE_USER"),
    snowflakePromptTemplate=snowflakeServicePayload,
)
```
# Register the Snowflake Cortex Inference Service as a custom provider in LiteLLM
```python
litellm.custom_provider_map = [ # 👈 KEY STEP - REGISTER HANDLER
{"provider": "snowflake-cortex-inference-service", "custom_handler": snowflake_cortex_inference_llm}
]
```
# Create a Custom LLM instance using the Snowflake Cortex Inference Service
```python
custom_llm = LLM(
    model="snowflake-cortex-inference-service/snowflake-cortex-inference-service",
    api_base=os.environ.get("SNOWFLAKE_CORTEXT_INFERENCE_LLM_BASE_URL"),
    api_key=os.environ.get("SNOWFLAKE_CUSTOMLLM_PRIVATE_KEY"),
)
```
# Use the Snowflake Cortex Inference Service to generate a response
```python
response = snowflake_cortex_inference_llm.completion(messages=[{"role": "user", "content": "Tell me about the ocean"}])
```
# Use LiteLLM to generate a response
```python
response = custom_llm.call(messages=[{"role": "user", "content": "Tell me about the ocean"}])
```

## ✅ When to Use Each Helper Class

| Use Case | Helper Class | Description |
| --- | --- | --- |
| You have a custom Cortex service | `SnowflakeCortexCustomServiceLLM` | For integrating with your own deployed services. |
| You want to use built-in models | `SnowflakeCortexInferenceService` | For leveraging Snowflake's foundation models. |

# Register either Snowflake Cortex Custom or Inference Services with an Agent
```python
researcher = Agent(
    role='Senior Researcher',
    goal='You are looking for cases from a ticketing system to help you with your research',
    verbose=True,
    llm=custom_llm,  # 👈 KEY STEP - USE THE CUSTOM LLM
    backstory='You are just looking for a list of cases or case related information',
    max_iter=1,
)
```
# Final Code with Task and Crew
```python
# Task for the researcher
research_task = Task(
    description='Make a single call hoping to get a similar response to case information',
    expected_output='case information',
    agent=researcher  # Assigning the task to the researcher
)

# Instantiate your crew
tech_crew = Crew(
    agents=[researcher],
    tasks=[research_task],
    process=Process.sequential  # Tasks will be executed one after the other
)

# Begin the task execution
result = tech_crew.kickoff()
```

# To check the results of the Crew, Tasks and Agents
```python
# If the agent is failing to complete the task, print the result to see what is going wrong
print("\n--- Execution Result ---")
print(result)
```

## Installation
Make sure to install the necessary dependencies before using the Snowflake Cortex CrewAI helper modules.

## Notes
Make sure to set the following environment variables with the appropriate values for your Snowflake Cortex instance:
- `SNOWFLAKE_ACCOUNT`
- `SNOWFLAKE_SERVICE_USER`
- `SNOWFLAKE_CORTEXT_CUSTOM_LLM_BASE_URL`
- `SNOWFLAKE_CORTEXT_INFERENCE_LLM_BASE_URL`
- `SNOWFLAKE_CUSTOMLLM_PRIVATE_KEY`
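A quick way to fail fast on missing configuration is to validate these variables at startup. This is a minimal sketch; the helper classes do not ship such a check:

```python
import os

REQUIRED_VARS = [
    "SNOWFLAKE_ACCOUNT",
    "SNOWFLAKE_SERVICE_USER",
    "SNOWFLAKE_CORTEXT_CUSTOM_LLM_BASE_URL",
    "SNOWFLAKE_CORTEXT_INFERENCE_LLM_BASE_URL",
    "SNOWFLAKE_CUSTOMLLM_PRIVATE_KEY",
]

def missing_vars(env=os.environ):
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARS if not env.get(name)]
```

Calling `missing_vars()` before constructing the helpers surfaces configuration problems immediately instead of at the first API call.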

You can modify the class to suit your needs, including adjusting request parameters, service payload and error handling.

## License
This project is released under the MIT License.
111 changes: 111 additions & 0 deletions snowflake_cortex/src/snowflake_auth_jwt_generator.py
@@ -0,0 +1,111 @@
# Author: Sean Iannuzzi
# Updated: January 14, 2024

from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import PublicFormat
from cryptography.hazmat.backends import default_backend
from datetime import timedelta, timezone, datetime
import base64
import hashlib
import jwt

# Constants for JWT claims
ISSUER = "iss" # Issuer of the JWT
EXPIRE_TIME = "exp" # Expiration time for the JWT
ISSUE_TIME = "iat" # Issue time for the JWT
SUBJECT = "sub" # Subject of the JWT (usually the user or service)

class JWTGenerator(object):
"""
JWTGenerator class to generate JSON Web Tokens using RSA private keys.
The class allows generating tokens with a specified lifetime and renewal delay.
"""
LIFETIME = timedelta(minutes=59) # Default token lifetime (59 minutes)
RENEWAL_DELTA = timedelta(minutes=54) # Default renewal delay (54 minutes)
ALGORITHM = "RS256" # JWT generation uses RSA with SHA256 algorithm

def __init__(self, account, user, private_key_string, passphrase, lifetime=LIFETIME, renewal_delay=RENEWAL_DELTA):
"""
Initializes the JWTGenerator class with necessary parameters.

:param account: The account name (e.g., organization or application name).
:param user: The user identifier (e.g., service account or user ID).
:param private_key_string: The private key as a string for JWT signing.
:param passphrase: Passphrase used to decrypt the private key, if it's encrypted.
:param lifetime: Token's lifetime duration (default is 59 minutes).
:param renewal_delay: Time before the token needs to be renewed (default is 54 minutes).
"""
self.account = self.prepare_account_name_for_jwt(account)
self.user = user.upper() # Ensure the username is in uppercase
self.qualified_username = self.account + "." + self.user # Full qualified username
self.lifetime = lifetime # Token's lifetime
self.renewal_delay = renewal_delay # Renewal time for token
self.private_key_string = private_key_string # The private key string
self.passphrase = passphrase # Passphrase for private key decryption (if needed)
self.renew_time = datetime.now(timezone.utc) # Time when the token needs to be renewed
self.token = None # Placeholder for the generated JWT

# Load the private key from the provided string and passphrase
        # Use the passphrase only when one is provided; unencrypted keys require password=None
        password_to_use = self.passphrase.encode() if self.passphrase else None
        self.private_key = load_pem_private_key(self.private_key_string.encode(), password_to_use, default_backend())  # Load (and decrypt, if needed) the private key

def prepare_account_name_for_jwt(self, raw_account):
"""
Prepares the account name by stripping off additional parts to standardize it for JWT.

:param raw_account: The raw account name.
:return: The cleaned account name.
"""
account = raw_account
if '.global' not in account:
idx = account.find('.')
if idx > 0:
account = account[:idx] # Extract the part before the first dot if it is not global
else:
idx = account.find('-')
if idx > 0:
account = account[:idx] # Extract the part before the first hyphen if it's a global account
return account.upper() # Return the standardized account name in uppercase

def get_token(self):
"""
Generates a JWT token using the private key if it's not expired or needs renewal.

:return: The generated JWT as a string.
"""
now = datetime.now(timezone.utc) # Current UTC time
# Renew the token if it's not set or if it's time to renew
if self.token is None or self.renew_time <= now:
self.renew_time = now + self.renewal_delay # Update the renewal time
public_key_fp = self.calculate_public_key_fingerprint(self.private_key) # Generate the public key fingerprint

# Prepare the JWT payload with necessary claims
payload = {
ISSUER: self.qualified_username + '.' + public_key_fp, # Issuer with public key fingerprint
SUBJECT: self.qualified_username, # Subject (user or service account)
ISSUE_TIME: now, # Issue time (current time)
EXPIRE_TIME: now + self.lifetime # Expiration time (current time + lifetime)
}

# Encode the payload using the private key and the RS256 algorithm
token = jwt.encode(payload, key=self.private_key, algorithm=self.ALGORITHM)
if isinstance(token, bytes):
token = token.decode('utf-8') # Decode the token to string if it's bytes
self.token = token # Store the generated token

return self.token # Return the generated JWT

def calculate_public_key_fingerprint(self, private_key):
"""
Calculates the SHA-256 fingerprint of the public key extracted from the private key.

:param private_key: The private key object.
:return: The fingerprint in SHA256 base64 encoded format.
"""
public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo) # Extract the public key in DER format
sha256hash = hashlib.sha256() # Create a SHA-256 hash object
sha256hash.update(public_key_raw) # Update the hash with the public key data
public_key_fp = 'SHA256:' + base64.b64encode(sha256hash.digest()).decode('utf-8') # Generate the fingerprint and encode in base64
return public_key_fp # Return the fingerprint
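For illustration, the account-name normalization implemented by `prepare_account_name_for_jwt` can be exercised standalone. The snippet below copies the method's logic and uses hypothetical account locators:

```python
def normalize_account_for_jwt(raw_account: str) -> str:
    """Standalone mirror of JWTGenerator.prepare_account_name_for_jwt, for illustration."""
    account = raw_account
    if '.global' not in account:
        idx = account.find('.')
        if idx > 0:
            account = account[:idx]  # Regional locator: keep the part before the first dot
    else:
        idx = account.find('-')
        if idx > 0:
            account = account[:idx]  # Global account: keep the part before the first hyphen
    return account.upper()

# Hypothetical examples:
#   normalize_account_for_jwt("xy12345.us-east-1") -> "XY12345"
#   normalize_account_for_jwt("myorg-acct.global") -> "MYORG"
```

The normalized name is what ends up in the JWT's `iss` and `sub` claims, so region suffixes must be stripped for Snowflake to accept the token.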
