Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions tests/test_clinical_trials.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def mock_bucket_client():
yield mock_client


@pytest.fixture
def mock_session():
@pytest.fixture(name="mock_session")
def fixture_mock_session() -> ClientSession:
return AsyncMock(spec=ClientSession)


Expand Down Expand Up @@ -93,35 +93,26 @@ def test_format_to_doc_details():


@pytest.mark.asyncio
async def test_add_clinical_trials_to_docs():
mock_session = AsyncMock(spec=ClientSession)
mock_docs = Mock(spec=Docs)
mock_docs.aadd_texts = AsyncMock()
mock_docs.texts = []

mock_response = AsyncMock()
mock_response.raise_for_status.return_value = None
async def test_add_clinical_trials_to_docs(mock_session) -> None:
mock_docs = Mock(spec=Docs, aadd_texts=AsyncMock(), texts=[])
mock_response = AsyncMock(raise_for_status=Mock(return_value=None))
mock_response.json.return_value = {
"studies": [
{"protocolSection": {"identificationModule": {"nctId": "NCT12345678"}}}
]
}
mock_session.get.return_value.__aenter__.return_value = mock_response

with patch(
"paperqa.sources.clinical_trials.search_retrieve_clinical_trials",
return_value=([SAMPLE_TRIAL_DATA], 1),
):
await add_clinical_trials_to_docs(
"test query", mock_docs, Settings(), session=mock_session
)

assert (
mock_docs.aadd_texts.call_count == 2
), "One for the metadata and one for the trial"
call_args = mock_docs.aadd_texts.call_args[1]
assert "doc" in call_args
assert isinstance(call_args["doc"].citation, str)
await add_clinical_trials_to_docs(
"test query", mock_docs, Settings(), session=mock_session
)

assert (
mock_docs.aadd_texts.call_count == 2
), "One for the metadata and one for the trial"
call_args = mock_docs.aadd_texts.call_args[1]
assert "doc" in call_args
assert isinstance(call_args["doc"].citation, str)


def test_parse_clinical_trial():
Expand Down