66
77# Copyright (c) Microsoft Corporation. All rights reserved.
88# Licensed under the MIT License.
9-
9+ from hashlib import sha256
1010from typing import Dict , List
11+ from threading import Semaphore
1112import json
1213from botbuilder .core .storage import Storage , StoreItem
1314import azure .cosmos .cosmos_client as cosmos_client
1415import azure .cosmos .errors as cosmos_errors
1516
1617
class CosmosDbConfig:
    """The class for CosmosDB configuration for the Azure Bot Framework."""

    def __init__(self, endpoint: str = None, masterkey: str = None, database: str = None, container: str = None,
                 partition_key: str = None, database_creation_options: dict = None,
                 container_creation_options: dict = None, **kwargs):
        """Create the Config object.

        Every setting may be given either as an explicit argument or as a
        keyword of the same name in ``kwargs``. When ``filename`` is supplied,
        the JSON object stored in that file replaces ``kwargs`` entirely.
        A truthy explicit argument always wins over the ``kwargs``/file value.

        :param endpoint: stored as ``self.endpoint``
        :param masterkey: stored as ``self.masterkey``
        :param database: stored as ``self.database``; falls back to 'bot_db'
        :param container: stored as ``self.container``; falls back to
            'bot_container'
        :param partition_key: stored as ``self.partition_key``
        :param database_creation_options: stored as
            ``self.database_creation_options``
        :param container_creation_options: stored as
            ``self.container_creation_options``
        :param filename: (keyword only, via ``kwargs``) path of a JSON file
            holding the settings above
        :return CosmosDbConfig:
        """
        self.__config_file = kwargs.get('filename')
        if self.__config_file:
            # Read the settings inside a context manager so the file handle is
            # closed even when the JSON is malformed.
            with open(self.__config_file) as config_file:
                kwargs = json.load(config_file)
        self.endpoint = endpoint or kwargs.get('endpoint')
        self.masterkey = masterkey or kwargs.get('masterkey')
        self.database = database or kwargs.get('database', 'bot_db')
        self.container = container or kwargs.get('container', 'bot_container')
        self.partition_key = partition_key or kwargs.get('partition_key')
        self.database_creation_options = database_creation_options or kwargs.get('database_creation_options')
        self.container_creation_options = container_creation_options or kwargs.get('container_creation_options')
43+
44+
class CosmosDbKeyEscape:
    """Helpers that turn a storage key into a string usable as a Cosmos id."""

    @staticmethod
    def sanitize_key(key) -> str:
        """Return the sanitized key.

        Replace characters that are not allowed in keys in Cosmos, then
        shorten the result with :meth:`truncate_key` if it is too long.

        :param key: the raw key; iterated character by character
        :return str: the escaped (and possibly truncated) key
        """
        # forbidden characters; '*' is listed too because it is the escape
        # marker itself and must not appear unescaped in the output
        bad_chars = ['\\', '?', '/', '#', '\t', '\n', '\r', '*']
        # map every forbidden character to '*' followed by its Unicode
        # code point, so the lookup per character is a single dict access
        escapes = {char: '*' + str(ord(char)) for char in bad_chars}
        key = ''.join(escapes.get(char, char) for char in key)

        return CosmosDbKeyEscape.truncate_key(key)

    @staticmethod
    def truncate_key(key: str) -> str:
        """Return ``key`` shortened to at most 255 characters.

        A key that fits is returned unchanged. A longer key keeps its leading
        characters and has its tail replaced by the SHA-256 hex digest of the
        whole (UTF-8 encoded) key, so the result is exactly 255 characters.

        :param key: the key to shorten
        :return str: the key, truncated when longer than 255 characters
        """
        MAX_KEY_LEN = 255

        if len(key) > MAX_KEY_LEN:
            aux_hash = sha256(key.encode('utf-8'))
            aux_hex = aux_hash.hexdigest()

            # keep as much of the prefix as leaves room for the digest
            key = key[0:MAX_KEY_LEN - len(aux_hex)] + aux_hex

        return key
