Skip to content
This repository was archived by the owner on Jan 5, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libraries/botbuilder-azure/botbuilder/azure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
# --------------------------------------------------------------------------

from .about import __version__
from .cosmosdb_storage import CosmosDbStorage, CosmosDbConfig
from .cosmosdb_storage import CosmosDbStorage, CosmosDbConfig, CosmosDbKeyEscape

__all__ = ['CosmosDbStorage',
'CosmosDbConfig',
'CosmosDbKeyEscape',
'__version__']
161 changes: 102 additions & 59 deletions libraries/botbuilder-azure/botbuilder/azure/cosmosdb_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from hashlib import sha256
from typing import Dict, List
from threading import Semaphore
import json
from botbuilder.core.storage import Storage, StoreItem
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.errors as cosmos_errors


class CosmosDbConfig():
class CosmosDbConfig:
"""The class for CosmosDB configuration for the Azure Bot Framework."""

def __init__(self, **kwargs):
def __init__(self, endpoint: str = None, masterkey: str = None, database: str = None, container: str = None,
partition_key: str = None, database_creation_options: dict = None,
container_creation_options: dict = None, **kwargs):
"""Create the Config object.

:param endpoint:
Expand All @@ -27,35 +30,77 @@ def __init__(self, **kwargs):
:param filename:
:return CosmosDbConfig:
"""
self.__config_file = kwargs.pop('filename', None)
self.__config_file = kwargs.get('filename')
if self.__config_file:
kwargs = json.load(open(self.__config_file))
self.endpoint = kwargs.pop('endpoint')
self.masterkey = kwargs.pop('masterkey')
self.database = kwargs.pop('database', 'bot_db')
self.container = kwargs.pop('container', 'bot_container')
self.endpoint = endpoint or kwargs.get('endpoint')
self.masterkey = masterkey or kwargs.get('masterkey')
self.database = database or kwargs.get('database', 'bot_db')
self.container = container or kwargs.get('container', 'bot_container')
self.partition_key = partition_key or kwargs.get('partition_key')
self.database_creation_options = database_creation_options or kwargs.get('database_creation_options')
self.container_creation_options = container_creation_options or kwargs.get('container_creation_options')


class CosmosDbKeyEscape:

@staticmethod
def sanitize_key(key) -> str:
"""Return the sanitized key.

Replace characters that are not allowed in keys in Cosmos.

:param key:
:return str:
"""
# forbidden characters
bad_chars = ['\\', '?', '/', '#', '\t', '\n', '\r', '*']
# replace those with with '*' and the
# Unicode code point of the character and return the new string
key = ''.join(
map(
lambda x: '*' + str(ord(x)) if x in bad_chars else x, key
)
)

return CosmosDbKeyEscape.truncate_key(key)

@staticmethod
def truncate_key(key: str) -> str:
MAX_KEY_LEN = 255

if len(key) > MAX_KEY_LEN:
aux_hash = sha256(key.encode('utf-8'))
aux_hex = aux_hash.hexdigest()

key = key[0:MAX_KEY_LEN - len(aux_hex)] + aux_hex

return key


class CosmosDbStorage(Storage):
"""The class for CosmosDB middleware for the Azure Bot Framework."""

def __init__(self, config: CosmosDbConfig):
def __init__(self, config: CosmosDbConfig, client: cosmos_client.CosmosClient = None):
"""Create the storage object.

:param config:
"""
super(CosmosDbStorage, self).__init__()
self.config = config
self.client = cosmos_client.CosmosClient(
self.client = client or cosmos_client.CosmosClient(
self.config.endpoint,
{'masterKey': self.config.masterkey}
)
)
# these are set by the functions that check
# the presence of the db and container or creates them
self.db = None
self.container = None
self._database_creation_options = config.database_creation_options
self._container_creation_options = config.container_creation_options
self.__semaphore = Semaphore()

