Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/providers/mongo/hooks/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _create_uri(self) -> str:
:return: URI string.
"""
srv = self.extras.pop("srv", False)
scheme = "mongodb+srv" if srv else "mongodb"
scheme = "mongodb+srv" if srv or self.connection.conn_type == "mongodb+srv" else "mongodb"
login = self.connection.login
password = self.connection.password
netloc = self.connection.host
Expand Down
17 changes: 17 additions & 0 deletions tests/providers/mongo/hooks/test_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def mongo_connections():
),
# Mongo establishes connection during initialization, so we need to have this connection
Connection(conn_id="fake_connection", conn_type="mongo", host="mongo", port=27017),
Connection(
conn_id="mongo_srv_scheme",
uri="mongodb+srv://test_user:test_password@test_host:1234/test_db"
),
]

with pytest.MonkeyPatch.context() as mp:
Expand Down Expand Up @@ -109,6 +113,10 @@ def test_srv(self):
hook = MongoHook(mongo_conn_id="mongo_default_with_srv")
assert hook.uri.startswith("mongodb+srv://")

def test_srv_scheme(self):
hook = MongoHook(mongo_conn_id="mongo_srv_scheme")
assert hook.uri.startswith("mongodb+srv://")

def test_insert_one(self):
collection = mongomock.MongoClient().db.collection
obj = {"test_insert_one": "test_value"}
Expand Down Expand Up @@ -272,6 +280,15 @@ def test_create_uri_srv_true(self):
self.hook.connection.schema = "test_db"
assert self.hook._create_uri() == "mongodb+srv://test_user:test_password@test_host:1234/test_db"

def test_create_uri_srv_scheme(self):
self.hook.connection.conn_type = "mongodb+srv"
self.hook.connection.login = "test_user"
self.hook.connection.password = "test_password"
self.hook.connection.host = "test_host"
self.hook.connection.port = 1234
self.hook.connection.schema = "test_db"
assert self.hook._create_uri() == "mongodb+srv://test_user:test_password@test_host:1234/test_db"

def test_delete_one(self):
collection = mongomock.MongoClient().db.collection
obj = {"_id": "1"}
Expand Down