3779
3880
3981class CosmosDbStorage (Storage ):
4082 """The class for CosmosDB middleware for the Azure Bot Framework."""
4183
42- def __init__ (self , config : CosmosDbConfig ):
84+ def __init__ (self , config : CosmosDbConfig , client : cosmos_client . CosmosClient = None ):
4385 """Create the storage object.
4486
4587 :param config:
4688 """
4789 super (CosmosDbStorage , self ).__init__ ()
4890 self .config = config
49- self .client = cosmos_client .CosmosClient (
91+ self .client = client or cosmos_client .CosmosClient (
5092 self .config .endpoint ,
5193 {'masterKey' : self .config .masterkey }
52- )
94+ )
5395 # these are set by the functions that check
5496 # the presence of the db and container or creates them
5597 self .db = None
5698 self .container = None
99+ self ._database_creation_options = config .database_creation_options
100+ self ._container_creation_options = config .container_creation_options
101+ self .__semaphore = Semaphore ()
57102
58- async def read (self , keys : List [str ]) -> dict :
103+ async def read (self , keys : List [str ]) -> Dict [ str , object ] :
59104 """Read storeitems from storage.
60105
61106 :param keys:
@@ -65,35 +110,39 @@ async def read(self, keys: List[str]) -> dict:
65110 # check if the database and container exists and if not create
66111 if not self .__container_exists :
67112 self .__create_db_and_container ()
68- if len ( keys ) > 0 :
113+ if keys :
69114 # create the parameters object
70115 parameters = [
71- {'name' : f'@id{ i } ' , 'value' : f'{ self . __sanitize_key (key )} ' }
116+ {'name' : f'@id{ i } ' , 'value' : f'{ CosmosDbKeyEscape . sanitize_key (key )} ' }
72117 for i , key in enumerate (keys )
73- ]
118+ ]
74119 # get the names of the params
75120 parameter_sequence = ',' .join (param .get ('name' )
76121 for param in parameters )
77122 # create the query
78123 query = {
79124 "query" :
80- f"SELECT c.id, c.realId, c.document, c._etag \
81- FROM c WHERE c.id in ({ parameter_sequence } )" ,
125+ f"SELECT c.id, c.realId, c.document, c._etag FROM c WHERE c.id in ({ parameter_sequence } )" ,
82126 "parameters" : parameters
83- }
84- options = {'enableCrossPartitionQuery' : True }
127+ }
128+
129+ if self .config .partition_key :
130+ options = {'partitionKey' : self .config .partition_key }
131+ else :
132+ options = {'enableCrossPartitionQuery' : True }
133+
85134 # run the query and store the results as a list
86135 results = list (
87136 self .client .QueryItems (
88137 self .__container_link , query , options )
89- )
138+ )
90139 # return a dict with a key and a StoreItem
91140 return {
92141 r .get ('realId' ): self .__create_si (r ) for r in results
93- }
142+ }
94143 else :
95- raise Exception ( 'cosmosdb_storage.read(): \
96- provide at least one key' )
144+ # No keys passed in, no result to return.
145+ return {}
97146 except TypeError as e :
98147 raise e
99148
@@ -112,7 +161,7 @@ async def write(self, changes: Dict[str, StoreItem]):
112161 # store the e_tag
113162 e_tag = change .e_tag
114163 # create the new document
115- doc = {'id' : self . __sanitize_key (key ),
164+ doc = {'id' : CosmosDbKeyEscape . sanitize_key (key ),
116165 'realId' : key ,
117166 'document' : self .__create_dict (change )
118167 }
@@ -122,16 +171,16 @@ async def write(self, changes: Dict[str, StoreItem]):
122171 database_or_Container_link = self .__container_link ,
123172 document = doc ,
124173 options = {'disableAutomaticIdGeneration' : True }
125- )
174+ )
126175 # if we have an etag, do opt. concurrency replace
127- elif (len (e_tag ) > 0 ):
176+ elif (len (e_tag ) > 0 ):
128177 access_condition = {'type' : 'IfMatch' , 'condition' : e_tag }
129178 self .client .ReplaceItem (
130179 document_link = self .__item_link (
131- self . __sanitize_key (key )),
180+ CosmosDbKeyEscape . sanitize_key (key )),
132181 new_document = doc ,
133182 options = {'accessCondition' : access_condition }
134- )
183+ )
135184 # error when there is no e_tag
136185 else :
137186 raise Exception ('cosmosdb_storage.write(): etag missing' )
@@ -148,10 +197,17 @@ async def delete(self, keys: List[str]):
148197 # check if the database and container exists and if not create
149198 if not self .__container_exists :
150199 self .__create_db_and_container ()
200+
201+ options = {}
202+ if self .config .partition_key :
203+ options ['partitionKey' ] = self .config .partition_key
204+
151205 # call the function for each key
152206 for k in keys :
153207 self .client .DeleteItem (
154- document_link = self .__item_link (self .__sanitize_key (k )))
208+ document_link = self .__item_link (CosmosDbKeyEscape .sanitize_key (k )),
209+ options = options
210+ )
155211 # print(res)
156212 except cosmos_errors .HTTPFailure as h :
157213 # print(h.status_code)
@@ -169,7 +225,8 @@ def __create_si(self, result) -> StoreItem:
169225 # get the document item from the result and turn into a dict
170226 doc = result .get ('document' )
171227 # readd the e_tag from Cosmos
172- doc ['e_tag' ] = result .get ('_etag' )
228+ if result .get ('_etag' ):
229+ doc ['e_tag' ] = result ['_etag' ]
173230 # create and return the StoreItem
174231 return StoreItem (** doc )
175232
@@ -183,28 +240,10 @@ def __create_dict(self, si: StoreItem) -> Dict:
183240 """
184241 # read the content
185242 non_magic_attr = ([attr for attr in dir (si )
186- if not attr .startswith ('_' ) or attr .__eq__ ('e_tag' )])
243+ if not attr .startswith ('_' ) or attr .__eq__ ('e_tag' )])
187244 # loop through attributes and write and return a dict
188245 return ({attr : getattr (si , attr )
189- for attr in non_magic_attr })
190-
191- def __sanitize_key (self , key ) -> str :
192- """Return the sanitized key.
193-
194- Replace characters that are not allowed in keys in Cosmos.
195-
196- :param key:
197- :return str:
198- """
199- # forbidden characters
200- bad_chars = ['\\ ' , '?' , '/' , '#' , '\t ' , '\n ' , '\r ' ]
201- # replace those with with '*' and the
202- # Unicode code point of the character and return the new string
203- return '' .join (
204- map (
205- lambda x : '*' + str (ord (x )) if x in bad_chars else x , key
206- )
207- )
246+ for attr in non_magic_attr })
208247
209248 def __item_link (self , id ) -> str :
210249 """Return the item link of a item in the container.
@@ -241,14 +280,15 @@ def __container_exists(self) -> bool:
241280
242281 def __create_db_and_container (self ):
243282 """Call the get or create methods."""
244- db_id = self .config .database
245- container_name = self .config .container
246- self .db = self .__get_or_create_database (self .client , db_id )
247- self .container = self .__get_or_create_container (
248- self .client , container_name
283+ with self .__semaphore :
284+ db_id = self .config .database
285+ container_name = self .config .container
286+ self .db = self ._get_or_create_database (self .client , db_id )
287+ self .container = self ._get_or_create_container (
288+ self .client , container_name
249289 )
250290
251- def __get_or_create_database (self , doc_client , id ) -> str :
291+ def _get_or_create_database (self , doc_client , id ) -> str :
252292 """Return the database link.
253293
254294 Check if the database exists or create the db.
@@ -269,10 +309,10 @@ def __get_or_create_database(self, doc_client, id) -> str:
269309 return dbs [0 ]['id' ]
270310 else :
271311 # create the database if it didn't exist
272- res = doc_client .CreateDatabase ({'id' : id })
312+ res = doc_client .CreateDatabase ({'id' : id }, self . _database_creation_options )
273313 return res ['id' ]
274314
275- def __get_or_create_container (self , doc_client , container ) -> str :
315+ def _get_or_create_container (self , doc_client , container ) -> str :
276316 """Return the container link.
277317
278318 Check if the container exists or create the container.
@@ -297,5 +337,8 @@ def __get_or_create_container(self, doc_client, container) -> str:
297337 else :
298338 # Create a container if it didn't exist
299339 res = doc_client .CreateContainer (
300- self .__database_link , {'id' : container })
340+ self .__database_link ,
341+ {'id' : container },
342+ self ._container_creation_options
343+ )
301344 return res ['id' ]
0 commit comments