async def read(self, keys: List[str]) -> dict:
async def read(self, keys: List[str]) -> Dict[str, object]:
"""Read storeitems from storage.

:param keys:
Expand All @@ -65,35 +110,39 @@ async def read(self, keys: List[str]) -> dict:
# check if the database and container exists and if not create
if not self.__container_exists:
self.__create_db_and_container()
if len(keys) > 0:
if keys:
# create the parameters object
parameters = [
{'name': f'@id{i}', 'value': f'{self.__sanitize_key(key)}'}
{'name': f'@id{i}', 'value': f'{CosmosDbKeyEscape.sanitize_key(key)}'}
for i, key in enumerate(keys)
]
]
# get the names of the params
parameter_sequence = ','.join(param.get('name')
for param in parameters)
# create the query
query = {
"query":
f"SELECT c.id, c.realId, c.document, c._etag \
FROM c WHERE c.id in ({parameter_sequence})",
f"SELECT c.id, c.realId, c.document, c._etag FROM c WHERE c.id in ({parameter_sequence})",
"parameters": parameters
}
options = {'enableCrossPartitionQuery': True}
}

if self.config.partition_key:
options = {'partitionKey': self.config.partition_key}
else:
options = {'enableCrossPartitionQuery': True}

# run the query and store the results as a list
results = list(
self.client.QueryItems(
self.__container_link, query, options)
)
)
# return a dict with a key and a StoreItem
return {
r.get('realId'): self.__create_si(r) for r in results
}
}
else:
raise Exception('cosmosdb_storage.read(): \
provide at least one key')
# No keys passed in, no result to return.
return {}
except TypeError as e:
raise e

Expand All @@ -112,7 +161,7 @@ async def write(self, changes: Dict[str, StoreItem]):
# store the e_tag
e_tag = change.e_tag
# create the new document
doc = {'id': self.__sanitize_key(key),
doc = {'id': CosmosDbKeyEscape.sanitize_key(key),
'realId': key,
'document': self.__create_dict(change)
}
Expand All @@ -122,16 +171,16 @@ async def write(self, changes: Dict[str, StoreItem]):
database_or_Container_link=self.__container_link,
document=doc,
options={'disableAutomaticIdGeneration': True}
)
)
# if we have an etag, do opt. concurrency replace
elif(len(e_tag) > 0):
elif (len(e_tag) > 0):
access_condition = {'type': 'IfMatch', 'condition': e_tag}
self.client.ReplaceItem(
document_link=self.__item_link(
self.__sanitize_key(key)),
CosmosDbKeyEscape.sanitize_key(key)),
new_document=doc,
options={'accessCondition': access_condition}
)
)
# error when there is no e_tag
else:
raise Exception('cosmosdb_storage.write(): etag missing')
Expand All @@ -148,10 +197,17 @@ async def delete(self, keys: List[str]):
# check if the database and container exists and if not create
if not self.__container_exists:
self.__create_db_and_container()

options = {}
if self.config.partition_key:
options['partitionKey'] = self.config.partition_key

# call the function for each key
for k in keys:
self.client.DeleteItem(
document_link=self.__item_link(self.__sanitize_key(k)))
document_link=self.__item_link(CosmosDbKeyEscape.sanitize_key(k)),
options=options
)
# print(res)
except cosmos_errors.HTTPFailure as h:
# print(h.status_code)
Expand All @@ -169,7 +225,8 @@ def __create_si(self, result) -> StoreItem:
# get the document item from the result and turn into a dict
doc = result.get('document')
# readd the e_tag from Cosmos
doc['e_tag'] = result.get('_etag')
if result.get('_etag'):
doc['e_tag'] = result['_etag']
# create and return the StoreItem
return StoreItem(**doc)

Expand All @@ -183,28 +240,10 @@ def __create_dict(self, si: StoreItem) -> Dict:
"""
# read the content
non_magic_attr = ([attr for attr in dir(si)
if not attr.startswith('_') or attr.__eq__('e_tag')])
if not attr.startswith('_') or attr.__eq__('e_tag')])
# loop through attributes and write and return a dict
return ({attr: getattr(si, attr)
for attr in non_magic_attr})

def __sanitize_key(self, key) -> str:
"""Return the sanitized key.

Replace characters that are not allowed in keys in Cosmos.

:param key:
:return str:
"""
# forbidden characters
bad_chars = ['\\', '?', '/', '#', '\t', '\n', '\r']
# replace those with with '*' and the
# Unicode code point of the character and return the new string
return ''.join(
map(
lambda x: '*'+str(ord(x)) if x in bad_chars else x, key
)
)
for attr in non_magic_attr})

def __item_link(self, id) -> str:
"""Return the item link of a item in the container.
Expand Down Expand Up @@ -241,14 +280,15 @@ def __container_exists(self) -> bool:

def __create_db_and_container(self):
"""Call the get or create methods."""
db_id = self.config.database
container_name = self.config.container
self.db = self.__get_or_create_database(self.client, db_id)
self.container = self.__get_or_create_container(
self.client, container_name
with self.__semaphore:
db_id = self.config.database
container_name = self.config.container
self.db = self._get_or_create_database(self.client, db_id)
self.container = self._get_or_create_container(
self.client, container_name
)

def __get_or_create_database(self, doc_client, id) -> str:
def _get_or_create_database(self, doc_client, id) -> str:
"""Return the database link.

Check if the database exists or create the db.
Expand All @@ -269,10 +309,10 @@ def __get_or_create_database(self, doc_client, id) -> str:
return dbs[0]['id']
else:
# create the database if it didn't exist
res = doc_client.CreateDatabase({'id': id})
res = doc_client.CreateDatabase({'id': id}, self._database_creation_options)
return res['id']

def __get_or_create_container(self, doc_client, container) -> str:
def _get_or_create_container(self, doc_client, container) -> str:
"""Return the container link.

Check if the container exists or create the container.
Expand All @@ -297,5 +337,8 @@ def __get_or_create_container(self, doc_client, container) -> str:
else:
# Create a container if it didn't exist
res = doc_client.CreateContainer(
self.__database_link, {'id': container})
self.__database_link,
{'id': container},
self._container_creation_options
)
return res['id']
Loading