Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions airflow/providers/mongo/hooks/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
"""Hook for Mongo DB."""
from __future__ import annotations

import warnings
from ssl import CERT_NONE
from typing import TYPE_CHECKING, Any, overload
from urllib.parse import quote_plus, urlunsplit

import pymongo
from pymongo import MongoClient, ReplaceOne

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,10 +59,19 @@ class MongoHook(BaseHook):
conn_type = "mongo"
hook_name = "MongoDB"

def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None:
def __init__(self, mongo_conn_id: str = default_conn_name, *args, **kwargs) -> None:
super().__init__(logger_name=kwargs.pop("logger_name", None))
self.mongo_conn_id = conn_id
self.connection = self.get_connection(conn_id)
if conn_id := kwargs.pop("conn_id", None):
warnings.warn(
"Parameter `conn_id` is deprecated and will be removed in a future releases. "
"Please use `mongo_conn_id` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
mongo_conn_id = conn_id

self.mongo_conn_id = mongo_conn_id
self.connection = self.get_connection(self.mongo_conn_id)
self.extras = self.connection.extra_dejson.copy()
self.client: MongoClient | None = None
self.uri = self._create_uri()
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/mongo/sensors/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ def poke(self, context: Context) -> bool:
self.log.info(
"Sensor check existence of the document that matches the following query: %s", self.query
)
hook = MongoHook(self.mongo_conn_id)
hook = MongoHook(mongo_conn_id=self.mongo_conn_id)
return hook.find(self.collection, self.query, mongo_db=self.mongo_db, find_one=True) is not None
24 changes: 17 additions & 7 deletions tests/integration/providers/mongo/sensors/test_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,32 @@
from airflow.models.dag import DAG
from airflow.providers.mongo.hooks.mongo import MongoHook
from airflow.providers.mongo.sensors.mongo import MongoSensor
from airflow.utils import db, timezone
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2017, 1, 1)


@pytest.fixture(scope="module", autouse=True)
def mongo_connections():
"""Create MongoDB connections which use for testing purpose."""
connections = [
Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017),
Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port=27017, schema="test"),
]

with pytest.MonkeyPatch.context() as mp:
for conn in connections:
mp.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.as_json())
yield


@pytest.mark.integration("mongo")
class TestMongoSensor:
def setup_method(self):
db.merge_conn(
Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port=27017, schema="test")
)

args = {"owner": "airflow", "start_date": DEFAULT_DATE}
self.dag = DAG("test_dag_id", default_args=args)

hook = MongoHook("mongo_test")
hook = MongoHook(mongo_conn_id="mongo_test")
hook.insert_one("foo", {"bar": "baz"})

self.sensor = MongoSensor(
Expand All @@ -53,7 +63,7 @@ def test_poke(self):
assert self.sensor.poke(None)

def test_sensor_with_db(self):
hook = MongoHook("mongo_test")
hook = MongoHook(mongo_conn_id="mongo_test")
hook.insert_one("nontest", {"1": "2"}, mongo_db="nontest")

sensor = MongoSensor(
Expand Down
60 changes: 46 additions & 14 deletions tests/providers/mongo/hooks/test_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from __future__ import annotations

import importlib
import warnings
from typing import TYPE_CHECKING

import pymongo
import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import Connection
from airflow.providers.mongo.hooks.mongo import MongoHook
from airflow.utils import db

pytestmark = pytest.mark.db_test

Expand All @@ -40,14 +41,36 @@
mongomock = None


@pytest.fixture(scope="module", autouse=True)
def mongo_connections():
"""Create MongoDB connections which use for testing purpose."""
connections = [
Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017),
Connection(
conn_id="mongo_default_with_srv",
conn_type="mongo",
host="mongo",
port=27017,
extra='{"srv": true}',
),
# Mongo establishes connection during initialization, so we need to have this connection
Connection(conn_id="fake_connection", conn_type="mongo", host="mongo", port=27017),
]

with pytest.MonkeyPatch.context() as mp:
for conn in connections:
mp.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.as_json())
yield


class MongoHookTest(MongoHook):
"""
Extending hook so that a mockmongo collection object can be passed in
to get_collection()
"""

def __init__(self, conn_id="mongo_default", *args, **kwargs):
super().__init__(conn_id=conn_id, *args, **kwargs)
def __init__(self, mongo_conn_id="mongo_default", *args, **kwargs):
super().__init__(mongo_conn_id=mongo_conn_id, *args, **kwargs)

def get_collection(self, mock_collection, mongo_db=None):
return mock_collection
Expand All @@ -56,24 +79,33 @@ def get_collection(self, mock_collection, mongo_db=None):
@pytest.mark.skipif(mongomock is None, reason="mongomock package not present")
class TestMongoHook:
def setup_method(self):
self.hook = MongoHookTest(conn_id="mongo_default", mongo_db="default")
self.hook = MongoHookTest(mongo_conn_id="mongo_default")
self.conn = self.hook.get_conn()
db.merge_conn(
Connection(
conn_id="mongo_default_with_srv",
conn_type="mongo",
host="mongo",
port=27017,
extra='{"srv": true}',

def test_mongo_conn_id(self):
with warnings.catch_warnings():
warnings.simplefilter("error", category=AirflowProviderDeprecationWarning)
# Use default "mongo_default"
assert MongoHook().mongo_conn_id == "mongo_default"
# Positional argument
assert MongoHook("fake_connection").mongo_conn_id == "fake_connection"

warning_message = "Parameter `conn_id` is deprecated"
with pytest.warns(AirflowProviderDeprecationWarning, match=warning_message):
assert MongoHook(conn_id="fake_connection").mongo_conn_id == "fake_connection"

with pytest.warns(AirflowProviderDeprecationWarning, match=warning_message):
assert (
MongoHook(conn_id="fake_connection", mongo_conn_id="foo-bar").mongo_conn_id
== "fake_connection"
)
)

def test_get_conn(self):
assert self.hook.connection.port == 27017
assert isinstance(self.conn, pymongo.MongoClient)

def test_srv(self):
hook = MongoHook(conn_id="mongo_default_with_srv")
hook = MongoHook(mongo_conn_id="mongo_default_with_srv")
assert hook.uri.startswith("mongodb+srv://")

def test_insert_one(self):
Expand Down Expand Up @@ -333,7 +365,7 @@ def test_distinct_with_filter(self):


def test_context_manager():
with MongoHook(conn_id="mongo_default", mongo_db="default") as ctx_hook:
with MongoHook(mongo_conn_id="mongo_default") as ctx_hook:
ctx_hook.get_conn()

assert isinstance(ctx_hook, MongoHook)
Expand Down