Skip to content

add some features, change depricated methods #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions pytest_async_mongodb/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def pytest_addoption(parser):
help='Try loading fixtures from this directory')



def wrapper(func):
@functools.wraps(func)
async def wrapped(*args, **kwargs):
Expand All @@ -54,9 +53,27 @@ def __getattribute__(self, name):
class AsyncCollection(AsyncClassMethod, mongomock.Collection):

ASYNC_METHODS = [
'find_one',
'find',
'find_one',
'find_one_and_delete',
'find_one_and_replace',
'find_one_and_update',
'find_and_modify',
'save',
'delete_one',
'delete_many',
'count',
'insert_one',
'insert_many',
'update_one',
'update_many',
'replace_one',
'count_documents',
'estimated_document_count',
'drop',
'create_index',
'ensure_index',
'map_reduce',
]

async def find_one(self, filter=None, *args, **kwargs):
Expand Down Expand Up @@ -89,6 +106,14 @@ def get_collection(self, name, codec_options=None, read_preference=None,
return collection


class Session:
async def __aenter__(self):
await asyncio.sleep(0)

async def __aexit__(self, exc_type, exc, tb):
await asyncio.sleep(0)


class AsyncMockMongoClient(mongomock.MongoClient):

def get_database(self, name, codec_options=None, read_preference=None,
Expand All @@ -98,23 +123,36 @@ def get_database(self, name, codec_options=None, read_preference=None,
db = self._databases[name] = AsyncDatabase(self, name)
return db

async def start_session(self, **kwargs):
await asyncio.sleep(0)
return Session()


@pytest.fixture(scope='function')
async def async_mongodb(pytestconfig):
client = AsyncMockMongoClient()
db = client['pytest']
await clean_database(db)
load_fixtures(db, pytestconfig)
await load_fixtures(db, pytestconfig)
return db


@pytest.fixture(scope='function')
async def async_mongodb_client(pytestconfig):
client = AsyncMockMongoClient()
db = client['pytest']
await clean_database(db)
await load_fixtures(db, pytestconfig)
return client


async def clean_database(db):
collections = await db.collection_names(include_system_collections=False)
for name in collections:
db.drop_collection(name)


def load_fixtures(db, config):
async def load_fixtures(db, config):
option_dir = config.getoption('async_mongodb_fixture_dir')
ini_dir = config.getini('async_mongodb_fixture_dir')
fixtures = config.getini('async_mongodb_fixtures')
Expand All @@ -127,10 +165,10 @@ def load_fixtures(db, config):
selected = fixtures and collection in fixtures
if selected and supported:
path = os.path.join(basedir, file_name)
load_fixture(db, collection, path, file_format)
await load_fixture(db, collection, path, file_format)


def load_fixture(db, collection, path, file_format):
async def load_fixture(db, collection, path, file_format):
if file_format == 'json':
loader = functools.partial(json.load, object_hook=json_util.object_hook)
elif file_format == 'yaml':
Expand All @@ -144,4 +182,4 @@ def load_fixture(db, collection, path, file_format):
_cache[path] = docs = loader(fp)

for document in docs:
db[collection].insert(document)
await db[collection].insert_one(document)
8 changes: 4 additions & 4 deletions tests/unit/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def test_load(async_mongodb):

@pytest.mark.asyncio
async def check_players(players):
count = await players.count()
count = await players.count_documents({})
assert count == 2
await check_keys_in_docs(players, ['name', 'surname', 'position'])
manuel = await players.find_one({'name': 'Manuel'})
Expand All @@ -24,7 +24,7 @@ async def check_players(players):

@pytest.mark.asyncio
async def check_championships(championships):
count = await championships.count()
count = await championships.count_documents({})
assert count == 3
await check_keys_in_docs(championships, ['year', 'host', 'winner'])

Expand All @@ -39,12 +39,12 @@ async def check_keys_in_docs(collection, keys):

@pytest.mark.asyncio
async def test_insert(async_mongodb):
async_mongodb.players.insert({
await async_mongodb.players.insert_one({
'name': 'Bastian',
'surname': 'Schweinsteiger',
'position': 'midfield'
})
count = await async_mongodb.players.count()
count = await async_mongodb.players.count_documents({})
bastian = await async_mongodb.players.find_one({'name': 'Bastian'})
assert count == 3
assert bastian.get('name') == 'Bastian'