Skip to content

Commit cf39461

Browse files
committed
adding snowflake support and snowflake example
1 parent 9c6e652 commit cf39461

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import snowflake.connector # type: ignore
2+
3+
from contextual import ContextualAI
4+
5+
SF_BASE_URL = 'xxxxx-xxxxx-xxxxx.snowflakecomputing.app'
6+
BASE_URL = f'https://{SF_BASE_URL}/v1'
7+
8+
SAMPLE_MESSAGE = 'Can you tell me about XYZ'
9+
10+
ctx = snowflake.connector.connect( # type: ignore
11+
user="",# snowflake account user
12+
password='', # snowflake account password
13+
account="organization-account", # snowflake organization and account <Organization>-<Account>
14+
session_parameters={
15+
'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT': 'json'
16+
})
17+
18+
# Obtain a session token.
19+
token_data = ctx._rest._token_request('ISSUE') # type: ignore
20+
token_extract = token_data['data']['sessionToken'] # type: ignore
21+
22+
# Create a request to the ingress endpoint with authz.
23+
api_key = f'\"{token_extract}\"'
24+
25+
client = ContextualAI(api_key=api_key, base_url=BASE_URL)
26+
27+
agents = [a for a in client.agents.list() ]
28+
29+
agent = agents[0] if agents else None
30+
31+
if agent is None:
32+
print('No agents found')
33+
exit()
34+
print(f"Found agent {agent.name} with id {agent.id}")
35+
36+
messages = [
37+
{
38+
'content': SAMPLE_MESSAGE,
39+
'role': 'user',
40+
}
41+
]
42+
43+
res = client.agents.query.create(agent.id, messages=messages) # type: ignore
44+
45+
output = res.message.content # type: ignore
46+
47+
print(output)

src/contextual/_client.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class ContextualAI(SyncAPIClient):
5656

5757
# client options
5858
api_key: str
59+
is_snowflake: bool
5960

6061
def __init__(
6162
self,
@@ -97,6 +98,11 @@ def __init__(
9798
if base_url is None:
9899
base_url = f"https://api.contextual.ai/v1"
99100

101+
if 'snowflakecomputing.app' in str(base_url):
102+
self.is_snowflake = True
103+
else:
104+
self.is_snowflake = False
105+
100106
super().__init__(
101107
version=__version__,
102108
base_url=base_url,
@@ -123,7 +129,10 @@ def qs(self) -> Querystring:
123129
@override
124130
def auth_headers(self) -> dict[str, str]:
125131
api_key = self.api_key
126-
return {"Authorization": f"Bearer {api_key}"}
132+
if self.is_snowflake:
133+
return {"Authorization": f"Snowflake Token={api_key}"}
134+
else:
135+
return {"Authorization": f"Bearer {api_key}"}
127136

128137
@property
129138
@override
@@ -228,6 +237,7 @@ class AsyncContextualAI(AsyncAPIClient):
228237

229238
# client options
230239
api_key: str
240+
is_snowflake: bool
231241

232242
def __init__(
233243
self,
@@ -269,6 +279,11 @@ def __init__(
269279
if base_url is None:
270280
base_url = f"https://api.contextual.ai/v1"
271281

282+
if 'snowflakecomputing.app' in str(base_url):
283+
self.is_snowflake = True
284+
else:
285+
self.is_snowflake = False
286+
272287
super().__init__(
273288
version=__version__,
274289
base_url=base_url,
@@ -295,7 +310,10 @@ def qs(self) -> Querystring:
295310
@override
296311
def auth_headers(self) -> dict[str, str]:
297312
api_key = self.api_key
298-
return {"Authorization": f"Bearer {api_key}"}
313+
if self.is_snowflake:
314+
return {"Authorization": f"Snowflake Token={api_key}"}
315+
else:
316+
return {"Authorization": f"Bearer {api_key}"}
299317

300318
@property
301319
@override

0 commit comments

Comments
 (0)